feat(llm-router): unify local ollama routing
This commit is contained in:
parent
91885f0d4f
commit
bf2c285e41
@ -3,6 +3,8 @@
|
|||||||
"version": "0.1.0",
|
"version": "0.1.0",
|
||||||
"private": true,
|
"private": true,
|
||||||
"scripts": {
|
"scripts": {
|
||||||
|
"predev": "corepack pnpm --dir ../.. --filter @bytelyst/llm-router build",
|
||||||
|
"prebuild": "corepack pnpm --dir ../.. --filter @bytelyst/llm-router build",
|
||||||
"dev": "next dev",
|
"dev": "next dev",
|
||||||
"build": "next build",
|
"build": "next build",
|
||||||
"start": "next start",
|
"start": "next start",
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import { InputBar } from './InputBar';
|
|||||||
import { MessageThread } from './MessageThread';
|
import { MessageThread } from './MessageThread';
|
||||||
import { ContextBar } from './ContextBar';
|
import { ContextBar } from './ContextBar';
|
||||||
import { estimateTokens, getModelContextWindow } from '../../lib/format';
|
import { estimateTokens, getModelContextWindow } from '../../lib/format';
|
||||||
import { autoDetectDefaults, classifyTask, resolveModel } from '../../lib/router';
|
import { autoDetectDefaults } from '../../lib/router';
|
||||||
import type { ModelDefaults } from '../../lib/types';
|
import type { ModelDefaults } from '../../lib/types';
|
||||||
|
|
||||||
interface ConversationViewProps {
|
interface ConversationViewProps {
|
||||||
@ -31,7 +31,6 @@ export function ConversationView({
|
|||||||
const [streaming, setStreaming] = useState(false);
|
const [streaming, setStreaming] = useState(false);
|
||||||
const [selectedModel, setSelectedModel] = useState(conversation.model || '__auto__');
|
const [selectedModel, setSelectedModel] = useState(conversation.model || '__auto__');
|
||||||
const [models, setModels] = useState<string[]>([]);
|
const [models, setModels] = useState<string[]>([]);
|
||||||
const [runningModels, setRunningModels] = useState<string[]>([]);
|
|
||||||
const [showModels, setShowModels] = useState(false);
|
const [showModels, setShowModels] = useState(false);
|
||||||
const [modelDefaults, setModelDefaults] = useState<ModelDefaults | null>(null);
|
const [modelDefaults, setModelDefaults] = useState<ModelDefaults | null>(null);
|
||||||
const abortRef = useRef<AbortController | null>(null);
|
const abortRef = useRef<AbortController | null>(null);
|
||||||
@ -53,9 +52,7 @@ export function ConversationView({
|
|||||||
if (!res.ok) return;
|
if (!res.ok) return;
|
||||||
const data = (await res.json()) as OllamaData;
|
const data = (await res.json()) as OllamaData;
|
||||||
const installed = data.models.map(m => m.name);
|
const installed = data.models.map(m => m.name);
|
||||||
const running = data.running.map(m => m.name);
|
|
||||||
setModels(installed);
|
setModels(installed);
|
||||||
setRunningModels(running);
|
|
||||||
|
|
||||||
const stored = localStorage.getItem('llm-model-defaults');
|
const stored = localStorage.getItem('llm-model-defaults');
|
||||||
if (stored) {
|
if (stored) {
|
||||||
@ -114,15 +111,7 @@ export function ConversationView({
|
|||||||
const onSend = async (text: string) => {
|
const onSend = async (text: string) => {
|
||||||
await ensureModels();
|
await ensureModels();
|
||||||
|
|
||||||
const routedModel = (() => {
|
const requestedModel = selectedModel !== '__auto__' ? selectedModel : undefined;
|
||||||
if (selectedModel !== '__auto__') return selectedModel;
|
|
||||||
const defaults = modelDefaults || autoDetectDefaults(models);
|
|
||||||
const taskType = classifyTask(text);
|
|
||||||
const resolved = resolveModel(taskType, defaults, runningModels, models);
|
|
||||||
return resolved.model;
|
|
||||||
})();
|
|
||||||
|
|
||||||
if (!routedModel) return;
|
|
||||||
|
|
||||||
const now = Date.now();
|
const now = Date.now();
|
||||||
const userMessage: Message = {
|
const userMessage: Message = {
|
||||||
@ -131,7 +120,7 @@ export function ConversationView({
|
|||||||
role: 'user',
|
role: 'user',
|
||||||
content: text,
|
content: text,
|
||||||
timestamp: now,
|
timestamp: now,
|
||||||
model: routedModel,
|
model: requestedModel,
|
||||||
};
|
};
|
||||||
|
|
||||||
await addMessage(userMessage);
|
await addMessage(userMessage);
|
||||||
@ -159,9 +148,7 @@ export function ConversationView({
|
|||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
model: routedModel,
|
model: requestedModel ?? '__auto__',
|
||||||
modelDefaults: modelDefaults,
|
|
||||||
taskType: classifyTask(text),
|
|
||||||
messages: chatPayload,
|
messages: chatPayload,
|
||||||
}),
|
}),
|
||||||
signal: controller.signal,
|
signal: controller.signal,
|
||||||
@ -171,6 +158,19 @@ export function ConversationView({
|
|||||||
throw new Error('Failed to stream response');
|
throw new Error('Failed to stream response');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const routedModel = res.headers.get('x-by-model') || requestedModel || '';
|
||||||
|
if (!routedModel) {
|
||||||
|
throw new Error('Router did not return a model');
|
||||||
|
}
|
||||||
|
const previousUserModel = userMessage.model;
|
||||||
|
userMessage.model = routedModel;
|
||||||
|
if (routedModel !== previousUserModel) {
|
||||||
|
await updateMessage(userMessage.id, { model: routedModel });
|
||||||
|
}
|
||||||
|
setMessages(prev =>
|
||||||
|
prev.map(message => (message.id === userMessage.id ? userMessage : message))
|
||||||
|
);
|
||||||
|
|
||||||
const reader = res.body.getReader();
|
const reader = res.body.getReader();
|
||||||
const decoder = new TextDecoder();
|
const decoder = new TextDecoder();
|
||||||
let buffer = '';
|
let buffer = '';
|
||||||
@ -254,7 +254,7 @@ export function ConversationView({
|
|||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: `Error: ${String(err)}`,
|
content: `Error: ${String(err)}`,
|
||||||
timestamp: Date.now(),
|
timestamp: Date.now(),
|
||||||
model: routedModel,
|
model: requestedModel,
|
||||||
};
|
};
|
||||||
setMessages(prev => {
|
setMessages(prev => {
|
||||||
const withoutTemp = prev.filter(m => m.id !== assistantId);
|
const withoutTemp = prev.filter(m => m.id !== assistantId);
|
||||||
|
|||||||
@ -1,5 +1,63 @@
|
|||||||
import { NextRequest } from 'next/server';
|
import { NextRequest } from 'next/server';
|
||||||
import { OLLAMA_URL } from '../../../lib/ollama-config';
|
import { OLLAMA_URL } from '../../../lib/ollama-config';
|
||||||
|
import {
|
||||||
|
LlmRouter,
|
||||||
|
createLocalOllamaProvider,
|
||||||
|
type ChatCompletionRequest,
|
||||||
|
type RoutePlan,
|
||||||
|
} from '../../../lib/llm-router';
|
||||||
|
|
||||||
|
interface OllamaTagsResponse {
|
||||||
|
models?: Array<{ name?: string }>;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface OllamaPsResponse {
|
||||||
|
models?: Array<{ name?: string }>;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function fetchOllamaJson<T>(path: string): Promise<T> {
|
||||||
|
const response = await fetch(`${OLLAMA_URL}${path}`, {
|
||||||
|
cache: 'no-store',
|
||||||
|
signal: AbortSignal.timeout(5000),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Ollama ${path} failed: ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (await response.json()) as T;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function resolvePlan(request: ChatCompletionRequest): Promise<RoutePlan> {
|
||||||
|
const [tags, ps] = await Promise.all([
|
||||||
|
fetchOllamaJson<OllamaTagsResponse>('/api/tags'),
|
||||||
|
fetchOllamaJson<OllamaPsResponse>('/api/ps').catch(() => ({ models: [] })),
|
||||||
|
]);
|
||||||
|
|
||||||
|
const installedModels = (tags.models ?? [])
|
||||||
|
.map(model => model.name?.trim() || '')
|
||||||
|
.filter(Boolean);
|
||||||
|
const runningModels = new Set(
|
||||||
|
(ps.models ?? []).map(model => model.name?.trim() || '').filter(Boolean)
|
||||||
|
);
|
||||||
|
|
||||||
|
if (installedModels.length === 0) {
|
||||||
|
throw new Error('No local Ollama models installed');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put currently loaded models first so the router prefers them when scores are equal.
|
||||||
|
const orderedModels = [
|
||||||
|
...installedModels.filter(model => runningModels.has(model)),
|
||||||
|
...installedModels.filter(model => !runningModels.has(model)),
|
||||||
|
];
|
||||||
|
|
||||||
|
const router = new LlmRouter({
|
||||||
|
providers: [createLocalOllamaProvider(orderedModels, `${OLLAMA_URL}/v1`)],
|
||||||
|
maxRetries: 1,
|
||||||
|
});
|
||||||
|
|
||||||
|
return router.plan(request);
|
||||||
|
}
|
||||||
|
|
||||||
export async function POST(request: NextRequest) {
|
export async function POST(request: NextRequest) {
|
||||||
try {
|
try {
|
||||||
@ -13,10 +71,15 @@ export async function POST(request: NextRequest) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const plan = await resolvePlan({
|
||||||
|
messages,
|
||||||
|
model: model && model !== '__auto__' ? `local-ollama:${String(model)}` : undefined,
|
||||||
|
});
|
||||||
|
|
||||||
const response = await fetch(`${OLLAMA_URL}/api/chat`, {
|
const response = await fetch(`${OLLAMA_URL}/api/chat`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({ model, messages, stream: true }),
|
body: JSON.stringify({ model: plan.model.id, messages, stream: true }),
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!response.ok || !response.body) {
|
if (!response.ok || !response.body) {
|
||||||
@ -31,6 +94,9 @@ export async function POST(request: NextRequest) {
|
|||||||
'Content-Type': 'application/x-ndjson',
|
'Content-Type': 'application/x-ndjson',
|
||||||
'Transfer-Encoding': 'chunked',
|
'Transfer-Encoding': 'chunked',
|
||||||
'Cache-Control': 'no-cache',
|
'Cache-Control': 'no-cache',
|
||||||
|
'x-by-provider': plan.provider.name,
|
||||||
|
'x-by-model': plan.model.id,
|
||||||
|
'x-by-category': plan.category,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|||||||
8
__LOCAL_LLMs/dashboard/src/app/lib/llm-router.ts
Normal file
8
__LOCAL_LLMs/dashboard/src/app/lib/llm-router.ts
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
export {
|
||||||
|
LlmRouter,
|
||||||
|
createLocalOllamaProvider,
|
||||||
|
classifyPrompt,
|
||||||
|
type ChatCompletionRequest,
|
||||||
|
type PromptCategory,
|
||||||
|
type RoutePlan,
|
||||||
|
} from '../../../../../packages/llm-router/dist/index.js';
|
||||||
@ -1,5 +1,9 @@
|
|||||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||||
import { getAvailableProviders, DEFAULT_PROVIDERS } from '../registry.js';
|
import {
|
||||||
|
createLocalOllamaProvider,
|
||||||
|
getAvailableProviders,
|
||||||
|
DEFAULT_PROVIDERS,
|
||||||
|
} from '../registry.js';
|
||||||
import type { ProviderConfig } from '../types.js';
|
import type { ProviderConfig } from '../types.js';
|
||||||
|
|
||||||
describe('getAvailableProviders', () => {
|
describe('getAvailableProviders', () => {
|
||||||
@ -8,8 +12,9 @@ describe('getAvailableProviders', () => {
|
|||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
// Save and clear all default provider env vars
|
// Save and clear all default provider env vars
|
||||||
for (const p of DEFAULT_PROVIDERS) {
|
for (const p of DEFAULT_PROVIDERS) {
|
||||||
saved[p.apiKeyEnv] = process.env[p.apiKeyEnv];
|
const key = p.apiKeyEnv!;
|
||||||
delete process.env[p.apiKeyEnv];
|
saved[key] = process.env[key];
|
||||||
|
delete process.env[key];
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -68,6 +73,13 @@ describe('getAvailableProviders', () => {
|
|||||||
delete process.env.CUSTOM_TEST_KEY;
|
delete process.env.CUSTOM_TEST_KEY;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('includes local providers that do not require API keys', () => {
|
||||||
|
const local = createLocalOllamaProvider(['qwen2.5-coder:7b']);
|
||||||
|
const result = getAvailableProviders([local]);
|
||||||
|
expect(result).toHaveLength(1);
|
||||||
|
expect(result[0]!.name).toBe('local-ollama');
|
||||||
|
});
|
||||||
|
|
||||||
it('DEFAULT_PROVIDERS includes all 4 providers', () => {
|
it('DEFAULT_PROVIDERS includes all 4 providers', () => {
|
||||||
expect(DEFAULT_PROVIDERS).toHaveLength(4);
|
expect(DEFAULT_PROVIDERS).toHaveLength(4);
|
||||||
const names = DEFAULT_PROVIDERS.map(p => p.name);
|
const names = DEFAULT_PROVIDERS.map(p => p.name);
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||||
import { LlmRouter } from '../router.js';
|
import { LlmRouter } from '../router.js';
|
||||||
|
import { createLocalOllamaProvider } from '../registry.js';
|
||||||
import type { ProviderConfig, ChatCompletionResponse } from '../types.js';
|
import type { ProviderConfig, ChatCompletionResponse } from '../types.js';
|
||||||
import * as client from '../client.js';
|
import * as client from '../client.js';
|
||||||
|
|
||||||
@ -241,6 +242,30 @@ describe('LlmRouter', () => {
|
|||||||
expect(router.getProviders()).toEqual(['test-fast', 'test-quality']);
|
expect(router.getProviders()).toEqual(['test-fast', 'test-quality']);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('plans the best provider/model without executing a request', () => {
|
||||||
|
const router = new LlmRouter({ providers: TEST_PROVIDERS });
|
||||||
|
const plan = router.plan({
|
||||||
|
messages: [{ role: 'user', content: 'Write a TypeScript endpoint with validation' }],
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(plan.provider.name).toBe('test-quality');
|
||||||
|
expect(plan.model.id).toBe('quality-model');
|
||||||
|
expect(plan.category).toBe('code');
|
||||||
|
expect(plan.explicit).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('plans local ollama models without requiring an API key', () => {
|
||||||
|
const router = new LlmRouter({
|
||||||
|
providers: [createLocalOllamaProvider(['qwen2.5-coder:7b', 'llama3.1:8b'])],
|
||||||
|
});
|
||||||
|
const plan = router.plan({
|
||||||
|
messages: [{ role: 'user', content: 'Refactor this TypeScript function' }],
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(plan.provider.name).toBe('local-ollama');
|
||||||
|
expect(plan.model.id).toBe('qwen2.5-coder:7b');
|
||||||
|
});
|
||||||
|
|
||||||
it('fires telemetry for explicit model routing', async () => {
|
it('fires telemetry for explicit model routing', async () => {
|
||||||
vi.mocked(client.sendChatCompletion).mockResolvedValueOnce({
|
vi.mocked(client.sendChatCompletion).mockResolvedValueOnce({
|
||||||
response: MOCK_RESPONSE,
|
response: MOCK_RESPONSE,
|
||||||
|
|||||||
@ -10,17 +10,19 @@ export async function sendChatCompletion(
|
|||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
timeoutMs: number = 30_000
|
timeoutMs: number = 30_000
|
||||||
): Promise<{ response: ChatCompletionResponse; latencyMs: number; status: number }> {
|
): Promise<{ response: ChatCompletionResponse; latencyMs: number; status: number }> {
|
||||||
const apiKey = process.env[provider.apiKeyEnv];
|
const apiKey = provider.apiKeyEnv ? process.env[provider.apiKeyEnv] : null;
|
||||||
if (!apiKey) {
|
if (provider.apiKeyEnv && !apiKey) {
|
||||||
throw new Error(`Missing API key: env var ${provider.apiKeyEnv} is not set`);
|
throw new Error(`Missing API key: env var ${provider.apiKeyEnv} is not set`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const url = `${provider.baseUrl}/chat/completions`;
|
const url = `${provider.baseUrl}/chat/completions`;
|
||||||
const headers: Record<string, string> = {
|
const headers: Record<string, string> = {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
Authorization: `Bearer ${apiKey}`,
|
|
||||||
...provider.extraHeaders,
|
...provider.extraHeaders,
|
||||||
};
|
};
|
||||||
|
if (apiKey) {
|
||||||
|
headers.Authorization = `Bearer ${apiKey}`;
|
||||||
|
}
|
||||||
|
|
||||||
const body = JSON.stringify({
|
const body = JSON.stringify({
|
||||||
model: modelId,
|
model: modelId,
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
export { LlmRouter } from './router.js';
|
export { LlmRouter } from './router.js';
|
||||||
export type { TelemetryEntry } from './router.js';
|
export type { TelemetryEntry } from './router.js';
|
||||||
|
|
||||||
export { DEFAULT_PROVIDERS, getAvailableProviders } from './registry.js';
|
export { DEFAULT_PROVIDERS, createLocalOllamaProvider, getAvailableProviders } from './registry.js';
|
||||||
export { classifyPrompt } from './classifier.js';
|
export { classifyPrompt } from './classifier.js';
|
||||||
export { HealthTracker } from './health.js';
|
export { HealthTracker } from './health.js';
|
||||||
export { selectCandidates, pickNext, excludeCandidate, createRoundRobinState } from './selector.js';
|
export { selectCandidates, pickNext, excludeCandidate, createRoundRobinState } from './selector.js';
|
||||||
@ -22,4 +22,5 @@ export type {
|
|||||||
ChatCompletionUsage,
|
ChatCompletionUsage,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
RouteResult,
|
RouteResult,
|
||||||
|
RoutePlan,
|
||||||
} from './types.js';
|
} from './types.js';
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import type { ProviderConfig } from './types.js';
|
import type { ModelConfig, PromptCategory, ProviderConfig } from './types.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Default free-tier provider configurations.
|
* Default free-tier provider configurations.
|
||||||
@ -121,6 +121,57 @@ export const DEFAULT_PROVIDERS: ProviderConfig[] = [
|
|||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
|
function inferStrengths(modelId: string): PromptCategory[] {
|
||||||
|
const lower = modelId.toLowerCase();
|
||||||
|
const strengths = new Set<PromptCategory>(['general']);
|
||||||
|
|
||||||
|
if (/coder|code|codestral|starcoder|deepseek/.test(lower)) strengths.add('code');
|
||||||
|
if (/r1|reason|think|math/.test(lower)) {
|
||||||
|
strengths.add('reasoning');
|
||||||
|
strengths.add('math');
|
||||||
|
}
|
||||||
|
if (/qwen|llama|mistral|chat/.test(lower)) strengths.add('creative');
|
||||||
|
|
||||||
|
return [...strengths];
|
||||||
|
}
|
||||||
|
|
||||||
|
function inferContextWindow(modelId: string): number {
|
||||||
|
const lower = modelId.toLowerCase();
|
||||||
|
if (/128k|131072/.test(lower)) return 128_000;
|
||||||
|
if (/64k|65536/.test(lower)) return 64_000;
|
||||||
|
if (/32k|32768|qwen2\.5/.test(lower)) return 32_768;
|
||||||
|
if (/16k|16384/.test(lower)) return 16_384;
|
||||||
|
return 8_192;
|
||||||
|
}
|
||||||
|
|
||||||
|
function inferSpeedTier(modelId: string): 1 | 2 | 3 {
|
||||||
|
const lower = modelId.toLowerCase();
|
||||||
|
if (/0\.5b|1b|3b|7b|mini|tiny/.test(lower)) return 1;
|
||||||
|
if (/14b|15b|16b|20b|22b|30b|32b/.test(lower)) return 2;
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createLocalOllamaProvider(
|
||||||
|
modelIds: string[],
|
||||||
|
baseUrl: string = 'http://localhost:11434/v1'
|
||||||
|
): ProviderConfig {
|
||||||
|
const models: ModelConfig[] = modelIds.map(modelId => ({
|
||||||
|
id: modelId,
|
||||||
|
label: modelId,
|
||||||
|
contextWindow: inferContextWindow(modelId),
|
||||||
|
strengths: inferStrengths(modelId),
|
||||||
|
speedTier: inferSpeedTier(modelId),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return {
|
||||||
|
name: 'local-ollama',
|
||||||
|
baseUrl,
|
||||||
|
models,
|
||||||
|
rpmLimit: 0,
|
||||||
|
tpmLimit: 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Filter providers to only those with API keys present in env.
|
* Filter providers to only those with API keys present in env.
|
||||||
*/
|
*/
|
||||||
@ -128,6 +179,7 @@ export function getAvailableProviders(
|
|||||||
providers: ProviderConfig[] = DEFAULT_PROVIDERS
|
providers: ProviderConfig[] = DEFAULT_PROVIDERS
|
||||||
): ProviderConfig[] {
|
): ProviderConfig[] {
|
||||||
return providers.filter(p => {
|
return providers.filter(p => {
|
||||||
|
if (!p.apiKeyEnv) return true;
|
||||||
const key = process.env[p.apiKeyEnv];
|
const key = process.env[p.apiKeyEnv];
|
||||||
return key !== undefined && key !== '';
|
return key !== undefined && key !== '';
|
||||||
});
|
});
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
import type {
|
import type {
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
|
PromptCategory,
|
||||||
RouterConfig,
|
RouterConfig,
|
||||||
ProviderConfig,
|
ProviderConfig,
|
||||||
RouteResult,
|
RouteResult,
|
||||||
|
RoutePlan,
|
||||||
HealthSnapshot,
|
HealthSnapshot,
|
||||||
} from './types.js';
|
} from './types.js';
|
||||||
import { DEFAULT_PROVIDERS, getAvailableProviders } from './registry.js';
|
import { DEFAULT_PROVIDERS, getAvailableProviders } from './registry.js';
|
||||||
@ -49,20 +51,14 @@ export class LlmRouter {
|
|||||||
async chat(request: ChatCompletionRequest): Promise<RouteResult> {
|
async chat(request: ChatCompletionRequest): Promise<RouteResult> {
|
||||||
const startTime = Date.now();
|
const startTime = Date.now();
|
||||||
|
|
||||||
// If user specified a specific provider:model or provider/model, try that first
|
const plan = this.planInternal(request, false);
|
||||||
if (request.model && (request.model.includes(':') || request.model.includes('/'))) {
|
|
||||||
return this.chatWithExplicitModel(request, startTime);
|
if (plan.explicit) {
|
||||||
|
return this.chatWithExplicitModel(request, startTime, plan);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Classify the prompt
|
const category = plan.category as PromptCategory;
|
||||||
const classification = classifyPrompt(request.messages);
|
let candidates = selectCandidates(this.providers, category, this.health);
|
||||||
|
|
||||||
// Get ranked candidates
|
|
||||||
let candidates = selectCandidates(this.providers, classification.category, this.health);
|
|
||||||
|
|
||||||
if (candidates.length === 0) {
|
|
||||||
throw new Error('No healthy providers available for routing');
|
|
||||||
}
|
|
||||||
|
|
||||||
let lastError: Error | null = null;
|
let lastError: Error | null = null;
|
||||||
|
|
||||||
@ -90,7 +86,7 @@ export class LlmRouter {
|
|||||||
model: model.id,
|
model: model.id,
|
||||||
attempt,
|
attempt,
|
||||||
latencyMs: result.latencyMs,
|
latencyMs: result.latencyMs,
|
||||||
category: classification.category,
|
category,
|
||||||
});
|
});
|
||||||
|
|
||||||
candidates = excludeCandidate(candidates, provider.name, model.id);
|
candidates = excludeCandidate(candidates, provider.name, model.id);
|
||||||
@ -110,7 +106,7 @@ export class LlmRouter {
|
|||||||
model: model.id,
|
model: model.id,
|
||||||
attempt,
|
attempt,
|
||||||
latencyMs: result.latencyMs,
|
latencyMs: result.latencyMs,
|
||||||
category: classification.category,
|
category,
|
||||||
tokens: result.response.usage?.total_tokens,
|
tokens: result.response.usage?.total_tokens,
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -137,7 +133,7 @@ export class LlmRouter {
|
|||||||
model: model.id,
|
model: model.id,
|
||||||
attempt,
|
attempt,
|
||||||
latencyMs: attemptLatency,
|
latencyMs: attemptLatency,
|
||||||
category: classification.category,
|
category,
|
||||||
error: lastError.message,
|
error: lastError.message,
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -155,33 +151,12 @@ export class LlmRouter {
|
|||||||
*/
|
*/
|
||||||
private async chatWithExplicitModel(
|
private async chatWithExplicitModel(
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
startTime: number
|
startTime: number,
|
||||||
|
plan?: RoutePlan
|
||||||
): Promise<RouteResult> {
|
): Promise<RouteResult> {
|
||||||
// Support both "provider:model" and "provider/model" separators
|
const resolved = plan ?? this.plan(request);
|
||||||
// Use first colon or first slash (whichever comes first) as separator
|
const provider = resolved.provider;
|
||||||
const raw = request.model!;
|
const modelId = resolved.model.id;
|
||||||
const colonIdx = raw.indexOf(':');
|
|
||||||
const slashIdx = raw.indexOf('/');
|
|
||||||
let sepIdx: number;
|
|
||||||
if (colonIdx === -1 && slashIdx === -1) {
|
|
||||||
sepIdx = -1;
|
|
||||||
} else if (colonIdx === -1) {
|
|
||||||
sepIdx = slashIdx;
|
|
||||||
} else if (slashIdx === -1) {
|
|
||||||
sepIdx = colonIdx;
|
|
||||||
} else {
|
|
||||||
sepIdx = Math.min(colonIdx, slashIdx);
|
|
||||||
}
|
|
||||||
|
|
||||||
const providerName = sepIdx === -1 ? raw : raw.slice(0, sepIdx);
|
|
||||||
const modelId = sepIdx === -1 ? '' : raw.slice(sepIdx + 1);
|
|
||||||
|
|
||||||
const provider = this.providers.find(p => p.name === providerName);
|
|
||||||
if (!provider) {
|
|
||||||
throw new Error(
|
|
||||||
`Provider "${providerName}" not found. Available: ${this.providers.map(p => p.name).join(', ')}`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const result = await sendChatCompletion(provider, modelId, request, this.timeoutMs);
|
const result = await sendChatCompletion(provider, modelId, request, this.timeoutMs);
|
||||||
@ -202,7 +177,7 @@ export class LlmRouter {
|
|||||||
category: 'explicit',
|
category: 'explicit',
|
||||||
});
|
});
|
||||||
|
|
||||||
throw new Error(`Rate limited by ${providerName} for model ${modelId}`);
|
throw new Error(`Rate limited by ${provider.name} for model ${modelId}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
this.health.record(provider.name, modelId, {
|
this.health.record(provider.name, modelId, {
|
||||||
@ -255,6 +230,88 @@ export class LlmRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
plan(request: ChatCompletionRequest): RoutePlan {
|
||||||
|
return this.planInternal(request, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
private planInternal(request: ChatCompletionRequest, advanceRoundRobin: boolean): RoutePlan {
|
||||||
|
const explicit = this.resolveExplicitModel(request.model);
|
||||||
|
if (explicit) {
|
||||||
|
return {
|
||||||
|
provider: explicit.provider,
|
||||||
|
model: explicit.model,
|
||||||
|
category: 'explicit',
|
||||||
|
explicit: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const classification = classifyPrompt(request.messages);
|
||||||
|
const candidates = selectCandidates(this.providers, classification.category, this.health);
|
||||||
|
|
||||||
|
if (candidates.length === 0) {
|
||||||
|
throw new Error('No healthy providers available for routing');
|
||||||
|
}
|
||||||
|
|
||||||
|
const pick = advanceRoundRobin
|
||||||
|
? pickNext(candidates, this.roundRobinState)
|
||||||
|
: (candidates[0] ?? null);
|
||||||
|
if (!pick) {
|
||||||
|
throw new Error('No provider available for routing');
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
provider: pick.provider,
|
||||||
|
model: pick.model,
|
||||||
|
category: classification.category,
|
||||||
|
explicit: false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private resolveExplicitModel(
|
||||||
|
model?: string
|
||||||
|
): { provider: ProviderConfig; model: RoutePlan['model'] } | null {
|
||||||
|
if (!model) return null;
|
||||||
|
|
||||||
|
if (model.includes(':') || model.includes('/')) {
|
||||||
|
const { providerName, modelId } = parseExplicitModel(model);
|
||||||
|
const provider = this.providers.find(p => p.name === providerName);
|
||||||
|
if (!provider) {
|
||||||
|
throw new Error(
|
||||||
|
`Provider "${providerName}" not found. Available: ${this.providers.map(p => p.name).join(', ')}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const matchedModel = provider.models.find(candidate => candidate.id === modelId);
|
||||||
|
if (!matchedModel) {
|
||||||
|
throw new Error(
|
||||||
|
`Model "${modelId}" not found for provider "${providerName}". Available: ${provider.models
|
||||||
|
.map(candidate => candidate.id)
|
||||||
|
.join(', ')}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return { provider, model: matchedModel };
|
||||||
|
}
|
||||||
|
|
||||||
|
const matches = this.providers.flatMap(provider =>
|
||||||
|
provider.models
|
||||||
|
.filter(candidate => candidate.id === model)
|
||||||
|
.map(candidate => ({ provider, model: candidate }))
|
||||||
|
);
|
||||||
|
|
||||||
|
if (matches.length === 1) {
|
||||||
|
return matches[0]!;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (matches.length > 1) {
|
||||||
|
throw new Error(
|
||||||
|
`Model "${model}" is available on multiple providers. Use provider:model format instead.`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
/** Get health snapshots for all tracked provider+model pairs. */
|
/** Get health snapshots for all tracked provider+model pairs. */
|
||||||
getHealth(): HealthSnapshot[] {
|
getHealth(): HealthSnapshot[] {
|
||||||
return this.health.allSnapshots();
|
return this.health.allSnapshots();
|
||||||
@ -271,6 +328,26 @@ export class LlmRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function parseExplicitModel(raw: string): { providerName: string; modelId: string } {
|
||||||
|
const colonIdx = raw.indexOf(':');
|
||||||
|
const slashIdx = raw.indexOf('/');
|
||||||
|
let sepIdx: number;
|
||||||
|
if (colonIdx === -1 && slashIdx === -1) {
|
||||||
|
sepIdx = -1;
|
||||||
|
} else if (colonIdx === -1) {
|
||||||
|
sepIdx = slashIdx;
|
||||||
|
} else if (slashIdx === -1) {
|
||||||
|
sepIdx = colonIdx;
|
||||||
|
} else {
|
||||||
|
sepIdx = Math.min(colonIdx, slashIdx);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
providerName: sepIdx === -1 ? raw : raw.slice(0, sepIdx),
|
||||||
|
modelId: sepIdx === -1 ? '' : raw.slice(sepIdx + 1),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// ── Telemetry types ────────────────────────────────────────────
|
// ── Telemetry types ────────────────────────────────────────────
|
||||||
|
|
||||||
export interface TelemetryEntry {
|
export interface TelemetryEntry {
|
||||||
|
|||||||
@ -18,8 +18,8 @@ export interface ProviderConfig {
|
|||||||
name: string;
|
name: string;
|
||||||
/** OpenAI-compatible base URL (e.g. https://api.groq.com/openai/v1) */
|
/** OpenAI-compatible base URL (e.g. https://api.groq.com/openai/v1) */
|
||||||
baseUrl: string;
|
baseUrl: string;
|
||||||
/** Environment variable name that holds the API key */
|
/** Environment variable name that holds the API key (omit for local/no-auth providers) */
|
||||||
apiKeyEnv: string;
|
apiKeyEnv?: string;
|
||||||
/** Available models on this provider */
|
/** Available models on this provider */
|
||||||
models: ModelConfig[];
|
models: ModelConfig[];
|
||||||
/** Extra headers to send with every request */
|
/** Extra headers to send with every request */
|
||||||
@ -134,3 +134,10 @@ export interface RouteResult {
|
|||||||
/** How many attempts were made */
|
/** How many attempts were made */
|
||||||
attempts: number;
|
attempts: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface RoutePlan {
|
||||||
|
provider: ProviderConfig;
|
||||||
|
model: ModelConfig;
|
||||||
|
category: PromptCategory | 'explicit';
|
||||||
|
explicit: boolean;
|
||||||
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user