learning_ai_common_plat/services/extraction-service/python/src/cache.py
saravanakumardb1 9c8a3169dc feat(extraction): Phase 5 caching + cost controls (5.1-5.6)
- 5.1: Python sidecar LRU cache (cache.py) with configurable TTL + max size
- 5.2: Fastify-level cache with X-Extraction-Cache HIT/MISS header + /extract/cache-stats
- 5.3-5.5: Per-user daily quota (free=10, pro=100, enterprise=unlimited) with 429 response
- 5.6: GET /extract/usage endpoint for admin usage reporting
- Both Python + TS caches use sha256(taskId:modelId:text) keys
- 46 TS tests + 29 Python tests still passing
2026-02-14 14:02:21 -08:00

99 lines
2.8 KiB
Python

"""
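In-memory LRU cache with TTL expiry for extraction results.
Cache key: sha256 of "{task_id}:{model_id}:{text}" (None ids become "").
TTL: configurable via EXTRACTION_CACHE_TTL env var (default 86400s = 24h)
Max size: configurable via EXTRACTION_CACHE_MAX_SIZE env var (default 1000 entries)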
"""
from __future__ import annotations

import hashlib
import os
import threading
import time
from collections import OrderedDict
from dataclasses import dataclass, field

import structlog

from .models import ExtractResponse
logger = structlog.get_logger(__name__)
# Tunables, overridable per deployment through the environment:
#   EXTRACTION_CACHE_TTL       entry lifetime in seconds (default 86400 = 24h)
#   EXTRACTION_CACHE_MAX_SIZE  maximum number of cached entries (default 1000)
# A non-integer value raises ValueError when this module is imported.
CACHE_TTL = int(os.getenv("EXTRACTION_CACHE_TTL", "86400"))
CACHE_MAX_SIZE = int(os.getenv("EXTRACTION_CACHE_MAX_SIZE", "1000"))
@dataclass
class CacheEntry:
    """One cached extraction result plus the time it was stored."""

    # The cached response handed back to callers on a cache hit.
    response: ExtractResponse
    # time.monotonic() reading taken when the entry is constructed; its age is
    # compared against the cache TTL on lookup. Monotonic time is unaffected by
    # wall-clock adjustments.
    created_at: float = field(default_factory=time.monotonic)
class ExtractionCache:
    """Thread-safe LRU cache with TTL expiry.

    Entries are keyed by a SHA-256 digest of ``task_id``, ``model_id`` and
    ``text`` (see ``_make_key``). Once ``max_size`` entries are held, the
    least-recently-used entry is evicted to make room. An entry older than
    ``ttl`` seconds is treated as a miss and removed the next time it is
    looked up. Every operation that touches the store or the counters runs
    under an internal lock.
    """

    def __init__(self, max_size: int = CACHE_MAX_SIZE, ttl: int = CACHE_TTL):
        """Create an empty cache.

        Args:
            max_size: Maximum number of entries kept. A value <= 0 disables
                storage entirely (``put`` becomes a no-op).
            ttl: Entry lifetime in seconds, measured with ``time.monotonic``.
        """
        self._store: OrderedDict[str, CacheEntry] = OrderedDict()
        self._max_size = max_size
        self._ttl = ttl
        self._hits = 0
        self._misses = 0
        # A lookup/expire/move_to_end or evict/insert sequence is several
        # OrderedDict operations; without a lock a concurrent caller can
        # remove the key in between (KeyError) or lose counter updates.
        self._lock = threading.Lock()

    @staticmethod
    def _make_key(text: str, task_id: str | None, model_id: str | None) -> str:
        """Return the SHA-256 hex digest identifying one extraction request.

        The hashed string is ``"{task_id}:{model_id}:{text}"``, with ``None``
        ids rendered as empty strings (the same scheme as the TS-side cache).
        """
        raw = f"{task_id or ''}:{model_id or ''}:{text}"
        # "surrogatepass" produces the same bytes as plain UTF-8 for valid
        # text, but lets strings with lone surrogates (e.g. decoded from JSON
        # "\ud800") be hashed instead of raising UnicodeEncodeError.
        # NOTE(review): ids containing ':' can collide ("a:b" + "c" vs
        # "a" + "b:c"); kept as-is for parity with the TS key format.
        return hashlib.sha256(raw.encode("utf-8", "surrogatepass")).hexdigest()

    def get(self, text: str, task_id: str | None, model_id: str | None) -> ExtractResponse | None:
        """Return the cached response for this request, or None on a miss.

        A hit marks the entry as most recently used. An entry older than the
        TTL is deleted and counted as a miss.
        """
        key = self._make_key(text, task_id, model_id)
        with self._lock:
            entry = self._store.get(key)
            if entry is None:
                self._misses += 1
                return None
            # Expired entries are only purged lazily, here on lookup.
            age = time.monotonic() - entry.created_at
            if age > self._ttl:
                del self._store[key]
                self._misses += 1
                logger.debug("cache_expired", key=key[:12], age_s=round(age))
                return None
            # Move to the MRU end so eviction (popitem(last=False)) skips it.
            self._store.move_to_end(key)
            self._hits += 1
        logger.debug("cache_hit", key=key[:12])
        return entry.response

    def put(self, text: str, task_id: str | None, model_id: str | None, response: ExtractResponse) -> None:
        """Store ``response`` as the most recently used entry for this request.

        Re-putting an existing key replaces it and restarts its TTL without
        evicting any other entry. When ``max_size`` <= 0 nothing is stored.
        """
        if self._max_size <= 0:
            # Nothing can fit; without this guard the eviction loop would call
            # popitem() on an empty OrderedDict and raise KeyError.
            return
        key = self._make_key(text, task_id, model_id)
        with self._lock:
            # Remove any existing entry first. Assigning to an existing
            # OrderedDict key keeps its old position, so a refreshed entry
            # would still look least-recently-used, and the capacity check
            # below would needlessly evict an unrelated entry.
            self._store.pop(key, None)
            # Evict least-recently-used entries until there is room for one.
            while len(self._store) >= self._max_size:
                evicted_key, _ = self._store.popitem(last=False)
                logger.debug("cache_evicted", key=evicted_key[:12])
            self._store[key] = CacheEntry(response=response)

    def clear(self) -> None:
        """Drop every entry and reset the hit/miss counters."""
        with self._lock:
            self._store.clear()
            self._hits = 0
            self._misses = 0

    @property
    def stats(self) -> dict:
        """Snapshot of size, limits and hit/miss counters.

        ``size`` counts stored entries, including expired ones that have not
        been looked up (and therefore purged) yet.
        """
        with self._lock:
            total = self._hits + self._misses
            return {
                "size": len(self._store),
                "max_size": self._max_size,
                "ttl": self._ttl,
                "hits": self._hits,
                "misses": self._misses,
                "hit_rate": round(self._hits / total, 3) if total > 0 else 0.0,
            }
# Process-wide shared instance, sized and aged by the module-level tunables.
extraction_cache = ExtractionCache(max_size=CACHE_MAX_SIZE, ttl=CACHE_TTL)