learning_ai_common_plat/packages/llm-router/src/router.ts

363 lines
10 KiB
TypeScript

import type {
ChatCompletionRequest,
PromptCategory,
RouterConfig,
ProviderConfig,
RouteResult,
RoutePlan,
HealthSnapshot,
} from './types.js';
import { DEFAULT_PROVIDERS, getAvailableProviders } from './registry.js';
import { classifyPrompt } from './classifier.js';
import { HealthTracker } from './health.js';
import { selectCandidates, pickNext, excludeCandidate, createRoundRobinState } from './selector.js';
import { sendChatCompletion } from './client.js';
/**
 * Routes chat completion requests across the configured LLM providers.
 *
 * Requests without an explicit model are classified, matched against the
 * candidates returned by `selectCandidates`, and sent to candidates in turn
 * until one succeeds or the retry budget is spent. Requests that name a model
 * (`provider:model`, `provider/model`, or a model id unique to one provider)
 * bypass classification and are sent exactly once.
 *
 * Every attempt is recorded in the health tracker and reported through the
 * optional `onTelemetry` callback.
 */
export class LlmRouter {
  private readonly providers: ProviderConfig[];
  private readonly health: HealthTracker;
  private readonly timeoutMs: number;
  private readonly maxRetries: number;
  private readonly log: (entry: TelemetryEntry) => void;
  private readonly roundRobinState: Map<string, number>;

  /**
   * @param config Optional router configuration. `providers` defaults to
   *   `DEFAULT_PROVIDERS`, `timeoutMs` to 30 000 and `maxRetries` to 3.
   *   `onTelemetry`, when given, receives one entry per attempt.
   * @throws Error when `getAvailableProviders` leaves no usable provider.
   */
  constructor(config?: RouterConfig & { onTelemetry?: (entry: TelemetryEntry) => void }) {
    const allProviders = config?.providers ?? DEFAULT_PROVIDERS;
    this.providers = getAvailableProviders(allProviders);
    if (this.providers.length === 0) {
      throw new Error(
        'No providers available. Set at least one API key env var: ' +
          allProviders.map(p => p.apiKeyEnv).join(', ')
      );
    }
    this.health = new HealthTracker({
      windowMs: config?.healthWindowMs,
      errorThreshold: config?.errorThreshold,
      rateLimitThreshold: config?.rateLimitThreshold,
    });
    this.timeoutMs = config?.timeoutMs ?? 30_000;
    this.maxRetries = config?.maxRetries ?? 3;
    this.log = config?.onTelemetry ?? (() => {});
    this.roundRobinState = createRoundRobinState();
  }

  /**
   * Route a chat completion request to the best available provider.
   *
   * A 429 response or a thrown error excludes that provider+model pair and
   * moves on to the next candidate, for at most `maxRetries` attempts.
   * NOTE(review): whether 5xx responses surface as thrown errors depends on
   * `sendChatCompletion`, which is not visible here.
   *
   * @throws Error when every attempted candidate failed or no candidate is left.
   */
  async chat(request: ChatCompletionRequest): Promise<RouteResult> {
    const startTime = Date.now();
    // Plan without advancing round-robin: the retry loop below does its own picking.
    const plan = this.planInternal(request, false);
    if (plan.explicit) {
      return this.chatWithExplicitModel(request, startTime, plan);
    }
    const category = plan.category as PromptCategory;
    let candidates = selectCandidates(this.providers, category, this.health);
    let lastError: Error | null = null;
    // Tracked separately from maxRetries: the loop can stop early when candidates run out.
    let attemptsMade = 0;
    for (let attempt = 1; attempt <= this.maxRetries; attempt++) {
      const pick = pickNext(candidates, this.roundRobinState);
      if (!pick) break;
      attemptsMade = attempt;
      const { provider, model } = pick;
      const attemptStart = Date.now();
      try {
        const result = await sendChatCompletion(provider, model.id, request, this.timeoutMs);
        if (result.status === 429) {
          // Rate limited — record and try next provider
          this.recordAndLog({
            event: 'rate_limit',
            provider: provider.name,
            model: model.id,
            attempt,
            latencyMs: result.latencyMs,
            category,
          });
          // Remember the rate limit so the final error is not reported as "unknown".
          lastError = new Error(`Rate limited by ${provider.name} for model ${model.id}`);
          candidates = excludeCandidate(candidates, provider.name, model.id);
          continue;
        }
        // Success
        this.recordAndLog({
          event: 'success',
          provider: provider.name,
          model: model.id,
          attempt,
          latencyMs: result.latencyMs,
          category,
          tokens: result.response.usage?.total_tokens,
        });
        return {
          response: result.response,
          provider: provider.name,
          model: model.id,
          totalLatencyMs: Date.now() - startTime,
          attempts: attempt,
        };
      } catch (err) {
        lastError = err instanceof Error ? err : new Error(String(err));
        const attemptLatency = Date.now() - attemptStart;
        this.recordAndLog({
          event: 'error',
          provider: provider.name,
          model: model.id,
          attempt,
          latencyMs: attemptLatency,
          category,
          error: lastError.message,
        });
        candidates = excludeCandidate(candidates, provider.name, model.id);
      }
    }
    throw new Error(
      `All providers exhausted after ${attemptsMade} attempts. Last error: ${lastError?.message ?? 'unknown'}`
    );
  }

  /**
   * Handle explicit provider:model routing (bypass classifier).
   *
   * Sends the request once, with no fallback. A 429 response is recorded as a
   * rate limit and rethrown as an `Error`; any other failure is recorded as an
   * error and rethrown unchanged.
   */
  private async chatWithExplicitModel(
    request: ChatCompletionRequest,
    startTime: number,
    plan?: RoutePlan
  ): Promise<RouteResult> {
    const resolved = plan ?? this.plan(request);
    const provider = resolved.provider;
    const modelId = resolved.model.id;
    // Set just before the rate-limit throw so the catch block can tell it apart
    // from genuine send errors without inspecting error message text.
    let rateLimitReported = false;
    try {
      const result = await sendChatCompletion(provider, modelId, request, this.timeoutMs);
      if (result.status === 429) {
        this.recordAndLog({
          event: 'rate_limit',
          provider: provider.name,
          model: modelId,
          attempt: 1,
          latencyMs: result.latencyMs,
          category: 'explicit',
        });
        rateLimitReported = true;
        throw new Error(`Rate limited by ${provider.name} for model ${modelId}`);
      }
      this.recordAndLog({
        event: 'success',
        provider: provider.name,
        model: modelId,
        attempt: 1,
        latencyMs: result.latencyMs,
        category: 'explicit',
        tokens: result.response.usage?.total_tokens,
      });
      return {
        response: result.response,
        provider: provider.name,
        model: modelId,
        totalLatencyMs: Date.now() - startTime,
        attempts: 1,
      };
    } catch (err) {
      // Re-throw rate-limit errors (already recorded and logged above)
      if (rateLimitReported) {
        throw err;
      }
      const latency = Date.now() - startTime;
      this.recordAndLog({
        event: 'error',
        provider: provider.name,
        model: modelId,
        attempt: 1,
        latencyMs: latency,
        category: 'explicit',
        error: err instanceof Error ? err.message : String(err),
      });
      throw err;
    }
  }

  /**
   * Record one attempt in the health tracker, then forward it to the
   * telemetry callback. Health is always recorded before the callback runs.
   */
  private recordAndLog(entry: TelemetryEntry): void {
    this.health.record(entry.provider, entry.model, {
      timestamp: Date.now(),
      latencyMs: entry.latencyMs,
      status: entry.event,
    });
    this.log(entry);
  }

  /**
   * Decide which provider and model a request would be routed to, without
   * sending it. For non-explicit requests this advances the round-robin state.
   *
   * @throws Error when an explicit model cannot be resolved or no candidate is available.
   */
  plan(request: ChatCompletionRequest): RoutePlan {
    return this.planInternal(request, true);
  }

  /**
   * Shared planning logic.
   *
   * @param advanceRoundRobin When true the pick goes through `pickNext` (which
   *   receives the round-robin state); when false the first candidate is
   *   returned and the state is left untouched.
   */
  private planInternal(request: ChatCompletionRequest, advanceRoundRobin: boolean): RoutePlan {
    const explicit = this.resolveExplicitModel(request.model);
    if (explicit) {
      return {
        provider: explicit.provider,
        model: explicit.model,
        category: 'explicit',
        explicit: true,
      };
    }
    const classification = classifyPrompt(request.messages);
    const candidates = selectCandidates(this.providers, classification.category, this.health);
    if (candidates.length === 0) {
      throw new Error('No healthy providers available for routing');
    }
    const pick = advanceRoundRobin
      ? pickNext(candidates, this.roundRobinState)
      : (candidates[0] ?? null);
    if (!pick) {
      throw new Error('No provider available for routing');
    }
    return {
      provider: pick.provider,
      model: pick.model,
      category: classification.category,
      explicit: false,
    };
  }

  /**
   * Resolve a caller-supplied model string to a configured provider+model pair.
   *
   * - `provider:model` or `provider/model`: both parts must match exactly,
   *   otherwise an error listing the available options is thrown.
   * - A bare id: returned when exactly one provider offers it; throws when
   *   several do.
   *
   * @returns The match, or `null` when `model` is empty or matches nothing
   *   (the caller then falls back to classifier-based routing).
   */
  private resolveExplicitModel(
    model?: string
  ): { provider: ProviderConfig; model: RoutePlan['model'] } | null {
    if (!model) return null;
    if (model.includes(':') || model.includes('/')) {
      const { providerName, modelId } = parseExplicitModel(model);
      const provider = this.providers.find(p => p.name === providerName);
      if (!provider) {
        throw new Error(
          `Provider "${providerName}" not found. Available: ${this.providers.map(p => p.name).join(', ')}`
        );
      }
      const matchedModel = provider.models.find(candidate => candidate.id === modelId);
      if (!matchedModel) {
        throw new Error(
          `Model "${modelId}" not found for provider "${providerName}". Available: ${provider.models
            .map(candidate => candidate.id)
            .join(', ')}`
        );
      }
      return { provider, model: matchedModel };
    }
    const matches = this.providers.flatMap(provider =>
      provider.models
        .filter(candidate => candidate.id === model)
        .map(candidate => ({ provider, model: candidate }))
    );
    const [onlyMatch] = matches;
    if (matches.length === 1 && onlyMatch) {
      return onlyMatch;
    }
    if (matches.length > 1) {
      throw new Error(
        `Model "${model}" is available on multiple providers. Use provider:model format instead.`
      );
    }
    return null;
  }

  /** Get health snapshots for all tracked provider+model pairs. */
  getHealth(): HealthSnapshot[] {
    return this.health.allSnapshots();
  }

  /** Get list of available (configured) providers. */
  getProviders(): string[] {
    return this.providers.map(p => p.name);
  }

  /** Reset health tracking data. */
  resetHealth(): void {
    this.health.reset();
  }
}
/**
 * Split an explicit model reference at its first `:` or `/`, whichever comes
 * earlier. The text before the separator is the provider name and everything
 * after it (including any later separators) is the model id.
 *
 * When neither separator is present, the whole input is returned as the
 * provider name with an empty model id.
 */
function parseExplicitModel(raw: string): { providerName: string; modelId: string } {
  // Keep only the separators that actually occur in the input.
  const separatorPositions = [raw.indexOf(':'), raw.indexOf('/')].filter(pos => pos !== -1);
  if (separatorPositions.length === 0) {
    return { providerName: raw, modelId: '' };
  }
  const cut = Math.min(...separatorPositions);
  return {
    providerName: raw.slice(0, cut),
    modelId: raw.slice(cut + 1),
  };
}
// ── Telemetry types ────────────────────────────────────────────
/**
 * One telemetry record describing the outcome of a single send attempt.
 * The router in this file passes one entry to the `onTelemetry` callback per attempt.
 */
export interface TelemetryEntry {
/** Outcome of the attempt. */
event: 'success' | 'rate_limit' | 'error';
/** Name of the provider the attempt was sent to. */
provider: string;
/** Id of the model the attempt was sent to. */
model: string;
/** 1-based attempt number within the request; always 1 for explicit-model requests. */
attempt: number;
/** Latency of the attempt in milliseconds. */
latencyMs: number;
/** Prompt category used for routing, or 'explicit' when the caller named the model. */
category: string;
/** Total token usage reported by the response; set on 'success' entries when available. */
tokens?: number;
/** Error message; set on 'error' entries. */
error?: string;
}