Coverage for src/mcp_server_langgraph/llm/validators.py: 83%

88 statements  

coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2LLM Response Validators using Pydantic AI 

3 

4Provides type-safe validation and structured extraction from LLM responses. 

5""" 

6 

7from typing import Generic, TypeVar 

8 

9from langchain_core.messages import AIMessage 

10from pydantic import BaseModel, Field, ValidationError 

11 

12from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer 

13 

14T = TypeVar("T", bound=BaseModel) 

15 

16 

class ValidatedResponse(Generic[T]):
    """
    Container for validated LLM response with metadata.

    Generic type parameter T should be a Pydantic model defining
    the expected structure of the response.
    """

    def __init__(self, data: T, raw_content: str, validation_success: bool = True, validation_errors: list[str] | None = None):
        """
        Initialize validated response.

        Args:
            data: Validated and parsed data
            raw_content: Original LLM response text
            validation_success: Whether validation passed
            validation_errors: List of validation error messages
        """
        self.data = data
        self.raw_content = raw_content
        self.validation_success = validation_success
        self.validation_errors = validation_errors or []

    def is_valid(self) -> bool:
        """Check if response passed validation."""
        return self.validation_success

    def get_errors(self) -> list[str]:
        """Get validation error messages."""
        return self.validation_errors

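# Illustrative use of the container (hypothetical Greeting model, not part of
# this module):
#
#     >>> class Greeting(BaseModel):
#     ...     text: str = "hi"
#     >>> wrapped = ValidatedResponse(data=Greeting(), raw_content='{"text": "hi"}')
#     >>> wrapped.is_valid()
#     True
#     >>> wrapped.get_errors()
#     []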

class EntityExtraction(BaseModel):
    """Extracted entities from text."""

    entities: list[dict[str, str]] = Field(description="List of extracted entities with type and value")
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence in extraction quality")

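# Example payload (illustrative values) that parses into EntityExtraction:
#
#     >>> import json
#     >>> raw = '{"entities": [{"type": "person", "value": "Ada"}], "confidence": 0.95}'
#     >>> EntityExtraction(**json.loads(raw)).confidence
#     0.95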

class IntentClassification(BaseModel):
    """User intent classification."""

    intent: str = Field(description="Primary intent category")
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence score")
    sub_intents: list[str] = Field(default_factory=list, description="Secondary or related intents")
    parameters: dict[str, str] = Field(default_factory=dict, description="Extracted parameters for the intent")


class SentimentAnalysis(BaseModel):
    """Sentiment analysis result."""

    sentiment: str = Field(description="Overall sentiment (positive/negative/neutral)")
    score: float = Field(ge=-1.0, le=1.0, description="Sentiment score (-1 to 1)")
    emotions: list[str] = Field(default_factory=list, description="Detected emotions")


class SummaryExtraction(BaseModel):
    """Extracted summary with key points."""

    summary: str = Field(description="Concise summary of content")
    key_points: list[str] = Field(description="Key points from content")
    length: int = Field(description="Character count of summary")
    compression_ratio: float = Field(ge=0.0, le=1.0, description="Ratio of summary to original length")

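# These models share a pattern: a typed payload plus a bounded confidence or
# score field. For instance (hypothetical values):
#
#     >>> import json
#     >>> raw = '{"intent": "book_flight", "confidence": 0.92, "parameters": {"city": "Tokyo"}}'
#     >>> IntentClassification(**json.loads(raw)).intent
#     'book_flight'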

class LLMValidator:
    """
    Validator for LLM responses with Pydantic models.

    Provides structured validation and extraction from free-text LLM outputs.
    """

    @staticmethod
    def validate_response(response: AIMessage | str, model_class: type[T], strict: bool = False) -> ValidatedResponse[T]:
        """
        Validate LLM response against a Pydantic model.

        Args:
            response: LLM response (AIMessage or string)
            model_class: Pydantic model class for validation
            strict: Whether to raise exception on validation failure

        Returns:
            ValidatedResponse containing parsed data or errors

        Raises:
            ValidationError: If strict=True and validation fails
        """
        with tracer.start_as_current_span("llm.validate_response") as span:
            # Extract content
            if isinstance(response, AIMessage):
                # Handle structured content (list) by converting to string
                content = response.content if isinstance(response.content, str) else str(response.content)
            else:
                content = str(response)

            span.set_attribute("response.length", len(content))
            span.set_attribute("model.name", model_class.__name__)

            try:
                # Attempt to parse as JSON first (for structured outputs)
                import json

                try:
                    data_dict = json.loads(content)
                    validated = model_class(**data_dict)
                except json.JSONDecodeError:
                    # Not JSON: fall back to wrapping the raw text in a 'content'
                    # field. Check model_fields (Pydantic v2); class-level
                    # hasattr() does not see model fields, so it always failed.
                    validated = model_class(content=content) if "content" in model_class.model_fields else None  # type: ignore[assignment]

                if validated is None:
                    raise ValueError(  # noqa: TRY003
                        f"Cannot parse non-JSON content for {model_class.__name__}. "  # noqa: EM102
                        "Model must accept 'content' field or response must be valid JSON."
                    )

                span.set_attribute("validation.success", True)

                logger.info(
                    "Response validated successfully", extra={"model": model_class.__name__, "content_length": len(content)}
                )

                metrics.successful_calls.add(1, {"operation": "validate_response"})

                return ValidatedResponse(data=validated, raw_content=content, validation_success=True)

            except ValidationError as e:
                span.set_attribute("validation.success", False)
                span.record_exception(e)

                errors = [str(err) for err in e.errors()]

                logger.warning("Response validation failed", extra={"model": model_class.__name__, "errors": errors})

                metrics.failed_calls.add(1, {"operation": "validate_response"})

                if strict:
                    raise

                # Return invalid response with errors
                # Create empty instance for data
                try:
                    empty_data = model_class()
                except Exception:
                    empty_data = None

                return ValidatedResponse(
                    data=empty_data,  # type: ignore[arg-type]
                    raw_content=content,
                    validation_success=False,
                    validation_errors=errors,
                )

            except Exception as e:
                span.record_exception(e)

                logger.error(f"Unexpected validation error: {e}", exc_info=True)

                metrics.failed_calls.add(1, {"operation": "validate_response"})

                if strict:
                    raise

                return ValidatedResponse(data=None, raw_content=content, validation_success=False, validation_errors=[str(e)])  # type: ignore[arg-type]
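# Sketch of validate_response round-tripping a JSON reply (hypothetical Answer
# model, not part of this module; assumes telemetry is initialized):
#
#     >>> class Answer(BaseModel):
#     ...     text: str
#     ...     confidence: float = 0.0
#     >>> ok = LLMValidator.validate_response('{"text": "42", "confidence": 0.9}', Answer)
#     >>> ok.is_valid(), ok.data.text
#     (True, '42')
#     >>> bad = LLMValidator.validate_response('{"confidence": "high"}', Answer)
#     >>> bad.is_valid()
#     False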

    @staticmethod
    def extract_entities(response: AIMessage | str) -> ValidatedResponse[EntityExtraction]:
        """
        Extract named entities from LLM response.

        Args:
            response: LLM response to extract from

        Returns:
            ValidatedResponse with EntityExtraction data
        """
        return LLMValidator.validate_response(response, EntityExtraction, strict=False)

    @staticmethod
    def classify_intent(response: AIMessage | str) -> ValidatedResponse[IntentClassification]:
        """
        Classify user intent from LLM response.

        Args:
            response: LLM response with intent classification

        Returns:
            ValidatedResponse with IntentClassification data
        """
        return LLMValidator.validate_response(response, IntentClassification, strict=False)

    @staticmethod
    def analyze_sentiment(response: AIMessage | str) -> ValidatedResponse[SentimentAnalysis]:
        """
        Analyze sentiment from LLM response.

        Args:
            response: LLM response with sentiment analysis

        Returns:
            ValidatedResponse with SentimentAnalysis data
        """
        return LLMValidator.validate_response(response, SentimentAnalysis, strict=False)

    @staticmethod
    def extract_summary(response: AIMessage | str) -> ValidatedResponse[SummaryExtraction]:
        """
        Extract summary and key points from LLM response.

        Args:
            response: LLM response with summary

        Returns:
            ValidatedResponse with SummaryExtraction data
        """
        return LLMValidator.validate_response(response, SummaryExtraction, strict=False)

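# Each helper is a thin, non-strict wrapper over validate_response; e.g.
# (illustrative values, assumes telemetry is initialized):
#
#     >>> reply = AIMessage(content='{"sentiment": "positive", "score": 0.8, "emotions": ["joy"]}')
#     >>> LLMValidator.analyze_sentiment(reply).data.sentiment
#     'positive'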

def validate_llm_response(response: AIMessage | str, expected_model: type[T]) -> ValidatedResponse[T]:
    """
    Convenience function to validate LLM response.

    Args:
        response: LLM response to validate
        expected_model: Pydantic model defining expected structure

    Returns:
        ValidatedResponse with parsed data
    """
    return LLMValidator.validate_response(response, expected_model, strict=False)
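

# A minimal runnable sketch of the convenience wrapper (assumes the telemetry
# module's logger/metrics/tracer are usable with default configuration):
if __name__ == "__main__":
    demo = validate_llm_response(
        '{"sentiment": "neutral", "score": 0.0, "emotions": []}',
        SentimentAnalysis,
    )
    print(demo.is_valid(), demo.data)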