feat(llm-router): unify local ollama routing

This commit is contained in:
root 2026-03-14 06:17:17 +00:00
parent 91885f0d4f
commit bf2c285e41
11 changed files with 323 additions and 71 deletions

View File

@ -3,6 +3,8 @@
"version": "0.1.0",
"private": true,
"scripts": {
"predev": "corepack pnpm --dir ../.. --filter @bytelyst/llm-router build",
"prebuild": "corepack pnpm --dir ../.. --filter @bytelyst/llm-router build",
"dev": "next dev",
"build": "next build",
"start": "next start",

View File

@ -8,7 +8,7 @@ import { InputBar } from './InputBar';
import { MessageThread } from './MessageThread';
import { ContextBar } from './ContextBar';
import { estimateTokens, getModelContextWindow } from '../../lib/format';
import { autoDetectDefaults, classifyTask, resolveModel } from '../../lib/router';
import { autoDetectDefaults } from '../../lib/router';
import type { ModelDefaults } from '../../lib/types';
interface ConversationViewProps {
@ -31,7 +31,6 @@ export function ConversationView({
const [streaming, setStreaming] = useState(false);
const [selectedModel, setSelectedModel] = useState(conversation.model || '__auto__');
const [models, setModels] = useState<string[]>([]);
const [runningModels, setRunningModels] = useState<string[]>([]);
const [showModels, setShowModels] = useState(false);
const [modelDefaults, setModelDefaults] = useState<ModelDefaults | null>(null);
const abortRef = useRef<AbortController | null>(null);
@ -53,9 +52,7 @@ export function ConversationView({
if (!res.ok) return;
const data = (await res.json()) as OllamaData;
const installed = data.models.map(m => m.name);
const running = data.running.map(m => m.name);
setModels(installed);
setRunningModels(running);
const stored = localStorage.getItem('llm-model-defaults');
if (stored) {
@ -114,15 +111,7 @@ export function ConversationView({
const onSend = async (text: string) => {
await ensureModels();
const routedModel = (() => {
if (selectedModel !== '__auto__') return selectedModel;
const defaults = modelDefaults || autoDetectDefaults(models);
const taskType = classifyTask(text);
const resolved = resolveModel(taskType, defaults, runningModels, models);
return resolved.model;
})();
if (!routedModel) return;
const requestedModel = selectedModel !== '__auto__' ? selectedModel : undefined;
const now = Date.now();
const userMessage: Message = {
@ -131,7 +120,7 @@ export function ConversationView({
role: 'user',
content: text,
timestamp: now,
model: routedModel,
model: requestedModel,
};
await addMessage(userMessage);
@ -159,9 +148,7 @@ export function ConversationView({
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
model: routedModel,
modelDefaults: modelDefaults,
taskType: classifyTask(text),
model: requestedModel ?? '__auto__',
messages: chatPayload,
}),
signal: controller.signal,
@ -171,6 +158,19 @@ export function ConversationView({
throw new Error('Failed to stream response');
}
const routedModel = res.headers.get('x-by-model') || requestedModel || '';
if (!routedModel) {
throw new Error('Router did not return a model');
}
const previousUserModel = userMessage.model;
userMessage.model = routedModel;
if (routedModel !== previousUserModel) {
await updateMessage(userMessage.id, { model: routedModel });
}
setMessages(prev =>
prev.map(message => (message.id === userMessage.id ? userMessage : message))
);
const reader = res.body.getReader();
const decoder = new TextDecoder();
let buffer = '';
@ -254,7 +254,7 @@ export function ConversationView({
role: 'assistant',
content: `Error: ${String(err)}`,
timestamp: Date.now(),
model: routedModel,
model: requestedModel,
};
setMessages(prev => {
const withoutTemp = prev.filter(m => m.id !== assistantId);

View File

@ -1,5 +1,63 @@
import { NextRequest } from 'next/server';
import { OLLAMA_URL } from '../../../lib/ollama-config';
import {
LlmRouter,
createLocalOllamaProvider,
type ChatCompletionRequest,
type RoutePlan,
} from '../../../lib/llm-router';
interface OllamaTagsResponse {
models?: Array<{ name?: string }>;
}
interface OllamaPsResponse {
models?: Array<{ name?: string }>;
}
async function fetchOllamaJson<T>(path: string): Promise<T> {
const response = await fetch(`${OLLAMA_URL}${path}`, {
cache: 'no-store',
signal: AbortSignal.timeout(5000),
});
if (!response.ok) {
throw new Error(`Ollama ${path} failed: ${response.status}`);
}
return (await response.json()) as T;
}
async function resolvePlan(request: ChatCompletionRequest): Promise<RoutePlan> {
const [tags, ps] = await Promise.all([
fetchOllamaJson<OllamaTagsResponse>('/api/tags'),
fetchOllamaJson<OllamaPsResponse>('/api/ps').catch(() => ({ models: [] })),
]);
const installedModels = (tags.models ?? [])
.map(model => model.name?.trim() || '')
.filter(Boolean);
const runningModels = new Set(
(ps.models ?? []).map(model => model.name?.trim() || '').filter(Boolean)
);
if (installedModels.length === 0) {
throw new Error('No local Ollama models installed');
}
// Put currently loaded models first so the router prefers them when scores are equal.
const orderedModels = [
...installedModels.filter(model => runningModels.has(model)),
...installedModels.filter(model => !runningModels.has(model)),
];
const router = new LlmRouter({
providers: [createLocalOllamaProvider(orderedModels, `${OLLAMA_URL}/v1`)],
maxRetries: 1,
});
return router.plan(request);
}
export async function POST(request: NextRequest) {
try {
@ -13,10 +71,15 @@ export async function POST(request: NextRequest) {
});
}
const plan = await resolvePlan({
messages,
model: model && model !== '__auto__' ? `local-ollama:${String(model)}` : undefined,
});
const response = await fetch(`${OLLAMA_URL}/api/chat`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ model, messages, stream: true }),
body: JSON.stringify({ model: plan.model.id, messages, stream: true }),
});
if (!response.ok || !response.body) {
@ -31,6 +94,9 @@ export async function POST(request: NextRequest) {
'Content-Type': 'application/x-ndjson',
'Transfer-Encoding': 'chunked',
'Cache-Control': 'no-cache',
'x-by-provider': plan.provider.name,
'x-by-model': plan.model.id,
'x-by-category': plan.category,
},
});
} catch (err) {

View File

@ -0,0 +1,8 @@
export {
LlmRouter,
createLocalOllamaProvider,
classifyPrompt,
type ChatCompletionRequest,
type PromptCategory,
type RoutePlan,
} from '../../../../../packages/llm-router/dist/index.js';

View File

@ -1,5 +1,9 @@
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { getAvailableProviders, DEFAULT_PROVIDERS } from '../registry.js';
import {
createLocalOllamaProvider,
getAvailableProviders,
DEFAULT_PROVIDERS,
} from '../registry.js';
import type { ProviderConfig } from '../types.js';
describe('getAvailableProviders', () => {
@ -8,8 +12,9 @@ describe('getAvailableProviders', () => {
beforeEach(() => {
// Save and clear all default provider env vars
for (const p of DEFAULT_PROVIDERS) {
saved[p.apiKeyEnv] = process.env[p.apiKeyEnv];
delete process.env[p.apiKeyEnv];
const key = p.apiKeyEnv!;
saved[key] = process.env[key];
delete process.env[key];
}
});
@ -68,6 +73,13 @@ describe('getAvailableProviders', () => {
delete process.env.CUSTOM_TEST_KEY;
});
it('includes local providers that do not require API keys', () => {
const local = createLocalOllamaProvider(['qwen2.5-coder:7b']);
const result = getAvailableProviders([local]);
expect(result).toHaveLength(1);
expect(result[0]!.name).toBe('local-ollama');
});
it('DEFAULT_PROVIDERS includes all 4 providers', () => {
expect(DEFAULT_PROVIDERS).toHaveLength(4);
const names = DEFAULT_PROVIDERS.map(p => p.name);

View File

@ -1,5 +1,6 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { LlmRouter } from '../router.js';
import { createLocalOllamaProvider } from '../registry.js';
import type { ProviderConfig, ChatCompletionResponse } from '../types.js';
import * as client from '../client.js';
@ -241,6 +242,30 @@ describe('LlmRouter', () => {
expect(router.getProviders()).toEqual(['test-fast', 'test-quality']);
});
it('plans the best provider/model without executing a request', () => {
const router = new LlmRouter({ providers: TEST_PROVIDERS });
const plan = router.plan({
messages: [{ role: 'user', content: 'Write a TypeScript endpoint with validation' }],
});
expect(plan.provider.name).toBe('test-quality');
expect(plan.model.id).toBe('quality-model');
expect(plan.category).toBe('code');
expect(plan.explicit).toBe(false);
});
it('plans local ollama models without requiring an API key', () => {
const router = new LlmRouter({
providers: [createLocalOllamaProvider(['qwen2.5-coder:7b', 'llama3.1:8b'])],
});
const plan = router.plan({
messages: [{ role: 'user', content: 'Refactor this TypeScript function' }],
});
expect(plan.provider.name).toBe('local-ollama');
expect(plan.model.id).toBe('qwen2.5-coder:7b');
});
it('fires telemetry for explicit model routing', async () => {
vi.mocked(client.sendChatCompletion).mockResolvedValueOnce({
response: MOCK_RESPONSE,

View File

@ -10,17 +10,19 @@ export async function sendChatCompletion(
request: ChatCompletionRequest,
timeoutMs: number = 30_000
): Promise<{ response: ChatCompletionResponse; latencyMs: number; status: number }> {
const apiKey = process.env[provider.apiKeyEnv];
if (!apiKey) {
const apiKey = provider.apiKeyEnv ? process.env[provider.apiKeyEnv] : null;
if (provider.apiKeyEnv && !apiKey) {
throw new Error(`Missing API key: env var ${provider.apiKeyEnv} is not set`);
}
const url = `${provider.baseUrl}/chat/completions`;
const headers: Record<string, string> = {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
...provider.extraHeaders,
};
if (apiKey) {
headers.Authorization = `Bearer ${apiKey}`;
}
const body = JSON.stringify({
model: modelId,

View File

@ -1,7 +1,7 @@
export { LlmRouter } from './router.js';
export type { TelemetryEntry } from './router.js';
export { DEFAULT_PROVIDERS, getAvailableProviders } from './registry.js';
export { DEFAULT_PROVIDERS, createLocalOllamaProvider, getAvailableProviders } from './registry.js';
export { classifyPrompt } from './classifier.js';
export { HealthTracker } from './health.js';
export { selectCandidates, pickNext, excludeCandidate, createRoundRobinState } from './selector.js';
@ -22,4 +22,5 @@ export type {
ChatCompletionUsage,
ChatCompletionResponse,
RouteResult,
RoutePlan,
} from './types.js';

View File

@ -1,4 +1,4 @@
import type { ProviderConfig } from './types.js';
import type { ModelConfig, PromptCategory, ProviderConfig } from './types.js';
/**
* Default free-tier provider configurations.
@ -121,6 +121,57 @@ export const DEFAULT_PROVIDERS: ProviderConfig[] = [
},
];
function inferStrengths(modelId: string): PromptCategory[] {
const lower = modelId.toLowerCase();
const strengths = new Set<PromptCategory>(['general']);
if (/coder|code|codestral|starcoder|deepseek/.test(lower)) strengths.add('code');
if (/r1|reason|think|math/.test(lower)) {
strengths.add('reasoning');
strengths.add('math');
}
if (/qwen|llama|mistral|chat/.test(lower)) strengths.add('creative');
return [...strengths];
}
function inferContextWindow(modelId: string): number {
const lower = modelId.toLowerCase();
if (/128k|131072/.test(lower)) return 128_000;
if (/64k|65536/.test(lower)) return 64_000;
if (/32k|32768|qwen2\.5/.test(lower)) return 32_768;
if (/16k|16384/.test(lower)) return 16_384;
return 8_192;
}
function inferSpeedTier(modelId: string): 1 | 2 | 3 {
const lower = modelId.toLowerCase();
if (/0\.5b|1b|3b|7b|mini|tiny/.test(lower)) return 1;
if (/14b|15b|16b|20b|22b|30b|32b/.test(lower)) return 2;
return 3;
}
export function createLocalOllamaProvider(
modelIds: string[],
baseUrl: string = 'http://localhost:11434/v1'
): ProviderConfig {
const models: ModelConfig[] = modelIds.map(modelId => ({
id: modelId,
label: modelId,
contextWindow: inferContextWindow(modelId),
strengths: inferStrengths(modelId),
speedTier: inferSpeedTier(modelId),
}));
return {
name: 'local-ollama',
baseUrl,
models,
rpmLimit: 0,
tpmLimit: 0,
};
}
/**
* Filter providers to only those with API keys present in env.
*/
@ -128,6 +179,7 @@ export function getAvailableProviders(
providers: ProviderConfig[] = DEFAULT_PROVIDERS
): ProviderConfig[] {
return providers.filter(p => {
if (!p.apiKeyEnv) return true;
const key = process.env[p.apiKeyEnv];
return key !== undefined && key !== '';
});

View File

@ -1,8 +1,10 @@
import type {
ChatCompletionRequest,
PromptCategory,
RouterConfig,
ProviderConfig,
RouteResult,
RoutePlan,
HealthSnapshot,
} from './types.js';
import { DEFAULT_PROVIDERS, getAvailableProviders } from './registry.js';
@ -49,20 +51,14 @@ export class LlmRouter {
async chat(request: ChatCompletionRequest): Promise<RouteResult> {
const startTime = Date.now();
// If user specified a specific provider:model or provider/model, try that first
if (request.model && (request.model.includes(':') || request.model.includes('/'))) {
return this.chatWithExplicitModel(request, startTime);
const plan = this.planInternal(request, false);
if (plan.explicit) {
return this.chatWithExplicitModel(request, startTime, plan);
}
// Classify the prompt
const classification = classifyPrompt(request.messages);
// Get ranked candidates
let candidates = selectCandidates(this.providers, classification.category, this.health);
if (candidates.length === 0) {
throw new Error('No healthy providers available for routing');
}
const category = plan.category as PromptCategory;
let candidates = selectCandidates(this.providers, category, this.health);
let lastError: Error | null = null;
@ -90,7 +86,7 @@ export class LlmRouter {
model: model.id,
attempt,
latencyMs: result.latencyMs,
category: classification.category,
category,
});
candidates = excludeCandidate(candidates, provider.name, model.id);
@ -110,7 +106,7 @@ export class LlmRouter {
model: model.id,
attempt,
latencyMs: result.latencyMs,
category: classification.category,
category,
tokens: result.response.usage?.total_tokens,
});
@ -137,7 +133,7 @@ export class LlmRouter {
model: model.id,
attempt,
latencyMs: attemptLatency,
category: classification.category,
category,
error: lastError.message,
});
@ -155,33 +151,12 @@ export class LlmRouter {
*/
private async chatWithExplicitModel(
request: ChatCompletionRequest,
startTime: number
startTime: number,
plan?: RoutePlan
): Promise<RouteResult> {
// Support both "provider:model" and "provider/model" separators
// Use first colon or first slash (whichever comes first) as separator
const raw = request.model!;
const colonIdx = raw.indexOf(':');
const slashIdx = raw.indexOf('/');
let sepIdx: number;
if (colonIdx === -1 && slashIdx === -1) {
sepIdx = -1;
} else if (colonIdx === -1) {
sepIdx = slashIdx;
} else if (slashIdx === -1) {
sepIdx = colonIdx;
} else {
sepIdx = Math.min(colonIdx, slashIdx);
}
const providerName = sepIdx === -1 ? raw : raw.slice(0, sepIdx);
const modelId = sepIdx === -1 ? '' : raw.slice(sepIdx + 1);
const provider = this.providers.find(p => p.name === providerName);
if (!provider) {
throw new Error(
`Provider "${providerName}" not found. Available: ${this.providers.map(p => p.name).join(', ')}`
);
}
const resolved = plan ?? this.plan(request);
const provider = resolved.provider;
const modelId = resolved.model.id;
try {
const result = await sendChatCompletion(provider, modelId, request, this.timeoutMs);
@ -202,7 +177,7 @@ export class LlmRouter {
category: 'explicit',
});
throw new Error(`Rate limited by ${providerName} for model ${modelId}`);
throw new Error(`Rate limited by ${provider.name} for model ${modelId}`);
}
this.health.record(provider.name, modelId, {
@ -255,6 +230,88 @@ export class LlmRouter {
}
}
plan(request: ChatCompletionRequest): RoutePlan {
return this.planInternal(request, true);
}
private planInternal(request: ChatCompletionRequest, advanceRoundRobin: boolean): RoutePlan {
const explicit = this.resolveExplicitModel(request.model);
if (explicit) {
return {
provider: explicit.provider,
model: explicit.model,
category: 'explicit',
explicit: true,
};
}
const classification = classifyPrompt(request.messages);
const candidates = selectCandidates(this.providers, classification.category, this.health);
if (candidates.length === 0) {
throw new Error('No healthy providers available for routing');
}
const pick = advanceRoundRobin
? pickNext(candidates, this.roundRobinState)
: (candidates[0] ?? null);
if (!pick) {
throw new Error('No provider available for routing');
}
return {
provider: pick.provider,
model: pick.model,
category: classification.category,
explicit: false,
};
}
private resolveExplicitModel(
model?: string
): { provider: ProviderConfig; model: RoutePlan['model'] } | null {
if (!model) return null;
if (model.includes(':') || model.includes('/')) {
const { providerName, modelId } = parseExplicitModel(model);
const provider = this.providers.find(p => p.name === providerName);
if (!provider) {
throw new Error(
`Provider "${providerName}" not found. Available: ${this.providers.map(p => p.name).join(', ')}`
);
}
const matchedModel = provider.models.find(candidate => candidate.id === modelId);
if (!matchedModel) {
throw new Error(
`Model "${modelId}" not found for provider "${providerName}". Available: ${provider.models
.map(candidate => candidate.id)
.join(', ')}`
);
}
return { provider, model: matchedModel };
}
const matches = this.providers.flatMap(provider =>
provider.models
.filter(candidate => candidate.id === model)
.map(candidate => ({ provider, model: candidate }))
);
if (matches.length === 1) {
return matches[0]!;
}
if (matches.length > 1) {
throw new Error(
`Model "${model}" is available on multiple providers. Use provider:model format instead.`
);
}
return null;
}
/** Get health snapshots for all tracked provider+model pairs. */
getHealth(): HealthSnapshot[] {
return this.health.allSnapshots();
@ -271,6 +328,26 @@ export class LlmRouter {
}
}
function parseExplicitModel(raw: string): { providerName: string; modelId: string } {
const colonIdx = raw.indexOf(':');
const slashIdx = raw.indexOf('/');
let sepIdx: number;
if (colonIdx === -1 && slashIdx === -1) {
sepIdx = -1;
} else if (colonIdx === -1) {
sepIdx = slashIdx;
} else if (slashIdx === -1) {
sepIdx = colonIdx;
} else {
sepIdx = Math.min(colonIdx, slashIdx);
}
return {
providerName: sepIdx === -1 ? raw : raw.slice(0, sepIdx),
modelId: sepIdx === -1 ? '' : raw.slice(sepIdx + 1),
};
}
// ── Telemetry types ────────────────────────────────────────────
export interface TelemetryEntry {

View File

@ -18,8 +18,8 @@ export interface ProviderConfig {
name: string;
/** OpenAI-compatible base URL (e.g. https://api.groq.com/openai/v1) */
baseUrl: string;
/** Environment variable name that holds the API key */
apiKeyEnv: string;
/** Environment variable name that holds the API key (omit for local/no-auth providers) */
apiKeyEnv?: string;
/** Available models on this provider */
models: ModelConfig[];
/** Extra headers to send with every request */
@ -134,3 +134,10 @@ export interface RouteResult {
/** How many attempts were made */
attempts: number;
}
export interface RoutePlan {
provider: ProviderConfig;
model: ModelConfig;
category: PromptCategory | 'explicit';
explicit: boolean;
}