Coverage for src/mcp_server_langgraph/core/agent.py: 58%
363 statements
coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2LangGraph Functional API Agent with full observability and multi-provider LLM support
3Includes OpenTelemetry and LangSmith tracing integration
4Enhanced with Pydantic AI for type-safe routing and responses
6Implements Anthropic's gather-action-verify-repeat agentic loop:
71. Gather Context: Compaction and just-in-time loading
82. Take Action: Routing and tool execution
93. Verify Work: LLM-as-judge pattern
104. Repeat: Iterative refinement based on feedback
11"""
13import operator
14from typing import Annotated, Any, Literal, Sequence, TypedDict
16from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
17from langchain_core.runnables import RunnableConfig
18from langgraph.checkpoint.base import BaseCheckpointSaver
19from langgraph.checkpoint.memory import MemorySaver
20from langgraph.graph import END, START, StateGraph
22from mcp_server_langgraph.core.config import settings
23from mcp_server_langgraph.core.context_manager import ContextManager
24from mcp_server_langgraph.core.url_utils import ensure_redis_password_encoded
25from mcp_server_langgraph.llm.factory import create_llm_from_config
26from mcp_server_langgraph.llm.verifier import OutputVerifier
27from mcp_server_langgraph.observability.telemetry import logger
29# Import Dynamic Context Loader if enabled
30try:
31 from mcp_server_langgraph.core.dynamic_context_loader import DynamicContextLoader, search_and_load_context
33 DYNAMIC_CONTEXT_AVAILABLE = True
34except ImportError:
35 DYNAMIC_CONTEXT_AVAILABLE = False
36 # Logger warning deferred to runtime to avoid initialization issues
38# Import Redis checkpointer if available
39try:
40 from langgraph.checkpoint.redis import RedisSaver
42 REDIS_CHECKPOINTER_AVAILABLE = True
43except ImportError:
44 REDIS_CHECKPOINTER_AVAILABLE = False
45 # Logger warning deferred to runtime when checkpointer is actually created
47# Import Pydantic AI for type-safe responses
48try:
49 from mcp_server_langgraph.llm.pydantic_agent import create_pydantic_agent
51 PYDANTIC_AI_AVAILABLE = True
52except ImportError:
53 PYDANTIC_AI_AVAILABLE = False
54 # Logger warning deferred to runtime to avoid initialization issues
56# Import LangSmith config if available
57try:
58 from mcp_server_langgraph.observability.langsmith import get_run_metadata, get_run_tags, langsmith_config
60 LANGSMITH_AVAILABLE = True
61except ImportError:
62 LANGSMITH_AVAILABLE = False
63 langsmith_config = None # type: ignore[assignment]
66class AgentState(TypedDict):
67 """
68 State for the agent graph.
70 Implements full agentic loop state management:
71 - Context: messages, compaction status
72 - Routing: next_action, confidence, reasoning
73 - Verification: verification results, refinement attempts
74 - Metadata: user_id, request_id
75 """
77 messages: Annotated[Sequence[BaseMessage], operator.add]
78 next_action: str
79 user_id: str | None
80 request_id: str | None
81 routing_confidence: float | None # Confidence from Pydantic AI routing
82 reasoning: str | None # Reasoning from Pydantic AI
84 # Context management
85 compaction_applied: bool | None # Whether compaction was applied
86 original_message_count: int | None # Message count before compaction
88 # Verification and refinement
89 verification_passed: bool | None # Whether verification passed
90 verification_score: float | None # Overall quality score (0-1)
91 verification_feedback: str | None # Feedback for refinement
92 refinement_attempts: int | None # Number of refinement iterations
93 user_request: str | None # Original user request for verification
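# --- Illustrative sketch (hypothetical, not part of the original module) ---
# A minimal initial AgentState for invoking the compiled graph: only "messages"
# needs real content up front; the remaining keys are filled in by the nodes as
# the loop runs. Field names mirror the TypedDict above; the values are made up.
def _example_initial_state(user_id: str, request_id: str) -> AgentState:
    return {
        "messages": [HumanMessage(content="Search for the latest LangGraph release notes")],
        "next_action": "",
        "user_id": user_id,
        "request_id": request_id,
        "routing_confidence": None,
        "reasoning": None,
        "compaction_applied": None,
        "original_message_count": None,
        "verification_passed": None,
        "verification_score": None,
        "verification_feedback": None,
        "refinement_attempts": None,
        "user_request": None,
    }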
96def _initialize_pydantic_agent() -> Any:
97 """Initialize Pydantic AI agent if available"""
98 if not PYDANTIC_AI_AVAILABLE:
99 return None
101 try:
102 pydantic_agent = create_pydantic_agent()
103 logger.info("Pydantic AI agent initialized for type-safe routing")
104 return pydantic_agent
105 except Exception as e:
106 logger.warning(f"Failed to initialize Pydantic AI agent: {e}", exc_info=True)
107 return None
110def _create_checkpointer(settings_to_use: Any | None = None) -> Any:
111 """
112 Create checkpointer backend based on configuration
114 Args:
115 settings_to_use: Optional Settings object to use. If None, uses global settings.
117 Returns:
118 BaseCheckpointSaver: Configured checkpointer (MemorySaver or RedisSaver)
119 """
120 # Use provided settings or fall back to global settings
121 effective_settings = settings_to_use if settings_to_use is not None else settings
123 backend = effective_settings.checkpoint_backend.lower()
125 if backend == "redis":
126 if not REDIS_CHECKPOINTER_AVAILABLE:
127 logger.warning(
128 "Redis checkpointer not available (langgraph-checkpoint-redis not installed), "
129 "falling back to MemorySaver. Add 'langgraph-checkpoint-redis' to "
130 "pyproject.toml dependencies, then run: uv sync"
131 )
132 return MemorySaver()
134 try:
135 logger.info(
136 "Initializing Redis checkpointer for distributed conversation state",
137 extra={
138 "redis_url": effective_settings.checkpoint_redis_url,
139 "ttl_seconds": effective_settings.checkpoint_redis_ttl,
140 },
141 )
143 # Create Redis checkpointer with TTL
144 # Note: RedisSaver.from_conn_string expects redis_url (not conn_string)
145 # and returns a context manager in langgraph-checkpoint-redis 0.1.2+
146 # Ensure password is URL-encoded to prevent parsing errors (defense-in-depth)
147 encoded_redis_url = ensure_redis_password_encoded(effective_settings.checkpoint_redis_url)
148 checkpointer_ctx = RedisSaver.from_conn_string(
149 redis_url=encoded_redis_url,
150 )
152 # Enter the context manager to get the actual RedisSaver instance
153 checkpointer = checkpointer_ctx.__enter__()
155 # Store context manager reference for proper cleanup on shutdown
156 # This prevents resource leaks (Redis connections, file descriptors)
157 checkpointer.__context_manager__ = checkpointer_ctx # type: ignore[attr-defined]
159 logger.info("Redis checkpointer initialized successfully")
160 return checkpointer
162 except Exception as e:
163 logger.error(
164 f"Failed to initialize Redis checkpointer: {e}. Falling back to MemorySaver",
165 exc_info=True,
166 )
167 return MemorySaver()
169 elif backend == "memory":
170 logger.info("Using in-memory checkpointer (not suitable for multi-replica deployments)")
171 return MemorySaver()
173 else:
174 logger.warning(f"Unknown checkpoint backend '{backend}', falling back to MemorySaver. Supported: 'memory', 'redis'")
175 return MemorySaver()
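# --- Illustrative sketch (hypothetical, not part of the original module) ---
# Backend selection is purely configuration-driven. The three fields below are
# exactly the ones read above (checkpoint_backend, checkpoint_redis_url,
# checkpoint_redis_ttl); the concrete URL and TTL values are made up, and the
# remaining Settings fields are assumed to have defaults or come from the
# environment.
def _example_redis_checkpointer() -> Any:
    from mcp_server_langgraph.core.config import Settings

    redis_settings = Settings(
        checkpoint_backend="redis",
        checkpoint_redis_url="redis://:s3cret@redis.example.internal:6379/0",
        checkpoint_redis_ttl=3600,
    )
    return _create_checkpointer(settings_to_use=redis_settings)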
178def create_checkpointer(settings_override: Any | None = None) -> Any:
179 """
180 Public API to create checkpointer backend based on configuration.
182 Args:
183 settings_override: Optional Settings object to override global settings.
184 Useful for testing with custom configurations.
186 Returns:
187 BaseCheckpointSaver: Configured checkpointer (MemorySaver or RedisSaver)
189 Example:
190 # Use global settings
191 checkpointer = create_checkpointer()
193 # Use custom settings for testing
194 test_settings = Settings(checkpoint_backend="memory")
195 checkpointer = create_checkpointer(test_settings)
196 """
197 # Pass settings directly to avoid global state mutation
198 # This eliminates race conditions in concurrent environments
199 return _create_checkpointer(settings_to_use=settings_override)
202def cleanup_checkpointer(checkpointer: BaseCheckpointSaver[Any]) -> None:
203 """
204 Clean up checkpointer resources on application shutdown.
206 Properly closes Redis connections and context managers to prevent:
207 - Connection pool exhaustion
208 - File descriptor leaks
209 - Memory leaks in long-running processes
211 Args:
212 checkpointer: Checkpointer instance to clean up
214 Usage:
215 # In FastAPI lifespan or atexit handler:
216 import atexit
217 checkpointer = create_checkpointer(settings)
218 atexit.register(lambda: cleanup_checkpointer(checkpointer))
220 Example:
221 # FastAPI lifespan context manager
222 @asynccontextmanager
223 async def lifespan(app: FastAPI):
224 checkpointer = create_checkpointer(settings)
225 yield
226 cleanup_checkpointer(checkpointer)
227 """
228 try:
229 # Check if checkpointer has context manager reference
230 if hasattr(checkpointer, "__context_manager__"):
231 context_manager = checkpointer.__context_manager__
232 logger.info("Cleaning up Redis checkpointer context manager")
234 # Exit context manager to close connections
235 context_manager.__exit__(None, None, None)
237 logger.info("Redis checkpointer cleanup completed successfully")
238 else:
239 logger.debug(f"Checkpointer {type(checkpointer).__name__} does not require cleanup")
241 except Exception as e:
242 logger.error(f"Error during checkpointer cleanup: {e}", exc_info=True)
245def _get_runnable_config(user_id: str | None = None, request_id: str | None = None) -> RunnableConfig | None:
246 """Get runnable config with LangSmith metadata"""
247 if not LANGSMITH_AVAILABLE or not langsmith_config.is_enabled():
248 return None
250 return RunnableConfig(
251 run_name="mcp-server-langgraph", tags=get_run_tags(user_id), metadata=get_run_metadata(user_id, request_id)
252 )
255def _fallback_routing(state: AgentState, last_message: HumanMessage) -> AgentState:
256 """Fallback routing logic without Pydantic AI"""
257 # Determine if this needs tools or direct response
258 content = last_message.content if isinstance(last_message.content, str) else str(last_message.content)
259 if any(keyword in content.lower() for keyword in ["search", "calculate", "lookup"]):
260 next_action = "use_tools"
261 else:
262 next_action = "respond"
264 # NOTE: Don't return "messages" key - operator.add would duplicate them!
265 # Only return fields we're modifying.
266 return { # type: ignore[typeddict-item]
267 "next_action": next_action,
268 "routing_confidence": 0.5, # Low confidence for fallback
269 "reasoning": "Fallback keyword-based routing",
270 "user_id": state.get("user_id"),
271 "request_id": state.get("request_id"),
272 }
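# --- Illustrative sketch (hypothetical messages, not part of the original module) ---
# Fallback routing is purely keyword-based: a message containing "search",
# "calculate", or "lookup" is routed to tools; anything else gets a direct
# response.
#
#   _fallback_routing(state, HumanMessage(content="calculate 2 + 2"))["next_action"]  # -> "use_tools"
#   _fallback_routing(state, HumanMessage(content="tell me a joke"))["next_action"]   # -> "respond"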
275def _create_agent_graph_singleton(settings_override: Any | None = None) -> Any: # noqa: C901
276 """
277 Create the LangGraph agent using functional API with LiteLLM and observability.
279 Implements Anthropic's agentic loop:
280 1. Gather Context: compact_context node
281 2. Take Action: route_input → use_tools → generate_response
282 3. Verify Work: verify_response node
283 4. Repeat: refine_response loop (max 3 iterations)
285 Args:
286 settings_override: Optional Settings instance to override global settings.
287 If None, uses the global settings object.
288 """
290 # Use override settings if provided, otherwise use global settings
291 effective_settings = settings_override if settings_override is not None else settings
293 # Initialize the model via LiteLLM factory
294 model = create_llm_from_config(effective_settings)
296 # Initialize Pydantic AI agent if available
297 pydantic_agent = _initialize_pydantic_agent()
299 # Initialize context manager for compaction
300 context_manager = ContextManager(compaction_threshold=8000, target_after_compaction=4000, recent_message_count=5)
302 # Initialize output verifier for quality checks
303 output_verifier = OutputVerifier(quality_threshold=0.7)
305 # Initialize dynamic context loader if enabled
306 enable_dynamic_loading = getattr(effective_settings, "enable_dynamic_context_loading", False)
307 context_loader = None
308 if enable_dynamic_loading and DYNAMIC_CONTEXT_AVAILABLE:  # coverage: 308 ↛ 309, condition never true
309 try:
310 context_loader = DynamicContextLoader()
311 logger.info("Dynamic context loader initialized")
312 except Exception as e:
313 logger.warning(f"Failed to initialize dynamic context loader: {e}", exc_info=True)
314 enable_dynamic_loading = False
316 # Feature flags for new capabilities
317 enable_context_compaction = getattr(effective_settings, "enable_context_compaction", True)
318 enable_verification = getattr(effective_settings, "enable_verification", True)
319 max_refinement_attempts = getattr(effective_settings, "max_refinement_attempts", 3)
321 # Define node functions
323 async def load_dynamic_context(state: AgentState) -> AgentState:
324 """
325 Load relevant context dynamically based on user request.
327 Implements Anthropic's Just-in-Time loading strategy.
328 """
329 if not enable_dynamic_loading or not context_loader:  # coverage: 329 ↛ 333, condition always true
330 # NOTE: Not modifying any state - return empty dict to avoid duplication
331 return {k: v for k, v in state.items() if k != "messages"} # type: ignore[return-value]
333 last_message = state["messages"][-1]
335 if isinstance(last_message, HumanMessage):
336 try:
337 logger.info("Loading dynamic context")
339 # Search for relevant context
340 query = last_message.content if isinstance(last_message.content, str) else str(last_message.content)
341 loaded_contexts = await search_and_load_context(
342 query=query,
343 loader=context_loader,
344 top_k=getattr(effective_settings, "dynamic_context_top_k", 3),
345 max_tokens=getattr(effective_settings, "dynamic_context_max_tokens", 2000),
346 )
348 if loaded_contexts:
349 # Convert to messages and prepend
350 context_messages = context_loader.to_messages(loaded_contexts)
352 # Insert context before user message
353 current_messages = list(state["messages"])
354 messages_before = current_messages[:-1]
355 user_message = current_messages[-1]
356 state["messages"] = messages_before + context_messages + [user_message]
358 logger.info(
359 "Dynamic context loaded",
360 extra={
361 "contexts_loaded": len(loaded_contexts),
362 "total_tokens": sum(c.token_count for c in loaded_contexts),
363 },
364 )
366 except Exception as e:
367 logger.error(f"Dynamic context loading failed: {e}", exc_info=True)
368 # Continue without dynamic context
370 # NOTE: We modified state["messages"] in place (line 356 inserts context).
371 # Don't return "messages" - operator.add would duplicate them!
372 return {k: v for k, v in state.items() if k != "messages"} # type: ignore[return-value]
374 async def compact_context(state: AgentState) -> AgentState:
375 """
376 Compact conversation context when approaching token limits.
378 Implements Anthropic's "Compaction" technique for long-horizon tasks.
379 """
380 if not enable_context_compaction:  # coverage: 380 ↛ 382, condition never true
381 # NOTE: Not modifying any state - exclude messages to avoid duplication
382 return {k: v for k, v in state.items() if k != "messages"} # type: ignore[return-value]
384 messages_list = list(state["messages"])
386 if context_manager.needs_compaction(messages_list):  # coverage: 386 ↛ 387, condition never true
387 try:
388 logger.info("Applying context compaction")
389 result = await context_manager.compact_conversation(messages_list)
391 state["messages"] = result.compacted_messages
392 state["compaction_applied"] = True
393 state["original_message_count"] = len(messages_list)
395 logger.info(
396 "Context compacted",
397 extra={
398 "original_messages": len(messages_list),
399 "compacted_messages": len(result.compacted_messages),
400 "compression_ratio": result.compression_ratio,
401 },
402 )
403 except Exception as e:
404 logger.error(f"Context compaction failed: {e}", exc_info=True)
405 # Continue without compaction on error
406 state["compaction_applied"] = False
407 else:
408 state["compaction_applied"] = False
410 # NOTE: We modified state["messages"] in place (line 391).
411 # Don't return "messages" - operator.add would duplicate them!
412 return {k: v for k, v in state.items() if k != "messages"} # type: ignore[return-value]
414 async def route_input(state: AgentState) -> AgentState:
415 """
416 Route based on message type with Pydantic AI for type-safe decisions.
418 Also captures original user request for verification later.
419 """
420 last_message = state["messages"][-1]
422 # Capture original user request for verification
423 if isinstance(last_message, HumanMessage):  # coverage: 423 ↛ 427, condition always true
424 user_request = last_message.content if isinstance(last_message.content, str) else str(last_message.content)
425 state["user_request"] = user_request
427 if isinstance(last_message, HumanMessage):  # coverage: 427 ↛ 462, condition always true
428 # Use Pydantic AI for intelligent routing if available
429 if pydantic_agent:  # coverage: 429 ↛ 455, condition always true
430 try:
431 # Route message asynchronously
432 decision = await pydantic_agent.route_message(
433 last_message.content,
434 context={"user_id": state.get("user_id", "unknown"), "message_count": str(len(state["messages"]))},
435 )
437 # Update state with type-safe decision
438 state["next_action"] = decision.action
439 state["routing_confidence"] = decision.confidence
440 state["reasoning"] = decision.reasoning
442 logger.info(
443 "Pydantic AI routing decision",
444 extra={"action": decision.action, "confidence": decision.confidence, "reasoning": decision.reasoning},
445 )
446 except Exception as e:
447 logger.error(f"Pydantic AI routing failed, using fallback: {e}", exc_info=True)
448 # Fallback to simple routing
449 fallback_result = _fallback_routing(state, last_message)
450 # Merge fallback result into state
451 for key, value in fallback_result.items():
452 state[key] = value # type: ignore[literal-required]
453 else:
454 # Fallback routing if Pydantic AI not available
455 fallback_result = _fallback_routing(state, last_message)
456 # Merge fallback result into state
457 for key, value in fallback_result.items():
458 state[key] = value # type: ignore[literal-required]
460 # NOTE: Don't return "messages" - operator.add would duplicate them!
461 # This applies to both Pydantic AI and fallback paths.
462 return {k: v for k, v in state.items() if k != "messages"} # type: ignore[return-value]
464 async def use_tools(state: AgentState) -> AgentState:
465 """
466 Execute tools based on LangChain tool calls.
468 Supports both serial and parallel execution based on settings.
470 Implementation status:
471 - ✅ Message state preservation (appends instead of replacing)
472 - ✅ Tool call extraction from AIMessage.tool_calls
473 - ✅ Real tool execution with error handling
474 - ✅ Support for both sync and async tools
475 - ✅ Parallel execution wired with enable_parallel_execution flag
477 Features:
478 - Serial execution (default): Tools executed one at a time
479 - Parallel execution (if enabled): Independent tools run concurrently
480 - Automatic dependency detection and topological sorting
481 - Graceful error handling with informative error messages
482 - Comprehensive logging and telemetry
484 For implementation reference, see:
485 - LangChain tool binding: https://python.langchain.com/docs/how_to/tool_calling/
486 - Parallel execution: docs/adr/ADR-0023-anthropic-tool-design-best-practices.md
487 """
488 messages = state["messages"]
489 last_message = messages[-1]
491 # Check if the last message contains tool calls
492 tool_calls = getattr(last_message, "tool_calls", None)
494 if not tool_calls or len(tool_calls) == 0:  # coverage: 494 ↛ 508, condition always true
495 # No tool calls found - this shouldn't happen if routed to use_tools
496 # Return a message indicating no tools were called
497 logger.warning(
498 "use_tools node reached but no tool calls found in last message",
499 extra={"message_type": type(last_message).__name__},
500 )
501 tool_response = AIMessage(
502 content="No tool calls found. Proceeding with direct response.",
503 )
504 # NOTE: Return only new message, not state["messages"] + [tool_response]
505 # operator.add automatically appends to existing messages
506 return {**state, "messages": [tool_response], "next_action": "respond"}
508 logger.info(
509 "Executing tools",
510 extra={
511 "tool_count": len(tool_calls),
512 "tools": [tc.get("name", "unknown") for tc in tool_calls],
513 "parallel_enabled": effective_settings.enable_parallel_execution,
514 },
515 )
517 # Check if parallel execution is enabled (use effective_settings for DI support)
518 enable_parallel = getattr(effective_settings, "enable_parallel_execution", False)
520 if enable_parallel and len(tool_calls) > 1:
521 # Use parallel execution for multiple tool calls
522 logger.info(f"Using parallel execution for {len(tool_calls)} tools")
523 tool_messages = await _execute_tools_parallel(tool_calls)
524 else:
525 # Use serial execution (default or single tool)
526 if enable_parallel:
527 logger.info("Parallel execution enabled but only 1 tool call - using serial execution")
528 tool_messages = await _execute_tools_serial(tool_calls)
530 # NOTE: Return only new messages, not state["messages"] + tool_messages
531 # operator.add automatically appends to existing messages
532 return {**state, "messages": tool_messages, "next_action": "respond"}
534 async def _execute_tools_serial(tool_calls: list[dict]) -> list: # type: ignore[type-arg]
535 """Execute tools serially (one at a time)"""
536 from langchain_core.messages import ToolMessage
538 from mcp_server_langgraph.tools import get_tool_by_name
540 tool_messages: list = [] # type: ignore[type-arg]
541 for tool_call in tool_calls:
542 tool_name = tool_call.get("name", "unknown")
543 tool_call_id = tool_call.get("id", str(len(tool_messages)))
544 tool_args = tool_call.get("args", {})
546 try:
547 # Find the tool by name
548 tool = get_tool_by_name(tool_name)
550 if tool is None:
551 from mcp_server_langgraph.tools import ALL_TOOLS
553 result_content = f"Error: Tool '{tool_name}' not found. Available tools: {[t.name for t in ALL_TOOLS]}"
554 logger.error(f"Tool '{tool_name}' not found", extra={"available_tools": [t.name for t in ALL_TOOLS]})
555 else:
556 # Execute the tool (tools can be sync or async)
557 logger.info(f"Invoking tool '{tool_name}'", extra={"args": tool_args})
559 if hasattr(tool, "ainvoke"):
560 # Async tool
561 result_content = await tool.ainvoke(tool_args)
562 else:
563 # Sync tool - invoke directly
564 result_content = tool.invoke(tool_args)
566 logger.info(
567 f"Tool '{tool_name}' executed successfully",
568 extra={"tool": tool_name, "result_length": len(str(result_content))},
569 )
571 except Exception as e:
572 result_content = f"Error executing tool '{tool_name}': {e!s}"
573 logger.error(
574 f"Tool execution failed: {tool_name}",
575 extra={"tool": tool_name, "args": tool_args, "error": str(e)},
576 exc_info=True,
577 )
579 # Create tool message with result
580 tool_message = ToolMessage(
581 content=str(result_content),
582 tool_call_id=tool_call_id,
583 name=tool_name,
584 )
585 tool_messages.append(tool_message)
587 return tool_messages
589 async def _execute_tools_parallel(tool_calls: list[dict]) -> list: # type: ignore[type-arg]
590 """Execute tools in parallel using ParallelToolExecutor"""
591 from langchain_core.messages import ToolMessage
593 from mcp_server_langgraph.core.parallel_executor import ParallelToolExecutor, ToolInvocation
594 from mcp_server_langgraph.tools import get_tool_by_name
596 # Create parallel executor (use effective_settings for DI support)
597 max_parallelism = getattr(effective_settings, "max_parallel_tools", 5)
598 executor = ParallelToolExecutor(max_parallelism=max_parallelism)
600 # Convert tool_calls to ToolInvocation objects
601 invocations: list[ToolInvocation] = []
602 for tool_call in tool_calls:
603 tool_name = tool_call.get("name", "unknown")
604 tool_args = tool_call.get("args", {})
605 tool_call_id = tool_call.get("id", f"call_{len(invocations)}")
607 invocation = ToolInvocation(tool_name=tool_name, arguments=tool_args, invocation_id=tool_call_id, dependencies=[])
608 invocations.append(invocation)
610 # Define tool executor function for parallel executor
611 async def execute_single_tool(tool_name: str, arguments: dict): # type: ignore[type-arg, no-untyped-def]
612 """Execute a single tool"""
613 tool = get_tool_by_name(tool_name)
614 if tool is None:
615 msg = f"Tool '{tool_name}' not found"
616 raise ValueError(msg)
618 if hasattr(tool, "ainvoke"):
619 return await tool.ainvoke(arguments)
620 else:
621 return tool.invoke(arguments)
623 # Execute tools in parallel
624 try:
625 results = await executor.execute_parallel(invocations, execute_single_tool)
627 # Convert results to ToolMessage objects
628 tool_messages = []
629 for result in results:
630 # Find the original tool_call to get the correct ID
631 original_call = next((tc for tc in tool_calls if tc.get("id") == result.invocation_id), None)
632 tool_call_id = original_call.get("id") if original_call else result.invocation_id
634 if result.error:
635 content = f"Error executing tool '{result.tool_name}': {result.error!s}"
636 else:
637 content = str(result.result)
639 tool_message = ToolMessage(content=content, tool_call_id=tool_call_id, name=result.tool_name)
640 tool_messages.append(tool_message)
642 return tool_messages
644 except Exception as e:
645 logger.error(f"Parallel tool execution failed: {e}", exc_info=True)
646 # Fall back to serial execution on failure
647 logger.warning("Falling back to serial execution due to parallel execution failure")
648 return await _execute_tools_serial(tool_calls)
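# --- Illustrative sketch (hypothetical, not part of the original module) ---
# The node above always passes dependencies=[], so every tool call may run
# concurrently. If one tool needs another's output, the ToolInvocation
# dependencies field (assumed to take invocation_ids, matching the constructor
# used above) lets ParallelToolExecutor order the calls. Tool names, arguments,
# and the stand-in executor below are made up.
async def _example_dependent_tool_plan() -> Any:
    from mcp_server_langgraph.core.parallel_executor import ParallelToolExecutor, ToolInvocation

    async def fake_tool(tool_name: str, arguments: dict):  # type: ignore[type-arg, no-untyped-def]
        return f"{tool_name} ran with {arguments}"

    plan = [
        ToolInvocation(tool_name="search", arguments={"query": "weather"}, invocation_id="call_0", dependencies=[]),
        ToolInvocation(tool_name="summarize", arguments={"source": "call_0"}, invocation_id="call_1", dependencies=["call_0"]),
    ]
    return await ParallelToolExecutor(max_parallelism=2).execute_parallel(plan, fake_tool)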
650 async def generate_response(state: AgentState) -> AgentState:
651 """Generate final response using LLM with Pydantic AI validation"""
652 messages = state["messages"]
654 messages_list = list(messages)
656 # Add refinement context if this is a refinement attempt
657 refinement_attempts = state.get("refinement_attempts", 0)
658 if refinement_attempts > 0 and state.get("verification_feedback"):  # type: ignore[operator]  # coverage: 658 ↛ 659, condition never true
659 refinement_prompt = SystemMessage(
660 content=f"<refinement_guidance>\n"
661 f"Previous response had issues. Please refine based on this feedback:\n"
662 f"{state['verification_feedback']}\n"
663 f"</refinement_guidance>"
664 )
665 messages_list = [refinement_prompt] + messages_list
667 # Use Pydantic AI for structured response if available
668 if pydantic_agent:  # coverage: 668 ↛ 700, condition always true
669 try:
670 # Generate type-safe response
671 typed_response = await pydantic_agent.generate_response(
672 messages_list,
673 context={
674 "user_id": state.get("user_id", "unknown"),
675 "routing_confidence": str(state.get("routing_confidence", 0.0)),
676 "refinement_attempt": str(refinement_attempts),
677 },
678 )
680 # Convert to AIMessage
681 response = AIMessage(content=typed_response.content)
683 logger.info(
684 "Pydantic AI response generated",
685 extra={
686 "confidence": typed_response.confidence,
687 "requires_clarification": typed_response.requires_clarification,
688 "sources": typed_response.sources,
689 "refinement_attempt": refinement_attempts,
690 },
691 )
692 except Exception as e:
693 logger.error(f"Pydantic AI response generation failed, using fallback: {e}", exc_info=True)
694 # Fallback to standard LLM (use async invoke)
695 # Type cast needed: list is invariant, so list[BaseMessage] != list[BaseMessage | dict[str, Any]]
696 response = await model.ainvoke(messages_list) # type: ignore[arg-type]
697 else:
698 # Standard LLM response (use async invoke)
699 # Type cast needed: list is invariant, so list[BaseMessage] != list[BaseMessage | dict[str, Any]]
700 response = await model.ainvoke(messages_list) # type: ignore[arg-type]
702 # NOTE: Returning [response] (not state["messages"] + [response]) is correct here.
703 # LangGraph's operator.add annotation on AgentState.messages (line 77) automatically
704 # merges/appends this list to the existing messages. Manually appending would cause
705 # duplication. See: https://langchain-ai.github.io/langgraph/reference/graphs/#stategraph
706 return {**state, "messages": [response], "next_action": "verify" if enable_verification else "end"}
708 async def verify_response(state: AgentState) -> AgentState:
709 """
710 Verify response quality using LLM-as-judge pattern.
712 Implements Anthropic's "Verify Work" step in the agent loop.
713 """
714 if not enable_verification:  # coverage: 714 ↛ 715, condition never true
715 state["next_action"] = "end"
716 # NOTE: Don't return "messages" - operator.add would duplicate them!
717 return {k: v for k, v in state.items() if k != "messages"} # type: ignore[return-value]
719 # Get the response to verify (last message)
720 response_message = state["messages"][-1]
721 response_content = response_message.content if hasattr(response_message, "content") else str(response_message)
722 # Ensure response_text is a string (content can be str or list)
723 response_text = response_content if isinstance(response_content, str) else str(response_content)
725 # Get user request
726 user_request = state.get("user_request") or ""
728 # Get conversation context (excluding the response we're verifying)
729 conversation_context = list(state["messages"])[:-1]
731 try:
732 logger.info("Verifying response quality")
733 verification_result = await output_verifier.verify_response(
734 response=response_text, user_request=user_request, conversation_context=conversation_context
735 )
737 state["verification_passed"] = verification_result.passed
738 state["verification_score"] = verification_result.overall_score
739 state["verification_feedback"] = verification_result.feedback
741 # Determine next action
742 refinement_attempts = state.get("refinement_attempts", 0)
744 if verification_result.passed:  # coverage: 744 ↛ 749, condition always true
745 state["next_action"] = "end"
746 logger.info(
747 "Verification passed", extra={"score": verification_result.overall_score, "attempts": refinement_attempts}
748 )
749 elif (refinement_attempts or 0) < max_refinement_attempts:
750 state["next_action"] = "refine"
751 logger.info(
752 "Verification failed, refining response",
753 extra={
754 "score": verification_result.overall_score,
755 "attempt": (refinement_attempts or 0) + 1,
756 "max_attempts": max_refinement_attempts,
757 },
758 )
759 else:
760 # Max attempts reached, accept response
761 state["next_action"] = "end"
762 logger.warning(
763 "Max refinement attempts reached, accepting response",
764 extra={"score": verification_result.overall_score, "attempts": refinement_attempts},
765 )
767 except Exception as e:
768 logger.error(f"Verification failed: {e}", exc_info=True)
769 # On verification error, accept response (fail-open)
770 state["verification_passed"] = True
771 state["next_action"] = "end"
773 # NOTE: Don't return "messages" - operator.add would duplicate them!
774 return {k: v for k, v in state.items() if k != "messages"} # type: ignore[return-value]
776 async def refine_response(state: AgentState) -> AgentState:
777 """
778 Refine response based on verification feedback.
780 Implements iterative refinement loop (part of "Repeat" in agentic loop).
781 """
782 # Increment refinement attempts
783 refinement_attempts = state.get("refinement_attempts", 0) or 0
784 state["refinement_attempts"] = refinement_attempts + 1
786 # Remove the failed response from messages
787 # It will be regenerated with refinement guidance
788 state["messages"] = state["messages"][:-1]
790 # Set next action to respond (will regenerate with feedback)
791 state["next_action"] = "respond"
793 feedback = state.get("verification_feedback") or ""
794 feedback_preview = feedback[:100] if isinstance(feedback, str) else ""
795 logger.info(
796 "Refining response",
797 extra={"attempt": state["refinement_attempts"], "feedback": feedback_preview},
798 )
800 # NOTE: We modified state["messages"] in place (line 788 removes the failed response).
801 # Don't return "messages" - operator.add would duplicate them!
802 return {k: v for k, v in state.items() if k != "messages"} # type: ignore[return-value]
804 def should_continue(state: AgentState) -> Literal["use_tools", "respond", "end"]:
805 """Conditional edge function for routing"""
806 next_action = state.get("next_action", "respond") or "respond"
807 # Default to "respond" if next_action is empty or not set
808 if not next_action or next_action not in ["use_tools", "respond", "end"]:  # coverage: 808 ↛ 809, condition never true
809 return "respond"
810 return next_action # type: ignore[return-value]
812 def should_verify(state: AgentState) -> Literal["verify", "refine", "end"]:
813 """Conditional edge function for verification loop"""
814 next_action = state.get("next_action", "end") or "end"
815 # Default to "end" if next_action is empty or invalid
816 if not next_action or next_action not in ["verify", "refine", "end"]:  # coverage: 816 ↛ 817, condition never true
817 return "end"
818 return next_action # type: ignore[return-value]
820 # Build the graph with full agentic loop
821 workflow = StateGraph(AgentState)
823 # Add nodes (Load → Gather → Route → Act → Verify → Repeat)
824 workflow.add_node("load_context", load_dynamic_context) # Just-in-Time Context Loading
825 workflow.add_node("compact", compact_context) # Gather Context (Compaction)
826 workflow.add_node("router", route_input) # Route Decision
827 workflow.add_node("tools", use_tools) # Take Action (tools)
828 workflow.add_node("respond", generate_response) # Take Action (response)
829 workflow.add_node("verify", verify_response) # Verify Work
830 workflow.add_node("refine", refine_response) # Repeat (refinement)
832 # Add edges for full agentic loop with dynamic context loading
833 workflow.add_edge(START, "load_context") # Start with JIT context loading
834 workflow.add_edge("load_context", "compact") # Then compaction
835 workflow.add_edge("compact", "router") # Then route
836 workflow.add_conditional_edges(
837 "router",
838 should_continue,
839 {
840 "use_tools": "tools",
841 "respond": "respond",
842 },
843 )
844 workflow.add_edge("tools", "respond")
845 workflow.add_conditional_edges(
846 "verify",
847 should_verify,
848 {
849 "verify": "verify", # Not used (defensive)
850 "refine": "refine", # Refinement needed
851 "end": END, # Verification passed
852 },
853 )
854 workflow.add_edge("respond", "verify") # Always verify responses
855 workflow.add_edge("refine", "respond") # Refine loops back to respond
857 # Compile with optional checkpointing (use effective_settings for DI support)
858 enable_checkpointing = getattr(effective_settings, "enable_checkpointing", True)
859 if enable_checkpointing:
860 checkpointer = _create_checkpointer(effective_settings)
861 return workflow.compile(checkpointer=checkpointer)
862 else:
863 # Compile without checkpointing (useful for testing with mocks)
864 logger.info("Checkpointing disabled - graph will not persist conversation state")
865 return workflow.compile()
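# --- Illustrative sketch (hypothetical, not part of the original module) ---
# Invoking the compiled graph: with checkpointing enabled, LangGraph expects a
# thread_id under "configurable" so the checkpointer can key conversation
# state. The thread_id, user_id, request_id, and message below are made up.
async def _example_invoke(graph: Any) -> Any:
    config = {"configurable": {"thread_id": "conversation-123"}}
    return await graph.ainvoke(
        {
            "messages": [HumanMessage(content="What's 2 + 2?")],
            "user_id": "user-123",
            "request_id": "req-456",
        },
        config=config,
    )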
868# IMPORTANT: Do NOT create agent_graph at module level
869# The lazy initialization pattern in telemetry.py requires observability to be initialized first
870# Entry points (mcp/server_stdio.py, mcp/server_streamable.py) must call init_observability()
871# before accessing agent_graph
872#
873# Legacy module-level export for backward compatibility (will be None until explicitly created)
874_agent_graph_cache = None
877def get_agent_graph() -> Any:
878 """
879 Get or create the agent graph singleton.
881 DEPRECATED: Use create_agent() or create_agent_graph() instead.
882 This function provides lazy initialization that respects observability initialization.
883 Call this instead of accessing agent_graph directly.
885 Returns:
886 Compiled LangGraph StateGraph
888 Raises:
889 RuntimeError: If observability is not initialized
890 """
891 global _agent_graph_cache
892 if _agent_graph_cache is None:
893 _agent_graph_cache = _create_agent_graph_singleton()
894 return _agent_graph_cache # type: ignore[no-any-return]
897# Backward compatibility: module-level agent_graph stays None; call get_agent_graph() to obtain the lazily created graph
898agent_graph = None
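# --- Illustrative sketch (hypothetical, not part of the original module) ---
# The startup ordering described above: initialize observability first, then
# build the graph lazily. The import location of init_observability is assumed
# from the comment's reference to telemetry.py; the real entry points
# (mcp/server_stdio.py, mcp/server_streamable.py) are the authoritative source.
def _example_entry_point_startup() -> Any:
    from mcp_server_langgraph.observability.telemetry import init_observability  # assumed location

    init_observability()
    return get_agent_graph()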
901# ==============================================================================
902# Dependency Injection API (NEW)
903# ==============================================================================
906def create_agent_graph(
907 settings: Any | None = None,
908 container: Any | None = None,
909) -> Any:
910 """
911 Create a new agent graph with dependency injection support.
913 This function creates a fresh agent graph instance using either:
914 - A container (preferred for full DI benefits)
915 - Custom settings
916 - Default settings (fallback)
918 Args:
919 settings: Optional Settings instance to use
920 container: Optional ApplicationContainer instance to use
922 Returns:
923 Compiled LangGraph StateGraph
925 Example:
926 # Using container (preferred)
927 from mcp_server_langgraph.core.container import create_test_container
928 container = create_test_container()
929 agent = create_agent_graph(container=container)
931 # Using custom settings
932 from mcp_server_langgraph.core.config import Settings
933 settings = Settings(environment="test")
934 agent = create_agent_graph(settings=settings)
936 # Using defaults
937 agent = create_agent_graph()
938 """
939 # Get settings from container or use provided/default
940 if container is not None:
941 actual_settings = container.settings
942 elif settings is not None:
943 actual_settings = settings
944 else:
945 # Use default settings (will use global settings object)
946 from mcp_server_langgraph.core.config import settings as default_settings
948 actual_settings = default_settings
950 # Create a fresh agent graph via create_agent_graph_impl, which threads
951 # the resolved settings through to the shared graph-construction logic
952 return create_agent_graph_impl(actual_settings)
955def create_agent_graph_impl(settings_to_use: Any) -> Any:
956 """
957 Implementation of agent graph creation with specific settings.
959 This properly threads the settings through to the agent graph creation,
960 enabling dependency injection for testing and multi-tenant deployments.
962 Args:
963 settings_to_use: Settings instance to use
965 Returns:
966 Compiled LangGraph StateGraph
967 """
968 # Pass settings to the singleton function for proper dependency injection
969 return _create_agent_graph_singleton(settings_override=settings_to_use)
972def create_agent(
973 settings: Any | None = None,
974 container: Any | None = None,
975) -> Any:
976 """
977 Create a new agent instance with dependency injection support.
979 This is the main factory function for creating agents. It supports:
980 - Container-based dependency injection (preferred)
981 - Custom settings
982 - Default configuration
984 Args:
985 settings: Optional Settings instance to override defaults
986 container: Optional ApplicationContainer for full DI
988 Returns:
989 Compiled agent graph ready for use
991 Example:
992 # Preferred: Using container
993 from mcp_server_langgraph.core.container import create_test_container
994 container = create_test_container()
995 agent = create_agent(container=container)
996 result = agent.invoke({"messages": [...]})
998 # Using custom settings
999 from mcp_server_langgraph.core.config import Settings
1000 settings = Settings(model_name="gpt-4")
1001 agent = create_agent(settings=settings)
1003 # Using defaults
1004 agent = create_agent()
1005 """
1006 return create_agent_graph(settings=settings, container=container)