# Coverage for src/mcp_server_langgraph/llm/pydantic_agent.py: 98% (105 statements)
# Report generated by coverage.py v7.12.0 at 2025-12-03 00:43 +0000
"""
Pydantic AI Agent - Type-safe agent implementation with structured outputs.

This module provides strongly-typed agent responses using Pydantic AI,
ensuring LLM outputs match expected schemas and improving reliability.

Note: pydantic-ai is an optional dependency. If not installed, this module will raise
ImportError when used. The agent.py module handles this gracefully with the
PYDANTIC_AI_AVAILABLE flag.
"""
from typing import Any, Literal

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from pydantic import BaseModel, Field

# Conditional import - pydantic-ai is optional for type-safe responses
# If not available, agent.py will fall back to standard routing
try:
    from pydantic_ai import Agent

    PYDANTIC_AI_AVAILABLE = True
except ImportError:
    PYDANTIC_AI_AVAILABLE = False
    # Placeholder so module-level references to Agent still resolve;
    # callers must check PYDANTIC_AI_AVAILABLE before instantiating.
    Agent = None  # type: ignore[unused-ignore,assignment,misc]

from mcp_server_langgraph.core.config import settings
from mcp_server_langgraph.core.prompts import RESPONSE_SYSTEM_PROMPT, ROUTER_SYSTEM_PROMPT
from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer
31# Structured Response Models
class RouterDecision(BaseModel):
    """Routing decision with reasoning for agent workflow."""

    # Next step chosen by the router; drives branching in the agent workflow.
    action: Literal["use_tools", "respond", "clarify"] = Field(description="Next action to take in the agent workflow")
    reasoning: str = Field(description="Explanation of why this action was chosen")
    # Only meaningful when action == "use_tools"; otherwise None.
    tool_name: str | None = Field(default=None, description="Name of tool to use if action is 'use_tools'")
    # ge/le constraints make pydantic reject out-of-range confidence values.
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence score for this decision (0-1)")
class ToolExecution(BaseModel):
    """Result of tool execution with metadata."""

    tool_name: str = Field(description="Name of the executed tool")
    result: str = Field(description="Tool execution result")
    success: bool = Field(description="Whether tool execution succeeded")
    # Populated when execution failed; None on success.
    error_message: str | None = Field(default=None, description="Error message if execution failed")
    # default_factory avoids sharing one mutable dict across instances.
    metadata: dict[str, str] = Field(default_factory=dict, description="Additional metadata about the execution")
class AgentResponse(BaseModel):
    """Final agent response with confidence and metadata."""

    content: str = Field(description="The main response content to show to the user")
    # ge/le constraints make pydantic reject out-of-range confidence values.
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence in this response (0-1)")
    requires_clarification: bool = Field(default=False, description="Whether the agent needs more information")
    # Only meaningful when requires_clarification is True.
    clarification_question: str | None = Field(default=None, description="Question to ask user if clarification needed")
    # default_factory avoids sharing mutable defaults across instances.
    sources: list[str] = Field(default_factory=list, description="Sources or references used to generate response")
    metadata: dict[str, str] = Field(default_factory=dict, description="Additional response metadata")
class PydanticAIAgentWrapper:
    """
    Wrapper for Pydantic AI agents with type-safe responses.

    Provides strongly-typed interactions with LLMs, ensuring outputs
    match expected schemas for improved reliability and debugging.
    """

    # Known provider -> pydantic-ai model-name prefix.
    # pydantic-ai (v0.0.14+) requires a provider prefix on every model name;
    # e.g. 'google-gla:gemini-2.5-flash' instead of the deprecated bare
    # 'gemini-2.5-flash' (which triggers a DeprecationWarning).
    _PROVIDER_PREFIXES = {
        "google": "google-gla",
        "gemini": "google-gla",
        "anthropic": "anthropic",
        "openai": "openai",
    }

    def __init__(self, model_name: str | None = None, provider: str | None = None) -> None:
        """
        Initialize Pydantic AI agent wrapper.

        Args:
            model_name: Model to use (defaults to settings.model_name)
            provider: Provider name (defaults to settings.llm_provider)
        """
        self.model_name = model_name or settings.model_name
        self.provider = provider or settings.llm_provider

        # Map provider to Pydantic AI model format
        self.pydantic_model_name = self._get_pydantic_model_name()

        # Create specialized agents for different tasks with XML-structured prompts
        # Note: output_type replaces deprecated result_type in pydantic-ai 0.1.0+
        self.router_agent = Agent(
            self.pydantic_model_name,
            output_type=RouterDecision,
            system_prompt=ROUTER_SYSTEM_PROMPT,
        )

        self.response_agent = Agent(
            self.pydantic_model_name,
            output_type=AgentResponse,
            system_prompt=RESPONSE_SYSTEM_PROMPT,
        )

        logger.info(
            "Pydantic AI agent wrapper initialized",
            extra={"model": self.model_name, "provider": self.provider, "pydantic_model": self.pydantic_model_name},
        )

    def _get_pydantic_model_name(self) -> str:
        """
        Map provider/model to Pydantic AI format with required provider prefix.

        PYDANTIC-AI REQUIREMENT (v0.0.14+): all model names must include a
        provider prefix. Without one, pydantic-ai emits:
            "DeprecationWarning: Specifying a model name without a provider
            prefix is deprecated. Instead of 'gemini-2.5-flash', use
            'google-gla:gemini-2.5-flash'."

        Known prefixes live in _PROVIDER_PREFIXES; unknown providers fall back
        to using the provider name itself as the prefix.

        Returns:
            Pydantic AI compatible model name with provider prefix
        """
        # Edge case: model name already includes a provider prefix - pass through.
        if ":" in self.model_name:
            logger.debug(
                f"Model name '{self.model_name}' already has provider prefix",
                extra={"model": self.model_name, "provider": self.provider},
            )
            return self.model_name

        prefix = self._PROVIDER_PREFIXES.get(self.provider)
        if prefix is None:
            # Unknown provider: use the provider name itself as a generic prefix.
            logger.warning(
                f"Unknown provider '{self.provider}', using provider-prefixed format",
                extra={"model": self.model_name, "provider": self.provider},
            )
            prefix = self.provider
        return f"{prefix}:{self.model_name}"

    @staticmethod
    def _prepend_context(prompt: str, context: dict[str, Any] | None) -> str:
        """Prefix prompt with a 'Context:' header built from context key/value pairs; no-op if context is empty."""
        if not context:
            return prompt
        context_str = "\n".join(f"{k}: {v}" for k, v in context.items())
        return f"Context:\n{context_str}\n\n{prompt}"

    async def route_message(self, message: str, context: dict[str, Any] | None = None) -> RouterDecision:
        """
        Determine the appropriate action for a user message.

        Args:
            message: User message to route
            context: Optional context about the conversation

        Returns:
            RouterDecision with action and reasoning

        Raises:
            Exception: re-raises any failure from the underlying agent run
                after recording it on the span and metrics.
        """
        with tracer.start_as_current_span("pydantic_ai.route") as span:
            span.set_attribute("message.length", len(message))

            try:
                # The "User message:" label is only applied when context is
                # present; with no context the bare message is the prompt.
                prompt = self._prepend_context(f"User message: {message}", context) if context else message

                # Get routing decision
                # Note: .output replaces deprecated .data in pydantic-ai 1.0+
                # See: https://ai.pydantic.dev/changelog/
                result = await self.router_agent.run(prompt)
                decision = result.output

                span.set_attribute("decision.action", decision.action)
                span.set_attribute("decision.confidence", decision.confidence)

                logger.info(
                    "Message routed",
                    extra={"action": decision.action, "confidence": decision.confidence, "reasoning": decision.reasoning},
                )

                metrics.successful_calls.add(1, {"operation": "route_message"})

                return decision

            except Exception as e:
                logger.error(f"Routing failed: {e}", extra={"user_message": message}, exc_info=True)
                metrics.failed_calls.add(1, {"operation": "route_message"})
                span.record_exception(e)
                raise

    async def generate_response(self, messages: list[BaseMessage], context: dict[str, Any] | None = None) -> AgentResponse:
        """
        Generate a typed response to user messages.

        Args:
            messages: Conversation history
            context: Optional context information

        Returns:
            AgentResponse with content and metadata

        Raises:
            Exception: re-raises any failure from the underlying agent run
                after recording it on the span and metrics.
        """
        with tracer.start_as_current_span("pydantic_ai.generate_response") as span:
            span.set_attribute("message.count", len(messages))

            try:
                # Convert messages to a prompt, prefixed with context if provided.
                conversation = self._prepend_context(self._format_conversation(messages), context)

                # Generate response
                # Note: .output replaces deprecated .data in pydantic-ai 1.0+
                # See: https://ai.pydantic.dev/changelog/
                result = await self.response_agent.run(conversation)
                response = result.output

                span.set_attribute("response.length", len(response.content))
                span.set_attribute("response.confidence", response.confidence)
                span.set_attribute("response.requires_clarification", response.requires_clarification)

                logger.info(
                    "Response generated",
                    extra={
                        "confidence": response.confidence,
                        "requires_clarification": response.requires_clarification,
                        "content_length": len(response.content),
                    },
                )

                metrics.successful_calls.add(1, {"operation": "generate_response"})

                return response

            except Exception as e:
                logger.error(f"Response generation failed: {e}", extra={"message_count": len(messages)}, exc_info=True)
                metrics.failed_calls.add(1, {"operation": "generate_response"})
                span.record_exception(e)
                raise

    def _format_conversation(self, messages: list[BaseMessage]) -> str:
        """
        Format conversation history for Pydantic AI.

        Args:
            messages: List of LangChain messages

        Returns:
            Formatted conversation string: one "Role: content" line per
            message, separated by blank lines. Non-human, non-AI messages
            are labeled "System".
        """
        lines = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                lines.append(f"User: {msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"Assistant: {msg.content}")
            else:
                lines.append(f"System: {msg.content}")

        return "\n\n".join(lines)
272# Factory function for easy integration
def create_pydantic_agent(model_name: str | None = None, provider: str | None = None) -> PydanticAIAgentWrapper:
    """
    Build a configured Pydantic AI agent wrapper.

    Args:
        model_name: Model to use (defaults to settings)
        provider: Provider name (defaults to settings)

    Returns:
        Configured PydanticAIAgentWrapper instance

    Raises:
        ImportError: If pydantic-ai is not installed
    """
    # Success path first: hand back a wrapper when the optional dep is present.
    if PYDANTIC_AI_AVAILABLE:
        return PydanticAIAgentWrapper(model_name=model_name, provider=provider)

    raise ImportError(
        "pydantic-ai is not installed. "
        "Add 'pydantic-ai' to pyproject.toml dependencies, then run: uv sync\n"
        "The agent will fall back to standard routing without Pydantic AI."
    )