# src/mcp_server_langgraph/core/dynamic_context_loader.py
1"""
2Dynamic Context Loading with Qdrant Vector Store
4Implements Anthropic's Just-in-Time context loading strategy:
5- Semantic search for relevant context
6- Progressive discovery patterns
7- Lightweight context references
8"""
10import asyncio
11import base64
12import time
13from datetime import datetime, timedelta, UTC
14from functools import lru_cache
15from typing import Any
17from cryptography.fernet import Fernet
18from langchain_core.embeddings import Embeddings
19from langchain_core.messages import BaseMessage, SystemMessage
20from pydantic import BaseModel, Field
21from qdrant_client import QdrantClient
22from qdrant_client.models import Distance, FieldCondition, Filter, MatchValue, PointStruct, VectorParams
24from mcp_server_langgraph.core.config import settings
25from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer
26from mcp_server_langgraph.utils.response_optimizer import count_tokens

def _create_embeddings(
    provider: str,
    model_name: str,
    google_api_key: str | None = None,
    task_type: str | None = None,
) -> Embeddings:
    """
    Create embeddings instance based on provider.

    Args:
        provider: "google" for Gemini API or "local" for sentence-transformers
        model_name: Model name (e.g., "models/text-embedding-004" or "all-MiniLM-L6-v2")
        google_api_key: Google API key (required for "google" provider)
        task_type: Task type for Google embeddings optimization

    Returns:
        Embeddings instance

    Raises:
        ValueError: If provider is unsupported or required API key is missing
    """
    if provider == "google":
        try:
            from langchain_google_genai import GoogleGenerativeAIEmbeddings
        except ImportError:
            msg = (
                "langchain-google-genai is required for Google embeddings. "
                "Add 'langchain-google-genai' to pyproject.toml dependencies, then run: uv sync"
            )
            raise ImportError(msg)

        if not google_api_key:
            msg = "GOOGLE_API_KEY is required for Google embeddings. Set via environment variable or Infisical."
            raise ValueError(msg)

        # Create Google embeddings with task type optimization
        from pydantic import SecretStr

        embeddings: Embeddings = GoogleGenerativeAIEmbeddings(
            model=model_name,
            google_api_key=SecretStr(google_api_key),
            task_type=task_type or "RETRIEVAL_DOCUMENT",
        )

        logger.info(
            "Initialized Google Gemini embeddings",
            extra={"model": model_name, "task_type": task_type},
        )

        return embeddings

    elif provider == "local":
        try:
            from sentence_transformers import SentenceTransformer

            class SentenceTransformerEmbeddings(Embeddings):
                """Wrapper to make SentenceTransformer compatible with LangChain Embeddings interface."""

                def __init__(self, model_name: str) -> None:
                    self.model = SentenceTransformer(model_name)

                def embed_documents(self, texts: list[str]) -> list[list[float]]:
                    """Embed multiple documents."""
                    embeddings = self.model.encode(texts)
                    return embeddings.tolist()  # type: ignore[no-any-return]

                def embed_query(self, text: str) -> list[float]:
                    """Embed a single query."""
                    embedding = self.model.encode(text)
                    return embedding.tolist()  # type: ignore[no-any-return]

            embeddings = SentenceTransformerEmbeddings(model_name)

            logger.info(
                "Initialized local sentence-transformers embeddings",
                extra={"model": model_name},
            )

            return embeddings

        except ImportError:
            msg = (
                "sentence-transformers is required for local embeddings. "
                "Add 'sentence-transformers' to pyproject.toml dependencies, then run: uv sync"
            )
            raise ImportError(msg)

    else:
        msg = f"Unsupported embedding provider: {provider}. Supported providers: 'google', 'local'"
        raise ValueError(msg)
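
# Usage sketch (illustrative only; not executed at import time). With
# sentence-transformers installed, a local embedder could be built like this:
#
#     embedder = _create_embeddings(provider="local", model_name="all-MiniLM-L6-v2")
#     vector = embedder.embed_query("retention policy")  # 384-dim list[float]
#
# The "google" provider works the same way but requires a GOOGLE_API_KEY
# (supplied via settings in this module).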

class ContextReference(BaseModel):
    """Lightweight reference to context that can be loaded on demand."""

    ref_id: str = Field(description="Unique identifier for this context")
    ref_type: str = Field(description="Type: conversation, document, tool_usage, file")
    summary: str = Field(description="Brief summary for filtering (< 100 chars)")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
    relevance_score: float | None = Field(default=None, description="Relevance score if from search")


class LoadedContext(BaseModel):
    """Full context loaded from a reference."""

    reference: ContextReference
    content: str = Field(description="Full context content")
    token_count: int = Field(description="Token count of content")
    loaded_at: float = Field(description="Timestamp when loaded")
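
# These two models implement the reference/load split described in the module
# docstring: agents carry only the lightweight ContextReference and fetch the
# full LoadedContext on demand. Illustrative construction (values are made up):
#
#     ref = ContextReference(
#         ref_id="conv-42",
#         ref_type="conversation",
#         summary="Refund policy discussion",
#     )
#     # later, when the content is actually needed:
#     # loaded = await loader.load_context(ref)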

class DynamicContextLoader:
    """
    Manages just-in-time context loading with semantic search.

    Follows Anthropic's recommendations:
    1. Store lightweight identifiers instead of full context
    2. Load context dynamically when needed
    3. Use semantic search for relevance
    4. Progressive discovery through iterative search
    """

    def __init__(
        self,
        qdrant_url: str | None = None,
        qdrant_port: int | None = None,
        collection_name: str | None = None,
        embedding_model: str | None = None,
        embedding_provider: str | None = None,
        embedding_dimensions: int | None = None,
        cache_size: int | None = None,
    ):
        """
        Initialize dynamic context loader with encryption and retention support.

        Args:
            qdrant_url: Qdrant server URL (defaults to settings.qdrant_url)
            qdrant_port: Qdrant server port (defaults to settings.qdrant_port)
            collection_name: Name of Qdrant collection (defaults to settings.qdrant_collection_name)
            embedding_model: Model name (defaults to settings.embedding_model_name)
            embedding_provider: "google" or "local" (defaults to settings.embedding_provider)
            embedding_dimensions: Embedding dimensions (defaults to settings.embedding_dimensions)
            cache_size: LRU cache size for loaded contexts (defaults to settings.context_cache_size)

        Note:
            For regulated workloads (HIPAA, GDPR, etc.):
            - Set enable_context_encryption=True in settings
            - Configure retention period with context_retention_days
            - For multi-tenant isolation, use enable_multi_tenant_isolation=True

        Migration from sentence-transformers:
            - Existing collections with different dimensions need recreation
            - Google embeddings (768) vs sentence-transformers (384)
        """
        self.qdrant_url = qdrant_url or settings.qdrant_url
        self.qdrant_port = qdrant_port or settings.qdrant_port
        self.collection_name = collection_name or settings.qdrant_collection_name

        # Embedding configuration
        self.embedding_provider = embedding_provider or settings.embedding_provider
        # Support both new and legacy config parameter names
        self.embedding_model_name = embedding_model or settings.embedding_model_name
        self.embedding_dim = embedding_dimensions or settings.embedding_dimensions

        cache_size = cache_size or settings.context_cache_size

        # Security & Compliance configuration
        self.enable_encryption = settings.enable_context_encryption
        self.retention_days = settings.context_retention_days
        self.enable_auto_deletion = settings.enable_auto_deletion

        # Initialize encryption if enabled
        self.cipher: Fernet | None = None
        if self.enable_encryption:
            if not settings.context_encryption_key:
                msg = (
                    "CRITICAL: Context encryption enabled but no encryption key configured. "
                    "Set CONTEXT_ENCRYPTION_KEY environment variable or configure via Infisical."
                )
                raise ValueError(msg)
            # Fernet requires base64-encoded 32-byte key
            try:
                self.cipher = Fernet(settings.context_encryption_key.encode())
            except Exception as e:
                msg = f"Invalid encryption key format: {e}. Generate with: Fernet.generate_key()"
                raise ValueError(msg)

        # Initialize Qdrant client
        self.client = QdrantClient(host=self.qdrant_url, port=self.qdrant_port)

        # Initialize embeddings
        self.embedder = _create_embeddings(
            provider=self.embedding_provider,
            model_name=self.embedding_model_name,
            google_api_key=settings.google_api_key,
            task_type=settings.embedding_task_type,
        )

        # Create collection if it doesn't exist
        self._ensure_collection_exists()

        # LRU cache for loaded contexts
        self._load_context_cached = lru_cache(maxsize=cache_size)(self._load_context_impl)

        logger.info(
            "DynamicContextLoader initialized",
            extra={
                "qdrant_url": self.qdrant_url,
                "collection": self.collection_name,
                "embedding_provider": self.embedding_provider,
                "embedding_model": self.embedding_model_name,
                "embedding_dim": self.embedding_dim,
                "encryption_enabled": self.enable_encryption,
                "retention_days": self.retention_days,
                "auto_deletion_enabled": self.enable_auto_deletion,
            },
        )
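
    # Construction sketch (illustrative): every argument falls back to settings,
    # so with a reachable Qdrant instance the defaults suffice:
    #
    #     loader = DynamicContextLoader()
    #     # or, overriding the collection for an experiment:
    #     loader = DynamicContextLoader(collection_name="context_scratch")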

    def _encrypt_content(self, content: str) -> str:
        """
        Encrypt content for storage (encryption-at-rest).

        Args:
            content: Plain text content

        Returns:
            Base64-encoded encrypted content
        """
        if not self.cipher:
            return content

        encrypted_bytes = self.cipher.encrypt(content.encode())
        return base64.b64encode(encrypted_bytes).decode()

    def _decrypt_content(self, encrypted_content: str) -> str:
        """
        Decrypt content from storage.

        Args:
            encrypted_content: Base64-encoded encrypted content

        Returns:
            Decrypted plain text
        """
        if not self.cipher:
            return encrypted_content

        try:
            encrypted_bytes = base64.b64decode(encrypted_content.encode())
            decrypted_bytes = self.cipher.decrypt(encrypted_bytes)
            return decrypted_bytes.decode()
        except Exception as e:
            logger.error(f"Decryption failed: {e}", exc_info=True)
            msg = f"Failed to decrypt content: {e}"
            raise ValueError(msg)
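
    # Key-provisioning sketch (illustrative): Fernet keys are generated once and
    # supplied via CONTEXT_ENCRYPTION_KEY (or Infisical); the two helpers above
    # then round-trip content transparently.
    #
    #     from cryptography.fernet import Fernet
    #     key = Fernet.generate_key()  # urlsafe base64-encoded 32-byte key
    #     # export CONTEXT_ENCRYPTION_KEY=<key> before starting the server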

    def _calculate_expiry_timestamp(self) -> float:
        """Calculate expiry timestamp based on retention policy."""
        expiry_date = datetime.now(UTC) + timedelta(days=self.retention_days)
        return expiry_date.timestamp()

    def _ensure_collection_exists(self) -> None:
        """Create Qdrant collection if it doesn't exist."""
        try:
            collections = self.client.get_collections().collections
            exists = any(c.name == self.collection_name for c in collections)

            if not exists:
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(size=self.embedding_dim, distance=Distance.COSINE),
                )
                logger.info(f"Created Qdrant collection: {self.collection_name}")
            else:
                logger.info(f"Qdrant collection exists: {self.collection_name}")
        except Exception as e:
            logger.error(f"Failed to ensure Qdrant collection: {e}", exc_info=True)
            raise

    async def index_context(
        self,
        ref_id: str,
        content: str,
        ref_type: str,
        summary: str,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """
        Index context for later retrieval.

        Args:
            ref_id: Unique identifier
            content: Full content to index
            ref_type: Type of context
            summary: Brief summary
            metadata: Additional metadata
        """
        with tracer.start_as_current_span("context.index") as span:
            span.set_attribute("ref_id", ref_id)
            span.set_attribute("ref_type", ref_type)

            try:
                # Generate embedding using LangChain Embeddings interface
                embedding = await asyncio.to_thread(self.embedder.embed_query, content)

                # Encrypt content if encryption is enabled
                stored_content = self._encrypt_content(content) if self.enable_encryption else content

                # Calculate expiry timestamp for retention
                created_at = datetime.now(UTC).timestamp()
                expires_at = self._calculate_expiry_timestamp()

                # Create point with encryption and retention support
                point = PointStruct(
                    id=ref_id,
                    vector=embedding,  # Already a list from embed_query
                    payload={
                        "ref_id": ref_id,
                        "ref_type": ref_type,
                        "summary": summary,
                        "content": stored_content,  # Encrypted if encryption enabled
                        "token_count": count_tokens(content),
                        "metadata": metadata or {},
                        "created_at": created_at,  # For retention tracking
                        "expires_at": expires_at,  # For auto-deletion
                        "encrypted": self.enable_encryption,  # Flag for decryption
                    },
                )

                # Upsert to Qdrant
                await asyncio.to_thread(self.client.upsert, collection_name=self.collection_name, points=[point])

                logger.info(f"Indexed context: {ref_id}", extra={"ref_type": ref_type, "summary": summary})
                metrics.successful_calls.add(1, {"operation": "index_context", "type": ref_type})

            except Exception as e:
                logger.error(f"Failed to index context: {e}", exc_info=True)
                metrics.failed_calls.add(1, {"operation": "index_context", "error": type(e).__name__})
                raise
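
    # Indexing sketch (illustrative; Qdrant point IDs must be unsigned integers
    # or UUID strings, so ref_id should be a UUID here):
    #
    #     await loader.index_context(
    #         ref_id="550e8400-e29b-41d4-a716-446655440000",
    #         content=full_transcript,  # hypothetical variable
    #         ref_type="conversation",
    #         summary="Support thread about data retention",
    #         metadata={"user_id": "u-123"},
    #     )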

    async def semantic_search(
        self,
        query: str,
        top_k: int = 5,
        ref_type_filter: str | None = None,
        min_score: float = 0.5,
    ) -> list[ContextReference]:
        """
        Search for relevant context using semantic similarity.

        Implements Anthropic's recommendation for search-focused retrieval.

        Args:
            query: Search query
            top_k: Number of results
            ref_type_filter: Optional filter by ref_type
            min_score: Minimum similarity score (0-1)

        Returns:
            List of context references sorted by relevance
        """
        with tracer.start_as_current_span("context.semantic_search") as span:
            span.set_attribute("query", query)
            span.set_attribute("top_k", top_k)

            try:
                # Generate query embedding using LangChain Embeddings interface
                query_embedding = await asyncio.to_thread(self.embedder.embed_query, query)

                # Build filter
                search_filter = None
                if ref_type_filter:
                    search_filter = Filter(must=[FieldCondition(key="ref_type", match=MatchValue(value=ref_type_filter))])

                # Search Qdrant
                results = await asyncio.to_thread(
                    self.client.search,  # type: ignore[attr-defined]
                    collection_name=self.collection_name,
                    query_vector=query_embedding,  # Already a list from embed_query
                    limit=top_k,
                    query_filter=search_filter,
                    score_threshold=min_score,
                )

                # Convert to ContextReferences
                references = []
                for result in results:
                    payload = result.payload
                    if payload is None:
                        continue
                    ref = ContextReference(
                        ref_id=payload["ref_id"],
                        ref_type=payload["ref_type"],
                        summary=payload["summary"],
                        metadata=payload.get("metadata", {}),
                        relevance_score=result.score,
                    )
                    references.append(ref)

                logger.info(
                    "Semantic search completed",
                    extra={"query": query, "results_found": len(references), "top_k": top_k},
                )

                span.set_attribute("results.count", len(references))
                metrics.successful_calls.add(1, {"operation": "semantic_search"})

                return references

            except Exception as e:
                logger.error(f"Semantic search failed: {e}", exc_info=True)
                metrics.failed_calls.add(1, {"operation": "semantic_search", "error": type(e).__name__})
                return []
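
    # Search sketch (illustrative): returns lightweight references only; no
    # content is loaded at this point.
    #
    #     refs = await loader.semantic_search(
    #         "customer refund policy",
    #         top_k=3,
    #         ref_type_filter="document",
    #     )
    #     for ref in refs:
    #         print(ref.ref_id, ref.relevance_score, ref.summary)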

    async def progressive_discover(
        self,
        initial_query: str,
        max_iterations: int = 3,
        expansion_keywords: list[str] | None = None,
    ) -> list[ContextReference]:
        """
        Progressive discovery: iteratively refine search based on results.

        Implements Anthropic's "Progressive Disclosure" pattern.

        Args:
            initial_query: Starting search query
            max_iterations: Maximum search iterations
            expansion_keywords: Keywords to expand search

        Returns:
            Aggregated context references from all iterations
        """
        with tracer.start_as_current_span("context.progressive_discover") as span:
            span.set_attribute("initial_query", initial_query)
            span.set_attribute("max_iterations", max_iterations)

            all_references = []
            seen_ids = set()
            current_query = initial_query

            for iteration in range(max_iterations):
                logger.info(f"Progressive discovery iteration {iteration + 1}/{max_iterations}")

                # Search with current query
                results = await self.semantic_search(current_query, top_k=5)

                # Add results not seen in earlier iterations
                for ref in results:
                    if ref.ref_id not in seen_ids:
                        all_references.append(ref)
                        seen_ids.add(ref.ref_id)

                # Stop when the search returns nothing
                if not results:
                    logger.info(f"No results in iteration {iteration + 1}, stopping")
                    break

                # Expand query for next iteration (guard against running out of keywords)
                if expansion_keywords and iteration < max_iterations - 1 and iteration < len(expansion_keywords):
                    current_query = f"{current_query} {expansion_keywords[iteration]}"

            span.set_attribute("total_references", len(all_references))
            span.set_attribute("iterations_completed", iteration + 1)

            logger.info(
                "Progressive discovery completed",
                extra={"iterations": iteration + 1, "total_references": len(all_references)},
            )

            return all_references
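
    # Discovery sketch (illustrative): each iteration appends one expansion
    # keyword to the query, and results are deduplicated by ref_id.
    #
    #     refs = await loader.progressive_discover(
    #         "incident postmortem",
    #         max_iterations=3,
    #         expansion_keywords=["database", "rollback"],
    #     )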

    async def load_context(self, reference: ContextReference) -> LoadedContext:
        """
        Load full context from a reference.

        Uses LRU cache for frequently accessed contexts.

        Args:
            reference: Context reference to load

        Returns:
            Loaded context with full content
        """
        with tracer.start_as_current_span("context.load") as span:
            span.set_attribute("ref_id", reference.ref_id)

            # Use cached implementation
            loaded = await asyncio.to_thread(self._load_context_cached, reference.ref_id)

            span.set_attribute("token_count", loaded.token_count)
            metrics.successful_calls.add(1, {"operation": "load_context", "type": reference.ref_type})

            return loaded

    def _load_context_impl(self, ref_id: str) -> LoadedContext:
        """
        Implementation of context loading (cached).

        Args:
            ref_id: Reference ID to load

        Returns:
            Loaded context
        """
        try:
            # Retrieve from Qdrant
            results = self.client.retrieve(collection_name=self.collection_name, ids=[ref_id])

            if not results:
                msg = f"Context not found: {ref_id}"
                raise ValueError(msg)

            result = results[0]
            payload = result.payload

            if payload is None:
                msg = f"Context payload is None: {ref_id}"
                raise ValueError(msg)

            reference = ContextReference(
                ref_id=payload["ref_id"],
                ref_type=payload["ref_type"],
                summary=payload["summary"],
                metadata=payload.get("metadata", {}),
            )

            # Decrypt content if it was encrypted
            stored_content = payload["content"]
            is_encrypted = payload.get("encrypted", False)
            content = self._decrypt_content(stored_content) if is_encrypted else stored_content

            loaded = LoadedContext(
                reference=reference,
                content=content,  # Decrypted content
                token_count=payload["token_count"],
                loaded_at=time.time(),
            )

            logger.info(f"Loaded context: {ref_id}", extra={"token_count": loaded.token_count})

            return loaded

        except Exception as e:
            logger.error(f"Failed to load context {ref_id}: {e}", exc_info=True)
            metrics.failed_calls.add(1, {"operation": "load_context", "error": type(e).__name__})
            raise

    async def load_batch(self, references: list[ContextReference], max_tokens: int = 4000) -> list[LoadedContext]:
        """
        Load multiple contexts up to token limit.

        Implements token-aware batching.

        Args:
            references: List of references to load
            max_tokens: Maximum total tokens

        Returns:
            List of loaded contexts within token budget
        """
        with tracer.start_as_current_span("context.load_batch") as span:
            loaded = []
            total_tokens = 0

            for ref in references:
                context = await self.load_context(ref)

                if total_tokens + context.token_count <= max_tokens:
                    loaded.append(context)
                    total_tokens += context.token_count
                else:
                    logger.info(
                        f"Token limit reached, loaded {len(loaded)}/{len(references)} contexts",
                        extra={"total_tokens": total_tokens, "limit": max_tokens},
                    )
                    break

            span.set_attribute("contexts_loaded", len(loaded))
            span.set_attribute("total_tokens", total_tokens)

            return loaded
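
    # Batching sketch (illustrative): semantic_search returns references sorted
    # by relevance, so the token budget is spent on the best matches first.
    #
    #     refs = await loader.semantic_search("onboarding checklist", top_k=10)
    #     contexts = await loader.load_batch(refs, max_tokens=3000)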

    def to_messages(self, loaded_contexts: list[LoadedContext]) -> list[BaseMessage]:
        """
        Convert loaded contexts to LangChain messages.

        Args:
            loaded_contexts: List of loaded contexts

        Returns:
            List of SystemMessages containing context
        """
        messages: list[BaseMessage] = []

        for ctx in loaded_contexts:
            message = SystemMessage(
                content=f'<context type="{ctx.reference.ref_type}" id="{ctx.reference.ref_id}">\n{ctx.content}\n</context>'
            )
            messages.append(message)

        return messages
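
    # Prompt-assembly sketch (illustrative): the tagged SystemMessages can be
    # prepended to an agent's message list so the model sees each chunk with its
    # provenance. HumanMessage would come from langchain_core.messages (not
    # imported in this module):
    #
    #     context_messages = loader.to_messages(contexts)
    #     prompt = [*context_messages, HumanMessage(content=user_input)]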


# Convenience functions
async def search_and_load_context(
    query: str,
    loader: DynamicContextLoader | None = None,
    top_k: int = 3,
    max_tokens: int = 2000,
) -> list[LoadedContext]:
    """
    Search for context and load top results within token budget.

    Args:
        query: Search query
        loader: Context loader instance (creates new if None)
        top_k: Number of results to search for
        max_tokens: Maximum tokens to load

    Returns:
        List of loaded contexts
    """
    if loader is None:
        loader = DynamicContextLoader()

    # Search
    references = await loader.semantic_search(query, top_k=top_k)

    # Load within budget
    loaded = await loader.load_batch(references, max_tokens=max_tokens)

    return loaded
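

# End-to-end sketch: index one document, then search and load it back within a
# token budget. Illustrative and assumes a reachable Qdrant instance plus
# configured embeddings; the ref_id below is an arbitrary example UUID.
if __name__ == "__main__":

    async def _demo() -> None:
        loader = DynamicContextLoader()
        await loader.index_context(
            ref_id="550e8400-e29b-41d4-a716-446655440000",
            content="Refunds are processed within 5 business days of approval.",
            ref_type="document",
            summary="Refund processing timeline",
        )
        contexts = await search_and_load_context("how long do refunds take?", loader=loader)
        for ctx in contexts:
            print(ctx.reference.ref_id, ctx.token_count)

    asyncio.run(_demo())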