/* Candidate selection for prompt routing: per-category model scoring, health filtering, and round-robin rotation across providers. */
import type { ModelConfig, PromptCategory, ProviderConfig } from './types.js';
|
|
import type { HealthTracker } from './health.js';
|
|
|
|
/**
 * One routable option: a provider paired with one of the models it serves.
 */
export interface SelectionCandidate {
  /** Provider that would handle the request. */
  provider: ProviderConfig;

  /** Model, taken from the provider's model list, that would be invoked. */
  model: ModelConfig;
}
|
|
|
|
/**
 * Build an empty round-robin state map.
 *
 * Every call returns a new, independent `Map`, so each router instance is
 * meant to hold its own.
 *
 * @returns An empty map from rotation key to rotation counter.
 */
export function createRoundRobinState(): Map<string, number> {
  const counters: Map<string, number> = new Map();
  return counters;
}
|
|
|
|
/**
 * Rate how well `model` fits a prompt of the given `category`.
 *
 * The result is a sum of independent bonuses; a larger value means a better
 * fit. A vision prompt sent to a model without vision support yields -1,
 * which marks the model as unusable for that prompt.
 *
 * NOTE(review): a `speedTier` above 4 makes the speed term negative and could
 * push the total below zero, which `selectCandidates` treats as "incompatible".
 * Confirm the allowed range of `speedTier`.
 *
 * @param model - Model whose declared capabilities are being rated.
 * @param category - Category of the prompt being routed.
 * @returns The accumulated score, or -1 for a model that cannot serve the category.
 */
function scoreModel(model: ModelConfig, category: PromptCategory): number {
  const wantsVision = category === 'vision';

  // Hard requirement: a vision prompt can never go to a model without vision support.
  if (wantsVision && !model.supportsVision) {
    return -1;
  }

  // A vision-capable model answering a vision prompt starts with a large head start.
  let total = wantsVision ? 15 : 0;

  // An explicitly declared strength for this category is the strongest general signal.
  if (model.strengths.includes(category)) {
    total += 10;
  }

  // Lower speed tier means a faster model, which earns a bigger bonus.
  total += (4 - model.speedTier) * 2;

  // Reasoning and creative prompts tend to run long, so reward a large context window.
  const isLongForm = category === 'reasoning' || category === 'creative';
  if (isLongForm && model.contextWindow >= 64_000) {
    total += 3;
  }

  // Code, math and reasoning prompts favour larger / reasoning-tuned models,
  // which are recognised here purely by substrings of the model id.
  const heavyCategories: string[] = ['code', 'math', 'reasoning'];
  if (heavyCategories.includes(category)) {
    const idMentions = (lower: string, upper: string): boolean =>
      model.id.includes(lower) || model.id.includes(upper);

    if (idMentions('70b', '70B')) {
      total += 5;
    }
    if (idMentions('r1', 'R1')) {
      total += 4;
    }
  }

  return total;
}
|
|
|
|
/**
 * Build the ranked list of provider+model pairs able to serve `category`.
 *
 * A pair is dropped when `health` reports it unhealthy or when its score is
 * negative (the scorer's marker for an incompatible model, e.g. a non-vision
 * model on a vision prompt). The survivors are ordered by descending score.
 *
 * At runtime each returned object also carries the numeric `score` it was
 * ranked by, although the declared element type does not expose it.
 *
 * @param providers - Providers whose models are considered.
 * @param category - Category of the prompt being routed.
 * @param health - Tracker consulted once per provider/model pair.
 * @returns The usable candidates, best score first.
 */
export function selectCandidates(
  providers: ProviderConfig[],
  category: PromptCategory,
  health: HealthTracker
): SelectionCandidate[] {
  type ScoredCandidate = SelectionCandidate & { score: number };

  const ranked: ScoredCandidate[] = [];

  for (const provider of providers) {
    for (const model of provider.models) {
      // Unhealthy pairs are skipped before any scoring work is done.
      if (!health.isHealthy(provider.name, model.id)) {
        continue;
      }

      const score = scoreModel(model, category);

      // Keep everything that is not strictly negative.
      if (!(score < 0)) {
        ranked.push({ provider, model, score });
      }
    }
  }

  // Highest score first.
  const byScoreDescending = (left: ScoredCandidate, right: ScoredCandidate): number =>
    right.score - left.score;
  ranked.sort(byScoreDescending);

  return ranked;
}
|
|
|
|
/**
 * Choose the candidate to try next, rotating across providers.
 *
 * The distinct provider names found in `candidates` (in first-appearance
 * order) form a rotation. A counter stored in `state`, keyed by those names
 * joined with commas, selects whose turn it is and is incremented on every
 * call. The first candidate in the list belonging to that provider is
 * returned, so with a best-first list this is that provider's best model.
 *
 * Lists with zero or one entry are answered directly and leave `state`
 * untouched.
 *
 * @param candidates - Candidates to choose from; presumably ordered best-first by the caller.
 * @param state - Rotation counters, mutated in place.
 * @returns The chosen candidate, or `null` when `candidates` is empty.
 */
export function pickNext(
  candidates: SelectionCandidate[],
  state: Map<string, number>
): SelectionCandidate | null {
  switch (candidates.length) {
    case 0:
      return null;
    case 1:
      return candidates[0]!;
    default:
      break;
  }

  // Distinct provider names, in the order they first appear in the list.
  const rotation = Array.from(new Set(candidates.map((candidate) => candidate.provider.name)));
  const stateKey = rotation.join(',');

  // Read this rotation's counter and advance it for the next call.
  const turn = state.get(stateKey) ?? 0;
  state.set(stateKey, turn + 1);

  const chosenProvider = rotation[turn % rotation.length]!;

  // First match in list order wins.
  for (const candidate of candidates) {
    if (candidate.provider.name === chosenProvider) {
      return candidate;
    }
  }

  // Fallback for a counter that does not map to any provider in the list.
  return candidates[0]!;
}
|
|
|
|
/**
 * Produce a copy of `candidates` without the given provider/model pair,
 * typically after that pair has failed.
 *
 * The input array is not modified. Only entries matching both the provider
 * name and the model id are removed.
 *
 * @param candidates - Current candidate list.
 * @param provider - Name of the provider to drop.
 * @param model - Id of the model to drop.
 * @returns A new array holding every other candidate, in the original order.
 */
export function excludeCandidate(
  candidates: SelectionCandidate[],
  provider: string,
  model: string
): SelectionCandidate[] {
  // Keep an entry as soon as either half of the pair differs.
  return candidates.filter(
    (candidate) => candidate.provider.name !== provider || candidate.model.id !== model
  );
}
|