Coverage for src/mcp_server_langgraph/core/cache.py: 94% (210 statements)
coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Multi-layer caching system for MCP server. 

3 

4Provides: 

5- L1: In-memory LRU cache (< 10ms, per-instance) 

6- L2: Redis distributed cache (< 50ms, shared) 

7- L3: Provider-native caching (Anthropic, Gemini prompt caching) 

8 

9Features: 

10- Automatic TTL and eviction 

11- Cache stampede prevention 

12- Metrics and observability 

13- Tiered cache promotion/demotion 

14 

15See ADR-0028 for design rationale. 

16""" 

17 

18import asyncio 

19import functools 

20import hashlib 

21import pickle 

22from collections.abc import Callable 

23from typing import Any, ParamSpec, TypeVar 

24from urllib.parse import urlparse, urlunparse 

25 

26import redis 

27from cachetools import TTLCache 

28 

29from mcp_server_langgraph.core.config import settings 

30from mcp_server_langgraph.observability.telemetry import logger, tracer 

31 

32P = ParamSpec("P") 

33T = TypeVar("T") 

34 

# Cache TTL configurations (seconds)
CACHE_TTLS = {
    "auth_permission": 300,  # 5 minutes - permissions don't change frequently
    "user_profile": 900,  # 15 minutes - profiles update infrequently
    "llm_response": 3600,  # 1 hour - deterministic responses
    "embedding": 86400,  # 24 hours - embeddings are deterministic
    "prometheus_query": 60,  # 1 minute - metrics change frequently
    "knowledge_base": 1800,  # 30 minutes - search index updates periodically
    "feature_flag": 60,  # 1 minute - fast rollout needed
}
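
# Example (illustrative key): a key like "embedding:model:text-hash" has the
# prefix "embedding", so _get_ttl_from_key() (defined below) returns
# CACHE_TTLS["embedding"] == 86400; keys with unknown prefixes fall back to a
# 300-second default.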

def _build_redis_url_with_db(redis_url: str, db: int) -> str:
    """
    Build Redis URL with database number using proper URL parsing.

    This prevents malformed URLs that occur with simple string concatenation:
    - redis://localhost:6379/ + /2 → redis://localhost:6379//2 (double slash)
    - redis://localhost:6379/0 + /2 → redis://localhost:6379/0/2 (nested paths)

    Args:
        redis_url: Base Redis URL (e.g., "redis://host:port" or "redis://:password@host:port")
        db: Database number to use

    Returns:
        Properly formatted Redis URL with database number (e.g., "redis://host:port/2")

    Example:
        >>> _build_redis_url_with_db("redis://localhost:6379", 2)
        'redis://localhost:6379/2'
        >>> _build_redis_url_with_db("redis://localhost:6379/", 2)
        'redis://localhost:6379/2'
        >>> _build_redis_url_with_db("redis://localhost:6379/0", 2)
        'redis://localhost:6379/2'
    """
    # Parse the URL to handle existing database numbers, trailing slashes, query params
    parsed = urlparse(redis_url)

    # Replace path with correct database number
    # Redis URL path is typically empty or /db_number
    new_path = f"/{db}"

    # Reconstruct URL with new database number
    # This pattern matches dependencies.py:189-212 (API key manager)
    return urlunparse(
        (
            parsed.scheme,  # redis:// or rediss://
            parsed.netloc,  # host:port (or user:pass@host:port)
            new_path,  # /db_number
            parsed.params,  # unused in Redis URLs
            parsed.query,  # query parameters (if any)
            parsed.fragment,  # fragment (if any)
        )
    )
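
# Credentials and TLS schemes survive the rewrite because only the path is
# replaced (illustrative values):
#     _build_redis_url_with_db("rediss://:secret@cache.internal:6380/0", 2)
#     -> "rediss://:secret@cache.internal:6380/2"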

class CacheLayer:
    """String constants identifying the cache layers"""

    L1 = "l1"  # In-memory LRU
    L2 = "l2"  # Redis distributed
    L3 = "l3"  # Provider-native

class CacheService:
    """
    Unified caching service with L1 + L2 layers.

    Usage:
        cache = CacheService()

        # Set value
        cache.set("user:profile:123", user_data, ttl=900)

        # Get value (tries L1 → L2 → None)
        user_data = cache.get("user:profile:123")

        # Delete value
        cache.delete("user:profile:123")

        # Clear all caches
        cache.clear()
    """

    def __init__(
        self,
        l1_maxsize: int = 1000,
        l1_ttl: int = 60,
        redis_url: str | None = None,
        redis_db: int = 2,
        redis_password: str | None = None,
        redis_ssl: bool = False,
    ):
        """
        Initialize cache service.

        Args:
            l1_maxsize: Max entries in L1 cache
            l1_ttl: Default TTL for L1 cache in seconds
            redis_url: Redis connection URL (e.g., "redis://host:port")
            redis_db: Redis database number (default: 2 for cache)
            redis_password: Redis password for authentication
            redis_ssl: Enable SSL/TLS for Redis connection
        """
        # L1: In-memory cache (per-instance)
        self.l1_cache = TTLCache(maxsize=l1_maxsize, ttl=l1_ttl)
        self.l1_maxsize = l1_maxsize
        self.l1_ttl = l1_ttl

        # L2: Redis cache (shared across instances)
        # Use redis.from_url() pattern (same as API key manager) instead of
        # Redis(host=..., port=...) to properly handle password and SSL

        # Get Redis config from settings if not provided
        if redis_url is None:
            redis_url = getattr(settings, "redis_url", "redis://localhost:6379")
        if redis_password is None:
            redis_password = getattr(settings, "redis_password", None)
        if redis_ssl is False:
            redis_ssl = getattr(settings, "redis_ssl", False)

        # Declare redis with Optional type for proper type checking
        self.redis: redis.Redis[bytes] | None = None  # type: ignore[type-arg]

        try:
            # Build Redis URL with database number using helper function
            # This prevents malformed URLs from simple string concatenation
            redis_url_with_db = _build_redis_url_with_db(redis_url, redis_db)

            # Build connection kwargs
            # NOTE: Only pass ssl=True when SSL is enabled. Passing ssl=False causes
            # "AbstractConnection.__init__() got an unexpected keyword argument 'ssl'"
            # because the non-SSL connection class doesn't accept the ssl parameter.
            connection_kwargs: dict[str, Any] = {
                "decode_responses": False,  # Keep binary for pickle
                "socket_connect_timeout": 2,
                "socket_timeout": 2,
            }
            if redis_password:
                connection_kwargs["password"] = redis_password
            if redis_ssl:
                connection_kwargs["ssl"] = True

            # Create Redis client using from_url() with full configuration
            # This matches the pattern in dependencies.py:215-220 (API key manager)
            self.redis = redis.from_url(  # type: ignore[no-untyped-call]
                redis_url_with_db,
                **connection_kwargs,
            )
            # Test connection
            self.redis.ping()
            self.redis_available = True
            logger.info(
                "Redis cache initialized",
                extra={"redis_url": redis_url, "db": redis_db, "ssl": redis_ssl},
            )
        except Exception as e:
            logger.warning(
                f"Redis cache unavailable, L2 cache disabled: {e}",
                extra={"redis_url": redis_url},
            )
            self.redis = None
            self.redis_available = False

        # Cache stampede prevention locks
        self._refresh_locks: dict[str, asyncio.Lock] = {}

        # Statistics
        self.stats = {
            "l1_hits": 0,
            "l1_misses": 0,
            "l2_hits": 0,
            "l2_misses": 0,
            "sets": 0,
            "deletes": 0,
        }

    def get(self, key: str, level: str = CacheLayer.L2) -> Any | None:
        """
        Get value from cache (L1 → L2 → None).

        Args:
            key: Cache key
            level: Cache level to search (l1 or l2)

        Returns:
            Cached value or None if not found
        """
        # Try L1 first
        if level in (CacheLayer.L1, CacheLayer.L2) and key in self.l1_cache:
            self.stats["l1_hits"] += 1
            logger.debug(f"L1 cache hit: {key}")

            # Emit metric
            self._emit_cache_hit_metric(CacheLayer.L1, key)

            return self.l1_cache[key]

        self.stats["l1_misses"] += 1

        # Try L2 if enabled
        if level == CacheLayer.L2 and self.redis_available and self.redis is not None:
            try:
                data = self.redis.get(key)
                if data:
                    value = pickle.loads(data)  # type: ignore[arg-type]

                    # Promote to L1
                    self.l1_cache[key] = value

                    self.stats["l2_hits"] += 1
                    logger.debug(f"L2 cache hit: {key}")

                    # Emit metric
                    self._emit_cache_hit_metric(CacheLayer.L2, key)

                    return value
            except Exception as e:
                logger.warning(f"L2 cache get failed: {e}", extra={"key": key})

        self.stats["l2_misses"] += 1

        # Emit cache miss metric
        self._emit_cache_miss_metric(level, key)

        return None
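
    # Example (sketch): level=CacheLayer.L1 checks only the in-process TTLCache,
    # while the default level=CacheLayer.L2 falls through to Redis on an L1 miss.
    #     value = cache.get("feature_flag:new_ui", level=CacheLayer.L1)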

    def set(  # type: ignore[no-untyped-def]
        self,
        key: str,
        value: Any,
        ttl: int | None = None,
        level: str = CacheLayer.L2,
    ):
        """
        Set value in cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time-to-live in seconds (default: from cache type)
            level: Cache level (l1, l2)
        """
        # Use default TTL for cache type
        if ttl is None:
            ttl = self._get_ttl_from_key(key)

        # Set in L1
        if level in (CacheLayer.L1, CacheLayer.L2):
            self.l1_cache[key] = value

        # Set in L2
        if level == CacheLayer.L2 and self.redis_available and self.redis is not None:
            try:
                self.redis.setex(key, ttl, pickle.dumps(value))
                logger.debug(f"L2 cache set: {key} (TTL: {ttl}s)")
            except Exception as e:
                logger.warning(f"L2 cache set failed: {e}", extra={"key": key})

        self.stats["sets"] += 1

        # Emit metric
        self._emit_cache_set_metric(level, key)

    def delete(self, key: str) -> None:
        """
        Delete from all cache levels.

        Args:
            key: Cache key to delete
        """
        # Delete from L1
        self.l1_cache.pop(key, None)

        # Delete from L2
        if self.redis_available and self.redis is not None:
            try:
                self.redis.delete(key)
            except Exception as e:
                logger.warning(f"L2 cache delete failed: {e}")

        self.stats["deletes"] += 1

    def clear(self, pattern: str | None = None) -> None:
        """
        Clear cache (all or by pattern).

        Args:
            pattern: Redis key pattern (e.g., "user:*") or None for all
        """
        # Clear L1
        self.l1_cache.clear()

        # Clear L2
        if self.redis_available and self.redis is not None:
            try:
                # SECURITY FIX (OpenAI Codex Finding #6): Always use pattern-based deletion.
                # Never use flushdb() as it clears the ENTIRE Redis database, including
                # other data structures that may share the same DB (e.g., API key cache).
                # Use pattern="*" for "clear all" instead of flushdb().
                search_pattern = pattern if pattern else "*"
                keys = self.redis.keys(search_pattern)
                if keys:
                    self.redis.delete(*keys)  # type: ignore[misc]
                    logger.info(f"Cleared L2 cache by pattern: {search_pattern} ({len(keys)} keys)")  # type: ignore[arg-type]
                else:
                    logger.info(f"No L2 cache keys matched pattern: {search_pattern}")
            except Exception as e:
                logger.warning(f"L2 cache clear failed: {e}")

    async def get_with_lock(
        self,
        key: str,
        fetcher: Callable,  # type: ignore[type-arg]
        ttl: int | None = None,
    ) -> Any:
        """
        Get from cache or fetch with lock (prevents cache stampede).

        Args:
            key: Cache key
            fetcher: Async function to fetch data if cache miss
            ttl: Time-to-live in seconds

        Returns:
            Cached or fetched value
        """
        # Try cache first (note: falsy cached values are treated as misses)
        if cached := self.get(key):
            return cached

        # Acquire lock for this key
        if key not in self._refresh_locks:
            self._refresh_locks[key] = asyncio.Lock()

        async with self._refresh_locks[key]:
            # Double-check cache (another request may have filled it)
            if cached := self.get(key):
                return cached

            # Fetch and cache
            value = await fetcher() if asyncio.iscoroutinefunction(fetcher) else fetcher()
            self.set(key, value, ttl)

            return value
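
    # Usage sketch (hypothetical fetcher, not part of the original module):
    # concurrent callers for the same key serialize on one per-key lock, so the
    # expensive fetch runs once per expiry instead of once per caller.
    #
    #     async def load_profile() -> dict[str, Any]:
    #         return await db.fetch_profile("123")  # hypothetical DB call
    #
    #     profile = await cache.get_with_lock("user_profile:123", load_profile, ttl=900)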

    def _get_ttl_from_key(self, key: str) -> int:
        """
        Determine TTL from cache key prefix.

        Args:
            key: Cache key (format: "type:...")

        Returns:
            TTL in seconds
        """
        # Extract cache type from key prefix
        cache_type = key.split(":")[0] if ":" in key else "default"

        return CACHE_TTLS.get(cache_type, 300)  # Default 5 minutes

    def _emit_cache_hit_metric(self, layer: str, key: str) -> None:
        """Emit cache hit metric"""
        try:
            from mcp_server_langgraph.observability.telemetry import config

            cache_type = key.split(":")[0] if ":" in key else "unknown"

            # Create metric if it doesn't exist
            if not hasattr(config, "cache_hits_counter"):
                config.cache_hits_counter = config.meter.create_counter(
                    name="cache.hits",
                    description="Total cache hits",
                    unit="1",
                )

            config.cache_hits_counter.add(
                1,
                attributes={"layer": layer, "cache_type": cache_type},
            )
        except Exception:
            pass  # Don't let metrics failure break caching

    def _emit_cache_miss_metric(self, layer: str, key: str) -> None:
        """Emit cache miss metric"""
        try:
            from mcp_server_langgraph.observability.telemetry import config

            cache_type = key.split(":")[0] if ":" in key else "unknown"

            if not hasattr(config, "cache_misses_counter"):
                config.cache_misses_counter = config.meter.create_counter(
                    name="cache.misses",
                    description="Total cache misses",
                    unit="1",
                )

            config.cache_misses_counter.add(
                1,
                attributes={"layer": layer, "cache_type": cache_type},
            )
        except Exception:
            pass

    def _emit_cache_set_metric(self, layer: str, key: str) -> None:
        """Emit cache set metric"""
        try:
            from mcp_server_langgraph.observability.telemetry import config

            cache_type = key.split(":")[0] if ":" in key else "unknown"

            if not hasattr(config, "cache_sets_counter"):
                config.cache_sets_counter = config.meter.create_counter(
                    name="cache.sets",
                    description="Total cache sets",
                    unit="1",
                )

            config.cache_sets_counter.add(
                1,
                attributes={"layer": layer, "cache_type": cache_type},
            )
        except Exception:
            pass

    def get_statistics(self) -> dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache stats
        """
        hit_rate_l1 = (
            self.stats["l1_hits"] / (self.stats["l1_hits"] + self.stats["l1_misses"])
            if (self.stats["l1_hits"] + self.stats["l1_misses"]) > 0
            else 0.0
        )

        hit_rate_l2 = (
            self.stats["l2_hits"] / (self.stats["l2_hits"] + self.stats["l2_misses"])
            if (self.stats["l2_hits"] + self.stats["l2_misses"]) > 0
            else 0.0
        )

        return {
            "l1": {
                "hits": self.stats["l1_hits"],
                "misses": self.stats["l1_misses"],
                "hit_rate": hit_rate_l1,
                "size": len(self.l1_cache),
                "max_size": self.l1_maxsize,
            },
            "l2": {
                "hits": self.stats["l2_hits"],
                "misses": self.stats["l2_misses"],
                "hit_rate": hit_rate_l2,
                "available": self.redis_available,
            },
            "total": {
                "sets": self.stats["sets"],
                "deletes": self.stats["deletes"],
            },
        }
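
# Example (sketch): reading the per-layer hit rate, e.g. for a health endpoint.
#     stats = get_cache().get_statistics()
#     l1_rate = stats["l1"]["hit_rate"]  # 0.0 when no lookups have happened yet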

# Global cache instance
_cache_service: CacheService | None = None


def get_cache() -> CacheService:
    """Get global cache service (singleton)"""
    global _cache_service
    if _cache_service is None:
        _cache_service = CacheService(
            l1_maxsize=getattr(settings, "l1_cache_size", 1000),
            l1_ttl=getattr(settings, "l1_cache_ttl", 60),
            redis_url=getattr(settings, "redis_url", "redis://localhost:6379"),
            redis_db=getattr(settings, "redis_cache_db", 2),
            redis_password=getattr(settings, "redis_password", None),
            redis_ssl=getattr(settings, "redis_ssl", False),
        )
    return _cache_service

def cached(
    key_prefix: str,
    ttl: int | None = None,
    level: str = CacheLayer.L2,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator for caching function results (async or sync).

    Args:
        key_prefix: Prefix for cache key
        ttl: Time-to-live in seconds (default: auto from key_prefix)
        level: Cache level (l1 or l2)

    Usage:
        @cached(key_prefix="user:profile", ttl=900)
        async def get_user_profile(user_id: str) -> dict[str, Any]:
            return await db.get_user(user_id)

        # Auto TTL from key prefix
        @cached(key_prefix="auth_permission")  # Uses 300s TTL
        async def check_permission(user: str, resource: str) -> bool:
            return await openfga.check(user, resource)
    """

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @functools.wraps(func)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            """Async wrapper with caching"""
            cache = get_cache()

            # Generate cache key
            key_parts = [key_prefix]
            key_parts.extend(str(arg) for arg in args)
            key_parts.extend(f"{k}:{v}" for k, v in sorted(kwargs.items()))
            key = ":".join(key_parts)

            # Hash if too long
            if len(key) > 200:
                key_hash = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()  # nosec B324
                key = f"{key_prefix}:hash:{key_hash}"

            with tracer.start_as_current_span(
                f"cache.{key_prefix}",
                attributes={"cache.key": key, "cache.level": level},
            ) as span:
                # Try cache (note: falsy cached values are treated as misses)
                if cached_value := cache.get(key, level=level):
                    span.set_attribute("cache.hit", True)
                    return cached_value  # type: ignore[no-any-return]

                span.set_attribute("cache.hit", False)

                # Cache miss: call function
                result = await func(*args, **kwargs)  # type: ignore[misc]

                # Store in cache
                cache.set(key, result, ttl=ttl, level=level)

                return result  # type: ignore[no-any-return]

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            """Sync wrapper with caching"""
            cache = get_cache()

            # Generate cache key
            key_parts = [key_prefix]
            key_parts.extend(str(arg) for arg in args)
            key_parts.extend(f"{k}:{v}" for k, v in sorted(kwargs.items()))
            key = ":".join(key_parts)

            # Hash if too long
            if len(key) > 200:
                key_hash = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()  # nosec B324
                key = f"{key_prefix}:hash:{key_hash}"

            # Try cache
            if cached_value := cache.get(key, level=level):
                return cached_value  # type: ignore[no-any-return]

            # Cache miss: call function
            result = func(*args, **kwargs)

            # Store in cache
            cache.set(key, result, ttl=ttl, level=level)

            return result

        # Return appropriate wrapper
        if asyncio.iscoroutinefunction(func):
            return async_wrapper  # type: ignore[return-value]  # Complex decorator typing
        else:
            return sync_wrapper

    return decorator
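
# Key shape produced by the decorator (illustrative): positional args and
# sorted kwargs are joined with ":", so get_user_profile("123") from the
# docstring above caches under "user:profile:123"; keys over 200 characters
# collapse to "<prefix>:hash:<md5>".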

def cache_invalidate(key_pattern: str) -> None:
    """
    Invalidate cache entries matching pattern.

    Args:
        key_pattern: Pattern to match (e.g., "user:123:*" for all user 123 caches)

    Usage:
        # Invalidate all caches for a user
        cache_invalidate("user:123:*")

        # Invalidate all auth caches
        cache_invalidate("auth:*")
    """
    cache = get_cache()
    # Note: clear() empties the entire L1 cache regardless of pattern; the
    # pattern only scopes the L2 (Redis) deletion.
    cache.clear(pattern=key_pattern)
    logger.info(f"Cache invalidated: {key_pattern}")

# Anthropic-specific prompt caching (L3)
def create_anthropic_cached_message(system_prompt: str, messages: list) -> dict[str, Any]:  # type: ignore[type-arg]
    """
    Create Anthropic message with prompt caching.

    Args:
        system_prompt: System prompt to cache
        messages: List of messages

    Returns:
        Message dict with cache_control
    """
    return {
        "model": "claude-sonnet-4-5-20250929",
        "system": [
            {
                "type": "text",
                "text": system_prompt,
                "cache_control": {"type": "ephemeral"},  # Cache this!
            }
        ],
        "messages": messages,
    }
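
# Usage sketch (assumes the `anthropic` SDK is installed; the max_tokens value
# is illustrative):
#     import anthropic
#     client = anthropic.Anthropic()
#     params = create_anthropic_cached_message(SYSTEM_PROMPT, [{"role": "user", "content": "Hi"}])
#     response = client.messages.create(max_tokens=1024, **params)
# The "cache_control" block lets the provider reuse the processed system prompt
# across calls (L3 caching), independent of this module's L1/L2 layers.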

# Helper function for cache key generation
def generate_cache_key(*parts: Any, prefix: str = "", version: str = "v1") -> str:
    """
    Generate standardized cache key.

    Args:
        *parts: Key components
        prefix: Key prefix (cache type)
        version: Cache version (for invalidation)

    Returns:
        Cache key string

    Usage:
        key = generate_cache_key("user_123", "profile", prefix="user", version="v1")
        # Returns: "user:user_123:profile:v1"
    """
    key_parts = [prefix] if prefix else []
    key_parts.extend(str(part) for part in parts)
    key_parts.append(version)

    key = ":".join(key_parts)

    # Hash if too long
    if len(key) > 200:
        key_hash = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()  # nosec B324
        key = f"{prefix}:hash:{key_hash}:{version}"

    return key