Coverage for src / mcp_server_langgraph / llm / pydantic_agent.py: 98%

105 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Pydantic AI Agent - Type-safe agent implementation with structured outputs 

3 

4This module provides strongly-typed agent responses using Pydantic AI, 

5ensuring LLM outputs match expected schemas and improving reliability. 

6 

7Note: pydantic-ai is an optional dependency. If not installed, this module will raise 

8ImportError when used. The agent.py module handles this gracefully with PYDANTIC_AI_AVAILABLE flag. 

9""" 

10 

11from typing import Any, Literal 

12 

13from langchain_core.messages import AIMessage, BaseMessage, HumanMessage 

14from pydantic import BaseModel, Field 

15 

# Conditional import - pydantic-ai is optional for type-safe responses.
# If not available, agent.py will fall back to standard routing.
try:
    from pydantic_ai import Agent

    # Flag consumed by create_pydantic_agent() below (and, per the module
    # docstring, by agent.py) to decide whether typed routing is available.
    PYDANTIC_AI_AVAILABLE = True
except ImportError:
    PYDANTIC_AI_AVAILABLE = False
    # Placeholder so module-level references to `Agent` still resolve when the
    # optional dependency is missing; callers must check PYDANTIC_AI_AVAILABLE
    # before instantiating anything that uses it.
    Agent = None  # type: ignore[unused-ignore,assignment,misc]

25 

26from mcp_server_langgraph.core.config import settings 

27from mcp_server_langgraph.core.prompts import RESPONSE_SYSTEM_PROMPT, ROUTER_SYSTEM_PROMPT 

28from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer 

29 

30 

# Structured Response Models
class RouterDecision(BaseModel):
    """Routing decision with reasoning for agent workflow."""

    # NOTE: this model is passed as output_type to the router Agent below, so
    # the Field descriptions become part of the schema the LLM must satisfy —
    # edit them with care.
    action: Literal["use_tools", "respond", "clarify"] = Field(description="Next action to take in the agent workflow")
    reasoning: str = Field(description="Explanation of why this action was chosen")
    # Only meaningful when action == "use_tools"; otherwise left as None.
    tool_name: str | None = Field(default=None, description="Name of tool to use if action is 'use_tools'")
    # ge/le enforce the 0-1 range at validation time.
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence score for this decision (0-1)")

39 

40 

class ToolExecution(BaseModel):
    """Result of tool execution with metadata."""

    # NOTE(review): not referenced elsewhere in this module; presumably
    # consumed by tool-execution code in agent.py — confirm before removing.
    tool_name: str = Field(description="Name of the executed tool")
    result: str = Field(description="Tool execution result")
    success: bool = Field(description="Whether tool execution succeeded")
    # Populated only when success is False.
    error_message: str | None = Field(default=None, description="Error message if execution failed")
    # default_factory avoids the shared-mutable-default pitfall.
    metadata: dict[str, str] = Field(default_factory=dict, description="Additional metadata about the execution")

49 

50 

class AgentResponse(BaseModel):
    """Final agent response with confidence and metadata."""

    # NOTE: this model is passed as output_type to the response Agent below,
    # so the Field descriptions shape the schema the LLM must satisfy.
    content: str = Field(description="The main response content to show to the user")
    # ge/le enforce the 0-1 range at validation time.
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence in this response (0-1)")
    requires_clarification: bool = Field(default=False, description="Whether the agent needs more information")
    # Only meaningful when requires_clarification is True.
    clarification_question: str | None = Field(default=None, description="Question to ask user if clarification needed")
    sources: list[str] = Field(default_factory=list, description="Sources or references used to generate response")
    metadata: dict[str, str] = Field(default_factory=dict, description="Additional response metadata")

60 

61 

class PydanticAIAgentWrapper:
    """
    Wrapper for Pydantic AI agents with type-safe responses.

    Provides strongly-typed interactions with LLMs, ensuring outputs
    match expected schemas for improved reliability and debugging.
    """

    def __init__(self, model_name: str | None = None, provider: str | None = None) -> None:
        """
        Initialize Pydantic AI agent wrapper.

        Args:
            model_name: Model to use (defaults to settings.model_name)
            provider: Provider name (defaults to settings.llm_provider)

        Raises:
            ImportError: If the optional pydantic-ai package is not installed.
        """
        # Fail fast with a clear error when this class is constructed directly
        # (bypassing create_pydantic_agent). Without this guard, `Agent` is
        # None and instantiation surfaces as an opaque
        # "TypeError: 'NoneType' object is not callable".
        if not PYDANTIC_AI_AVAILABLE:
            raise ImportError(
                "pydantic-ai is not installed. "
                "Add 'pydantic-ai' to pyproject.toml dependencies, then run: uv sync"
            )

        self.model_name = model_name or settings.model_name
        self.provider = provider or settings.llm_provider

        # Map provider/model to the prefixed format pydantic-ai requires
        self.pydantic_model_name = self._get_pydantic_model_name()

        # Create specialized agents for different tasks with XML-structured prompts.
        # Note: output_type replaces deprecated result_type in pydantic-ai 0.1.0+
        self.router_agent = Agent(
            self.pydantic_model_name,
            output_type=RouterDecision,
            system_prompt=ROUTER_SYSTEM_PROMPT,
        )

        self.response_agent = Agent(
            self.pydantic_model_name,
            output_type=AgentResponse,
            system_prompt=RESPONSE_SYSTEM_PROMPT,
        )

        logger.info(
            "Pydantic AI agent wrapper initialized",
            extra={"model": self.model_name, "provider": self.provider, "pydantic_model": self.pydantic_model_name},
        )

    def _get_pydantic_model_name(self) -> str:
        """
        Map provider/model to Pydantic AI format with required provider prefix.

        PYDANTIC-AI REQUIREMENT (v0.0.14+): model names must include a provider
        prefix; a bare name like 'gemini-2.5-flash' triggers a
        DeprecationWarning telling you to use 'google-gla:gemini-2.5-flash'.

        Provider Prefix Mapping:
            - Google Gemini:    'google-gla:'  (e.g., 'google-gla:gemini-2.5-flash')
            - Anthropic Claude: 'anthropic:'   (e.g., 'anthropic:claude-sonnet-4-5-20250929')
            - OpenAI:           'openai:'      (e.g., 'openai:gpt-4')
            - Unknown:          '<provider>:'  (generic fallback)

        Returns:
            Pydantic AI compatible model name with provider prefix
        """
        # Edge case: the configured model name already carries a prefix
        if ":" in self.model_name:
            logger.debug(
                f"Model name '{self.model_name}' already has provider prefix",
                extra={"model": self.model_name, "provider": self.provider},
            )
            return self.model_name

        # Add provider prefix based on provider type
        if self.provider in ("google", "gemini"):
            # Google Gemini models require the google-gla prefix
            return f"google-gla:{self.model_name}"
        elif self.provider == "anthropic":
            return f"anthropic:{self.model_name}"
        elif self.provider == "openai":
            return f"openai:{self.model_name}"
        else:
            # Unknown provider: use the provider name itself as the prefix
            logger.warning(
                f"Unknown provider '{self.provider}', using provider-prefixed format",
                extra={"model": self.model_name, "provider": self.provider},
            )
            return f"{self.provider}:{self.model_name}"

    async def route_message(self, message: str, context: dict[str, Any] | None = None) -> RouterDecision:
        """
        Determine the appropriate action for a user message.

        Args:
            message: User message to route
            context: Optional context about the conversation

        Returns:
            RouterDecision with action and reasoning

        Raises:
            Exception: Re-raises any failure from the underlying agent run
                after recording it on the span and metrics.
        """
        with tracer.start_as_current_span("pydantic_ai.route") as span:
            span.set_attribute("message.length", len(message))

            try:
                # Prepend conversation context (if any) to the user message
                prompt = message
                if context:
                    context_str = "\n".join(f"{k}: {v}" for k, v in context.items())
                    prompt = f"Context:\n{context_str}\n\nUser message: {message}"

                # Get routing decision.
                # Note: .output replaces deprecated .data in pydantic-ai 1.0+
                # See: https://ai.pydantic.dev/changelog/
                result = await self.router_agent.run(prompt)
                decision = result.output

                span.set_attribute("decision.action", decision.action)
                span.set_attribute("decision.confidence", decision.confidence)

                logger.info(
                    "Message routed",
                    extra={"action": decision.action, "confidence": decision.confidence, "reasoning": decision.reasoning},
                )

                metrics.successful_calls.add(1, {"operation": "route_message"})

                return decision

            except Exception as e:
                logger.error(f"Routing failed: {e}", extra={"user_message": message}, exc_info=True)
                metrics.failed_calls.add(1, {"operation": "route_message"})
                span.record_exception(e)
                raise

    async def generate_response(self, messages: list[BaseMessage], context: dict[str, Any] | None = None) -> AgentResponse:
        """
        Generate a typed response to user messages.

        Args:
            messages: Conversation history
            context: Optional context information

        Returns:
            AgentResponse with content and metadata

        Raises:
            Exception: Re-raises any failure from the underlying agent run
                after recording it on the span and metrics.
        """
        with tracer.start_as_current_span("pydantic_ai.generate_response") as span:
            span.set_attribute("message.count", len(messages))

            try:
                # Flatten the LangChain message history into a plain-text prompt
                conversation = self._format_conversation(messages)

                # Add context if provided
                if context:
                    context_str = "\n".join(f"{k}: {v}" for k, v in context.items())
                    conversation = f"Context:\n{context_str}\n\n{conversation}"

                # Generate response.
                # Note: .output replaces deprecated .data in pydantic-ai 1.0+
                # See: https://ai.pydantic.dev/changelog/
                result = await self.response_agent.run(conversation)
                response = result.output

                span.set_attribute("response.length", len(response.content))
                span.set_attribute("response.confidence", response.confidence)
                span.set_attribute("response.requires_clarification", response.requires_clarification)

                logger.info(
                    "Response generated",
                    extra={
                        "confidence": response.confidence,
                        "requires_clarification": response.requires_clarification,
                        "content_length": len(response.content),
                    },
                )

                metrics.successful_calls.add(1, {"operation": "generate_response"})

                return response

            except Exception as e:
                logger.error(f"Response generation failed: {e}", extra={"message_count": len(messages)}, exc_info=True)
                metrics.failed_calls.add(1, {"operation": "generate_response"})
                span.record_exception(e)
                raise

    def _format_conversation(self, messages: list[BaseMessage]) -> str:
        """
        Format conversation history for Pydantic AI.

        Args:
            messages: List of LangChain messages

        Returns:
            Formatted conversation string with "User:"/"Assistant:"/"System:"
            role labels, entries separated by blank lines
        """
        lines = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                lines.append(f"User: {msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"Assistant: {msg.content}")
            else:
                # Anything that is neither Human nor AI is labeled System
                lines.append(f"System: {msg.content}")

        return "\n\n".join(lines)

270 

271 

# Factory function for easy integration
def create_pydantic_agent(model_name: str | None = None, provider: str | None = None) -> PydanticAIAgentWrapper:
    """
    Build a configured PydanticAIAgentWrapper.

    Args:
        model_name: Model identifier; falls back to settings when omitted
        provider: LLM provider name; falls back to settings when omitted

    Returns:
        A ready-to-use PydanticAIAgentWrapper instance

    Raises:
        ImportError: When the optional pydantic-ai package is missing
    """
    if PYDANTIC_AI_AVAILABLE:
        return PydanticAIAgentWrapper(model_name=model_name, provider=provider)

    raise ImportError(
        "pydantic-ai is not installed. "
        "Add 'pydantic-ai' to pyproject.toml dependencies, then run: uv sync\n"
        "The agent will fall back to standard routing without Pydantic AI."
    )