Coverage for src / mcp_server_langgraph / core / cache.py: 94%
210 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2Multi-layer caching system for MCP server.
4Provides:
5- L1: In-memory LRU cache (< 10ms, per-instance)
6- L2: Redis distributed cache (< 50ms, shared)
7- L3: Provider-native caching (Anthropic, Gemini prompt caching)
9Features:
10- Automatic TTL and eviction
11- Cache stampede prevention
12- Metrics and observability
13- Tiered cache promotion/demotion
15See ADR-0028 for design rationale.
16"""
18import asyncio
19import functools
20import hashlib
21import pickle
22from collections.abc import Callable
23from typing import Any, ParamSpec, TypeVar
24from urllib.parse import urlparse, urlunparse
26import redis
27from cachetools import TTLCache
29from mcp_server_langgraph.core.config import settings
30from mcp_server_langgraph.observability.telemetry import logger, tracer
# Typing helpers used by the @cached decorator to preserve wrapped signatures
P = ParamSpec("P")
T = TypeVar("T")

# Cache TTL configurations (seconds), keyed by cache-key prefix.
# _get_ttl_from_key() looks up the part of the key before the first ":".
CACHE_TTLS = {
    "auth_permission": 300,  # 5 minutes - permissions don't change frequently
    "user_profile": 900,  # 15 minutes - profiles update infrequently
    "llm_response": 3600,  # 1 hour - deterministic responses
    "embedding": 86400,  # 24 hours - embeddings are deterministic
    "prometheus_query": 60,  # 1 minute - metrics change frequently
    "knowledge_base": 1800,  # 30 minutes - search index updates periodically
    "feature_flag": 60,  # 1 minute - fast rollout needed
}
47def _build_redis_url_with_db(redis_url: str, db: int) -> str:
48 """
49 Build Redis URL with database number using proper URL parsing.
51 This prevents malformed URLs that occur with simple string concatenation:
52 - redis://localhost:6379/ + /2 → redis://localhost:6379//2 (double slash)
53 - redis://localhost:6379/0 + /2 → redis://localhost:6379/0/2 (nested paths)
55 Args:
56 redis_url: Base Redis URL (e.g., "redis://host:port" or "redis://:password@host:port")
57 db: Database number to use
59 Returns:
60 Properly formatted Redis URL with database number (e.g., "redis://host:port/2")
62 Example:
63 >>> _build_redis_url_with_db("redis://localhost:6379", 2)
64 'redis://localhost:6379/2'
65 >>> _build_redis_url_with_db("redis://localhost:6379/", 2)
66 'redis://localhost:6379/2'
67 >>> _build_redis_url_with_db("redis://localhost:6379/0", 2)
68 'redis://localhost:6379/2'
69 """
70 # Parse the URL to handle existing database numbers, trailing slashes, query params
71 parsed = urlparse(redis_url)
73 # Replace path with correct database number
74 # Redis URL path is typically empty or /db_number
75 new_path = f"/{db}"
77 # Reconstruct URL with new database number
78 # This pattern matches dependencies.py:189-212 (API key manager)
79 return urlunparse(
80 (
81 parsed.scheme, # redis:// or rediss://
82 parsed.netloc, # host:port (or user:pass@host:port)
83 new_path, # /db_number
84 parsed.params, # unused in Redis URLs
85 parsed.query, # query parameters (if any)
86 parsed.fragment, # fragment (if any)
87 )
88 )
class CacheLayer:
    """Enum for cache layers"""

    # NOTE: plain string constants, not enum.Enum — callers compare with ==
    # and the values flow directly into metric attributes and log fields.
    L1 = "l1"  # In-memory LRU
    L2 = "l2"  # Redis distributed
    L3 = "l3"  # Provider-native
class CacheService:
    """
    Unified caching service with L1 + L2 layers.

    L1 is a per-instance in-memory TTLCache; L2 is a shared Redis database.
    L2 hits are promoted into L1. All L2 failures degrade gracefully to
    L1-only operation and are never raised to the caller.

    Usage:
        cache = CacheService()

        # Set value
        cache.set("user:profile:123", user_data, ttl=900)

        # Get value (tries L1 → L2 → None)
        user_data = cache.get("user:profile:123")

        # Delete value
        cache.delete("user:profile:123")

        # Clear all caches
        cache.clear()
    """

    def __init__(
        self,
        l1_maxsize: int = 1000,
        l1_ttl: int = 60,
        redis_url: str | None = None,
        redis_db: int = 2,
        redis_password: str | None = None,
        redis_ssl: bool = False,
    ):
        """
        Initialize cache service.

        Args:
            l1_maxsize: Max entries in L1 cache
            l1_ttl: Default TTL for L1 cache in seconds
            redis_url: Redis connection URL (e.g., "redis://host:port")
            redis_db: Redis database number (default: 2 for cache)
            redis_password: Redis password for authentication
            redis_ssl: Enable SSL/TLS for Redis connection
        """
        # L1: In-memory cache (per-instance)
        self.l1_cache = TTLCache(maxsize=l1_maxsize, ttl=l1_ttl)
        self.l1_maxsize = l1_maxsize
        self.l1_ttl = l1_ttl

        # L2: Redis cache (shared across instances)
        # Use redis.from_url() pattern (same as API key manager) instead of
        # Redis(host=..., port=...) to properly handle password and SSL.

        # Fall back to settings for any Redis config not passed explicitly
        if redis_url is None:
            redis_url = getattr(settings, "redis_url", "redis://localhost:6379")
        if redis_password is None:
            redis_password = getattr(settings, "redis_password", None)
        if redis_ssl is False:
            redis_ssl = getattr(settings, "redis_ssl", False)

        # Declare redis with Optional type for proper type checking
        self.redis: redis.Redis[bytes] | None = None  # type: ignore[type-arg]

        try:
            # Build Redis URL with database number using the helper function;
            # prevents malformed URLs from simple string concatenation.
            redis_url_with_db = _build_redis_url_with_db(redis_url, redis_db)

            # NOTE: Only pass ssl=True when SSL is enabled. Passing ssl=False causes
            # "AbstractConnection.__init__() got an unexpected keyword argument 'ssl'"
            # because the non-SSL connection class doesn't accept the ssl parameter.
            connection_kwargs: dict[str, Any] = {
                "decode_responses": False,  # Keep binary for pickle
                "socket_connect_timeout": 2,
                "socket_timeout": 2,
            }
            if redis_password:
                connection_kwargs["password"] = redis_password
            if redis_ssl:
                connection_kwargs["ssl"] = True

            # Create Redis client using from_url() with full configuration
            # (matches the pattern in dependencies.py, the API key manager).
            self.redis = redis.from_url(  # type: ignore[no-untyped-call]
                redis_url_with_db,
                **connection_kwargs,
            )
            # Fail fast if the server is unreachable so we degrade to L1-only
            self.redis.ping()
            self.redis_available = True
            logger.info(
                "Redis cache initialized",
                extra={"redis_url": redis_url, "db": redis_db, "ssl": redis_ssl},
            )
        except Exception as e:
            logger.warning(
                f"Redis cache unavailable, L2 cache disabled: {e}",
                extra={"redis_url": redis_url},
            )
            self.redis = None
            self.redis_available = False

        # Cache stampede prevention locks (one asyncio.Lock per key)
        self._refresh_locks: dict[str, asyncio.Lock] = {}

        # Hit/miss/set/delete counters, exposed via get_statistics()
        self.stats = {
            "l1_hits": 0,
            "l1_misses": 0,
            "l2_hits": 0,
            "l2_misses": 0,
            "sets": 0,
            "deletes": 0,
        }

    def get(self, key: str, level: str = CacheLayer.L2) -> Any | None:
        """
        Get value from cache (L1 → L2 → None).

        Args:
            key: Cache key
            level: Cache level to search (l1 or l2)

        Returns:
            Cached value or None if not found.

        Note:
            None doubles as the miss sentinel, so a cached None value is
            indistinguishable from a miss.
        """
        # Try L1 first
        if level in (CacheLayer.L1, CacheLayer.L2) and key in self.l1_cache:
            self.stats["l1_hits"] += 1
            logger.debug(f"L1 cache hit: {key}")
            self._emit_cache_hit_metric(CacheLayer.L1, key)
            return self.l1_cache[key]

        self.stats["l1_misses"] += 1

        # Try L2 if requested
        if level == CacheLayer.L2:
            if self.redis_available and self.redis is not None:
                try:
                    data = self.redis.get(key)
                    if data:
                        # SECURITY NOTE: pickle.loads on Redis data — safe only
                        # while the Redis instance is trusted/private.
                        value = pickle.loads(data)  # type: ignore[arg-type]

                        # Promote to L1 so subsequent reads are served locally
                        self.l1_cache[key] = value

                        self.stats["l2_hits"] += 1
                        logger.debug(f"L2 cache hit: {key}")
                        self._emit_cache_hit_metric(CacheLayer.L2, key)
                        return value
                except Exception as e:
                    logger.warning(f"L2 cache get failed: {e}", extra={"key": key})

            # BUGFIX: count an L2 miss only when an L2 lookup was requested.
            # Previously this counter also incremented for L1-only lookups,
            # inflating the L2 miss rate in get_statistics().
            self.stats["l2_misses"] += 1

        # Emit cache miss metric
        self._emit_cache_miss_metric(level, key)
        return None

    def set(
        self,
        key: str,
        value: Any,
        ttl: int | None = None,
        level: str = CacheLayer.L2,
    ) -> None:
        """
        Set value in cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time-to-live in seconds (default: from cache type prefix)
            level: Cache level (l1, l2)
        """
        # Use default TTL for cache type
        if ttl is None:
            ttl = self._get_ttl_from_key(key)

        # Set in L1 (L1 entries expire with the cache-wide l1_ttl)
        if level in (CacheLayer.L1, CacheLayer.L2):
            self.l1_cache[key] = value

        # Set in L2 with an explicit per-key TTL
        if level == CacheLayer.L2 and self.redis_available and self.redis is not None:
            try:
                self.redis.setex(key, ttl, pickle.dumps(value))
                logger.debug(f"L2 cache set: {key} (TTL: {ttl}s)")
            except Exception as e:
                logger.warning(f"L2 cache set failed: {e}", extra={"key": key})

        self.stats["sets"] += 1
        self._emit_cache_set_metric(level, key)

    def delete(self, key: str) -> None:
        """
        Delete from all cache levels.

        Args:
            key: Cache key to delete
        """
        # Delete from L1 (pop with default so missing keys are a no-op)
        self.l1_cache.pop(key, None)

        # Delete from L2
        if self.redis_available and self.redis is not None:
            try:
                self.redis.delete(key)
            except Exception as e:
                logger.warning(f"L2 cache delete failed: {e}")

        self.stats["deletes"] += 1

    def clear(self, pattern: str | None = None) -> None:
        """
        Clear cache (all or by pattern).

        Args:
            pattern: Redis key pattern (e.g., "user:*") or None for all
        """
        # Clear L1
        self.l1_cache.clear()

        # Clear L2
        if self.redis_available and self.redis is not None:
            try:
                # SECURITY FIX (OpenAI Codex Finding #6): Always use pattern-based deletion
                # Never use flushdb() as it clears the ENTIRE Redis database, including
                # other data structures that may share the same DB (e.g., API key cache).
                # Use pattern="*" for "clear all" instead of flushdb().
                search_pattern = pattern if pattern else "*"
                # Use SCAN (cursor-based) instead of KEYS: KEYS blocks the Redis
                # server for the full scan on large databases.
                keys = list(self.redis.scan_iter(match=search_pattern))
                if keys:
                    self.redis.delete(*keys)  # type: ignore[misc]
                    logger.info(f"Cleared L2 cache by pattern: {search_pattern} ({len(keys)} keys)")
                else:
                    logger.info(f"No L2 cache keys matched pattern: {search_pattern}")
            except Exception as e:
                logger.warning(f"L2 cache clear failed: {e}")

    async def get_with_lock(
        self,
        key: str,
        fetcher: Callable,  # type: ignore[type-arg]
        ttl: int | None = None,
    ) -> Any:
        """
        Get from cache or fetch with lock (prevents cache stampede).

        Args:
            key: Cache key
            fetcher: Sync or async function to fetch data on cache miss
            ttl: Time-to-live in seconds

        Returns:
            Cached or fetched value
        """
        # Fast path: serve from cache without taking the lock.
        # BUGFIX: compare against None instead of truthiness so that cached
        # falsy values (False, 0, "", []) are served from cache instead of
        # being refetched on every call.
        if (cached := self.get(key)) is not None:
            return cached

        # One lock per key; setdefault avoids a check-then-set race.
        # NOTE(review): this dict grows with distinct keys and is never
        # pruned — acceptable while the key space is bounded.
        lock = self._refresh_locks.setdefault(key, asyncio.Lock())

        async with lock:
            # Double-check cache (another request may have filled it while
            # we waited for the lock)
            if (cached := self.get(key)) is not None:
                return cached

            # Fetch and cache (fetcher may be sync or async)
            value = await fetcher() if asyncio.iscoroutinefunction(fetcher) else fetcher()
            self.set(key, value, ttl)
            return value

    def _get_ttl_from_key(self, key: str) -> int:
        """
        Determine TTL from cache key prefix.

        Args:
            key: Cache key (format: "type:...")

        Returns:
            TTL in seconds (default 300 for unknown prefixes)
        """
        # Extract cache type from key prefix
        cache_type = key.split(":")[0] if ":" in key else "default"
        return CACHE_TTLS.get(cache_type, 300)  # Default 5 minutes

    def _emit_counter_metric(self, attr: str, name: str, description: str, layer: str, key: str) -> None:
        """
        Lazily create (on the telemetry config) and increment a named counter.

        Shared implementation for the hit/miss/set emitters below. Swallows
        all errors so metrics failures never break caching.
        """
        try:
            from mcp_server_langgraph.observability.telemetry import config

            cache_type = key.split(":")[0] if ":" in key else "unknown"

            # Create metric on first use and memoize it on the config object
            if not hasattr(config, attr):
                setattr(
                    config,
                    attr,
                    config.meter.create_counter(name=name, description=description, unit="1"),
                )

            getattr(config, attr).add(1, attributes={"layer": layer, "cache_type": cache_type})
        except Exception:
            pass  # Don't let metrics failure break caching

    def _emit_cache_hit_metric(self, layer: str, key: str) -> None:
        """Emit cache hit metric"""
        self._emit_counter_metric("cache_hits_counter", "cache.hits", "Total cache hits", layer, key)

    def _emit_cache_miss_metric(self, layer: str, key: str) -> None:
        """Emit cache miss metric"""
        self._emit_counter_metric("cache_misses_counter", "cache.misses", "Total cache misses", layer, key)

    def _emit_cache_set_metric(self, layer: str, key: str) -> None:
        """Emit cache set metric"""
        self._emit_counter_metric("cache_sets_counter", "cache.sets", "Total cache sets", layer, key)

    def get_statistics(self) -> dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with per-layer hit/miss counts, hit rates, sizes,
            and total set/delete counts.
        """
        l1_total = self.stats["l1_hits"] + self.stats["l1_misses"]
        hit_rate_l1 = self.stats["l1_hits"] / l1_total if l1_total > 0 else 0.0

        l2_total = self.stats["l2_hits"] + self.stats["l2_misses"]
        hit_rate_l2 = self.stats["l2_hits"] / l2_total if l2_total > 0 else 0.0

        return {
            "l1": {
                "hits": self.stats["l1_hits"],
                "misses": self.stats["l1_misses"],
                "hit_rate": hit_rate_l1,
                "size": len(self.l1_cache),
                "max_size": self.l1_maxsize,
            },
            "l2": {
                "hits": self.stats["l2_hits"],
                "misses": self.stats["l2_misses"],
                "hit_rate": hit_rate_l2,
                "available": self.redis_available,
            },
            "total": {
                "sets": self.stats["sets"],
                "deletes": self.stats["deletes"],
            },
        }
# Global cache instance, lazily created on first use by get_cache()
_cache_service: CacheService | None = None
def get_cache() -> CacheService:
    """Get global cache service (singleton)"""
    global _cache_service
    # Guard clause: reuse the existing instance when already constructed.
    if _cache_service is not None:
        return _cache_service
    # Lazily build the singleton from settings, with safe defaults when a
    # setting is absent.
    _cache_service = CacheService(
        l1_maxsize=getattr(settings, "l1_cache_size", 1000),
        l1_ttl=getattr(settings, "l1_cache_ttl", 60),
        redis_url=getattr(settings, "redis_url", "redis://localhost:6379"),
        redis_db=getattr(settings, "redis_cache_db", 2),
        redis_password=getattr(settings, "redis_password", None),
        redis_ssl=getattr(settings, "redis_ssl", False),
    )
    return _cache_service
def cached(
    key_prefix: str,
    ttl: int | None = None,
    level: str = CacheLayer.L2,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator for caching sync or async function results.

    Args:
        key_prefix: Prefix for cache key
        ttl: Time-to-live in seconds (default: auto from key_prefix)
        level: Cache level (l1 or l2)

    Usage:
        @cached(key_prefix="user:profile", ttl=900)
        async def get_user_profile(user_id: str) -> dict[str, Any]:
            return await db.get_user(user_id)

        # Auto TTL from key prefix
        @cached(key_prefix="auth_permission")  # Uses 300s TTL
        async def check_permission(user: str, resource: str) -> bool:
            return await openfga.check(user, resource)

    Note:
        A result of None is cached in the backing store but cannot be
        distinguished from a miss, so None results are recomputed each call.
    """

    def _make_key(args: tuple, kwargs: dict) -> str:  # type: ignore[type-arg]
        """Build the cache key from the call arguments (shared by both wrappers)."""
        key_parts = [key_prefix]
        key_parts.extend(str(arg) for arg in args)
        # Sort kwargs so keyword order doesn't change the key
        key_parts.extend(f"{k}:{v}" for k, v in sorted(kwargs.items()))
        key = ":".join(key_parts)

        # Hash over-long keys to keep them Redis-friendly (md5 is fine here:
        # it's a cache key, not a security boundary)
        if len(key) > 200:
            key_hash = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()  # nosec B324
            key = f"{key_prefix}:hash:{key_hash}"
        return key

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @functools.wraps(func)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            """Async wrapper with caching"""
            cache = get_cache()
            key = _make_key(args, kwargs)

            with tracer.start_as_current_span(
                f"cache.{key_prefix}",
                attributes={"cache.key": key, "cache.level": level},
            ) as span:
                # Try cache.
                # BUGFIX: compare against None instead of truthiness so cached
                # falsy results (False, 0, "", []) count as hits — previously
                # e.g. a cached `check_permission(...) == False` was refetched
                # on every call.
                if (cached_value := cache.get(key, level=level)) is not None:
                    span.set_attribute("cache.hit", True)
                    return cached_value  # type: ignore[no-any-return]

                span.set_attribute("cache.hit", False)

                # Cache miss: call function
                result = await func(*args, **kwargs)  # type: ignore[misc]

                # Store in cache
                cache.set(key, result, ttl=ttl, level=level)
                return result  # type: ignore[no-any-return]

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            """Sync wrapper with caching"""
            cache = get_cache()
            key = _make_key(args, kwargs)

            # Try cache (same None-sentinel comparison as the async wrapper)
            if (cached_value := cache.get(key, level=level)) is not None:
                return cached_value  # type: ignore[no-any-return]

            # Cache miss: call function
            result = func(*args, **kwargs)

            # Store in cache
            cache.set(key, result, ttl=ttl, level=level)
            return result

        # Return the wrapper matching the decorated function's flavor
        if asyncio.iscoroutinefunction(func):
            return async_wrapper  # type: ignore[return-value] # Complex decorator typing
        else:
            return sync_wrapper

    return decorator
def cache_invalidate(key_pattern: str) -> None:
    """
    Invalidate cache entries matching pattern.

    Args:
        key_pattern: Pattern to match (e.g., "user:123:*" for all user 123 caches)

    Usage:
        # Invalidate all caches for a user
        cache_invalidate("user:123:*")

        # Invalidate all auth caches
        cache_invalidate("auth:*")
    """
    # Delegate to the singleton's pattern-based clear, then record it
    get_cache().clear(pattern=key_pattern)
    logger.info(f"Cache invalidated: {key_pattern}")
# Anthropic-specific prompt caching (L3)
def create_anthropic_cached_message(system_prompt: str, messages: list) -> dict[str, Any]:  # type: ignore[type-arg]
    """
    Create Anthropic message with prompt caching.

    Marks the system prompt with ephemeral cache_control so Anthropic's
    provider-side prompt cache (the L3 layer) can reuse it across requests.

    Args:
        system_prompt: System prompt to cache
        messages: List of messages

    Returns:
        Message dict with cache_control
    """
    cached_system_block = {
        "type": "text",
        "text": system_prompt,
        "cache_control": {"type": "ephemeral"},  # Cache this!
    }
    return {
        "model": "claude-sonnet-4-5-20250929",
        "system": [cached_system_block],
        "messages": messages,
    }
# Helper function for cache key generation
def generate_cache_key(*parts: Any, prefix: str = "", version: str = "v1") -> str:
    """
    Generate standardized cache key.

    Args:
        *parts: Key components
        prefix: Key prefix (cache type)
        version: Cache version (for invalidation)

    Returns:
        Cache key string

    Usage:
        key = generate_cache_key("user_123", "profile", prefix="user", version="v1")
        # Returns: "user:user_123:profile:v1"
    """
    # Assemble: [prefix?] + stringified parts + version
    components: list[str] = []
    if prefix:
        components.append(prefix)
    components.extend(str(part) for part in parts)
    components.append(version)

    key = ":".join(components)

    # Collapse over-long keys to a digest (md5 is used as a non-security
    # fingerprint here, hence usedforsecurity=False)
    if len(key) > 200:
        digest = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()  # nosec B324
        key = f"{prefix}:hash:{digest}:{version}"

    return key