From bf2c285e4199bdf28828c34410916dec0132174b Mon Sep 17 00:00:00 2001 From: root Date: Sat, 14 Mar 2026 06:17:17 +0000 Subject: [PATCH] feat(llm-router): unify local ollama routing --- __LOCAL_LLMs/dashboard/package.json | 2 + .../components/ConversationView.tsx | 36 ++-- .../src/app/api/ollama/chat/route.ts | 68 +++++++- .../dashboard/src/app/lib/llm-router.ts | 8 + .../llm-router/src/__tests__/registry.test.ts | 18 +- .../llm-router/src/__tests__/router.test.ts | 25 +++ packages/llm-router/src/client.ts | 8 +- packages/llm-router/src/index.ts | 3 +- packages/llm-router/src/registry.ts | 54 +++++- packages/llm-router/src/router.ts | 161 +++++++++++++----- packages/llm-router/src/types.ts | 11 +- 11 files changed, 323 insertions(+), 71 deletions(-) create mode 100644 __LOCAL_LLMs/dashboard/src/app/lib/llm-router.ts diff --git a/__LOCAL_LLMs/dashboard/package.json b/__LOCAL_LLMs/dashboard/package.json index 8d44f3db..52126cc8 100644 --- a/__LOCAL_LLMs/dashboard/package.json +++ b/__LOCAL_LLMs/dashboard/package.json @@ -3,6 +3,8 @@ "version": "0.1.0", "private": true, "scripts": { + "predev": "corepack pnpm --dir ../.. --filter @bytelyst/llm-router build", + "prebuild": "corepack pnpm --dir ../.. --filter @bytelyst/llm-router build", "dev": "next dev", "build": "next build", "start": "next start", diff --git a/__LOCAL_LLMs/dashboard/src/app/(workspace)/components/ConversationView.tsx b/__LOCAL_LLMs/dashboard/src/app/(workspace)/components/ConversationView.tsx index 5a1e78a1..58a25ffb 100644 --- a/__LOCAL_LLMs/dashboard/src/app/(workspace)/components/ConversationView.tsx +++ b/__LOCAL_LLMs/dashboard/src/app/(workspace)/components/ConversationView.tsx @@ -8,7 +8,7 @@ import { InputBar } from './InputBar'; import { MessageThread } from './MessageThread'; import { ContextBar } from './ContextBar'; import { estimateTokens, getModelContextWindow } from '../../lib/format'; -import { autoDetectDefaults, classifyTask, resolveModel } from '../../lib/router'; +import { autoDetectDefaults } from '../../lib/router'; import type { ModelDefaults } from '../../lib/types'; interface ConversationViewProps { @@ -31,7 +31,6 @@ export function ConversationView({ const [streaming, setStreaming] = useState(false); const [selectedModel, setSelectedModel] = useState(conversation.model || '__auto__'); const [models, setModels] = useState([]); - const [runningModels, setRunningModels] = useState([]); const [showModels, setShowModels] = useState(false); const [modelDefaults, setModelDefaults] = useState(null); const abortRef = useRef(null); @@ -53,9 +52,7 @@ export function ConversationView({ if (!res.ok) return; const data = (await res.json()) as OllamaData; const installed = data.models.map(m => m.name); - const running = data.running.map(m => m.name); setModels(installed); - setRunningModels(running); const stored = localStorage.getItem('llm-model-defaults'); if (stored) { @@ -114,15 +111,7 @@ export function ConversationView({ const onSend = async (text: string) => { await ensureModels(); - const routedModel = (() => { - if (selectedModel !== '__auto__') return selectedModel; - const defaults = modelDefaults || autoDetectDefaults(models); - const taskType = classifyTask(text); - const resolved = resolveModel(taskType, defaults, runningModels, models); - return resolved.model; - })(); - - if (!routedModel) return; + const requestedModel = selectedModel !== '__auto__' ? selectedModel : undefined; const now = Date.now(); const userMessage: Message = { @@ -131,7 +120,7 @@ export function ConversationView({ role: 'user', content: text, timestamp: now, - model: routedModel, + model: requestedModel, }; await addMessage(userMessage); @@ -159,9 +148,7 @@ export function ConversationView({ method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ - model: routedModel, - modelDefaults: modelDefaults, - taskType: classifyTask(text), + model: requestedModel ?? '__auto__', messages: chatPayload, }), signal: controller.signal, @@ -171,6 +158,19 @@ export function ConversationView({ throw new Error('Failed to stream response'); } + const routedModel = res.headers.get('x-by-model') || requestedModel || ''; + if (!routedModel) { + throw new Error('Router did not return a model'); + } + const previousUserModel = userMessage.model; + userMessage.model = routedModel; + if (routedModel !== previousUserModel) { + await updateMessage(userMessage.id, { model: routedModel }); + } + setMessages(prev => + prev.map(message => (message.id === userMessage.id ? userMessage : message)) + ); + const reader = res.body.getReader(); const decoder = new TextDecoder(); let buffer = ''; @@ -254,7 +254,7 @@ export function ConversationView({ role: 'assistant', content: `Error: ${String(err)}`, timestamp: Date.now(), - model: routedModel, + model: requestedModel, }; setMessages(prev => { const withoutTemp = prev.filter(m => m.id !== assistantId); diff --git a/__LOCAL_LLMs/dashboard/src/app/api/ollama/chat/route.ts b/__LOCAL_LLMs/dashboard/src/app/api/ollama/chat/route.ts index 9718cd3a..241045f5 100644 --- a/__LOCAL_LLMs/dashboard/src/app/api/ollama/chat/route.ts +++ b/__LOCAL_LLMs/dashboard/src/app/api/ollama/chat/route.ts @@ -1,5 +1,63 @@ import { NextRequest } from 'next/server'; import { OLLAMA_URL } from '../../../lib/ollama-config'; +import { + LlmRouter, + createLocalOllamaProvider, + type ChatCompletionRequest, + type RoutePlan, +} from '../../../lib/llm-router'; + +interface OllamaTagsResponse { + models?: Array<{ name?: string }>; +} + +interface OllamaPsResponse { + models?: Array<{ name?: string }>; +} + +async function fetchOllamaJson(path: string): Promise { + const response = await fetch(`${OLLAMA_URL}${path}`, { + cache: 'no-store', + signal: AbortSignal.timeout(5000), + }); + + if (!response.ok) { + throw new Error(`Ollama ${path} failed: ${response.status}`); + } + + return (await response.json()) as T; +} + +async function resolvePlan(request: ChatCompletionRequest): Promise { + const [tags, ps] = await Promise.all([ + fetchOllamaJson('/api/tags'), + fetchOllamaJson('/api/ps').catch(() => ({ models: [] })), + ]); + + const installedModels = (tags.models ?? []) + .map(model => model.name?.trim() || '') + .filter(Boolean); + const runningModels = new Set( + (ps.models ?? []).map(model => model.name?.trim() || '').filter(Boolean) + ); + + if (installedModels.length === 0) { + throw new Error('No local Ollama models installed'); + } + + // Put currently loaded models first so the router prefers them when scores are equal. + const orderedModels = [ + ...installedModels.filter(model => runningModels.has(model)), + ...installedModels.filter(model => !runningModels.has(model)), + ]; + + const router = new LlmRouter({ + providers: [createLocalOllamaProvider(orderedModels, `${OLLAMA_URL}/v1`)], + maxRetries: 1, + }); + + return router.plan(request); +} export async function POST(request: NextRequest) { try { @@ -13,10 +71,15 @@ export async function POST(request: NextRequest) { }); } + const plan = await resolvePlan({ + messages, + model: model && model !== '__auto__' ? `local-ollama:${String(model)}` : undefined, + }); + const response = await fetch(`${OLLAMA_URL}/api/chat`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ model, messages, stream: true }), + body: JSON.stringify({ model: plan.model.id, messages, stream: true }), }); if (!response.ok || !response.body) { @@ -31,6 +94,9 @@ export async function POST(request: NextRequest) { 'Content-Type': 'application/x-ndjson', 'Transfer-Encoding': 'chunked', 'Cache-Control': 'no-cache', + 'x-by-provider': plan.provider.name, + 'x-by-model': plan.model.id, + 'x-by-category': plan.category, }, }); } catch (err) { diff --git a/__LOCAL_LLMs/dashboard/src/app/lib/llm-router.ts b/__LOCAL_LLMs/dashboard/src/app/lib/llm-router.ts new file mode 100644 index 00000000..e6f0e8b9 --- /dev/null +++ b/__LOCAL_LLMs/dashboard/src/app/lib/llm-router.ts @@ -0,0 +1,8 @@ +export { + LlmRouter, + createLocalOllamaProvider, + classifyPrompt, + type ChatCompletionRequest, + type PromptCategory, + type RoutePlan, +} from '../../../../../packages/llm-router/dist/index.js'; diff --git a/packages/llm-router/src/__tests__/registry.test.ts b/packages/llm-router/src/__tests__/registry.test.ts index 4bf9fd9e..a66729d6 100644 --- a/packages/llm-router/src/__tests__/registry.test.ts +++ b/packages/llm-router/src/__tests__/registry.test.ts @@ -1,5 +1,9 @@ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { getAvailableProviders, DEFAULT_PROVIDERS } from '../registry.js'; +import { + createLocalOllamaProvider, + getAvailableProviders, + DEFAULT_PROVIDERS, +} from '../registry.js'; import type { ProviderConfig } from '../types.js'; describe('getAvailableProviders', () => { @@ -8,8 +12,9 @@ describe('getAvailableProviders', () => { beforeEach(() => { // Save and clear all default provider env vars for (const p of DEFAULT_PROVIDERS) { - saved[p.apiKeyEnv] = process.env[p.apiKeyEnv]; - delete process.env[p.apiKeyEnv]; + const key = p.apiKeyEnv!; + saved[key] = process.env[key]; + delete process.env[key]; } }); @@ -68,6 +73,13 @@ describe('getAvailableProviders', () => { delete process.env.CUSTOM_TEST_KEY; }); + it('includes local providers that do not require API keys', () => { + const local = createLocalOllamaProvider(['qwen2.5-coder:7b']); + const result = getAvailableProviders([local]); + expect(result).toHaveLength(1); + expect(result[0]!.name).toBe('local-ollama'); + }); + it('DEFAULT_PROVIDERS includes all 4 providers', () => { expect(DEFAULT_PROVIDERS).toHaveLength(4); const names = DEFAULT_PROVIDERS.map(p => p.name); diff --git a/packages/llm-router/src/__tests__/router.test.ts b/packages/llm-router/src/__tests__/router.test.ts index 9ece98e8..6f3bebc4 100644 --- a/packages/llm-router/src/__tests__/router.test.ts +++ b/packages/llm-router/src/__tests__/router.test.ts @@ -1,5 +1,6 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { LlmRouter } from '../router.js'; +import { createLocalOllamaProvider } from '../registry.js'; import type { ProviderConfig, ChatCompletionResponse } from '../types.js'; import * as client from '../client.js'; @@ -241,6 +242,30 @@ describe('LlmRouter', () => { expect(router.getProviders()).toEqual(['test-fast', 'test-quality']); }); + it('plans the best provider/model without executing a request', () => { + const router = new LlmRouter({ providers: TEST_PROVIDERS }); + const plan = router.plan({ + messages: [{ role: 'user', content: 'Write a TypeScript endpoint with validation' }], + }); + + expect(plan.provider.name).toBe('test-quality'); + expect(plan.model.id).toBe('quality-model'); + expect(plan.category).toBe('code'); + expect(plan.explicit).toBe(false); + }); + + it('plans local ollama models without requiring an API key', () => { + const router = new LlmRouter({ + providers: [createLocalOllamaProvider(['qwen2.5-coder:7b', 'llama3.1:8b'])], + }); + const plan = router.plan({ + messages: [{ role: 'user', content: 'Refactor this TypeScript function' }], + }); + + expect(plan.provider.name).toBe('local-ollama'); + expect(plan.model.id).toBe('qwen2.5-coder:7b'); + }); + it('fires telemetry for explicit model routing', async () => { vi.mocked(client.sendChatCompletion).mockResolvedValueOnce({ response: MOCK_RESPONSE, diff --git a/packages/llm-router/src/client.ts b/packages/llm-router/src/client.ts index febd13b8..fe577694 100644 --- a/packages/llm-router/src/client.ts +++ b/packages/llm-router/src/client.ts @@ -10,17 +10,19 @@ export async function sendChatCompletion( request: ChatCompletionRequest, timeoutMs: number = 30_000 ): Promise<{ response: ChatCompletionResponse; latencyMs: number; status: number }> { - const apiKey = process.env[provider.apiKeyEnv]; - if (!apiKey) { + const apiKey = provider.apiKeyEnv ? process.env[provider.apiKeyEnv] : null; + if (provider.apiKeyEnv && !apiKey) { throw new Error(`Missing API key: env var ${provider.apiKeyEnv} is not set`); } const url = `${provider.baseUrl}/chat/completions`; const headers: Record = { 'Content-Type': 'application/json', - Authorization: `Bearer ${apiKey}`, ...provider.extraHeaders, }; + if (apiKey) { + headers.Authorization = `Bearer ${apiKey}`; + } const body = JSON.stringify({ model: modelId, diff --git a/packages/llm-router/src/index.ts b/packages/llm-router/src/index.ts index f7dbe513..74ab47bd 100644 --- a/packages/llm-router/src/index.ts +++ b/packages/llm-router/src/index.ts @@ -1,7 +1,7 @@ export { LlmRouter } from './router.js'; export type { TelemetryEntry } from './router.js'; -export { DEFAULT_PROVIDERS, getAvailableProviders } from './registry.js'; +export { DEFAULT_PROVIDERS, createLocalOllamaProvider, getAvailableProviders } from './registry.js'; export { classifyPrompt } from './classifier.js'; export { HealthTracker } from './health.js'; export { selectCandidates, pickNext, excludeCandidate, createRoundRobinState } from './selector.js'; @@ -22,4 +22,5 @@ export type { ChatCompletionUsage, ChatCompletionResponse, RouteResult, + RoutePlan, } from './types.js'; diff --git a/packages/llm-router/src/registry.ts b/packages/llm-router/src/registry.ts index 2d533111..5db490bc 100644 --- a/packages/llm-router/src/registry.ts +++ b/packages/llm-router/src/registry.ts @@ -1,4 +1,4 @@ -import type { ProviderConfig } from './types.js'; +import type { ModelConfig, PromptCategory, ProviderConfig } from './types.js'; /** * Default free-tier provider configurations. @@ -121,6 +121,57 @@ export const DEFAULT_PROVIDERS: ProviderConfig[] = [ }, ]; +function inferStrengths(modelId: string): PromptCategory[] { + const lower = modelId.toLowerCase(); + const strengths = new Set(['general']); + + if (/coder|code|codestral|starcoder|deepseek/.test(lower)) strengths.add('code'); + if (/r1|reason|think|math/.test(lower)) { + strengths.add('reasoning'); + strengths.add('math'); + } + if (/qwen|llama|mistral|chat/.test(lower)) strengths.add('creative'); + + return [...strengths]; +} + +function inferContextWindow(modelId: string): number { + const lower = modelId.toLowerCase(); + if (/128k|131072/.test(lower)) return 128_000; + if (/64k|65536/.test(lower)) return 64_000; + if (/32k|32768|qwen2\.5/.test(lower)) return 32_768; + if (/16k|16384/.test(lower)) return 16_384; + return 8_192; +} + +function inferSpeedTier(modelId: string): 1 | 2 | 3 { + const lower = modelId.toLowerCase(); + if (/0\.5b|1b|3b|7b|mini|tiny/.test(lower)) return 1; + if (/14b|15b|16b|20b|22b|30b|32b/.test(lower)) return 2; + return 3; +} + +export function createLocalOllamaProvider( + modelIds: string[], + baseUrl: string = 'http://localhost:11434/v1' +): ProviderConfig { + const models: ModelConfig[] = modelIds.map(modelId => ({ + id: modelId, + label: modelId, + contextWindow: inferContextWindow(modelId), + strengths: inferStrengths(modelId), + speedTier: inferSpeedTier(modelId), + })); + + return { + name: 'local-ollama', + baseUrl, + models, + rpmLimit: 0, + tpmLimit: 0, + }; +} + /** * Filter providers to only those with API keys present in env. */ @@ -128,6 +179,7 @@ export function getAvailableProviders( providers: ProviderConfig[] = DEFAULT_PROVIDERS ): ProviderConfig[] { return providers.filter(p => { + if (!p.apiKeyEnv) return true; const key = process.env[p.apiKeyEnv]; return key !== undefined && key !== ''; }); diff --git a/packages/llm-router/src/router.ts b/packages/llm-router/src/router.ts index e8cdd821..f46951de 100644 --- a/packages/llm-router/src/router.ts +++ b/packages/llm-router/src/router.ts @@ -1,8 +1,10 @@ import type { ChatCompletionRequest, + PromptCategory, RouterConfig, ProviderConfig, RouteResult, + RoutePlan, HealthSnapshot, } from './types.js'; import { DEFAULT_PROVIDERS, getAvailableProviders } from './registry.js'; @@ -49,20 +51,14 @@ export class LlmRouter { async chat(request: ChatCompletionRequest): Promise { const startTime = Date.now(); - // If user specified a specific provider:model or provider/model, try that first - if (request.model && (request.model.includes(':') || request.model.includes('/'))) { - return this.chatWithExplicitModel(request, startTime); + const plan = this.planInternal(request, false); + + if (plan.explicit) { + return this.chatWithExplicitModel(request, startTime, plan); } - // Classify the prompt - const classification = classifyPrompt(request.messages); - - // Get ranked candidates - let candidates = selectCandidates(this.providers, classification.category, this.health); - - if (candidates.length === 0) { - throw new Error('No healthy providers available for routing'); - } + const category = plan.category as PromptCategory; + let candidates = selectCandidates(this.providers, category, this.health); let lastError: Error | null = null; @@ -90,7 +86,7 @@ export class LlmRouter { model: model.id, attempt, latencyMs: result.latencyMs, - category: classification.category, + category, }); candidates = excludeCandidate(candidates, provider.name, model.id); @@ -110,7 +106,7 @@ export class LlmRouter { model: model.id, attempt, latencyMs: result.latencyMs, - category: classification.category, + category, tokens: result.response.usage?.total_tokens, }); @@ -137,7 +133,7 @@ export class LlmRouter { model: model.id, attempt, latencyMs: attemptLatency, - category: classification.category, + category, error: lastError.message, }); @@ -155,33 +151,12 @@ export class LlmRouter { */ private async chatWithExplicitModel( request: ChatCompletionRequest, - startTime: number + startTime: number, + plan?: RoutePlan ): Promise { - // Support both "provider:model" and "provider/model" separators - // Use first colon or first slash (whichever comes first) as separator - const raw = request.model!; - const colonIdx = raw.indexOf(':'); - const slashIdx = raw.indexOf('/'); - let sepIdx: number; - if (colonIdx === -1 && slashIdx === -1) { - sepIdx = -1; - } else if (colonIdx === -1) { - sepIdx = slashIdx; - } else if (slashIdx === -1) { - sepIdx = colonIdx; - } else { - sepIdx = Math.min(colonIdx, slashIdx); - } - - const providerName = sepIdx === -1 ? raw : raw.slice(0, sepIdx); - const modelId = sepIdx === -1 ? '' : raw.slice(sepIdx + 1); - - const provider = this.providers.find(p => p.name === providerName); - if (!provider) { - throw new Error( - `Provider "${providerName}" not found. Available: ${this.providers.map(p => p.name).join(', ')}` - ); - } + const resolved = plan ?? this.plan(request); + const provider = resolved.provider; + const modelId = resolved.model.id; try { const result = await sendChatCompletion(provider, modelId, request, this.timeoutMs); @@ -202,7 +177,7 @@ export class LlmRouter { category: 'explicit', }); - throw new Error(`Rate limited by ${providerName} for model ${modelId}`); + throw new Error(`Rate limited by ${provider.name} for model ${modelId}`); } this.health.record(provider.name, modelId, { @@ -255,6 +230,88 @@ export class LlmRouter { } } + plan(request: ChatCompletionRequest): RoutePlan { + return this.planInternal(request, true); + } + + private planInternal(request: ChatCompletionRequest, advanceRoundRobin: boolean): RoutePlan { + const explicit = this.resolveExplicitModel(request.model); + if (explicit) { + return { + provider: explicit.provider, + model: explicit.model, + category: 'explicit', + explicit: true, + }; + } + + const classification = classifyPrompt(request.messages); + const candidates = selectCandidates(this.providers, classification.category, this.health); + + if (candidates.length === 0) { + throw new Error('No healthy providers available for routing'); + } + + const pick = advanceRoundRobin + ? pickNext(candidates, this.roundRobinState) + : (candidates[0] ?? null); + if (!pick) { + throw new Error('No provider available for routing'); + } + + return { + provider: pick.provider, + model: pick.model, + category: classification.category, + explicit: false, + }; + } + + private resolveExplicitModel( + model?: string + ): { provider: ProviderConfig; model: RoutePlan['model'] } | null { + if (!model) return null; + + if (model.includes(':') || model.includes('/')) { + const { providerName, modelId } = parseExplicitModel(model); + const provider = this.providers.find(p => p.name === providerName); + if (!provider) { + throw new Error( + `Provider "${providerName}" not found. Available: ${this.providers.map(p => p.name).join(', ')}` + ); + } + + const matchedModel = provider.models.find(candidate => candidate.id === modelId); + if (!matchedModel) { + throw new Error( + `Model "${modelId}" not found for provider "${providerName}". Available: ${provider.models + .map(candidate => candidate.id) + .join(', ')}` + ); + } + + return { provider, model: matchedModel }; + } + + const matches = this.providers.flatMap(provider => + provider.models + .filter(candidate => candidate.id === model) + .map(candidate => ({ provider, model: candidate })) + ); + + if (matches.length === 1) { + return matches[0]!; + } + + if (matches.length > 1) { + throw new Error( + `Model "${model}" is available on multiple providers. Use provider:model format instead.` + ); + } + + return null; + } + /** Get health snapshots for all tracked provider+model pairs. */ getHealth(): HealthSnapshot[] { return this.health.allSnapshots(); @@ -271,6 +328,26 @@ export class LlmRouter { } } +function parseExplicitModel(raw: string): { providerName: string; modelId: string } { + const colonIdx = raw.indexOf(':'); + const slashIdx = raw.indexOf('/'); + let sepIdx: number; + if (colonIdx === -1 && slashIdx === -1) { + sepIdx = -1; + } else if (colonIdx === -1) { + sepIdx = slashIdx; + } else if (slashIdx === -1) { + sepIdx = colonIdx; + } else { + sepIdx = Math.min(colonIdx, slashIdx); + } + + return { + providerName: sepIdx === -1 ? raw : raw.slice(0, sepIdx), + modelId: sepIdx === -1 ? '' : raw.slice(sepIdx + 1), + }; +} + // ── Telemetry types ──────────────────────────────────────────── export interface TelemetryEntry { diff --git a/packages/llm-router/src/types.ts b/packages/llm-router/src/types.ts index 7b7c1d17..cc8abee4 100644 --- a/packages/llm-router/src/types.ts +++ b/packages/llm-router/src/types.ts @@ -18,8 +18,8 @@ export interface ProviderConfig { name: string; /** OpenAI-compatible base URL (e.g. https://api.groq.com/openai/v1) */ baseUrl: string; - /** Environment variable name that holds the API key */ - apiKeyEnv: string; + /** Environment variable name that holds the API key (omit for local/no-auth providers) */ + apiKeyEnv?: string; /** Available models on this provider */ models: ModelConfig[]; /** Extra headers to send with every request */ @@ -134,3 +134,10 @@ export interface RouteResult { /** How many attempts were made */ attempts: number; } + +export interface RoutePlan { + provider: ProviderConfig; + model: ModelConfig; + category: PromptCategory | 'explicit'; + explicit: boolean; +}