Coverage for src/mcp_server_langgraph/core/dynamic_context_loader.py: 15%

237 statements  

coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Dynamic Context Loading with Qdrant Vector Store 

3 

4Implements Anthropic's Just-in-Time context loading strategy: 

5- Semantic search for relevant context 

6- Progressive discovery patterns 

7- Lightweight context references 

8""" 

9 

10import asyncio 

11import base64 

12import time 

13from datetime import datetime, timedelta, UTC 

14from functools import lru_cache 

15from typing import Any 

16 

17from cryptography.fernet import Fernet 

18from langchain_core.embeddings import Embeddings 

19from langchain_core.messages import BaseMessage, SystemMessage 

20from pydantic import BaseModel, Field 

21from qdrant_client import QdrantClient 

22from qdrant_client.models import Distance, FieldCondition, Filter, MatchValue, PointStruct, VectorParams 

23 

24from mcp_server_langgraph.core.config import settings 

25from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer 

26from mcp_server_langgraph.utils.response_optimizer import count_tokens 

27 


def _create_embeddings(
    provider: str,
    model_name: str,
    google_api_key: str | None = None,
    task_type: str | None = None,
) -> Embeddings:
    """
    Create embeddings instance based on provider.

    Args:
        provider: "google" for Gemini API or "local" for sentence-transformers
        model_name: Model name (e.g., "models/text-embedding-004" or "all-MiniLM-L6-v2")
        google_api_key: Google API key (required for "google" provider)
        task_type: Task type for Google embeddings optimization

    Returns:
        Embeddings instance

    Raises:
        ValueError: If provider is unsupported or required API key is missing
        ImportError: If the provider's optional dependency is not installed
    """
    if provider == "google":
        try:
            from langchain_google_genai import GoogleGenerativeAIEmbeddings
        except ImportError as e:
            msg = (
                "langchain-google-genai is required for Google embeddings. "
                "Add 'langchain-google-genai' to pyproject.toml dependencies, then run: uv sync"
            )
            raise ImportError(msg) from e

        if not google_api_key:
            msg = "GOOGLE_API_KEY is required for Google embeddings. Set via environment variable or Infisical."
            raise ValueError(msg)

        # Create Google embeddings with task type optimization
        from pydantic import SecretStr

        embeddings: Embeddings = GoogleGenerativeAIEmbeddings(
            model=model_name,
            google_api_key=SecretStr(google_api_key),
            task_type=task_type or "RETRIEVAL_DOCUMENT",
        )

        logger.info(
            "Initialized Google Gemini embeddings",
            extra={"model": model_name, "task_type": task_type},
        )

        return embeddings

    elif provider == "local":
        try:
            from sentence_transformers import SentenceTransformer

            class SentenceTransformerEmbeddings(Embeddings):
                """Wrapper to make SentenceTransformer compatible with LangChain Embeddings interface."""

                def __init__(self, model_name: str) -> None:
                    self.model = SentenceTransformer(model_name)

                def embed_documents(self, texts: list[str]) -> list[list[float]]:
                    """Embed multiple documents."""
                    embeddings = self.model.encode(texts)
                    return embeddings.tolist()  # type: ignore[no-any-return]

                def embed_query(self, text: str) -> list[float]:
                    """Embed a single query."""
                    embedding = self.model.encode(text)
                    return embedding.tolist()  # type: ignore[no-any-return]

            embeddings = SentenceTransformerEmbeddings(model_name)

            logger.info(
                "Initialized local sentence-transformers embeddings",
                extra={"model": model_name},
            )

            return embeddings

        except ImportError as e:
            msg = (
                "sentence-transformers is required for local embeddings. "
                "Add 'sentence-transformers' to pyproject.toml dependencies, then run: uv sync"
            )
            raise ImportError(msg) from e

    else:
        msg = f"Unsupported embedding provider: {provider}. Supported providers: 'google', 'local'"
        raise ValueError(msg)

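
# Usage sketch (illustrative, not part of the module API; the model name below is
# the documented "local" default):
#
#     embedder = _create_embeddings(provider="local", model_name="all-MiniLM-L6-v2")
#     vector = embedder.embed_query("example query")  # list[float]; 384 dims for MiniLM
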

class ContextReference(BaseModel):
    """Lightweight reference to context that can be loaded on demand."""

    ref_id: str = Field(description="Unique identifier for this context")
    ref_type: str = Field(description="Type: conversation, document, tool_usage, file")
    summary: str = Field(description="Brief summary for filtering (< 100 chars)")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
    relevance_score: float | None = Field(default=None, description="Relevance score if from search")


class LoadedContext(BaseModel):
    """Full context loaded from a reference."""

    reference: ContextReference
    content: str = Field(description="Full context content")
    token_count: int = Field(description="Token count of content")
    loaded_at: float = Field(description="Timestamp when loaded")

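
# These two models implement the reference-then-load split at the heart of
# just-in-time loading: the agent keeps only ContextReference objects (a short
# summary plus ids) in working memory, and the full text enters the context
# window only when a LoadedContext is materialized via load_context(). A
# hypothetical reference:
#
#     ref = ContextReference(ref_id="<uuid>", ref_type="document", summary="Q3 incident postmortem")
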

class DynamicContextLoader:
    """
    Manages just-in-time context loading with semantic search.

    Follows Anthropic's recommendations:
    1. Store lightweight identifiers instead of full context
    2. Load context dynamically when needed
    3. Use semantic search for relevance
    4. Progressive discovery through iterative search
    """

    def __init__(
        self,
        qdrant_url: str | None = None,
        qdrant_port: int | None = None,
        collection_name: str | None = None,
        embedding_model: str | None = None,
        embedding_provider: str | None = None,
        embedding_dimensions: int | None = None,
        cache_size: int | None = None,
    ):
        """
        Initialize dynamic context loader with encryption and retention support.

        Args:
            qdrant_url: Qdrant server URL (defaults to settings.qdrant_url)
            qdrant_port: Qdrant server port (defaults to settings.qdrant_port)
            collection_name: Name of Qdrant collection (defaults to settings.qdrant_collection_name)
            embedding_model: Model name (defaults to settings.embedding_model_name)
            embedding_provider: "google" or "local" (defaults to settings.embedding_provider)
            embedding_dimensions: Embedding dimensions (defaults to settings.embedding_dimensions)
            cache_size: LRU cache size for loaded contexts (defaults to settings.context_cache_size)

        Note:
            For regulated workloads (HIPAA, GDPR, etc.):
            - Set enable_context_encryption=True in settings
            - Configure retention period with context_retention_days
            - For multi-tenant isolation, use enable_multi_tenant_isolation=True

        Migration from sentence-transformers:
            - Existing collections with different dimensions need recreation
            - Google embeddings (768) vs sentence-transformers (384)
        """
        self.qdrant_url = qdrant_url or settings.qdrant_url
        self.qdrant_port = qdrant_port or settings.qdrant_port
        self.collection_name = collection_name or settings.qdrant_collection_name

        # Embedding configuration
        self.embedding_provider = embedding_provider or settings.embedding_provider
        # Support both new and legacy config parameter names
        self.embedding_model_name = embedding_model or settings.embedding_model_name
        self.embedding_dim = embedding_dimensions or settings.embedding_dimensions

        cache_size = cache_size or settings.context_cache_size

        # Security & Compliance configuration
        self.enable_encryption = settings.enable_context_encryption
        self.retention_days = settings.context_retention_days
        self.enable_auto_deletion = settings.enable_auto_deletion

        # Initialize encryption if enabled
        self.cipher: Fernet | None = None
        if self.enable_encryption:
            if not settings.context_encryption_key:
                msg = (
                    "CRITICAL: Context encryption enabled but no encryption key configured. "
                    "Set CONTEXT_ENCRYPTION_KEY environment variable or configure via Infisical."
                )
                raise ValueError(msg)
            # Fernet requires a base64-encoded 32-byte key
            try:
                self.cipher = Fernet(settings.context_encryption_key.encode())
            except Exception as e:
                msg = f"Invalid encryption key format: {e}. Generate with: Fernet.generate_key()"
                raise ValueError(msg) from e

        # Initialize Qdrant client
        self.client = QdrantClient(host=self.qdrant_url, port=self.qdrant_port)

        # Initialize embeddings
        self.embedder = _create_embeddings(
            provider=self.embedding_provider,
            model_name=self.embedding_model_name,
            google_api_key=settings.google_api_key,
            task_type=settings.embedding_task_type,
        )

        # Create collection if it doesn't exist
        self._ensure_collection_exists()

        # LRU cache for loaded contexts
        self._load_context_cached = lru_cache(maxsize=cache_size)(self._load_context_impl)

        logger.info(
            "DynamicContextLoader initialized",
            extra={
                "qdrant_url": self.qdrant_url,
                "collection": self.collection_name,
                "embedding_provider": self.embedding_provider,
                "embedding_model": self.embedding_model_name,
                "embedding_dim": self.embedding_dim,
                "encryption_enabled": self.enable_encryption,
                "retention_days": self.retention_days,
                "auto_deletion_enabled": self.enable_auto_deletion,
            },
        )

    def _encrypt_content(self, content: str) -> str:
        """
        Encrypt content for storage (encryption-at-rest).

        Args:
            content: Plain text content

        Returns:
            Base64-encoded encrypted content
        """
        if not self.cipher:
            return content

        encrypted_bytes = self.cipher.encrypt(content.encode())
        return base64.b64encode(encrypted_bytes).decode()

    def _decrypt_content(self, encrypted_content: str) -> str:
        """
        Decrypt content from storage.

        Args:
            encrypted_content: Base64-encoded encrypted content

        Returns:
            Decrypted plain text
        """
        if not self.cipher:
            return encrypted_content

        try:
            encrypted_bytes = base64.b64decode(encrypted_content.encode())
            decrypted_bytes = self.cipher.decrypt(encrypted_bytes)
            return decrypted_bytes.decode()
        except Exception as e:
            logger.error(f"Decryption failed: {e}", exc_info=True)
            msg = f"Failed to decrypt content: {e}"
            raise ValueError(msg) from e

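
    # Key management sketch: the cipher above expects a Fernet key provisioned
    # out of band (per __init__: via environment variable or Infisical).
    # Generate a compatible key once with:
    #
    #     python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
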

    def _calculate_expiry_timestamp(self) -> float:
        """Calculate expiry timestamp based on retention policy."""
        expiry_date = datetime.now(UTC) + timedelta(days=self.retention_days)
        return expiry_date.timestamp()

    def _ensure_collection_exists(self) -> None:
        """Create Qdrant collection if it doesn't exist."""
        try:
            collections = self.client.get_collections().collections
            exists = any(c.name == self.collection_name for c in collections)

            if not exists:
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(size=self.embedding_dim, distance=Distance.COSINE),
                )
                logger.info(f"Created Qdrant collection: {self.collection_name}")
            else:
                logger.info(f"Qdrant collection exists: {self.collection_name}")
        except Exception as e:
            logger.error(f"Failed to ensure Qdrant collection: {e}", exc_info=True)
            raise

    async def index_context(
        self,
        ref_id: str,
        content: str,
        ref_type: str,
        summary: str,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """
        Index context for later retrieval.

        Args:
            ref_id: Unique identifier (used as the Qdrant point ID, so it must be
                an unsigned integer or a UUID string)
            content: Full content to index
            ref_type: Type of context
            summary: Brief summary
            metadata: Additional metadata
        """
        with tracer.start_as_current_span("context.index") as span:
            span.set_attribute("ref_id", ref_id)
            span.set_attribute("ref_type", ref_type)

            try:
                # Generate embedding using LangChain Embeddings interface
                embedding = await asyncio.to_thread(self.embedder.embed_query, content)

                # Encrypt content if encryption is enabled
                stored_content = self._encrypt_content(content) if self.enable_encryption else content

                # Calculate expiry timestamp for retention
                created_at = datetime.now(UTC).timestamp()
                expires_at = self._calculate_expiry_timestamp()

                # Create point with encryption and retention support
                point = PointStruct(
                    id=ref_id,
                    vector=embedding,  # Already a list from embed_query
                    payload={
                        "ref_id": ref_id,
                        "ref_type": ref_type,
                        "summary": summary,
                        "content": stored_content,  # Encrypted if encryption enabled
                        "token_count": count_tokens(content),
                        "metadata": metadata or {},
                        "created_at": created_at,  # For retention tracking
                        "expires_at": expires_at,  # For auto-deletion
                        "encrypted": self.enable_encryption,  # Flag for decryption
                    },
                )

                # Upsert to Qdrant
                await asyncio.to_thread(self.client.upsert, collection_name=self.collection_name, points=[point])

                logger.info(f"Indexed context: {ref_id}", extra={"ref_type": ref_type, "summary": summary})
                metrics.successful_calls.add(1, {"operation": "index_context", "type": ref_type})

            except Exception as e:
                logger.error(f"Failed to index context: {e}", exc_info=True)
                metrics.failed_calls.add(1, {"operation": "index_context", "error": type(e).__name__})
                raise

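
    # Indexing sketch (hypothetical values; assumes a reachable Qdrant and a
    # configured embedder). Note the UUID ref_id, per the point-ID constraint
    # documented above:
    #
    #     await loader.index_context(
    #         ref_id=str(uuid.uuid4()),
    #         content="Full tool output ...",
    #         ref_type="tool_usage",
    #         summary="Invoice lookup result for ACME-1042",
    #     )
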

    async def semantic_search(
        self,
        query: str,
        top_k: int = 5,
        ref_type_filter: str | None = None,
        min_score: float = 0.5,
    ) -> list[ContextReference]:
        """
        Search for relevant context using semantic similarity.

        Implements Anthropic's recommendation for search-focused retrieval.

        Args:
            query: Search query
            top_k: Number of results
            ref_type_filter: Optional filter by ref_type
            min_score: Minimum similarity score (0-1)

        Returns:
            List of context references sorted by relevance
        """
        with tracer.start_as_current_span("context.semantic_search") as span:
            span.set_attribute("query", query)
            span.set_attribute("top_k", top_k)

            try:
                # Generate query embedding using LangChain Embeddings interface
                query_embedding = await asyncio.to_thread(self.embedder.embed_query, query)

                # Build filter
                search_filter = None
                if ref_type_filter:
                    search_filter = Filter(must=[FieldCondition(key="ref_type", match=MatchValue(value=ref_type_filter))])

                # Search Qdrant
                results = await asyncio.to_thread(
                    self.client.search,  # type: ignore[attr-defined]
                    collection_name=self.collection_name,
                    query_vector=query_embedding,  # Already a list from embed_query
                    limit=top_k,
                    query_filter=search_filter,
                    score_threshold=min_score,
                )

                # Convert to ContextReferences
                references = []
                for result in results:
                    payload = result.payload
                    if payload is None:
                        continue
                    ref = ContextReference(
                        ref_id=payload["ref_id"],
                        ref_type=payload["ref_type"],
                        summary=payload["summary"],
                        metadata=payload.get("metadata", {}),
                        relevance_score=result.score,
                    )
                    references.append(ref)

                logger.info(
                    "Semantic search completed",
                    extra={"query": query, "results_found": len(references), "top_k": top_k},
                )

                span.set_attribute("results.count", len(references))
                metrics.successful_calls.add(1, {"operation": "semantic_search"})

                return references

            except Exception as e:
                logger.error(f"Semantic search failed: {e}", exc_info=True)
                metrics.failed_calls.add(1, {"operation": "semantic_search", "error": type(e).__name__})
                return []

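
    # Search sketch (illustrative): filter by ref_type and raise the score floor
    # when precision matters more than recall.
    #
    #     refs = await loader.semantic_search(
    #         "billing errors in Q3", top_k=3, ref_type_filter="document", min_score=0.7
    #     )
    #     # refs carry summaries and scores only; nothing has been loaded yet
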

    async def progressive_discover(
        self,
        initial_query: str,
        max_iterations: int = 3,
        expansion_keywords: list[str] | None = None,
    ) -> list[ContextReference]:
        """
        Progressive discovery: iteratively refine search based on results.

        Implements Anthropic's "Progressive Disclosure" pattern.

        Args:
            initial_query: Starting search query
            max_iterations: Maximum search iterations
            expansion_keywords: Keywords to expand search

        Returns:
            Aggregated context references from all iterations
        """
        with tracer.start_as_current_span("context.progressive_discover") as span:
            span.set_attribute("initial_query", initial_query)
            span.set_attribute("max_iterations", max_iterations)

            all_references = []
            seen_ids = set()
            current_query = initial_query
            iteration = -1  # Keeps the span attributes valid if max_iterations <= 0

            for iteration in range(max_iterations):
                logger.info(f"Progressive discovery iteration {iteration + 1}/{max_iterations}")

                # Search with current query
                results = await self.semantic_search(current_query, top_k=5)

                # Add results not seen in earlier iterations
                for ref in results:
                    if ref.ref_id not in seen_ids:
                        all_references.append(ref)
                        seen_ids.add(ref.ref_id)

                # Stop if the search returned no results at all
                if not results:
                    logger.info(f"No results in iteration {iteration + 1}, stopping")
                    break

                # Expand query for next iteration (guard against short keyword lists)
                if expansion_keywords and iteration < max_iterations - 1 and iteration < len(expansion_keywords):
                    current_query = f"{current_query} {expansion_keywords[iteration]}"

            span.set_attribute("total_references", len(all_references))
            span.set_attribute("iterations_completed", iteration + 1)

            logger.info(
                "Progressive discovery completed",
                extra={"iterations": iteration + 1, "total_references": len(all_references)},
            )

            return all_references

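
    # Discovery sketch (illustrative keywords): each pass appends one keyword to
    # the query, so later iterations can surface context the initial phrasing
    # missed.
    #
    #     refs = await loader.progressive_discover(
    #         "deployment failure",
    #         max_iterations=3,
    #         expansion_keywords=["kubernetes", "rollback"],
    #     )
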

    async def load_context(self, reference: ContextReference) -> LoadedContext:
        """
        Load full context from a reference.

        Uses LRU cache for frequently accessed contexts.

        Args:
            reference: Context reference to load

        Returns:
            Loaded context with full content
        """
        with tracer.start_as_current_span("context.load") as span:
            span.set_attribute("ref_id", reference.ref_id)

            # Use cached implementation
            loaded = await asyncio.to_thread(self._load_context_cached, reference.ref_id)

            span.set_attribute("token_count", loaded.token_count)
            metrics.successful_calls.add(1, {"operation": "load_context", "type": reference.ref_type})

            return loaded

    def _load_context_impl(self, ref_id: str) -> LoadedContext:
        """
        Implementation of context loading (cached).

        Args:
            ref_id: Reference ID to load

        Returns:
            Loaded context
        """
        try:
            # Retrieve from Qdrant
            results = self.client.retrieve(collection_name=self.collection_name, ids=[ref_id])

            if not results:
                msg = f"Context not found: {ref_id}"
                raise ValueError(msg)

            result = results[0]
            payload = result.payload

            if payload is None:
                msg = f"Context payload is None: {ref_id}"
                raise ValueError(msg)

            reference = ContextReference(
                ref_id=payload["ref_id"],
                ref_type=payload["ref_type"],
                summary=payload["summary"],
                metadata=payload.get("metadata", {}),
            )

            # Decrypt content if it was encrypted
            stored_content = payload["content"]
            is_encrypted = payload.get("encrypted", False)
            content = self._decrypt_content(stored_content) if is_encrypted else stored_content

            loaded = LoadedContext(
                reference=reference,
                content=content,  # Decrypted content
                token_count=payload["token_count"],
                loaded_at=time.time(),
            )

            logger.info(f"Loaded context: {ref_id}", extra={"token_count": loaded.token_count})

            return loaded

        except Exception as e:
            logger.error(f"Failed to load context {ref_id}: {e}", exc_info=True)
            metrics.failed_calls.add(1, {"operation": "load_context", "error": type(e).__name__})
            raise

    async def load_batch(self, references: list[ContextReference], max_tokens: int = 4000) -> list[LoadedContext]:
        """
        Load multiple contexts up to a token limit.

        Implements token-aware batching.

        Args:
            references: List of references to load
            max_tokens: Maximum total tokens

        Returns:
            List of loaded contexts within the token budget
        """
        with tracer.start_as_current_span("context.load_batch") as span:
            loaded = []
            total_tokens = 0

            for ref in references:
                context = await self.load_context(ref)

                if total_tokens + context.token_count <= max_tokens:
                    loaded.append(context)
                    total_tokens += context.token_count
                else:
                    logger.info(
                        f"Token limit reached, loaded {len(loaded)}/{len(references)} contexts",
                        extra={"total_tokens": total_tokens, "limit": max_tokens},
                    )
                    break

            span.set_attribute("contexts_loaded", len(loaded))
            span.set_attribute("total_tokens", total_tokens)

            return loaded

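
    # Budgeting sketch (illustrative): references are admitted in order until the
    # next context would overflow the budget, so pass them sorted by relevance
    # (semantic_search already returns them that way).
    #
    #     loaded = await loader.load_batch(refs, max_tokens=2000)
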

    def to_messages(self, loaded_contexts: list[LoadedContext]) -> list[BaseMessage]:
        """
        Convert loaded contexts to LangChain messages.

        Args:
            loaded_contexts: List of loaded contexts

        Returns:
            List of SystemMessages containing context
        """
        messages: list[BaseMessage] = []

        for ctx in loaded_contexts:
            message = SystemMessage(
                content=f'<context type="{ctx.reference.ref_type}" id="{ctx.reference.ref_id}">\n{ctx.content}\n</context>'
            )
            messages.append(message)

        return messages


# Convenience functions
async def search_and_load_context(
    query: str,
    loader: DynamicContextLoader | None = None,
    top_k: int = 3,
    max_tokens: int = 2000,
) -> list[LoadedContext]:
    """
    Search for context and load top results within a token budget.

    Args:
        query: Search query
        loader: Context loader instance (creates a new one if None)
        top_k: Number of results to search for
        max_tokens: Maximum tokens to load

    Returns:
        List of loaded contexts
    """
    if loader is None:
        loader = DynamicContextLoader()

    # Search
    references = await loader.semantic_search(query, top_k=top_k)

    # Load within budget
    loaded = await loader.load_batch(references, max_tokens=max_tokens)

    return loaded
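

if __name__ == "__main__":
    # End-to-end smoke-test sketch (illustrative; assumes a reachable Qdrant
    # instance and a configured embedding provider). Not exercised by the tests.
    import uuid

    async def _demo() -> None:
        loader = DynamicContextLoader()
        await loader.index_context(
            ref_id=str(uuid.uuid4()),  # Qdrant point IDs must be UUIDs or unsigned ints
            content="The Q3 incident was caused by an expired TLS certificate.",
            ref_type="document",
            summary="Q3 incident postmortem",
        )
        contexts = await search_and_load_context("Why did the Q3 incident happen?", loader=loader)
        for message in loader.to_messages(contexts):
            print(message.content)

    asyncio.run(_demo())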