diff --git a/README.md b/README.md
index f844d5b..45937c1 100644
--- a/README.md
+++ b/README.md
@@ -157,7 +157,22 @@ Pytest marks warnings as errors; update fixtures or add targeted `filterwarnings
 ## WhatsApp & Admin Endpoints
 - **Webhook:** `POST /meta-whatsapp` (signature verification + LangGraph processing). Verification handshake uses `GET /meta-whatsapp`.
 - **Progress messaging:** Status texts sourced from `bt_servant_engine.services.status_messages`.
-- **Admin API:** See `bt_servant_engine.apps.api.routes.admin` for vector store maintenance (collection merges, document management) secured via bearer token headers when `ENABLE_ADMIN_AUTH=True`.
+- **Admin API:** See `bt_servant_engine.apps.api.routes.admin` for vector store maintenance (collection merges, document management) secured via bearer token headers when `ENABLE_ADMIN_AUTH=True`. Cache controls are exposed here as well:
+  - `POST /cache/clear` clears every cache namespace.
+  - `POST /cache/{name}/clear` clears an individual cache (e.g., `passage_summary`).
+  - `GET /cache/stats` reports global cache settings, hit/miss counters, and disk usage.
+  - `GET /cache/{name}?sample_limit=10` inspects a specific cache, including metadata for its most recently used entries (`sample_limit` accepts 1-100).
+  - Both clear endpoints accept `older_than_days=` to prune only entries older than the cutoff instead of clearing everything.
+
+---
+
+## Cache Configuration
+- Defaults: caching is enabled with a disk backend under `${DATA_DIR}/cache`; entries never expire (TTL of `-1`) and each cache is capped at 500 MB on disk. All of these defaults can be overridden via environment variables.
+- Toggle or tune via env settings (see `bt_servant_engine/core/config.py`):
+  - `CACHE_ENABLED`, `CACHE_BACKEND` (`disk` | `memory`), `CACHE_DISK_MAX_BYTES`
+  - Per-cache toggles (defaults in parentheses): `CACHE_SELECTION_ENABLED` (`false`), `CACHE_SUMMARY_ENABLED` (`true`), `CACHE_KEYWORDS_ENABLED` (`true`), `CACHE_TRANSLATION_HELPS_ENABLED` (`true`), `CACHE_RAG_VECTOR_ENABLED` (`false`), `CACHE_RAG_FINAL_ENABLED` (`false`)
+  - Per-cache TTL and size controls: `CACHE_SELECTION_TTL_SECONDS`, `CACHE_SUMMARY_TTL_SECONDS`, `CACHE_TRANSLATION_HELPS_TTL_SECONDS`, etc. (set a TTL to `-1` for no expiry), plus `CACHE_<NAME>_MAX_ENTRIES` entry caps used by the memory backend
+- The admin endpoints above can purge or inspect caches without redeploying; deleting `${DATA_DIR}/cache` also resets the disk stores.
 
 ---
 
diff --git a/bt_servant_engine/apps/api/routes/admin.py b/bt_servant_engine/apps/api/routes/admin.py
index a64f1bd..1ad5b34 100644
--- a/bt_servant_engine/apps/api/routes/admin.py
+++ b/bt_servant_engine/apps/api/routes/admin.py
@@ -25,6 +25,7 @@
 )
 from bt_servant_engine.core.logging import get_logger
 from bt_servant_engine.services import runtime
+from bt_servant_engine.services.cache_manager import cache_manager
 
 router = APIRouter()
 logger = get_logger(__name__)
@@ -33,6 +34,8 @@
     os.environ.get("OPENAI_EMBED_MAX_TOKENS_PER_REQUEST", "290000")
 )
 
+MAX_CACHE_SAMPLE_LIMIT = 100
+
 # Re-export merge helpers for compatibility with existing references.
iter_collection_batches = merge_helpers.iter_collection_batches estimate_tokens = merge_helpers.estimate_tokens @@ -870,3 +873,114 @@ async def chroma_catch_all(_path: str, _: None = Depends(require_admin_token)): __all__ = ["router"] + + +@router.post("/cache/clear") +async def clear_all_caches( + older_than_days: float | None = None, + _: None = Depends(require_admin_token), +): + """Clear all registered caches.""" + if older_than_days is not None: + if older_than_days <= 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "older_than_days must be greater than 0"}, + ) + cutoff = time.time() - older_than_days * 86400 + logger.info( + "[cache-admin] pruning all caches older than %.2f days (cutoff=%s)", + older_than_days, + cutoff, + ) + removed = cache_manager.prune_all(cutoff) + return JSONResponse( + status_code=status.HTTP_200_OK, + content={ + "status": "pruned", + "cutoff_epoch": cutoff, + "removed": removed, + }, + ) + logger.info("[cache-admin] clearing all caches via admin endpoint") + cache_manager.clear_all() + return JSONResponse(status_code=status.HTTP_200_OK, content={"status": "cleared"}) + + +@router.post("/cache/{name}/clear") +async def clear_named_cache( + name: str, + older_than_days: float | None = None, + _: None = Depends(require_admin_token), +): + """Clear a specific cache namespace.""" + if older_than_days is not None: + if older_than_days <= 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "older_than_days must be greater than 0"}, + ) + cutoff = time.time() - older_than_days * 86400 + try: + removed = cache_manager.prune_cache(name, cutoff) + except KeyError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": f"Cache '{name}' not found"}, + ) from exc + logger.info( + "[cache-admin] pruned cache %s older than %.2f days (cutoff=%s, removed=%d)", + name, + older_than_days, + cutoff, + removed, + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content={ + "status": "pruned", + "cache": name, + "cutoff_epoch": cutoff, + "removed": removed, + }, + ) + try: + cache_manager.clear_cache(name) + except KeyError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail={"error": f"Cache '{name}' not found"} + ) from exc + logger.info("[cache-admin] cleared cache %s via admin endpoint", name) + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"status": "cleared", "cache": name}, + ) + + +@router.get("/cache/stats") +async def get_cache_stats(_: None = Depends(require_admin_token)): + """Return summary stats for all caches.""" + data = cache_manager.stats() + return JSONResponse(status_code=status.HTTP_200_OK, content=data) + + +@router.get("/cache/{name}") +async def inspect_cache(name: str, sample_limit: int = 10, _: None = Depends(require_admin_token)): + """Return detailed stats and samples for a specific cache.""" + if sample_limit <= 0 or sample_limit > MAX_CACHE_SAMPLE_LIMIT: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": f"sample_limit must be between 1 and {MAX_CACHE_SAMPLE_LIMIT}", + }, + ) + try: + cache = cache_manager.cache(name) + except KeyError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail={"error": f"Cache '{name}' not found"} + ) from exc + logger.info("[cache-admin] inspecting cache %s (limit=%d)", name, sample_limit) + details = cache.detailed_stats(sample_limit=sample_limit) + details["sample_limit"] = sample_limit + 
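+    # "details" now carries the cache's stats plus up to sample_limit entry samples,
+    # ordered by most recent access (see CacheStore.detailed_stats).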
return JSONResponse(status_code=status.HTTP_200_OK, content=details) diff --git a/bt_servant_engine/core/config.py b/bt_servant_engine/core/config.py index cc1258a..7bf0d99 100644 --- a/bt_servant_engine/core/config.py +++ b/bt_servant_engine/core/config.py @@ -43,6 +43,30 @@ class Settings(BaseSettings): PROGRESS_MESSAGE_EMOJI: str = Field(default="⏳") PROGRESS_MESSAGE_EMOJI_OVERRIDES: dict[str, str] = Field(default_factory=dict) + # Cache configuration + CACHE_ENABLED: bool = Field(default=True) + CACHE_BACKEND: Literal["disk", "memory"] = Field(default="disk") + CACHE_DISK_MAX_BYTES: int = Field(default=500 * 1024 * 1024) # 500MB + CACHE_DEFAULT_TTL_SECONDS: int = Field(default=-1) + CACHE_SELECTION_ENABLED: bool = Field(default=False) + CACHE_SELECTION_TTL_SECONDS: int = Field(default=-1) + CACHE_SELECTION_MAX_ENTRIES: int = Field(default=5000) + CACHE_SUMMARY_ENABLED: bool = Field(default=True) + CACHE_SUMMARY_TTL_SECONDS: int = Field(default=-1) + CACHE_SUMMARY_MAX_ENTRIES: int = Field(default=1500) + CACHE_KEYWORDS_ENABLED: bool = Field(default=True) + CACHE_KEYWORDS_TTL_SECONDS: int = Field(default=-1) + CACHE_KEYWORDS_MAX_ENTRIES: int = Field(default=3000) + CACHE_TRANSLATION_HELPS_ENABLED: bool = Field(default=True) + CACHE_TRANSLATION_HELPS_TTL_SECONDS: int = Field(default=-1) + CACHE_TRANSLATION_HELPS_MAX_ENTRIES: int = Field(default=1000) + CACHE_RAG_VECTOR_ENABLED: bool = Field(default=False) + CACHE_RAG_VECTOR_TTL_SECONDS: int = Field(default=-1) + CACHE_RAG_VECTOR_MAX_ENTRIES: int = Field(default=3000) + CACHE_RAG_FINAL_ENABLED: bool = Field(default=False) + CACHE_RAG_FINAL_TTL_SECONDS: int = Field(default=-1) + CACHE_RAG_FINAL_MAX_ENTRIES: int = Field(default=1500) + DATA_DIR: Path = Field(default=Path("/data")) OPENAI_PRICING_JSON: str = Field( default=( diff --git a/bt_servant_engine/services/cache_manager.py b/bt_servant_engine/services/cache_manager.py new file mode 100644 index 0000000..90d6b23 --- /dev/null +++ b/bt_servant_engine/services/cache_manager.py @@ -0,0 +1,650 @@ +"""Caching infrastructure for intent handlers with disk and memory backends.""" + +# pylint: disable=missing-function-docstring,too-many-instance-attributes,too-many-locals,too-many-branches,redefined-outer-name + +from __future__ import annotations + +import hashlib +import json +import shutil +import threading +import time +from collections import OrderedDict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, Iterable + +from bt_servant_engine.core.config import config +from bt_servant_engine.core.intents import IntentType +from bt_servant_engine.core.logging import get_logger +from utils.perf import get_current_trace, record_external_span + +logger = get_logger(__name__) + +CACHE_SCHEMA_VERSION = "v1.0.0-cache" +_INTENT_SENTINEL = "__cache_intent__" +_TUPLE_SENTINEL = "__cache_tuple__" + + +def _hash_key(key: Any) -> tuple[str, str]: + """Return a stable hash for arbitrary key objects.""" + key_repr = repr(key) + digest = hashlib.sha256(key_repr.encode("utf-8")).hexdigest() + return digest, key_repr + + +def _encode_payload(value: Any) -> bytes: + def _default(obj: Any) -> Any: + if isinstance(obj, IntentType): + return {_INTENT_SENTINEL: obj.value} + if isinstance(obj, tuple): + return {_TUPLE_SENTINEL: list(obj)} + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + text = json.dumps(value, default=_default, ensure_ascii=False, separators=(",", ":")) + return text.encode("utf-8") + + +def 
_decode_payload(data: bytes) -> Any: + def _object_hook(obj: dict[str, Any]) -> Any: + if len(obj) == 1 and _INTENT_SENTINEL in obj: + return IntentType(obj[_INTENT_SENTINEL]) + if len(obj) == 1 and _TUPLE_SENTINEL in obj: + return tuple(obj[_TUPLE_SENTINEL]) + intent_value = obj.get("intent") + if isinstance(intent_value, str): + try: + obj["intent"] = IntentType(intent_value) + except ValueError: + pass + return obj + + text = data.decode("utf-8") + return json.loads(text, object_hook=_object_hook) + + +def _perf_span(name: str, start: float, end: float) -> None: + trace = get_current_trace() + if trace: + record_external_span(name, start, end, trace_id=trace) + + +@dataclass(slots=True) +class CacheStats: + """Mutable counters for cache operations.""" + + hits: int = 0 + misses: int = 0 + stores: int = 0 + evictions: int = 0 + + def as_dict(self) -> dict[str, int]: + return { + "hits": self.hits, + "misses": self.misses, + "stores": self.stores, + "evictions": self.evictions, + } + + +@dataclass(slots=True) +class CacheEntryMeta: + """Metadata describing a stored cache entry.""" + + key: str + key_repr: str + created_at: float + expires_at: float + last_access: float + size_bytes: int + + def is_expired(self, now: float) -> bool: + if self.expires_at < 0: + return False + return now >= self.expires_at + + +class MemoryCacheBackend: + """In-memory cache backend with LRU eviction.""" + + def __init__(self, max_entries: int | None) -> None: + self._lock = threading.RLock() + self._entries: "OrderedDict[str, tuple[Any, CacheEntryMeta]]" = OrderedDict() + self._max_entries = max_entries + + def get(self, key: str, now: float) -> tuple[Any, CacheEntryMeta] | None: + with self._lock: + entry = self._entries.get(key) + if entry is None: + return None + value, meta = entry + if meta.is_expired(now): + self._entries.pop(key, None) + return None + meta.last_access = now + self._entries.move_to_end(key) + return value, meta + + def set(self, key: str, value: Any, meta: CacheEntryMeta) -> None: + with self._lock: + self._entries[key] = (value, meta) + self._entries.move_to_end(key) + self._evict_if_needed() + + def delete(self, key: str) -> None: + with self._lock: + self._entries.pop(key, None) + + def clear(self) -> None: + with self._lock: + self._entries.clear() + + def entries(self) -> Iterable[CacheEntryMeta]: + with self._lock: + return [meta for _, meta in self._entries.values()] + + def prune_older_than(self, cutoff: float) -> list[CacheEntryMeta]: + removed: list[CacheEntryMeta] = [] + with self._lock: + keys = list(self._entries.keys()) + for key in keys: + _, meta = self._entries[key] + if meta.created_at < cutoff: + removed.append(meta) + self._entries.pop(key, None) + return removed + + def _evict_if_needed(self) -> list[CacheEntryMeta]: + evicted: list[CacheEntryMeta] = [] + if self._max_entries is None: + return evicted + while len(self._entries) > self._max_entries: + _, (_, meta) = self._entries.popitem(last=False) + evicted.append(meta) + return evicted + + +class DiskCacheBackend: + """Disk-backed cache storing serialized values with manifest bookkeeping.""" + + def __init__(self, cache_dir: Path, max_bytes: int) -> None: + self._cache_dir = cache_dir + self._manifest_path = cache_dir / "manifest.json" + self._lock = threading.RLock() + self._max_bytes = max_bytes + self._entries: Dict[str, CacheEntryMeta] = {} + self._total_bytes = 0 + self._load_manifest() + + def _load_manifest(self) -> None: + if not self._manifest_path.exists(): + self._cache_dir.mkdir(parents=True, 
exist_ok=True) + self._persist_manifest() + return + try: + raw = json.loads(self._manifest_path.read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError): + logger.warning("[cache] manifest load failed for %s; rebuilding", self._manifest_path) + shutil.rmtree(self._cache_dir, ignore_errors=True) + self._cache_dir.mkdir(parents=True, exist_ok=True) + self._persist_manifest() + return + entries = raw.get("entries", {}) + total_bytes = 0 + for key, meta in entries.items(): + path = self._cache_dir / f"{key}.json" + if not path.exists(): + continue + cem = CacheEntryMeta( + key=key, + key_repr=meta.get("key_repr", key), + created_at=float(meta.get("created_at", time.time())), + expires_at=float(meta.get("expires_at", 0)), + last_access=float(meta.get("last_access", 0)), + size_bytes=int(meta.get("size_bytes", path.stat().st_size)), + ) + self._entries[key] = cem + total_bytes += cem.size_bytes + self._total_bytes = total_bytes + + def _persist_manifest(self) -> None: + data = { + "entries": { + key: { + "key_repr": meta.key_repr, + "created_at": meta.created_at, + "expires_at": meta.expires_at, + "last_access": meta.last_access, + "size_bytes": meta.size_bytes, + } + for key, meta in self._entries.items() + }, + "total_bytes": self._total_bytes, + } + tmp = self._manifest_path.with_suffix(".json.tmp") + tmp.write_text(json.dumps(data, indent=2), encoding="utf-8") + tmp.replace(self._manifest_path) + + def get(self, key: str, now: float) -> tuple[Any, CacheEntryMeta] | None: + with self._lock: + meta = self._entries.get(key) + if meta is None: + return None + if meta.is_expired(now): + self._delete_unlocked(key) + return None + path = self._cache_dir / f"{key}.json" + try: + with path.open("rb") as fh: + payload = fh.read() + value = _decode_payload(payload) + except (OSError, json.JSONDecodeError, ValueError, TypeError): + logger.warning("[cache] failed to load cache entry %s; removing", key) + self._delete_unlocked(key) + return None + meta.last_access = now + self._persist_manifest() + return value, meta + + def set(self, key: str, value: Any, meta: CacheEntryMeta) -> list[CacheEntryMeta]: + payload = _encode_payload(value) + meta.size_bytes = len(payload) + with self._lock: + self._cache_dir.mkdir(parents=True, exist_ok=True) + path = self._cache_dir / f"{key}.json" + tmp = path.with_suffix(".tmp") + with tmp.open("wb") as fh: + fh.write(payload) + tmp.replace(path) + self._entries[key] = meta + self._total_bytes = sum(entry.size_bytes for entry in self._entries.values()) + evicted = self._evict_if_needed_unlocked() + self._persist_manifest() + return evicted + + def delete(self, key: str) -> None: + with self._lock: + self._delete_unlocked(key) + self._persist_manifest() + + def _delete_unlocked(self, key: str) -> None: + meta = self._entries.pop(key, None) + if meta is None: + return + path = self._cache_dir / f"{key}.json" + try: + path.unlink() + except FileNotFoundError: + pass + self._total_bytes = max(0, self._total_bytes - meta.size_bytes) + + def clear(self) -> None: + with self._lock: + shutil.rmtree(self._cache_dir, ignore_errors=True) + self._entries.clear() + self._total_bytes = 0 + self._cache_dir.mkdir(parents=True, exist_ok=True) + self._persist_manifest() + + def entries(self) -> Iterable[CacheEntryMeta]: + with self._lock: + return list(self._entries.values()) + + def total_bytes(self) -> int: + with self._lock: + return self._total_bytes + + def prune_older_than(self, cutoff: float) -> list[CacheEntryMeta]: + removed: list[CacheEntryMeta] = [] + with 
self._lock: + keys = list(self._entries.keys()) + for key in keys: + meta = self._entries.get(key) + if meta and meta.created_at < cutoff: + self._delete_unlocked(key) + removed.append(meta) + self._persist_manifest() + return removed + + def _evict_if_needed_unlocked(self) -> list[CacheEntryMeta]: + evicted: list[CacheEntryMeta] = [] + while self._total_bytes > self._max_bytes and self._entries: + key, meta = min(self._entries.items(), key=lambda item: item[1].last_access) + self._delete_unlocked(key) + evicted.append(meta) + return evicted + + +@dataclass(slots=True) +class CacheConfig: + """Configuration for an individual cache namespace.""" + + name: str + ttl_seconds: int + max_entries: int | None + + +class CacheStore: + """High-level cache interface with logging and perf instrumentation.""" + + def __init__( + self, + config: CacheConfig, + backend: MemoryCacheBackend | DiskCacheBackend, + *, + enabled: bool, + ) -> None: + self._config = config + self._backend = backend + self._enabled = enabled + self._stats = CacheStats() + self._lock = threading.RLock() + + @property + def name(self) -> str: + return self._config.name + + def get_or_set( + self, + key: Any, + compute: Callable[[], Any], + *, + should_store: Callable[[Any], bool] | None = None, + ) -> tuple[Any, bool]: + if not self._enabled: + value = compute() + return value, False + key_hash, key_repr = _hash_key(key) + now = time.time() + lookup_start = time.time() + cached = self._backend.get(key_hash, now) + lookup_end = time.time() + if cached: + self._stats.hits += 1 + _, meta = cached + logger.info( + "[cache] hit name=%s key=%s size=%s age=%.1fs", + self.name, + key_repr, + meta.size_bytes, + now - meta.created_at, + ) + _perf_span(f"cache_hit:{self.name}", lookup_start, lookup_end) + return cached[0], True + self._stats.misses += 1 + logger.info("[cache] miss name=%s key=%s", self.name, key_repr) + compute_start = time.time() + try: + value = compute() + finally: + compute_end = time.time() + _perf_span(f"cache_miss:{self.name}", compute_start, compute_end) + if should_store is not None and not should_store(value): + logger.info( + "[cache] store skipped name=%s key=%s (predicate)", + self.name, + key_repr, + ) + return value, False + now = time.time() + expires_at = -1 if self._config.ttl_seconds < 0 else now + self._config.ttl_seconds + meta = CacheEntryMeta( + key=key_hash, + key_repr=key_repr, + created_at=now, + expires_at=expires_at, + last_access=now, + size_bytes=0, + ) + start_store = time.time() + evicted: list[CacheEntryMeta] = [] + if isinstance(self._backend, MemoryCacheBackend): + self._backend.set(key_hash, value, meta) + else: + evicted = self._backend.set(key_hash, value, meta) + store_duration = time.time() - start_store + self._stats.stores += 1 + self._stats.evictions += len(evicted) + if evicted: + logger.info( + "[cache] eviction name=%s count=%d keys=%s", + self.name, + len(evicted), + [m.key_repr for m in evicted], + ) + logger.info( + "[cache] store name=%s key=%s size=%s ttl=%ss", + self.name, + key_repr, + meta.size_bytes, + self._config.ttl_seconds, + ) + if store_duration: + _perf_span(f"cache_store:{self.name}", start_store, start_store + store_duration) + return value, False + + def clear(self) -> None: + logger.info("[cache] clear name=%s", self.name) + self._backend.clear() + self._stats = CacheStats() + + def stats(self) -> dict[str, Any]: + entries = self._sorted_entries() + newest = max((e.created_at for e in entries), default=None) + oldest = min((e.created_at for e in entries), 
default=None)
+        bytes_used = 0
+        if isinstance(self._backend, DiskCacheBackend):
+            bytes_used = self._backend.total_bytes()
+        return {
+            "name": self.name,
+            "enabled": self._enabled,
+            "ttl_seconds": self._config.ttl_seconds,
+            "max_entries": self._config.max_entries,
+            "entry_count": len(entries),
+            "bytes_used": bytes_used,
+            "oldest_entry_epoch": oldest,
+            "newest_entry_epoch": newest,
+            "stats": self._stats.as_dict(),
+        }
+
+    def detailed_stats(self, sample_limit: int = 10) -> dict[str, Any]:
+        data = self.stats()
+        entries = self._sorted_entries()
+        samples = []
+        now = time.time()
+        for meta in entries[:sample_limit]:
+            # Report None (not float("inf")) for no-expiry entries: Starlette's
+            # JSONResponse serializes with allow_nan=False, so Infinity would raise
+            # a ValueError when the admin inspect endpoint renders the payload.
+            ttl_remaining = None if meta.expires_at < 0 else max(0.0, meta.expires_at - now)
+            samples.append(
+                {
+                    "key_repr": meta.key_repr,
+                    "size_bytes": meta.size_bytes,
+                    "created_at": meta.created_at,
+                    "expires_at": meta.expires_at,
+                    "last_access": meta.last_access,
+                    "age_seconds": max(0.0, now - meta.created_at),
+                    "ttl_remaining": ttl_remaining,
+                }
+            )
+        data["samples"] = samples
+        return data
+
+    def _sorted_entries(self) -> list[CacheEntryMeta]:
+        entries = list(self._backend.entries())
+        entries.sort(key=lambda e: e.last_access, reverse=True)
+        return entries
+
+    def prune_older_than(self, cutoff: float) -> int:
+        # Both backends implement prune_older_than, so no isinstance dispatch is needed.
+        removed = self._backend.prune_older_than(cutoff)
+        count = len(removed)
+        if count:
+            self._stats.evictions += count
+            logger.info(
+                "[cache] prune name=%s removed=%d cutoff=%s",
+                self.name,
+                count,
+                cutoff,
+            )
+        return count
+
+
+class CacheManager:
+    """Registry of caches shared across the app."""
+
+    def __init__(
+        self,
+        *,
+        enabled: bool,
+        backend_type: str,
+        disk_root: Path,
+        disk_max_bytes: int,
+    ) -> None:
+        self._enabled = enabled
+        self._backend_type = backend_type
+        self._disk_root = disk_root
+        self._disk_max_bytes = disk_max_bytes
+        self._caches: Dict[str, CacheStore] = {}
+        self._lock = threading.RLock()
+
+    def register(self, cache_config: CacheConfig, *, cache_enabled: bool = True) -> CacheStore:
+        with self._lock:
+            if cache_config.name in self._caches:
+                return self._caches[cache_config.name]
+            backend = self._create_backend(cache_config)
+            store = CacheStore(cache_config, backend, enabled=self._enabled and cache_enabled)
+            self._caches[cache_config.name] = store
+            return store
+
+    def cache(self, name: str) -> CacheStore:
+        with self._lock:
+            return self._caches[name]
+
+    def _create_backend(self, cache_config: CacheConfig) -> MemoryCacheBackend | DiskCacheBackend:
+        if self._backend_type == "memory":
+            return MemoryCacheBackend(cache_config.max_entries)
+        cache_dir = self._disk_root / cache_config.name
+        cache_dir.mkdir(parents=True, exist_ok=True)
+        return DiskCacheBackend(cache_dir, max_bytes=self._disk_max_bytes)
+
+    def clear_all(self) -> None:
+        logger.info("[cache] clearing all caches")
+        for cache in self._caches.values():
+            cache.clear()
+        if self._backend_type == "disk":
+            shutil.rmtree(self._disk_root, ignore_errors=True)
+            self._disk_root.mkdir(parents=True, exist_ok=True)
+
+    def clear_cache(self, name: str) -> None:
+        cache = self._caches.get(name)
+        if not cache:
+            raise KeyError(name)
+        cache.clear()
+        if self._backend_type == "disk":
+            cache_dir = self._disk_root / name
+            shutil.rmtree(cache_dir, ignore_errors=True)
+            cache_dir.mkdir(parents=True, exist_ok=True)
+
+    def prune_cache(self, name: str, cutoff: float) -> int:
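+        """Drop entries created before ``cutoff`` from one cache; raises KeyError if unknown."""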
+ cache = self._caches.get(name) + if not cache: + raise KeyError(name) + return cache.prune_older_than(cutoff) + + def prune_all(self, cutoff: float) -> dict[str, int]: + result: dict[str, int] = {} + for name, cache in self._caches.items(): + result[name] = cache.prune_older_than(cutoff) + return result + + def stats(self) -> dict[str, Any]: + caches = {name: cache.stats() for name, cache in self._caches.items()} + return { + "enabled": self._enabled, + "backend": self._backend_type, + "disk_root": str(self._disk_root), + "disk_max_bytes": self._disk_max_bytes, + "caches": caches, + } + + def cache_names(self) -> list[str]: + return list(self._caches.keys()) + + +def _init_manager() -> CacheManager: + disk_root = config.DATA_DIR / "cache" + disk_root.mkdir(parents=True, exist_ok=True) + manager = CacheManager( + enabled=config.CACHE_ENABLED, + backend_type=config.CACHE_BACKEND, + disk_root=disk_root, + disk_max_bytes=config.CACHE_DISK_MAX_BYTES, + ) + manager.register( + CacheConfig( + name="selection", + ttl_seconds=config.CACHE_SELECTION_TTL_SECONDS, + max_entries=config.CACHE_SELECTION_MAX_ENTRIES, + ), + cache_enabled=config.CACHE_SELECTION_ENABLED, + ) + manager.register( + CacheConfig( + name="passage_summary", + ttl_seconds=config.CACHE_SUMMARY_TTL_SECONDS, + max_entries=config.CACHE_SUMMARY_MAX_ENTRIES, + ), + cache_enabled=config.CACHE_SUMMARY_ENABLED, + ) + manager.register( + CacheConfig( + name="passage_keywords", + ttl_seconds=config.CACHE_KEYWORDS_TTL_SECONDS, + max_entries=config.CACHE_KEYWORDS_MAX_ENTRIES, + ), + cache_enabled=config.CACHE_KEYWORDS_ENABLED, + ) + manager.register( + CacheConfig( + name="translation_helps", + ttl_seconds=config.CACHE_TRANSLATION_HELPS_TTL_SECONDS, + max_entries=config.CACHE_TRANSLATION_HELPS_MAX_ENTRIES, + ), + cache_enabled=config.CACHE_TRANSLATION_HELPS_ENABLED, + ) + manager.register( + CacheConfig( + name="rag_vector", + ttl_seconds=config.CACHE_RAG_VECTOR_TTL_SECONDS, + max_entries=config.CACHE_RAG_VECTOR_MAX_ENTRIES, + ), + cache_enabled=config.CACHE_RAG_VECTOR_ENABLED, + ) + manager.register( + CacheConfig( + name="rag_final", + ttl_seconds=config.CACHE_RAG_FINAL_TTL_SECONDS, + max_entries=config.CACHE_RAG_FINAL_MAX_ENTRIES, + ), + cache_enabled=config.CACHE_RAG_FINAL_ENABLED, + ) + return manager + + +cache_manager = _init_manager() + + +def get_cache(name: str) -> CacheStore: + """Convenience accessor for registered caches.""" + return cache_manager.cache(name) + + +__all__ = [ + "CACHE_SCHEMA_VERSION", + "CacheManager", + "cache_manager", + "get_cache", +] diff --git a/bt_servant_engine/services/graph_pipeline.py b/bt_servant_engine/services/graph_pipeline.py index 4e6031a..9d440a5 100644 --- a/bt_servant_engine/services/graph_pipeline.py +++ b/bt_servant_engine/services/graph_pipeline.py @@ -6,6 +6,7 @@ from __future__ import annotations +import hashlib import json from dataclasses import dataclass from typing import Any, Callable, Optional, cast @@ -15,6 +16,7 @@ from bt_servant_engine.core.intents import IntentType from bt_servant_engine.core.logging import get_logger +from bt_servant_engine.services.cache_manager import CACHE_SCHEMA_VERSION, get_cache from bt_servant_engine.services.openai_utils import track_openai_usage logger = get_logger(__name__) @@ -66,6 +68,36 @@ class OpenAIQueryPayload: boilerplate_features_message: str +@dataclass(frozen=True) +class VectorCacheKey: + """Cache key for vector database queries.""" + + schema: str + transformed_query: str + collections: tuple[str, ...] 
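+    # The retrieval settings below are folded into the key so tuning top_k or the
+    # relevance cutoff can never reuse stale cached results.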
+ top_k: int + relevance_cutoff: float + + +@dataclass(frozen=True) +class FinalResponseCacheKey: + """Cache key for final RAG responses.""" + + schema: str + transformed_query: str + agentic_strength: str + model_name: str + docs_digest: str + chat_history_digest: str + + +VECTOR_CACHE_SCHEMA = f"{CACHE_SCHEMA_VERSION}:rag_vector:v1" +FINAL_CACHE_SCHEMA = f"{CACHE_SCHEMA_VERSION}:rag_final:v1" + +_VECTOR_CACHE = get_cache("rag_vector") +_FINAL_CACHE = get_cache("rag_final") + + def _extract_query_rows(results: Any) -> tuple[list[str], list[float], list[Any]]: document_rows = cast(list[list[str]], results.get("documents", [])) if not document_rows: @@ -202,7 +234,7 @@ def query_vector_db( _boilerplate_features_message: str, *, config: VectorQueryConfig | None = None, -) -> dict[str, list[dict[str, str]]]: +) -> dict[str, list[dict[str, Any]]]: """Query the vector DB (Chroma) across ranked collections and filter by relevance. Args: @@ -217,18 +249,40 @@ def query_vector_db( Dictionary with "docs" key containing filtered document list """ vector_config = config or VectorQueryConfig() - filtered_docs: list[dict[str, Any]] = [] - # this loop is the current implementation of the "stacked ranked" algorithm - for collection_name in stack_rank_collections: - filtered_docs.extend( - _query_collection_docs( - transformed_query, - collection_name, - get_chroma_collection_fn, - vector_config, + cache_key = VectorCacheKey( + schema=VECTOR_CACHE_SCHEMA, + transformed_query=transformed_query, + collections=tuple(stack_rank_collections), + top_k=vector_config.top_k, + relevance_cutoff=vector_config.relevance_cutoff, + ) + + def _compute_vector() -> dict[str, list[dict[str, Any]]]: + filtered_docs: list[dict[str, Any]] = [] + for collection_name in stack_rank_collections: + filtered_docs.extend( + _query_collection_docs( + transformed_query, + collection_name, + get_chroma_collection_fn, + vector_config, + ) ) + logger.info( + "[rag-vector] retrieved %d docs for query signature %s", + len(filtered_docs), + cache_key.collections, + ) + return {"docs": filtered_docs} + + vector_response, hit = _VECTOR_CACHE.get_or_set(cache_key, _compute_vector) + if hit: + logger.info( + "[rag-vector] cache hit for query=%s collections=%s", + transformed_query, + cache_key.collections, ) - return {"docs": filtered_docs} + return vector_response def query_open_ai( @@ -246,58 +300,90 @@ def query_open_ai( Returns: Dictionary with "responses" key containing response list """ - try: - if not payload.docs: - return _no_docs_response(payload.boilerplate_features_message) - - context = json.dumps(payload.docs, indent=2) - logger.info("context passed to final node:\n\n%s", context) - messages = _build_rag_messages(payload, context) - model_name = dependencies.model_for_agentic_strength( - payload.agentic_strength, - allow_low=False, - allow_very_low=True, - ) - response = client.responses.create( - model=model_name, - instructions=FINAL_RESPONSE_AGENT_SYSTEM_PROMPT, - input=cast(Any, messages), - ) - usage = getattr(response, "usage", None) - track_openai_usage( - usage, - model_name, - dependencies.extract_cached_input_tokens, - dependencies.add_tokens, - ) - bt_servant_response = response.output_text - logger.info("response from openai: %s", bt_servant_response) - logger.debug("%d characters returned from openAI", len(bt_servant_response)) + if not payload.docs: + return _no_docs_response(payload.boilerplate_features_message) + + docs_digest = hashlib.sha256( + json.dumps(payload.docs, ensure_ascii=False, 
sort_keys=True).encode("utf-8") + ).hexdigest() + history_slice = payload.chat_history[-4:] if payload.chat_history else [] + chat_history_digest = hashlib.sha256( + json.dumps(history_slice, ensure_ascii=False, sort_keys=True).encode("utf-8") + ).hexdigest() + model_name = dependencies.model_for_agentic_strength( + payload.agentic_strength, + allow_low=False, + allow_very_low=True, + ) + cache_key = FinalResponseCacheKey( + schema=FINAL_CACHE_SCHEMA, + transformed_query=payload.transformed_query, + agentic_strength=payload.agentic_strength, + model_name=model_name, + docs_digest=docs_digest, + chat_history_digest=chat_history_digest, + ) + store_flag = {"store": True} + + def _compute_final_response() -> dict[str, list[dict[str, Any]]]: + try: + context = json.dumps(payload.docs, indent=2) + logger.info("context passed to final node:\n\n%s", context) + messages = _build_rag_messages(payload, context) + response = client.responses.create( + model=model_name, + instructions=FINAL_RESPONSE_AGENT_SYSTEM_PROMPT, + input=cast(Any, messages), + ) + usage = getattr(response, "usage", None) + track_openai_usage( + usage, + model_name, + dependencies.extract_cached_input_tokens, + dependencies.add_tokens, + ) + bt_servant_response = response.output_text + logger.info("response from openai: %s", bt_servant_response) + logger.debug("%d characters returned from openAI", len(bt_servant_response)) + + resource_summary = _summarize_resource_usage(payload.docs) + logger.info( + "bt servant used the following resources to generate its response: %s", + resource_summary, + ) - resource_summary = _summarize_resource_usage(payload.docs) - logger.info( - "bt servant used the following resources to generate its response: %s", - resource_summary, - ) + return { + "responses": [ + { + "intent": IntentType.GET_BIBLE_TRANSLATION_ASSISTANCE, + "response": bt_servant_response, + } + ] + } + except OpenAIError: + store_flag["store"] = False + logger.error("Error during OpenAI request", exc_info=True) + error_msg = ( + "I encountered some problems while trying to respond. Let Ian know about this one." + ) + return { + "responses": [ + { + "intent": IntentType.GET_BIBLE_TRANSLATION_ASSISTANCE, + "response": error_msg, + } + ] + } - return { - "responses": [ - { - "intent": IntentType.GET_BIBLE_TRANSLATION_ASSISTANCE, - "response": bt_servant_response, - } - ] - } - except OpenAIError: - logger.error("Error during OpenAI request", exc_info=True) - error_msg = ( - "I encountered some problems while trying to respond. Let Ian know about this one." 
+ final_response, hit = _FINAL_CACHE.get_or_set( + cache_key, + _compute_final_response, + should_store=lambda _: store_flag["store"], + ) + if hit: + logger.info( + "[rag-final] cache hit for query=%s docs=%d", + payload.transformed_query, + len(payload.docs), ) - return { - "responses": [ - { - "intent": IntentType.GET_BIBLE_TRANSLATION_ASSISTANCE, - "response": error_msg, - } - ] - } + return final_response diff --git a/bt_servant_engine/services/intents/passage_intents.py b/bt_servant_engine/services/intents/passage_intents.py index 224038e..5bdce43 100644 --- a/bt_servant_engine/services/intents/passage_intents.py +++ b/bt_servant_engine/services/intents/passage_intents.py @@ -2,6 +2,7 @@ from __future__ import annotations +import hashlib import re from dataclasses import dataclass from pathlib import Path @@ -15,6 +16,7 @@ from bt_servant_engine.core.language import Language, ResponseLanguage from bt_servant_engine.core.language import SUPPORTED_LANGUAGE_MAP as supported_language_map from bt_servant_engine.core.logging import get_logger +from bt_servant_engine.services.cache_manager import CACHE_SCHEMA_VERSION, get_cache from bt_servant_engine.services.openai_utils import track_openai_usage from bt_servant_engine.services.passage_selection import ( PassageSelectionRequest, @@ -79,6 +81,39 @@ class PassageSelectionResult: ranges: list[RangeSelection] +# pylint: disable=too-many-instance-attributes +@dataclass(frozen=True) +class SummaryCacheKey: + """Cache key for passage summaries.""" + + schema: str + canonical_book: str + ranges: tuple["RangeSelection", ...] + source_language: str + source_version: str | None + agentic_strength: str + model_name: str + verses_digest: str + + +@dataclass(frozen=True) +class KeywordsCacheKey: + """Cache key for keyword lookups.""" + + schema: str + canonical_book: str + ranges: tuple["RangeSelection", ...] + data_root: str + data_root_mtime: int + + +SUMMARY_CACHE_SCHEMA = f"{CACHE_SCHEMA_VERSION}:passage_summary:v1" +KEYWORDS_CACHE_SCHEMA = f"{CACHE_SCHEMA_VERSION}:passage_keywords:v1" + +_SUMMARY_CACHE = get_cache("passage_summary") +_KEYWORDS_CACHE = get_cache("passage_keywords") + + @dataclass(slots=True) class PassageSummaryRequest: """Inputs required for generating a passage summary.""" @@ -236,17 +271,14 @@ def _build_summary_messages(ref_label: str, verses_text: str) -> list[EasyInputM def _summarize_passage( ref_label: str, - verses: list[tuple[str, str]], + verses_text: str, + verses_count: int, request: PassageSummaryRequest, + model_name: str, ) -> str: - messages = _build_summary_messages(ref_label, _join_passage_text(verses)) + messages = _build_summary_messages(ref_label, verses_text) deps = request.selection.dependencies - model_name = request.model_for_agentic_strength( - request.agentic_strength, - allow_low=True, - allow_very_low=True, - ) - logger.info("[passage-summary] summarizing %d verses", len(verses)) + logger.info("[passage-summary] summarizing %d verses", verses_count) summary_resp = deps.client.responses.create( model=model_name, instructions=PASSAGE_SUMMARY_AGENT_SYSTEM_PROMPT, @@ -489,7 +521,7 @@ def _build_verbatim_passage_response(passage: RetrievedPassage) -> dict[str, Any } -def get_passage_summary(request: PassageSummaryRequest) -> dict[str, Any]: +def get_passage_summary(request: PassageSummaryRequest) -> dict[str, Any]: # pylint: disable=too-many-locals """Handle get-passage-summary: extract refs, retrieve verses, summarize. 
- If user query language is not English, translate the transformed query to English @@ -544,21 +576,52 @@ def get_passage_summary(request: PassageSummaryRequest) -> dict[str, Any]: _, ref_label = _localize_reference(selection_result, source) logger.info("[passage-summary] label=%s", ref_label) - summary_text = _summarize_passage(ref_label, verses, request) - response_text = f"Summary of {ref_label}:\n\n{summary_text}" - logger.info("[passage-summary] done") - return { - "responses": [ - { - "intent": IntentType.GET_PASSAGE_SUMMARY, - "response": response_text, - } - ], - "passage_followup_context": _passage_followup_context( - IntentType.GET_PASSAGE_SUMMARY, - selection_result, - ), - } + verses_text = _join_passage_text(verses) + model_name = request.model_for_agentic_strength( + request.agentic_strength, + allow_low=True, + allow_very_low=True, + ) + ranges_tuple = tuple(selection_result.ranges) + verses_digest = hashlib.sha256(verses_text.encode("utf-8")).hexdigest() + cache_key = SummaryCacheKey( + schema=SUMMARY_CACHE_SCHEMA, + canonical_book=selection_result.canonical_book, + ranges=ranges_tuple, + source_language=source.language, + source_version=source.version, + agentic_strength=request.agentic_strength, + model_name=model_name, + verses_digest=verses_digest, + ) + + def _compute_summary() -> dict[str, Any]: + summary_text = _summarize_passage( + ref_label, + verses_text, + len(verses), + request, + model_name, + ) + response_text = f"Summary of {ref_label}:\n\n{summary_text}" + logger.info("[passage-summary] done (generated)") + return { + "responses": [ + { + "intent": IntentType.GET_PASSAGE_SUMMARY, + "response": response_text, + } + ], + "passage_followup_context": _passage_followup_context( + IntentType.GET_PASSAGE_SUMMARY, + selection_result, + ), + } + + summary_response, hit = _SUMMARY_CACHE.get_or_set(cache_key, _compute_summary) + if hit: + logger.info("[passage-summary] served from cache label=%s", ref_label) + return summary_response def get_passage_keywords(request: PassageKeywordsRequest) -> dict[str, Any]: @@ -581,39 +644,63 @@ def get_passage_keywords(request: PassageKeywordsRequest) -> dict[str, Any]: # Retrieve keywords from keyword dataset data_root = Path("sources") / "keyword_data" - logger.info("[passage-keywords] retrieving keywords from %s", data_root) - keywords = select_keywords( - data_root, - selection_result.canonical_book, - selection_result.ranges, + try: + data_root_mtime = int(data_root.stat().st_mtime_ns) + except OSError: + data_root_mtime = 0 + cache_key = KeywordsCacheKey( + schema=KEYWORDS_CACHE_SCHEMA, + canonical_book=selection_result.canonical_book, + ranges=tuple(selection_result.ranges), + data_root=str(data_root), + data_root_mtime=data_root_mtime, ) - logger.info("[passage-keywords] retrieved %d keyword(s)", len(keywords)) - if not keywords: - logger.info("[passage-keywords] no keywords found; prompting user") + def _compute_keywords() -> dict[str, Any]: + logger.info("[passage-keywords] retrieving keywords from %s", data_root) + keywords = select_keywords( + data_root, + selection_result.canonical_book, + selection_result.ranges, + ) + logger.info("[passage-keywords] retrieved %d keyword(s)", len(keywords)) + + if not keywords: + logger.info("[passage-keywords] no keywords found; prompting user") + return { + "responses": [ + { + "intent": IntentType.GET_PASSAGE_KEYWORDS, + "response": MISSING_KEYWORDS_MESSAGE, + } + ], + } + + ref_label = label_ranges(selection_result.canonical_book, selection_result.ranges) + header = f"Keywords in 
{ref_label}\n\n" + body = ", ".join(keywords) + response_text = header + body + logger.info("[passage-keywords] done (generated)") return { "responses": [ - {"intent": IntentType.GET_PASSAGE_KEYWORDS, "response": MISSING_KEYWORDS_MESSAGE} - ] + { + "intent": IntentType.GET_PASSAGE_KEYWORDS, + "response": response_text, + } + ], + "passage_followup_context": _passage_followup_context( + IntentType.GET_PASSAGE_KEYWORDS, + selection_result, + ), } - ref_label = label_ranges(selection_result.canonical_book, selection_result.ranges) - header = f"Keywords in {ref_label}\n\n" - body = ", ".join(keywords) - response_text = header + body - logger.info("[passage-keywords] done") - return { - "responses": [ - { - "intent": IntentType.GET_PASSAGE_KEYWORDS, - "response": response_text, - } - ], - "passage_followup_context": _passage_followup_context( - IntentType.GET_PASSAGE_KEYWORDS, - selection_result, - ), - } + keyword_response, hit = _KEYWORDS_CACHE.get_or_set(cache_key, _compute_keywords) + if hit: + logger.info( + "[passage-keywords] served from cache for book=%s", + selection_result.canonical_book, + ) + return keyword_response def retrieve_scripture(request: RetrieveScriptureRequest) -> dict[str, Any]: diff --git a/bt_servant_engine/services/intents/translation_intents.py b/bt_servant_engine/services/intents/translation_intents.py index 03dd821..201f565 100644 --- a/bt_servant_engine/services/intents/translation_intents.py +++ b/bt_servant_engine/services/intents/translation_intents.py @@ -2,6 +2,8 @@ from __future__ import annotations +import hashlib +import json import re from dataclasses import dataclass from pathlib import Path @@ -15,6 +17,7 @@ from bt_servant_engine.core.language import Language, ResponseLanguage, TranslatedPassage from bt_servant_engine.core.language import SUPPORTED_LANGUAGE_MAP as supported_language_map from bt_servant_engine.core.logging import get_logger +from bt_servant_engine.services.cache_manager import CACHE_SCHEMA_VERSION, get_cache from bt_servant_engine.services.openai_utils import track_openai_usage from bt_servant_engine.services.passage_selection import ( PassageSelectionDependencies, @@ -140,6 +143,23 @@ class TranslationHelpsDependencies: build_messages_fn: Callable[..., Any] +@dataclass(frozen=True) +class TranslationHelpsCacheKey: + """Cache key for translation helps responses.""" + + schema: str + canonical_book: str + ranges: tuple[TranslationRange, ...] + selection_note: str + agentic_strength: str + model_name: str + raw_helps_digest: str + + +TRANSLATION_HELPS_CACHE_SCHEMA = f"{CACHE_SCHEMA_VERSION}:translation_helps:v1" +_TRANSLATION_HELPS_CACHE = get_cache("translation_helps") + + TRANSLATE_PASSAGE_AGENT_SYSTEM_PROMPT = """ # Task @@ -674,40 +694,64 @@ def get_translation_helps( "I couldn't prepare translation helps for that selection. Please try again." 
) - messages = dependencies.build_messages_fn( - payload.ref_label, - payload.context_obj, - payload.selection_note, - ) - - logger.info("[translation-helps] invoking LLM with %d helps", len(payload.raw_helps)) model_name = dependencies.select_model_fn( request.agentic_strength, allow_low=True, allow_very_low=True ) - resp = request.client.responses.create( - model=model_name, - instructions=TRANSLATION_HELPS_AGENT_SYSTEM_PROMPT, - input=cast(Any, messages), - store=False, + raw_helps_digest = hashlib.sha256( + json.dumps(payload.raw_helps, ensure_ascii=False, sort_keys=True).encode("utf-8") + ).hexdigest() + selection_note = payload.selection_note or "" + cache_key = TranslationHelpsCacheKey( + schema=TRANSLATION_HELPS_CACHE_SCHEMA, + canonical_book=payload.canonical_book, + ranges=tuple(payload.ranges), + selection_note=selection_note, + agentic_strength=request.agentic_strength, + model_name=model_name, + raw_helps_digest=raw_helps_digest, ) - usage = getattr(resp, "usage", None) - track_openai_usage(usage, model_name, dependencies.extract_cached_tokens_fn, add_tokens) - header = f"Translation helps for {payload.ref_label}\n\n" - response_text = header + (resp.output_text or "") - return { - "responses": [ - { + def _compute_translation_helps() -> dict[str, Any]: + messages = dependencies.build_messages_fn( + payload.ref_label, + payload.context_obj, + payload.selection_note, + ) + logger.info("[translation-helps] invoking LLM with %d helps", len(payload.raw_helps)) + resp = request.client.responses.create( + model=model_name, + instructions=TRANSLATION_HELPS_AGENT_SYSTEM_PROMPT, + input=cast(Any, messages), + store=False, + ) + usage = getattr(resp, "usage", None) + track_openai_usage(usage, model_name, dependencies.extract_cached_tokens_fn, add_tokens) + + header = f"Translation helps for {payload.ref_label}\n\n" + response_text = header + (resp.output_text or "") + logger.info("[translation-helps] done (generated)") + return { + "responses": [ + { + "intent": IntentType.GET_TRANSLATION_HELPS, + "response": response_text, + } + ], + "passage_followup_context": { "intent": IntentType.GET_TRANSLATION_HELPS, - "response": response_text, - } - ], - "passage_followup_context": { - "intent": IntentType.GET_TRANSLATION_HELPS, - "book": payload.canonical_book, - "ranges": payload.ranges, - }, - } + "book": payload.canonical_book, + "ranges": payload.ranges, + }, + } + + helps_response, hit = _TRANSLATION_HELPS_CACHE.get_or_set(cache_key, _compute_translation_helps) + if hit: + logger.info( + "[translation-helps] served from cache for book=%s ranges=%s", + payload.canonical_book, + payload.ranges, + ) + return helps_response __all__ = [ diff --git a/bt_servant_engine/services/passage_selection.py b/bt_servant_engine/services/passage_selection.py index 556e093..df4f8f8 100644 --- a/bt_servant_engine/services/passage_selection.py +++ b/bt_servant_engine/services/passage_selection.py @@ -2,6 +2,7 @@ from __future__ import annotations +import hashlib import re from dataclasses import dataclass from typing import Any, Callable, Optional, cast @@ -12,6 +13,7 @@ from bt_servant_engine.core.language import Language from bt_servant_engine.core.logging import get_logger from bt_servant_engine.core.models import PassageRef, PassageSelection +from bt_servant_engine.services.cache_manager import CACHE_SCHEMA_VERSION, get_cache from bt_servant_engine.services.openai_utils import extract_cached_input_tokens, track_openai_usage from utils.bsb import FULL_BOOK_SENTINEL, normalize_book_name from utils.perf 
import add_tokens @@ -74,6 +76,17 @@ """ +@dataclass(frozen=True) +class SelectionCacheKey: + """Cache key for passage selection normalization.""" + + schema: str + query: str + query_lang: str + focus_hint: str | None + book_map_digest: str + + @dataclass(slots=True) class PassageSelectionDependencies: """External services required for passage selection parsing.""" @@ -94,6 +107,17 @@ class PassageSelectionRequest: focus_hint: Optional[str] = None +_SELECTION_CACHE = get_cache("selection") + + +def _book_map_digest(book_map: dict[str, Any]) -> str: + keys = ",".join(sorted(book_map.keys())) + return hashlib.sha256(keys.encode("utf-8")).hexdigest() + + +SelectionResult = tuple[str | None, list["RangeSelection"] | None, str | None] + + def _prepare_parse_input(request: PassageSelectionRequest) -> str: if request.query_lang == Language.ENGLISH.value: logger.info("[selection-helper] parsing in English (no translation needed)") @@ -234,20 +258,35 @@ def _ranges_from_selections( def resolve_selection_for_single_book( request: PassageSelectionRequest, -) -> tuple[str | None, list[tuple[int, int | None, int | None, int | None]] | None, str | None]: - """Parse and normalize a user query into a single canonical book and ranges. +) -> SelectionResult: + """Parse and normalize a user query into a single canonical book and ranges.""" + + schema = f"{CACHE_SCHEMA_VERSION}:selection:v1" + cache_key = SelectionCacheKey( + schema=schema, + query=request.query, + query_lang=request.query_lang, + focus_hint=request.focus_hint or "", + book_map_digest=_book_map_digest(request.dependencies.book_map), + ) - Args: - request: Structured selection request containing query metadata and dependencies. + def _compute() -> SelectionResult: + return _resolve_selection_uncached(request) - Returns: - Tuple of (canonical_book, ranges, error_message). On success, the - error_message is None. On failure, canonical_book and ranges are None and - error_message contains a user-friendly explanation. + result, hit = _SELECTION_CACHE.get_or_set(cache_key, _compute) + if hit: + logger.info( + "[selection-helper] cache hit; query_lang=%s; query=%s", + request.query_lang, + request.query, + ) + return result - If ``focus_hint`` is provided, it is sent as a developer message to steer the - selection model toward the clause relevant to the current intent. 
- """ + +def _resolve_selection_uncached( + request: PassageSelectionRequest, +) -> SelectionResult: + """Uncached passage selection parsing for cache miss handling.""" logger.info( "[selection-helper] start; query_lang=%s; query=%s", request.query_lang, @@ -277,6 +316,9 @@ def resolve_selection_for_single_book( return canonical_book, ranges, None +RangeSelection = tuple[int, int | None, int | None, int | None] + + __all__ = [ "PASSAGE_SELECTION_AGENT_SYSTEM_PROMPT", "PassageSelectionDependencies", diff --git a/tests/apps/api/routes/test_admin_cache_endpoints.py b/tests/apps/api/routes/test_admin_cache_endpoints.py new file mode 100644 index 0000000..1cca563 --- /dev/null +++ b/tests/apps/api/routes/test_admin_cache_endpoints.py @@ -0,0 +1,184 @@ +"""Tests for admin cache management endpoints.""" + +# pylint: disable=missing-function-docstring + +from __future__ import annotations + +from http import HTTPStatus + +from fastapi.testclient import TestClient + +from bt_servant_engine.apps.api.app import create_app +from bt_servant_engine.apps.api.routes import admin +from bt_servant_engine.bootstrap import build_default_service_container +from bt_servant_engine.core.config import config as app_config + + +class _StubCache: + def __init__(self) -> None: + self.cleared = False + self.sample_limit = 0 + + def clear(self) -> None: + self.cleared = True + + def stats(self) -> dict[str, object]: + return { + "name": "selection", + "enabled": True, + "ttl_seconds": 100, + "max_entries": 10, + "entry_count": 1, + "bytes_used": 0, + "oldest_entry_epoch": None, + "newest_entry_epoch": None, + "stats": {"hits": 1, "misses": 0, "stores": 1, "evictions": 0}, + } + + def detailed_stats(self, sample_limit: int = 10) -> dict[str, object]: + self.sample_limit = sample_limit + data = self.stats() + data["samples"] = [ + { + "key_repr": "dummy-key", + "size_bytes": 42, + "created_at": 1.0, + "expires_at": 1000.0, + "last_access": 10.0, + "age_seconds": 9.0, + "ttl_remaining": 991.0, + } + ] + return data + + +class _StubCacheManager: + def __init__(self) -> None: + self.clear_all_called = False + self.clear_called: list[str] = [] + self.prune_all_cutoff: float | None = None + self.prune_cache_calls: list[tuple[str, float]] = [] + self.cache_obj = _StubCache() + + def clear_all(self) -> None: + self.clear_all_called = True + + def clear_cache(self, name: str) -> None: + if name != "selection": + raise KeyError(name) + self.clear_called.append(name) + + def stats(self) -> dict[str, object]: + return { + "enabled": True, + "backend": "memory", + "disk_root": "/tmp/cache", + "disk_max_bytes": 512, + "caches": {"selection": self.cache_obj.stats()}, + } + + def cache(self, name: str) -> _StubCache: + if name != "selection": + raise KeyError(name) + return self.cache_obj + + def prune_all(self, cutoff: float) -> dict[str, int]: + self.prune_all_cutoff = cutoff + return {"selection": 1} + + def prune_cache(self, name: str, cutoff: float) -> int: + if name != "selection": + raise KeyError(name) + self.prune_cache_calls.append((name, cutoff)) + return 2 + + +def _make_client(monkeypatch) -> tuple[TestClient, _StubCacheManager]: + stub = _StubCacheManager() + monkeypatch.setattr(admin, "cache_manager", stub) + monkeypatch.setattr(app_config, "ENABLE_ADMIN_AUTH", False) + client = TestClient(create_app(build_default_service_container())) + return client, stub + + +def test_clear_all_caches_endpoint(monkeypatch): + client, stub = _make_client(monkeypatch) + resp = client.post("/cache/clear") + assert resp.status_code == 
HTTPStatus.OK + assert resp.json()["status"] == "cleared" + assert stub.clear_all_called + + +def test_clear_named_cache_endpoint(monkeypatch): + client, stub = _make_client(monkeypatch) + resp = client.post("/cache/selection/clear") + assert resp.status_code == HTTPStatus.OK + assert resp.json()["cache"] == "selection" + assert stub.clear_called == ["selection"] + + resp = client.post("/cache/unknown/clear") + assert resp.status_code == HTTPStatus.NOT_FOUND + assert resp.json()["detail"]["error"] == "Cache 'unknown' not found" + + +def test_prune_all_caches_endpoint(monkeypatch): + client, stub = _make_client(monkeypatch) + resp = client.post("/cache/clear", params={"older_than_days": 1}) + assert resp.status_code == HTTPStatus.OK + payload = resp.json() + assert payload["status"] == "pruned" + assert "cutoff_epoch" in payload + assert payload["removed"] == {"selection": 1} + assert stub.prune_all_cutoff is not None + + +def test_prune_named_cache_endpoint(monkeypatch): + client, stub = _make_client(monkeypatch) + days = 2 + resp = client.post("/cache/selection/clear", params={"older_than_days": days}) + assert resp.status_code == HTTPStatus.OK + payload = resp.json() + assert payload["status"] == "pruned" + removed_expected = 2 + assert payload["removed"] == removed_expected + assert stub.prune_cache_calls and stub.prune_cache_calls[0][1] > 0 + + +def test_prune_invalid_params(monkeypatch): + client, _ = _make_client(monkeypatch) + resp = client.post("/cache/clear", params={"older_than_days": -1}) + assert resp.status_code == HTTPStatus.BAD_REQUEST + resp = client.post("/cache/selection/clear", params={"older_than_days": 0}) + assert resp.status_code == HTTPStatus.BAD_REQUEST + + +def test_get_cache_stats_endpoint(monkeypatch): + client, _ = _make_client(monkeypatch) + resp = client.get("/cache/stats") + assert resp.status_code == HTTPStatus.OK + data = resp.json() + assert data["enabled"] is True + assert "selection" in data["caches"] + + +def test_inspect_cache_endpoint(monkeypatch): + client, stub = _make_client(monkeypatch) + sample_limit = 5 + resp = client.get("/cache/selection", params={"sample_limit": sample_limit}) + assert resp.status_code == HTTPStatus.OK + data = resp.json() + assert data["name"] == "selection" + assert data["sample_limit"] == sample_limit + assert stub.cache_obj.sample_limit == sample_limit + assert data["samples"] + + bad_resp = client.get("/cache/selection", params={"sample_limit": 0}) + assert bad_resp.status_code == HTTPStatus.BAD_REQUEST + assert ( + bad_resp.json()["detail"]["error"] + == f"sample_limit must be between 1 and {admin.MAX_CACHE_SAMPLE_LIMIT}" + ) + + missing = client.get("/cache/missing") + assert missing.status_code == HTTPStatus.NOT_FOUND + assert missing.json()["detail"]["error"] == "Cache 'missing' not found"
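
Reviewer note — usage sketch (not part of the patch): the handler-facing surface added here is `get_cache(...).get_or_set(...)`. A minimal sketch of the intended call pattern follows, assuming `bt_servant_engine` is importable and that importing `cache_manager` has registered the `passage_summary` cache (as `_init_manager()` above does); `DemoKey` and `expensive_summary` are hypothetical names used only for illustration.

```python
"""Minimal sketch of the cache API introduced in this patch (assumptions noted above)."""
from dataclasses import dataclass

from bt_servant_engine.services.cache_manager import CACHE_SCHEMA_VERSION, get_cache


@dataclass(frozen=True)
class DemoKey:  # hypothetical key type, not part of the patch
    # Frozen dataclasses have a stable repr(), which _hash_key() digests with SHA-256.
    schema: str
    book: str
    chapter: int


def expensive_summary(book: str, chapter: int) -> str:
    cache = get_cache("passage_summary")
    key = DemoKey(schema=f"{CACHE_SCHEMA_VERSION}:demo:v1", book=book, chapter=chapter)

    def _compute() -> str:
        # Runs only on a cache miss; the result is stored unless should_store vetoes it.
        return f"(summary of {book} {chapter})"

    # get_or_set returns (value, hit); should_store lets callers avoid caching
    # failure payloads, mirroring the store_flag pattern in query_open_ai.
    value, _hit = cache.get_or_set(key, _compute, should_store=lambda v: bool(v))
    return value
```

Because the schema string participates in the key, bumping a `*_CACHE_SCHEMA` constant invalidates existing entries without any store migration, exactly as the `rag_vector`/`rag_final` keys are versioned in `graph_pipeline.py`.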