feat(llm): add vision, streaming, and embedding support
- ContentPart types (TextContentPart, ImageUrlContentPart) for multipart messages - ChatMessage.content now accepts string | ContentPart[] for vision - EmbeddingRequest/Response types + optional embed() on LLMProvider - chatCompletionStream() implemented in OpenAI + Azure providers (SSE parsing) - embed() implemented in OpenAI + Azure providers - Vision helpers: isVisionMessage, hasVisionContent, buildVisionMessage, getMessageText - MockLLMProvider: streaming, embedding, vision content support - 27 tests passing (up from 7)
This commit is contained in:
parent
43bf51a290
commit
151e07207b
@ -1,10 +1,117 @@
|
||||
/**
|
||||
* Tests for LLM providers and factory.
|
||||
* Tests for LLM providers, factory, types, and helpers.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, vi, afterEach } from 'vitest';
|
||||
import { MockLLMProvider } from '../providers/mock.js';
|
||||
import { createLLMProvider, _resetLLM } from '../factory.js';
|
||||
import { isVisionMessage, hasVisionContent, buildVisionMessage, getMessageText } from '../types.js';
|
||||
import type { ChatMessage, ChatCompletionRequest, EmbeddingResponse } from '../types.js';
|
||||
|
||||
// ── Helper function tests ─────────────────────────────────────────
|
||||
|
||||
describe('isVisionMessage', () => {
|
||||
it('returns false for string content', () => {
|
||||
const msg: ChatMessage = { role: 'user', content: 'hello' };
|
||||
expect(isVisionMessage(msg)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false for text-only multipart', () => {
|
||||
const msg: ChatMessage = {
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'hello' }],
|
||||
};
|
||||
expect(isVisionMessage(msg)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns true when message contains image_url part', () => {
|
||||
const msg: ChatMessage = {
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'describe this' },
|
||||
{ type: 'image_url', image_url: { url: 'https://example.com/img.png' } },
|
||||
],
|
||||
};
|
||||
expect(isVisionMessage(msg)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('hasVisionContent', () => {
|
||||
it('returns false for text-only request', () => {
|
||||
const req: ChatCompletionRequest = {
|
||||
messages: [
|
||||
{ role: 'system', content: 'You are helpful' },
|
||||
{ role: 'user', content: 'hello' },
|
||||
],
|
||||
};
|
||||
expect(hasVisionContent(req)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns true when any message has image content', () => {
|
||||
const req: ChatCompletionRequest = {
|
||||
messages: [
|
||||
{ role: 'system', content: 'You are helpful' },
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'what is this?' },
|
||||
{ type: 'image_url', image_url: { url: 'data:image/png;base64,abc' } },
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
expect(hasVisionContent(req)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildVisionMessage', () => {
|
||||
it('builds a multipart user message with text and image', () => {
|
||||
const msg = buildVisionMessage('Describe this', 'https://img.com/a.png');
|
||||
expect(msg.role).toBe('user');
|
||||
expect(Array.isArray(msg.content)).toBe(true);
|
||||
const parts = msg.content as Array<{ type: string }>;
|
||||
expect(parts).toHaveLength(2);
|
||||
expect(parts[0]).toEqual({ type: 'text', text: 'Describe this' });
|
||||
expect(parts[1]).toEqual({
|
||||
type: 'image_url',
|
||||
image_url: { url: 'https://img.com/a.png', detail: 'auto' },
|
||||
});
|
||||
});
|
||||
|
||||
it('respects detail parameter', () => {
|
||||
const msg = buildVisionMessage('hi', 'https://img.com/b.png', 'high');
|
||||
const parts = msg.content as Array<{ type: string; image_url?: { detail: string } }>;
|
||||
expect(parts[1]?.image_url?.detail).toBe('high');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getMessageText', () => {
|
||||
it('returns string content directly', () => {
|
||||
expect(getMessageText({ role: 'user', content: 'hello' })).toBe('hello');
|
||||
});
|
||||
|
||||
it('extracts text from multipart content', () => {
|
||||
const msg: ChatMessage = {
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'line one' },
|
||||
{ type: 'image_url', image_url: { url: 'https://img.com/x.png' } },
|
||||
{ type: 'text', text: 'line two' },
|
||||
],
|
||||
};
|
||||
expect(getMessageText(msg)).toBe('line one\nline two');
|
||||
});
|
||||
|
||||
it('returns empty for image-only multipart', () => {
|
||||
const msg: ChatMessage = {
|
||||
role: 'user',
|
||||
content: [{ type: 'image_url', image_url: { url: 'https://img.com/x.png' } }],
|
||||
};
|
||||
expect(getMessageText(msg)).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
// ── MockLLMProvider tests ─────────────────────────────────────────
|
||||
|
||||
describe('MockLLMProvider', () => {
|
||||
let provider: MockLLMProvider;
|
||||
@ -57,8 +164,124 @@ describe('MockLLMProvider', () => {
|
||||
provider.reset();
|
||||
expect(provider.calls).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('handles multipart (vision) content in echo response', async () => {
|
||||
const result = await provider.chatCompletion({
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Describe this image' },
|
||||
{ type: 'image_url', image_url: { url: 'https://example.com/img.png' } },
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
expect(result.content).toContain('Describe this image');
|
||||
});
|
||||
|
||||
// ── Streaming tests ──────────────────────────────────
|
||||
|
||||
it('streams default echo response word by word', async () => {
|
||||
const stream = provider.chatCompletionStream!({
|
||||
messages: [{ role: 'user', content: 'Hi' }],
|
||||
});
|
||||
const chunks: string[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
const full = chunks.join('');
|
||||
expect(full).toContain('Hi');
|
||||
});
|
||||
|
||||
it('streams queued response word by word', async () => {
|
||||
provider.addResponse({
|
||||
content: 'Hello World',
|
||||
model: 'test',
|
||||
finishReason: 'stop',
|
||||
usage: { promptTokens: 1, completionTokens: 2, totalTokens: 3 },
|
||||
});
|
||||
const stream = provider.chatCompletionStream!({
|
||||
messages: [{ role: 'user', content: 'x' }],
|
||||
});
|
||||
const chunks: string[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
expect(chunks).toEqual(['Hello ', 'World ']);
|
||||
});
|
||||
|
||||
it('streaming tracks calls', async () => {
|
||||
const req = { messages: [{ role: 'user' as const, content: 'stream test' }] };
|
||||
const stream = provider.chatCompletionStream!(req);
|
||||
const drained: string[] = [];
|
||||
for await (const chunk of stream) {
|
||||
drained.push(chunk);
|
||||
}
|
||||
expect(drained.length).toBeGreaterThan(0);
|
||||
expect(provider.calls).toHaveLength(1);
|
||||
});
|
||||
|
||||
// ── Embedding tests ──────────────────────────────────
|
||||
|
||||
it('returns deterministic embeddings for single input', async () => {
|
||||
const result = await provider.embed!({ input: 'hello world' });
|
||||
expect(result.embeddings).toHaveLength(1);
|
||||
expect(result.embeddings[0].length).toBe(8);
|
||||
expect(result.model).toBe('mock-embedding-model');
|
||||
// Verify normalized (magnitude ≈ 1)
|
||||
const mag = Math.sqrt(result.embeddings[0].reduce((s, v) => s + v * v, 0));
|
||||
expect(mag).toBeCloseTo(1.0, 3);
|
||||
});
|
||||
|
||||
it('returns multiple embeddings for array input', async () => {
|
||||
const result = await provider.embed!({ input: ['hello', 'world'] });
|
||||
expect(result.embeddings).toHaveLength(2);
|
||||
expect(result.embeddings[0]).not.toEqual(result.embeddings[1]);
|
||||
});
|
||||
|
||||
it('returns queued embedding response', async () => {
|
||||
const custom: EmbeddingResponse = {
|
||||
embeddings: [[0.1, 0.2, 0.3]],
|
||||
model: 'custom-embed',
|
||||
usage: { promptTokens: 1, completionTokens: 0, totalTokens: 1 },
|
||||
};
|
||||
provider.addEmbeddingResponse(custom);
|
||||
const result = await provider.embed!({ input: 'test' });
|
||||
expect(result).toEqual(custom);
|
||||
});
|
||||
|
||||
it('tracks embed calls', async () => {
|
||||
await provider.embed!({ input: 'track me' });
|
||||
expect(provider.embedCalls).toHaveLength(1);
|
||||
expect(provider.embedCalls[0].input).toBe('track me');
|
||||
});
|
||||
|
||||
it('deterministic: same input produces same embedding', async () => {
|
||||
const r1 = await provider.embed!({ input: 'identical text' });
|
||||
const r2 = await provider.embed!({ input: 'identical text' });
|
||||
expect(r1.embeddings[0]).toEqual(r2.embeddings[0]);
|
||||
});
|
||||
|
||||
it('reset clears embed state', async () => {
|
||||
const custom: EmbeddingResponse = {
|
||||
embeddings: [[0.5]],
|
||||
model: 'm',
|
||||
usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 },
|
||||
};
|
||||
provider.addEmbeddingResponse(custom);
|
||||
await provider.embed!({ input: 'test' });
|
||||
provider.reset();
|
||||
expect(provider.embedCalls).toHaveLength(0);
|
||||
// After reset, embed should return default (not queued) response
|
||||
const result = await provider.embed!({ input: 'after reset' });
|
||||
expect(result.model).toBe('mock-embedding-model');
|
||||
});
|
||||
});
|
||||
|
||||
// ── Factory tests ─────────────────────────────────────────────────
|
||||
|
||||
describe('createLLMProvider', () => {
|
||||
afterEach(() => {
|
||||
_resetLLM();
|
||||
|
||||
@ -5,8 +5,15 @@ export type {
|
||||
ChatMessage,
|
||||
TokenUsage,
|
||||
LLMProviderType,
|
||||
ContentPart,
|
||||
TextContentPart,
|
||||
ImageUrlContentPart,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
} from './types.js';
|
||||
|
||||
export { isVisionMessage, hasVisionContent, buildVisionMessage, getMessageText } from './types.js';
|
||||
|
||||
export { getLLM, createLLMProvider, setLLM, _resetLLM } from './factory.js';
|
||||
export { createFallbackChain } from './fallback.js';
|
||||
export { AzureOpenAIProvider, type AzureOpenAIConfig } from './providers/azure-openai.js';
|
||||
|
||||
@ -3,14 +3,22 @@
|
||||
*
|
||||
* Uses Azure OpenAI REST API with api-key authentication.
|
||||
* Reads config from AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_KEY, AZURE_OPENAI_DEPLOYMENT.
|
||||
* Supports text, vision (multipart content), streaming, and embeddings.
|
||||
*/
|
||||
|
||||
import type { ChatCompletionRequest, ChatCompletionResponse, LLMProvider } from '../types.js';
|
||||
import type {
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
LLMProvider,
|
||||
} from '../types.js';
|
||||
|
||||
export interface AzureOpenAIConfig {
|
||||
endpoint: string;
|
||||
apiKey: string;
|
||||
deployment: string;
|
||||
embeddingDeployment?: string;
|
||||
apiVersion?: string;
|
||||
}
|
||||
|
||||
@ -26,6 +34,10 @@ export class AzureOpenAIProvider implements LLMProvider {
|
||||
process.env.AZURE_OPENAI_DEPLOYMENT ||
|
||||
process.env.OPENAI_MODEL ||
|
||||
'gpt-4o-mini',
|
||||
embeddingDeployment:
|
||||
config?.embeddingDeployment ||
|
||||
process.env.AZURE_OPENAI_EMBEDDING_DEPLOYMENT ||
|
||||
'text-embedding-3-small',
|
||||
apiVersion: config?.apiVersion || process.env.AZURE_OPENAI_API_VERSION || '2024-06-01',
|
||||
};
|
||||
}
|
||||
@ -34,6 +46,24 @@ export class AzureOpenAIProvider implements LLMProvider {
|
||||
return Boolean(this.config.endpoint && this.config.apiKey && this.config.deployment);
|
||||
}
|
||||
|
||||
private getBaseUrl(): string {
|
||||
return this.config.endpoint.replace(/\/+$/, '');
|
||||
}
|
||||
|
||||
private getChatUrl(): string {
|
||||
const base = this.getBaseUrl();
|
||||
const deployment = encodeURIComponent(this.config.deployment);
|
||||
const version = encodeURIComponent(this.config.apiVersion!);
|
||||
return `${base}/openai/deployments/${deployment}/chat/completions?api-version=${version}`;
|
||||
}
|
||||
|
||||
private getHeaders(): Record<string, string> {
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': this.config.apiKey,
|
||||
};
|
||||
}
|
||||
|
||||
async chatCompletion(req: ChatCompletionRequest): Promise<ChatCompletionResponse> {
|
||||
if (!this.isConfigured()) {
|
||||
throw new Error(
|
||||
@ -41,8 +71,7 @@ export class AzureOpenAIProvider implements LLMProvider {
|
||||
);
|
||||
}
|
||||
|
||||
const base = this.config.endpoint.replace(/\/+$/, '');
|
||||
const url = `${base}/openai/deployments/${encodeURIComponent(this.config.deployment)}/chat/completions?api-version=${encodeURIComponent(this.config.apiVersion!)}`;
|
||||
const url = this.getChatUrl();
|
||||
|
||||
const body = {
|
||||
messages: req.messages,
|
||||
@ -55,10 +84,7 @@ export class AzureOpenAIProvider implements LLMProvider {
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'api-key': this.config.apiKey,
|
||||
},
|
||||
headers: this.getHeaders(),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
@ -85,4 +111,116 @@ export class AzureOpenAIProvider implements LLMProvider {
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
async *chatCompletionStream(req: ChatCompletionRequest): AsyncIterable<string> {
|
||||
if (!this.isConfigured()) {
|
||||
throw new Error(
|
||||
'Azure OpenAI is not configured (missing AZURE_OPENAI_ENDPOINT or AZURE_OPENAI_KEY)'
|
||||
);
|
||||
}
|
||||
|
||||
const url = this.getChatUrl();
|
||||
|
||||
const body = {
|
||||
messages: req.messages,
|
||||
temperature: req.temperature,
|
||||
max_tokens: req.maxTokens,
|
||||
top_p: req.topP,
|
||||
stop: req.stop,
|
||||
response_format: req.responseFormat,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: this.getHeaders(),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
throw new Error(`Azure OpenAI streaming error ${response.status}: ${text}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error('Azure OpenAI streaming: no response body');
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = '';
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split('\n');
|
||||
buffer = lines.pop() ?? '';
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || !trimmed.startsWith('data: ')) continue;
|
||||
const data = trimmed.slice(6);
|
||||
if (data === '[DONE]') return;
|
||||
try {
|
||||
const parsed = JSON.parse(data) as {
|
||||
choices: Array<{ delta: { content?: string } }>;
|
||||
};
|
||||
const delta = parsed.choices?.[0]?.delta?.content;
|
||||
if (delta) yield delta;
|
||||
} catch {
|
||||
// skip malformed SSE chunks
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
}
|
||||
|
||||
async embed(req: EmbeddingRequest): Promise<EmbeddingResponse> {
|
||||
if (!this.isConfigured()) {
|
||||
throw new Error(
|
||||
'Azure OpenAI is not configured (missing AZURE_OPENAI_ENDPOINT or AZURE_OPENAI_KEY)'
|
||||
);
|
||||
}
|
||||
|
||||
const base = this.getBaseUrl();
|
||||
const deployment = encodeURIComponent(this.config.embeddingDeployment!);
|
||||
const version = encodeURIComponent(this.config.apiVersion!);
|
||||
const url = `${base}/openai/deployments/${deployment}/embeddings?api-version=${version}`;
|
||||
|
||||
const body = {
|
||||
input: req.input,
|
||||
};
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: this.getHeaders(),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
throw new Error(`Azure OpenAI embedding error ${response.status}: ${text}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as {
|
||||
data: Array<{ embedding: number[]; index: number }>;
|
||||
model: string;
|
||||
usage: { prompt_tokens: number; total_tokens: number };
|
||||
};
|
||||
|
||||
return {
|
||||
embeddings: data.data.sort((a, b) => a.index - b.index).map(d => d.embedding),
|
||||
model: data.model,
|
||||
usage: {
|
||||
promptTokens: data.usage.prompt_tokens,
|
||||
completionTokens: 0,
|
||||
totalTokens: data.usage.total_tokens,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@ -11,6 +11,7 @@ import type {
|
||||
ChatMessage,
|
||||
LLMProvider,
|
||||
} from '../types.js';
|
||||
import { getMessageText } from '../types.js';
|
||||
|
||||
export interface GeminiConfig {
|
||||
apiKey: string;
|
||||
@ -102,13 +103,13 @@ export class GeminiProvider implements LLMProvider {
|
||||
contents: GeminiContent[];
|
||||
} {
|
||||
const systemMessages = messages.filter(m => m.role === 'system');
|
||||
const systemInstruction = systemMessages.map(m => m.content).join('\n') || null;
|
||||
const systemInstruction = systemMessages.map(m => getMessageText(m)).join('\n') || null;
|
||||
|
||||
const contents: GeminiContent[] = messages
|
||||
.filter(m => m.role !== 'system')
|
||||
.map(m => ({
|
||||
role: m.role === 'assistant' ? 'model' : 'user',
|
||||
parts: [{ text: m.content }],
|
||||
role: (m.role === 'assistant' ? 'model' : 'user') as 'user' | 'model',
|
||||
parts: [{ text: getMessageText(m) }],
|
||||
}));
|
||||
|
||||
// Gemini requires at least one user turn
|
||||
|
||||
@ -2,13 +2,23 @@
|
||||
* Mock LLM provider — for testing.
|
||||
*
|
||||
* Returns pre-configured responses or a default echo response.
|
||||
* Supports vision content, streaming, and embedding.
|
||||
*/
|
||||
|
||||
import type { ChatCompletionRequest, ChatCompletionResponse, LLMProvider } from '../types.js';
|
||||
import type {
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
LLMProvider,
|
||||
} from '../types.js';
|
||||
import { getMessageText } from '../types.js';
|
||||
|
||||
export class MockLLMProvider implements LLMProvider {
|
||||
private responses: ChatCompletionResponse[] = [];
|
||||
private embeddingResponses: EmbeddingResponse[] = [];
|
||||
public calls: ChatCompletionRequest[] = [];
|
||||
public embedCalls: EmbeddingRequest[] = [];
|
||||
|
||||
constructor(responses?: ChatCompletionResponse[]) {
|
||||
if (responses) this.responses = [...responses];
|
||||
@ -18,11 +28,16 @@ export class MockLLMProvider implements LLMProvider {
|
||||
return true;
|
||||
}
|
||||
|
||||
/** Add a response to the queue. */
|
||||
/** Add a chat response to the queue. */
|
||||
addResponse(response: ChatCompletionResponse): void {
|
||||
this.responses.push(response);
|
||||
}
|
||||
|
||||
/** Add an embedding response to the queue. */
|
||||
addEmbeddingResponse(response: EmbeddingResponse): void {
|
||||
this.embeddingResponses.push(response);
|
||||
}
|
||||
|
||||
async chatCompletion(req: ChatCompletionRequest): Promise<ChatCompletionResponse> {
|
||||
this.calls.push(req);
|
||||
|
||||
@ -30,19 +45,74 @@ export class MockLLMProvider implements LLMProvider {
|
||||
return this.responses.shift()!;
|
||||
}
|
||||
|
||||
// Default echo response
|
||||
// Default echo response — handles both string and multipart content
|
||||
const lastMessage = req.messages[req.messages.length - 1];
|
||||
const text = lastMessage ? getMessageText(lastMessage) : '(empty)';
|
||||
return {
|
||||
content: `Mock response to: ${lastMessage?.content ?? '(empty)'}`,
|
||||
content: `Mock response to: ${text}`,
|
||||
model: req.model ?? 'mock-model',
|
||||
finishReason: 'stop',
|
||||
usage: { promptTokens: 10, completionTokens: 10, totalTokens: 20 },
|
||||
};
|
||||
}
|
||||
|
||||
async *chatCompletionStream(req: ChatCompletionRequest): AsyncIterable<string> {
|
||||
this.calls.push(req);
|
||||
|
||||
if (this.responses.length > 0) {
|
||||
const resp = this.responses.shift()!;
|
||||
// Yield word-by-word to simulate streaming
|
||||
const words = resp.content.split(' ');
|
||||
for (const word of words) {
|
||||
yield word + ' ';
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const lastMessage = req.messages[req.messages.length - 1];
|
||||
const text = lastMessage ? getMessageText(lastMessage) : '(empty)';
|
||||
const words = `Mock response to: ${text}`.split(' ');
|
||||
for (const word of words) {
|
||||
yield word + ' ';
|
||||
}
|
||||
}
|
||||
|
||||
async embed(req: EmbeddingRequest): Promise<EmbeddingResponse> {
|
||||
this.embedCalls.push(req);
|
||||
|
||||
if (this.embeddingResponses.length > 0) {
|
||||
return this.embeddingResponses.shift()!;
|
||||
}
|
||||
|
||||
// Default: return deterministic pseudo-embeddings (dimension 8 for testing)
|
||||
const inputs = Array.isArray(req.input) ? req.input : [req.input];
|
||||
const embeddings = inputs.map(text => {
|
||||
// Simple hash-based deterministic vector for testing
|
||||
const vec = new Array(8).fill(0);
|
||||
for (let i = 0; i < text.length; i++) {
|
||||
vec[i % 8] += text.charCodeAt(i) / 1000;
|
||||
}
|
||||
// Normalize
|
||||
const mag = Math.sqrt(vec.reduce((sum, v) => sum + v * v, 0)) || 1;
|
||||
return vec.map(v => v / mag);
|
||||
});
|
||||
|
||||
return {
|
||||
embeddings,
|
||||
model: req.model ?? 'mock-embedding-model',
|
||||
usage: {
|
||||
promptTokens: inputs.join(' ').split(/\s+/).length,
|
||||
completionTokens: 0,
|
||||
totalTokens: inputs.join(' ').split(/\s+/).length,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/** Reset call history and responses. */
|
||||
reset(): void {
|
||||
this.calls = [];
|
||||
this.embedCalls = [];
|
||||
this.responses = [];
|
||||
this.embeddingResponses = [];
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,14 +3,22 @@
|
||||
*
|
||||
* Uses OpenAI REST API with Bearer token authentication.
|
||||
* Reads config from OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL.
|
||||
* Supports text, vision (multipart content), streaming, and embeddings.
|
||||
*/
|
||||
|
||||
import type { ChatCompletionRequest, ChatCompletionResponse, LLMProvider } from '../types.js';
|
||||
import type {
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
LLMProvider,
|
||||
} from '../types.js';
|
||||
|
||||
export interface OpenAIConfig {
|
||||
apiKey: string;
|
||||
baseUrl?: string;
|
||||
model?: string;
|
||||
embeddingModel?: string;
|
||||
}
|
||||
|
||||
export class OpenAIProvider implements LLMProvider {
|
||||
@ -21,6 +29,8 @@ export class OpenAIProvider implements LLMProvider {
|
||||
apiKey: config?.apiKey || process.env.OPENAI_API_KEY || '',
|
||||
baseUrl: config?.baseUrl || process.env.OPENAI_BASE_URL || 'https://api.openai.com/v1',
|
||||
model: config?.model || process.env.OPENAI_MODEL || 'gpt-4o-mini',
|
||||
embeddingModel:
|
||||
config?.embeddingModel || process.env.LLM_EMBEDDING_MODEL || 'text-embedding-3-small',
|
||||
};
|
||||
}
|
||||
|
||||
@ -28,13 +38,23 @@ export class OpenAIProvider implements LLMProvider {
|
||||
return Boolean(this.config.apiKey);
|
||||
}
|
||||
|
||||
private getBaseUrl(): string {
|
||||
return this.config.baseUrl!.replace(/\/+$/, '');
|
||||
}
|
||||
|
||||
private getHeaders(): Record<string, string> {
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${this.config.apiKey}`,
|
||||
};
|
||||
}
|
||||
|
||||
async chatCompletion(req: ChatCompletionRequest): Promise<ChatCompletionResponse> {
|
||||
if (!this.isConfigured()) {
|
||||
throw new Error('OpenAI is not configured (missing OPENAI_API_KEY)');
|
||||
}
|
||||
|
||||
const base = this.config.baseUrl!.replace(/\/+$/, '');
|
||||
const url = `${base}/chat/completions`;
|
||||
const url = `${this.getBaseUrl()}/chat/completions`;
|
||||
|
||||
const body = {
|
||||
model: req.model || this.config.model,
|
||||
@ -48,10 +68,7 @@ export class OpenAIProvider implements LLMProvider {
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${this.config.apiKey}`,
|
||||
},
|
||||
headers: this.getHeaders(),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
@ -78,4 +95,111 @@ export class OpenAIProvider implements LLMProvider {
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
async *chatCompletionStream(req: ChatCompletionRequest): AsyncIterable<string> {
|
||||
if (!this.isConfigured()) {
|
||||
throw new Error('OpenAI is not configured (missing OPENAI_API_KEY)');
|
||||
}
|
||||
|
||||
const url = `${this.getBaseUrl()}/chat/completions`;
|
||||
|
||||
const body = {
|
||||
model: req.model || this.config.model,
|
||||
messages: req.messages,
|
||||
temperature: req.temperature,
|
||||
max_tokens: req.maxTokens,
|
||||
top_p: req.topP,
|
||||
stop: req.stop,
|
||||
response_format: req.responseFormat,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: this.getHeaders(),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
throw new Error(`OpenAI streaming error ${response.status}: ${text}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error('OpenAI streaming: no response body');
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = '';
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split('\n');
|
||||
buffer = lines.pop() ?? '';
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || !trimmed.startsWith('data: ')) continue;
|
||||
const data = trimmed.slice(6);
|
||||
if (data === '[DONE]') return;
|
||||
try {
|
||||
const parsed = JSON.parse(data) as {
|
||||
choices: Array<{ delta: { content?: string } }>;
|
||||
};
|
||||
const delta = parsed.choices?.[0]?.delta?.content;
|
||||
if (delta) yield delta;
|
||||
} catch {
|
||||
// skip malformed SSE chunks
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
}
|
||||
|
||||
async embed(req: EmbeddingRequest): Promise<EmbeddingResponse> {
|
||||
if (!this.isConfigured()) {
|
||||
throw new Error('OpenAI is not configured (missing OPENAI_API_KEY)');
|
||||
}
|
||||
|
||||
const url = `${this.getBaseUrl()}/embeddings`;
|
||||
|
||||
const body = {
|
||||
model: req.model || this.config.embeddingModel,
|
||||
input: req.input,
|
||||
};
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: this.getHeaders(),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
throw new Error(`OpenAI embedding error ${response.status}: ${text}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as {
|
||||
data: Array<{ embedding: number[]; index: number }>;
|
||||
model: string;
|
||||
usage: { prompt_tokens: number; total_tokens: number };
|
||||
};
|
||||
|
||||
return {
|
||||
embeddings: data.data.sort((a, b) => a.index - b.index).map(d => d.embedding),
|
||||
model: data.model,
|
||||
usage: {
|
||||
promptTokens: data.usage.prompt_tokens,
|
||||
completionTokens: 0,
|
||||
totalTokens: data.usage.total_tokens,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,19 +3,53 @@
|
||||
*
|
||||
* Provides a unified chat completion API that works with
|
||||
* Azure OpenAI, OpenAI direct, or mock providers.
|
||||
* Supports text, vision (image), and embedding modalities.
|
||||
*/
|
||||
|
||||
// ── Content Parts (vision support) ────────────────────────────────
|
||||
|
||||
/** A text segment within a multipart message. */
|
||||
export interface TextContentPart {
|
||||
type: 'text';
|
||||
text: string;
|
||||
}
|
||||
|
||||
/** An image URL segment within a multipart message (vision). */
|
||||
export interface ImageUrlContentPart {
|
||||
type: 'image_url';
|
||||
image_url: { url: string; detail?: 'auto' | 'low' | 'high' };
|
||||
}
|
||||
|
||||
/** A single part of a multipart message — text or image. */
|
||||
export type ContentPart = TextContentPart | ImageUrlContentPart;
|
||||
|
||||
// ── Chat Messages ─────────────────────────────────────────────────
|
||||
|
||||
export interface ChatMessage {
|
||||
role: 'system' | 'user' | 'assistant' | 'tool';
|
||||
/** Text string OR multipart content array (for vision messages). */
|
||||
content: string | ContentPart[];
|
||||
name?: string;
|
||||
}
|
||||
|
||||
// ── Provider Interface ────────────────────────────────────────────
|
||||
|
||||
export interface LLMProvider {
|
||||
/** Send a chat completion request. */
|
||||
chatCompletion(req: ChatCompletionRequest): Promise<ChatCompletionResponse>;
|
||||
|
||||
/** Stream a chat completion response. */
|
||||
/** Stream a chat completion response — yields content delta strings. */
|
||||
chatCompletionStream?(req: ChatCompletionRequest): AsyncIterable<string>;
|
||||
|
||||
/** Generate vector embeddings for input text(s). */
|
||||
embed?(req: EmbeddingRequest): Promise<EmbeddingResponse>;
|
||||
|
||||
/** Check if the provider is configured with valid credentials. */
|
||||
isConfigured(): boolean;
|
||||
}
|
||||
|
||||
// ── Chat Completion ───────────────────────────────────────────────
|
||||
|
||||
export interface ChatCompletionRequest {
|
||||
messages: ChatMessage[];
|
||||
model?: string;
|
||||
@ -26,12 +60,6 @@ export interface ChatCompletionRequest {
|
||||
responseFormat?: { type: 'text' | 'json_object' };
|
||||
}
|
||||
|
||||
export interface ChatMessage {
|
||||
role: 'system' | 'user' | 'assistant' | 'tool';
|
||||
content: string;
|
||||
name?: string;
|
||||
}
|
||||
|
||||
export interface ChatCompletionResponse {
|
||||
content: string;
|
||||
model: string;
|
||||
@ -39,6 +67,24 @@ export interface ChatCompletionResponse {
|
||||
finishReason: 'stop' | 'length' | 'content_filter' | 'tool_calls' | null;
|
||||
}
|
||||
|
||||
// ── Embeddings ────────────────────────────────────────────────────
|
||||
|
||||
export interface EmbeddingRequest {
|
||||
/** One or more texts to embed. */
|
||||
input: string | string[];
|
||||
/** Override the default embedding model. */
|
||||
model?: string;
|
||||
}
|
||||
|
||||
export interface EmbeddingResponse {
|
||||
/** One embedding vector per input string. */
|
||||
embeddings: number[][];
|
||||
model: string;
|
||||
usage: TokenUsage;
|
||||
}
|
||||
|
||||
// ── Shared ────────────────────────────────────────────────────────
|
||||
|
||||
export interface TokenUsage {
|
||||
promptTokens: number;
|
||||
completionTokens: number;
|
||||
@ -46,3 +92,40 @@ export interface TokenUsage {
|
||||
}
|
||||
|
||||
export type LLMProviderType = 'azure' | 'openai' | 'perplexity' | 'gemini' | 'mock';
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────
|
||||
|
||||
/** Type guard: does this message contain image content parts? */
|
||||
export function isVisionMessage(msg: ChatMessage): boolean {
|
||||
if (typeof msg.content === 'string') return false;
|
||||
return msg.content.some((p) => p.type === 'image_url');
|
||||
}
|
||||
|
||||
/** Does the request contain any vision (image) messages? */
|
||||
export function hasVisionContent(req: ChatCompletionRequest): boolean {
|
||||
return req.messages.some(isVisionMessage);
|
||||
}
|
||||
|
||||
/** Convenience builder for a user message with text + image. */
|
||||
export function buildVisionMessage(
|
||||
text: string,
|
||||
imageUrl: string,
|
||||
detail: 'auto' | 'low' | 'high' = 'auto',
|
||||
): ChatMessage {
|
||||
return {
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text },
|
||||
{ type: 'image_url', image_url: { url: imageUrl, detail } },
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
/** Extract plain text from a ChatMessage content (string or multipart). */
|
||||
export function getMessageText(msg: ChatMessage): string {
|
||||
if (typeof msg.content === 'string') return msg.content;
|
||||
return msg.content
|
||||
.filter((p): p is TextContentPart => p.type === 'text')
|
||||
.map((p) => p.text)
|
||||
.join('\n');
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user