Coverage for src/mcp_server_langgraph/llm/validators.py: 83%
88 statements
« prev ^ index » next — coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2LLM Response Validators using Pydantic AI
4Provides type-safe validation and structured extraction from LLM responses.
5"""
7from typing import Generic, TypeVar
9from langchain_core.messages import AIMessage
10from pydantic import BaseModel, Field, ValidationError
12from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer
14T = TypeVar("T", bound=BaseModel)
class ValidatedResponse(Generic[T]):
    """
    Container pairing a parsed LLM response with validation metadata.

    The type parameter T is the Pydantic model describing the expected
    structure of the response payload.
    """

    def __init__(self, data: T, raw_content: str, validation_success: bool = True, validation_errors: list[str] | None = None):
        """
        Initialize validated response.

        Args:
            data: Validated and parsed data
            raw_content: Original LLM response text
            validation_success: Whether validation passed
            validation_errors: List of validation error messages
        """
        # Normalize a missing/empty error list to a fresh empty list so
        # get_errors() always returns a list.
        self.validation_errors = [] if not validation_errors else validation_errors
        self.validation_success = validation_success
        self.raw_content = raw_content
        self.data = data

    def is_valid(self) -> bool:
        """Return True when the response passed validation."""
        return self.validation_success

    def get_errors(self) -> list[str]:
        """Return the list of validation error messages (empty if valid)."""
        return self.validation_errors
class EntityExtraction(BaseModel):
    """Extracted entities from text."""

    # Each entity is a str->str mapping (presumably {"type": ..., "value": ...}
    # per the Field description — the schema enforces only the str->str typing).
    entities: list[dict[str, str]] = Field(description="List of extracted entities with type and value")
    # Extraction-quality confidence, constrained to [0.0, 1.0].
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence in extraction quality")
class IntentClassification(BaseModel):
    """User intent classification."""

    # Primary intent label; free-form string (no closed enum enforced here).
    intent: str = Field(description="Primary intent category")
    # Classifier confidence, constrained to [0.0, 1.0].
    confidence: float = Field(ge=0.0, le=1.0, description="Confidence score")
    # Optional secondary intents; defaults to an empty list.
    sub_intents: list[str] = Field(default_factory=list, description="Secondary or related intents")
    # Slot/parameter values extracted for the intent; defaults to an empty dict.
    parameters: dict[str, str] = Field(default_factory=dict, description="Extracted parameters for the intent")
class SentimentAnalysis(BaseModel):
    """Sentiment analysis result."""

    # Overall polarity label; expected values per the description are
    # positive/negative/neutral, but the schema does not enforce the set.
    sentiment: str = Field(description="Overall sentiment (positive/negative/neutral)")
    # Signed polarity score, constrained to [-1.0, 1.0].
    score: float = Field(ge=-1.0, le=1.0, description="Sentiment score (-1 to 1)")
    # Optional list of detected emotion labels; defaults to empty.
    emotions: list[str] = Field(default_factory=list, description="Detected emotions")
class SummaryExtraction(BaseModel):
    """Extracted summary with key points."""

    # NOTE: all four fields are required (no defaults), so a bare
    # SummaryExtraction() cannot be constructed — relevant to the
    # empty-instance fallback in LLMValidator.validate_response.
    summary: str = Field(description="Concise summary of content")
    key_points: list[str] = Field(description="Key points from content")
    # Character count of `summary` — supplied by the LLM, not recomputed here.
    length: int = Field(description="Character count of summary")
    compression_ratio: float = Field(ge=0.0, le=1.0, description="Ratio of summary to original length")
class LLMValidator:
    """
    Validator for LLM responses with Pydantic models.

    Provides structured validation and extraction from free-text LLM outputs.
    """

    @staticmethod
    def validate_response(response: AIMessage | str, model_class: type[T], strict: bool = False) -> ValidatedResponse[T]:
        """
        Validate LLM response against a Pydantic model.

        Args:
            response: LLM response (AIMessage or string)
            model_class: Pydantic model class for validation
            strict: Whether to raise exception on validation failure

        Returns:
            ValidatedResponse containing parsed data or errors

        Raises:
            ValidationError: If strict=True and validation fails
        """
        with tracer.start_as_current_span("llm.validate_response") as span:
            # Extract content
            if isinstance(response, AIMessage):
                # Handle structured content (list) by converting to string
                content = response.content if isinstance(response.content, str) else str(response.content)
            else:
                content = str(response)

            span.set_attribute("response.length", len(content))
            span.set_attribute("model.name", model_class.__name__)

            try:
                # Attempt to parse as JSON first (for structured outputs)
                import json

                try:
                    data_dict = json.loads(content)
                    validated = model_class(**data_dict)
                except json.JSONDecodeError:
                    # Not JSON: fall back to wrapping the raw text in a
                    # 'content' field, but only when the model declares one.
                    # BUGFIX: hasattr(model_class, "content") is False for
                    # Pydantic v2 models (fields live in model_fields, not as
                    # class attributes), so the previous fallback could never
                    # succeed. Check model_fields first; keep hasattr as a
                    # fallback for non-v2 model classes.
                    field_names = getattr(model_class, "model_fields", None)
                    has_content_field = (
                        "content" in field_names if field_names is not None else hasattr(model_class, "content")
                    )
                    if has_content_field:
                        validated = model_class(content=content)
                    else:
                        raise ValueError(  # noqa: TRY003
                            f"Cannot parse non-JSON content for {model_class.__name__}. "  # noqa: EM102
                            "Model must accept 'content' field or response must be valid JSON."
                        )

                span.set_attribute("validation.success", True)

                logger.info(
                    "Response validated successfully", extra={"model": model_class.__name__, "content_length": len(content)}
                )

                metrics.successful_calls.add(1, {"operation": "validate_response"})

                return ValidatedResponse(data=validated, raw_content=content, validation_success=True)

            except ValidationError as e:
                span.set_attribute("validation.success", False)
                span.record_exception(e)

                errors = [str(err) for err in e.errors()]

                logger.warning("Response validation failed", extra={"model": model_class.__name__, "errors": errors})

                metrics.failed_calls.add(1, {"operation": "validate_response"})

                if strict:
                    raise

                # Return invalid response with errors.
                # Use a default-constructed instance for data when the model
                # allows it; models with required fields raise here and we
                # fall back to None.
                try:
                    empty_data = model_class()
                except Exception:
                    empty_data = None

                return ValidatedResponse(
                    data=empty_data,  # type: ignore[arg-type]
                    raw_content=content,
                    validation_success=False,
                    validation_errors=errors,
                )

            except Exception as e:
                # Catch-all boundary: the ValueError above (unparseable
                # content) and any unexpected failure land here.
                span.record_exception(e)

                # Lazy %-formatting: message is only built when emitted.
                logger.error("Unexpected validation error: %s", e, exc_info=True)

                metrics.failed_calls.add(1, {"operation": "validate_response"})

                if strict:
                    raise

                return ValidatedResponse(data=None, raw_content=content, validation_success=False, validation_errors=[str(e)])  # type: ignore[arg-type]

    @staticmethod
    def extract_entities(response: AIMessage | str) -> ValidatedResponse[EntityExtraction]:
        """
        Extract named entities from LLM response.

        Args:
            response: LLM response to extract from

        Returns:
            ValidatedResponse with EntityExtraction data
        """
        return LLMValidator.validate_response(response, EntityExtraction, strict=False)

    @staticmethod
    def classify_intent(response: AIMessage | str) -> ValidatedResponse[IntentClassification]:
        """
        Classify user intent from LLM response.

        Args:
            response: LLM response with intent classification

        Returns:
            ValidatedResponse with IntentClassification data
        """
        return LLMValidator.validate_response(response, IntentClassification, strict=False)

    @staticmethod
    def analyze_sentiment(response: AIMessage | str) -> ValidatedResponse[SentimentAnalysis]:
        """
        Analyze sentiment from LLM response.

        Args:
            response: LLM response with sentiment analysis

        Returns:
            ValidatedResponse with SentimentAnalysis data
        """
        return LLMValidator.validate_response(response, SentimentAnalysis, strict=False)

    @staticmethod
    def extract_summary(response: AIMessage | str) -> ValidatedResponse[SummaryExtraction]:
        """
        Extract summary and key points from LLM response.

        Args:
            response: LLM response with summary

        Returns:
            ValidatedResponse with SummaryExtraction data
        """
        return LLMValidator.validate_response(response, SummaryExtraction, strict=False)
def validate_llm_response(response: AIMessage | str, expected_model: type[T]) -> ValidatedResponse[T]:
    """
    Convenience function to validate LLM response.

    Thin module-level wrapper over LLMValidator.validate_response with
    non-strict behavior (errors are reported, not raised).

    Args:
        response: LLM response to validate
        expected_model: Pydantic model defining expected structure

    Returns:
        ValidatedResponse with parsed data
    """
    result = LLMValidator.validate_response(response, expected_model, strict=False)
    return result