diff --git a/services/extraction-service/python/src/app.py b/services/extraction-service/python/src/app.py index 30b584d2..330c2239 100644 --- a/services/extraction-service/python/src/app.py +++ b/services/extraction-service/python/src/app.py @@ -13,6 +13,7 @@ import structlog import uvicorn from fastapi import FastAPI, HTTPException, Request +from .cache import extraction_cache from .extractor import extract from .models import ( BatchExtractRequest, @@ -30,9 +31,10 @@ app = FastAPI( ) -@app.get("/health", response_model=HealthResponse) -async def health() -> HealthResponse: - return HealthResponse() +@app.get("/health") +async def health(): + base = HealthResponse() + return {**base.model_dump(), "cache": extraction_cache.stats} @app.post("/extract", response_model=ExtractResponse) @@ -47,6 +49,12 @@ async def extract_endpoint(req: ExtractRequest, request: Request) -> ExtractResp request_id=request_id, ) + # Check cache first + cached = extraction_cache.get(req.text, req.task_id, req.model_id) + if cached is not None: + logger.info("cache_hit", task_id=req.task_id, request_id=request_id) + return cached + try: result = await extract( text=req.text, @@ -57,6 +65,7 @@ async def extract_endpoint(req: ExtractRequest, request: Request) -> ExtractResp max_workers=req.max_workers, max_char_buffer=req.max_char_buffer, ) + extraction_cache.put(req.text, req.task_id, req.model_id, result) return result except Exception as exc: logger.error("extract_failed", error=str(exc), request_id=request_id) diff --git a/services/extraction-service/python/src/cache.py b/services/extraction-service/python/src/cache.py new file mode 100644 index 00000000..1175ae8f --- /dev/null +++ b/services/extraction-service/python/src/cache.py @@ -0,0 +1,98 @@ +""" +In-memory LRU cache for extraction results. 
@dataclass
class CacheEntry:
    """A cached extraction response together with the time it was stored."""

    response: ExtractResponse
    # Monotonic clock: TTL arithmetic is unaffected by wall-clock adjustments.
    created_at: float = field(default_factory=time.monotonic)


class ExtractionCache:
    """Bounded in-memory LRU cache with per-entry TTL expiry.

    All reads and writes of the internal store and counters happen under a
    single lock, so the cache's own state is safe to touch from several
    threads. Cached responses are returned by reference (not copied), so
    callers must treat them as read-only.

    NOTE(review): entries are keyed on ``text``, ``task_id`` and ``model_id``
    only. If other request inputs influence the extraction result, two
    different requests can share one entry -- confirm against the caller.

    Args:
        max_size: Maximum number of entries kept. A value ``<= 0`` disables
            caching (``put`` becomes a no-op) instead of raising.
        ttl: Entry lifetime in seconds; older entries count as misses.
    """

    def __init__(self, max_size: int = CACHE_MAX_SIZE, ttl: int = CACHE_TTL):
        # Function-scope import: the module's import block has no `threading`.
        import threading

        self._store: OrderedDict[str, CacheEntry] = OrderedDict()
        self._max_size = max_size
        self._ttl = ttl
        self._hits = 0
        self._misses = 0
        self._lock = threading.Lock()

    @staticmethod
    def _make_key(text: str, task_id: str | None, model_id: str | None) -> str:
        """Return a SHA-256 hex digest identifying the (task, model, text) triple.

        ``None`` and ``""`` are treated alike for ``task_id`` / ``model_id``.
        Each part is length-prefixed before hashing so that a delimiter
        inside one field cannot make two different triples hash the same
        (joining with ":" would map ("a", "b:c", "d") and ("a:b", "c", "d")
        to one key).
        """
        digest = hashlib.sha256()
        for part in (task_id or "", model_id or "", text):
            data = part.encode()
            digest.update(len(data).to_bytes(8, "big"))
            digest.update(data)
        return digest.hexdigest()

    def get(self, text: str, task_id: str | None, model_id: str | None) -> ExtractResponse | None:
        """Return the cached response for the triple, or ``None`` on a miss.

        An entry older than the TTL is removed and reported as a miss. A hit
        marks the entry as most recently used.
        """
        key = self._make_key(text, task_id, model_id)

        with self._lock:
            entry = self._store.get(key)
            if entry is None:
                self._misses += 1
                return None

            age = time.monotonic() - entry.created_at
            if age > self._ttl:
                del self._store[key]
                self._misses += 1
                logger.debug("cache_expired", key=key[:12], age_s=round(age))
                return None

            # Most recently used entries live at the end of the OrderedDict.
            self._store.move_to_end(key)
            self._hits += 1
            logger.debug("cache_hit", key=key[:12])
            return entry.response

    def put(self, text: str, task_id: str | None, model_id: str | None, response: ExtractResponse) -> None:
        """Store ``response`` for the triple, evicting LRU entries if full."""
        if self._max_size <= 0:
            # Caching disabled; without this guard the eviction loop would
            # call popitem() on an empty store and raise KeyError.
            return

        key = self._make_key(text, task_id, model_id)

        with self._lock:
            # Drop any previous entry first: an overwrite must not evict an
            # unrelated entry, and the refreshed entry must become the most
            # recently used one (plain assignment keeps the old position).
            self._store.pop(key, None)

            while len(self._store) >= self._max_size:
                evicted_key, _ = self._store.popitem(last=False)
                logger.debug("cache_evicted", key=evicted_key[:12])

            self._store[key] = CacheEntry(response=response)

    def clear(self) -> None:
        """Remove every entry and reset the hit/miss counters."""
        with self._lock:
            self._store.clear()
            self._hits = 0
            self._misses = 0

    @property
    def stats(self) -> dict:
        """Snapshot of size, limits, hit/miss counters and hit rate (3 d.p.)."""
        with self._lock:
            hits, misses, size = self._hits, self._misses, len(self._store)
        total = hits + misses
        return {
            "size": size,
            "max_size": self._max_size,
            "ttl": self._ttl,
            "hits": hits,
            "misses": misses,
            "hit_rate": round(hits / total, 3) if total > 0 else 0.0,
        }
/**
 * Look up a cached sidecar response for (text, taskId, modelId).
 *
 * Returns `null` and counts a miss when there is no entry or the entry is
 * older than CACHE_TTL_MS (expired entries are removed). On a hit the entry
 * is re-inserted so that it becomes the most recently used one.
 *
 * NOTE(review): the key covers only text, taskId and modelId. If other
 * request fields affect the extraction result, confirm that sharing one
 * entry between such requests is intended.
 */
function cacheGet(text: string, taskId?: string, modelId?: string): SidecarExtractResponse | null {
  const key = cacheKey(text, taskId, modelId);
  const entry = cache.get(key);
  if (!entry) {
    cacheMisses++;
    return null;
  }
  if (Date.now() - entry.createdAt > CACHE_TTL_MS) {
    cache.delete(key);
    cacheMisses++;
    return null;
  }
  // A Map iterates in insertion order, so delete + set moves this entry to
  // the "newest" end. Without it eviction is FIFO, not LRU.
  cache.delete(key);
  cache.set(key, entry);
  cacheHits++;
  return entry.response;
}

/**
 * Store a sidecar response for (text, taskId, modelId), evicting the least
 * recently used entries while the cache is at capacity.
 */
function cachePut(
  text: string,
  taskId: string | undefined,
  modelId: string | undefined,
  response: SidecarExtractResponse
): void {
  // CACHE_MAX comes from parseInt(); an unparseable value yields NaN and
  // `size >= NaN` is always false, which would let the map grow without
  // bound. Treat NaN and non-positive limits as "caching disabled".
  if (!(CACHE_MAX > 0)) return;

  const key = cacheKey(text, taskId, modelId);

  // Remove a previous entry for this key first so an overwrite neither
  // evicts an unrelated entry nor keeps its stale position in the order.
  cache.delete(key);

  while (cache.size >= CACHE_MAX) {
    const oldest = cache.keys().next();
    if (oldest.done) break;
    cache.delete(oldest.value);
  }
  cache.set(key, { response, createdAt: Date.now() });
}
cached.extractions, + metadata: { + modelId: cached.metadata.model_id, + durationMs: cached.metadata.duration_ms, + tokenCount: cached.metadata.token_count, + charCount: cached.metadata.char_count, + }, + requestId, + }); + } + + reply.header('X-Extraction-Cache', 'MISS'); + const result = await sidecarExtract( { text, @@ -47,6 +139,9 @@ export async function extractRoutes(app: FastifyInstance) { requestId ); + cachePut(text, taskId, modelId, result); + if (userId) incrementUsage(userId, userPlan); + req.log.info( { entityCount: result.extractions.length, durationMs: result.metadata.duration_ms }, 'extraction complete' @@ -133,4 +228,31 @@ export async function extractRoutes(app: FastifyInstance) { return reply.status(503).send({ status: 'error', error: message }); } }); + + /** + * GET /extract/usage — Per-user extraction usage (admin). + */ + app.get('/extract/usage', async (req, reply) => { + const userId = (req.query as Record).userId; + const plan = (req.query as Record).plan || 'free'; + if (!userId) { + throw new BadRequestError('userId query parameter is required'); + } + return reply.send(getUsageSummary(userId, plan)); + }); + + /** + * GET /extract/cache-stats — Cache statistics. + */ + app.get('/extract/cache-stats', async (_req, reply) => { + const total = cacheHits + cacheMisses; + return reply.send({ + size: cache.size, + maxSize: CACHE_MAX, + ttlMs: CACHE_TTL_MS, + hits: cacheHits, + misses: cacheMisses, + hitRate: total > 0 ? Math.round((cacheHits / total) * 1000) / 1000 : 0, + }); + }); } diff --git a/services/extraction-service/src/modules/extract/usage.ts b/services/extraction-service/src/modules/extract/usage.ts new file mode 100644 index 00000000..989e2a67 --- /dev/null +++ b/services/extraction-service/src/modules/extract/usage.ts @@ -0,0 +1,114 @@ +/** + * Per-user daily extraction quota enforcement. 
// ── Quota tiers ──────────────────────────────────────────────────

/** Daily extraction allowance per plan; `Infinity` means unlimited. */
const PLAN_QUOTAS: Record<string, number> = {
  free: 10,
  pro: 100,
  enterprise: Infinity,
};

/**
 * Return the daily quota for `plan`, falling back to the free tier for any
 * name that is not one of PLAN_QUOTAS' own keys.
 *
 * The own-property check matters because `plan` is caller-supplied: a plain
 * `PLAN_QUOTAS[plan] ?? ...` lookup resolves names such as "constructor",
 * "toString" or "__proto__" through Object.prototype and returns a
 * non-numeric value instead of falling back.
 */
export function getQuota(plan: string): number {
  return Object.prototype.hasOwnProperty.call(PLAN_QUOTAS, plan)
    ? PLAN_QUOTAS[plan]
    : PLAN_QUOTAS.free;
}

// ── In-memory usage tracker (no Cosmos dependency for now) ───────

/** Per-user counters, keyed by `${userId}:${YYYY-MM-DD}`. */
const usageStore = new Map<string, { count: number; date: string }>();

/** Current day as YYYY-MM-DD, taken from the UTC ISO timestamp. */
function todayKey(): string {
  return new Date().toISOString().slice(0, 10);
}

/**
 * Build the usageStore key for a user on a given day.
 * `day` defaults to today so existing single-argument calls keep working.
 */
function storeKey(userId: string, day: string = todayKey()): string {
  return `${userId}:${day}`;
}

/**
 * Check if user is within their daily quota.
 * Returns { allowed, remaining, limit, used }.
 *
 * Plans with an unlimited quota are always allowed and report `used: 0`
 * without consulting the store.
 */
export function checkQuota(
  userId: string,
  plan: string = 'free'
): { allowed: boolean; remaining: number; limit: number; used: number } {
  const limit = getQuota(plan);
  if (limit === Infinity) {
    return { allowed: true, remaining: Infinity, limit, used: 0 };
  }

  // Read the clock once so the lookup key and the date comparison cannot
  // disagree when the call straddles UTC midnight.
  const today = todayKey();
  const entry = usageStore.get(storeKey(userId, today));

  // A missing or stale entry means nothing has been used today.
  const used = entry && entry.date === today ? entry.count : 0;
  const remaining = Math.max(0, limit - used);

  return { allowed: used < limit, remaining, limit, used };
}
/** Day (YYYY-MM-DD) for which stale counters were last purged. */
let usagePurgedDay = '';

/**
 * Delete counters recorded for any day other than `today`.
 * Does the scan at most once per day, so the store holds only the current
 * day's entries instead of growing by one key per user per day forever.
 */
function purgeStaleUsage(today: string): void {
  if (usagePurgedDay === today) return;
  for (const [key, entry] of usageStore) {
    if (entry.date !== today) usageStore.delete(key);
  }
  usagePurgedDay = today;
}

/**
 * Increment usage counter for user. Call after successful extraction.
 *
 * `_plan` is accepted for signature compatibility and is not used.
 */
export function incrementUsage(userId: string, _plan: string = 'free'): void {
  // Read the clock once so the key and the stored date always agree, even
  // when the call straddles midnight.
  const today = todayKey();
  purgeStaleUsage(today);

  // Same `${userId}:${day}` layout that storeKey() produces.
  const key = `${userId}:${today}`;
  const entry = usageStore.get(key);

  if (entry && entry.date === today) {
    entry.count++;
  } else {
    usageStore.set(key, { count: 1, date: today });
  }
}

/**
 * Get usage summary for a user (for the usage reporting endpoint).
 *
 * Unlimited quotas are reported as -1 for both `limit` and `remaining`,
 * since Infinity does not survive JSON serialization.
 */
export function getUsageSummary(
  userId: string,
  plan: string = 'free'
): {
  userId: string;
  date: string;
  used: number;
  limit: number;
  remaining: number;
  plan: string;
} {
  const quota = checkQuota(userId, plan);
  const finiteOrMinusOne = (value: number): number => (value === Infinity ? -1 : value);
  return {
    userId,
    date: todayKey(),
    used: quota.used,
    limit: finiteOrMinusOne(quota.limit),
    remaining: finiteOrMinusOne(quota.remaining),
    plan,
  };
}