Coverage for src / mcp_server_langgraph / mcp / streaming.py: 83%
90 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2MCP Streaming with Pydantic AI Validation
4Provides type-safe streaming responses for MCP server with validation.
5"""
7import asyncio
8import json
9from collections.abc import AsyncIterator
11from pydantic import BaseModel, Field
13from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer
class StreamChunk(BaseModel):
    """A single validated chunk within a streaming response."""

    # Payload text carried by this chunk.
    content: str = Field(description="Content chunk")
    # Zero-based position of the chunk within the overall stream.
    chunk_index: int = Field(description="Index in the stream sequence")
    # True only on the terminating chunk of the stream.
    is_final: bool = Field(default=False, description="Whether this is the final chunk")
    # Free-form string-to-string metadata attached to the chunk.
    metadata: dict[str, str] = Field(default_factory=dict, description="Chunk metadata")
class StreamedResponse(BaseModel):
    """Validated aggregate view of a completed (or failed) stream."""

    # Every chunk received for the stream, in arrival order.
    chunks: list[StreamChunk] = Field(description="All received chunks")
    total_length: int = Field(description="Total content length")
    chunk_count: int = Field(description="Number of chunks received")
    is_complete: bool = Field(description="Whether stream completed successfully")
    error_message: str | None = Field(default=None, description="Error message if stream failed")

    def get_full_content(self) -> str:
        """
        Reconstruct full content from chunks.

        Returns:
            Complete concatenated content
        """
        pieces = [part.content for part in self.chunks]
        return "".join(pieces)
class MCPStreamingValidator:
    """
    Validator for MCP streaming responses.

    Ensures streaming chunks are properly formatted and validates
    the complete response once streaming is finished.
    """

    def __init__(self) -> None:
        """Initialize streaming validator."""
        # Maps stream_id -> chunks received so far for that stream.
        self.active_streams: dict[str, list[StreamChunk]] = {}

    async def validate_chunk(self, chunk_data: dict, stream_id: str) -> StreamChunk | None:  # type: ignore[type-arg]
        """
        Validate a single stream chunk.

        Args:
            chunk_data: Raw chunk data
            stream_id: Unique stream identifier

        Returns:
            Validated StreamChunk or None if invalid
        """
        with tracer.start_as_current_span("mcp.validate_chunk") as span:
            span.set_attribute("stream.id", stream_id)

            try:
                chunk = StreamChunk(**chunk_data)

                # Track chunk in active stream; setdefault replaces the
                # explicit membership check + list creation.
                self.active_streams.setdefault(stream_id, []).append(chunk)

                span.set_attribute("chunk.index", chunk.chunk_index)
                span.set_attribute("chunk.is_final", chunk.is_final)

                logger.debug(
                    "Chunk validated",
                    extra={"stream_id": stream_id, "chunk_index": chunk.chunk_index, "content_length": len(chunk.content)},
                )

                metrics.successful_calls.add(1, {"operation": "validate_chunk"})

                return chunk

            except Exception as e:
                # Broad catch is deliberate: any malformed chunk is logged,
                # counted, and reported as None rather than aborting the stream.
                logger.error(f"Chunk validation failed: {e}", extra={"stream_id": stream_id}, exc_info=True)

                metrics.failed_calls.add(1, {"operation": "validate_chunk"})
                span.record_exception(e)

                return None

    async def finalize_stream(self, stream_id: str) -> StreamedResponse:
        """
        Finalize and validate complete stream.

        Args:
            stream_id: Stream to finalize

        Returns:
            Complete StreamedResponse with all chunks
        """
        with tracer.start_as_current_span("mcp.finalize_stream") as span:
            span.set_attribute("stream.id", stream_id)

            # pop() both fetches and cleans up the stream in one step,
            # replacing the original lookup-then-del sequence.
            chunks = self.active_streams.pop(stream_id, None)
            if chunks is None:
                logger.warning(f"Attempted to finalize unknown stream: {stream_id}")

                return StreamedResponse(
                    chunks=[], total_length=0, chunk_count=0, is_complete=False, error_message="Stream not found"
                )

            # Calculate totals
            total_length = sum(len(chunk.content) for chunk in chunks)
            chunk_count = len(chunks)

            # Stream completed properly only if a final chunk was observed.
            is_complete = any(chunk.is_final for chunk in chunks)

            response = StreamedResponse(
                chunks=chunks,
                total_length=total_length,
                chunk_count=chunk_count,
                is_complete=is_complete,
                error_message=None if is_complete else "Stream incomplete",
            )

            span.set_attribute("stream.total_length", total_length)
            span.set_attribute("stream.chunk_count", chunk_count)
            span.set_attribute("stream.is_complete", is_complete)

            logger.info(
                "Stream finalized",
                extra={
                    "stream_id": stream_id,
                    "total_length": total_length,
                    "chunk_count": chunk_count,
                    "is_complete": is_complete,
                },
            )

            metrics.successful_calls.add(1, {"operation": "finalize_stream"})

            return response
async def stream_validated_response(content: str, chunk_size: int = 100, stream_id: str | None = None) -> AsyncIterator[str]:
    """
    Stream response with validation as newline-delimited JSON.

    Args:
        content: Content to stream
        chunk_size: Size of each chunk in characters
        stream_id: Optional stream identifier

    Yields:
        JSON-serialized StreamChunk objects
    """
    stream_id = stream_id or "default"

    with tracer.start_as_current_span("mcp.stream_validated") as span:
        span.set_attribute("stream.id", stream_id)
        span.set_attribute("content.length", len(content))
        span.set_attribute("chunk.size", chunk_size)

        try:
            # Split content into chunks. Empty content still produces one
            # empty chunk so consumers always receive an is_final=True marker
            # (previously an empty stream never terminated explicitly).
            chunks = [content[i : i + chunk_size] for i in range(0, len(content), chunk_size)] or [""]

            total_chunks = len(chunks)
            span.set_attribute("stream.total_chunks", total_chunks)

            logger.info(
                "Starting validated stream",
                extra={"stream_id": stream_id, "total_chunks": total_chunks, "content_length": len(content)},
            )

            # Stream each chunk
            for idx, chunk_content in enumerate(chunks):
                is_final = idx == total_chunks - 1

                chunk = StreamChunk(
                    content=chunk_content,
                    chunk_index=idx,
                    is_final=is_final,
                    metadata={"stream_id": stream_id, "total_chunks": str(total_chunks)},
                )

                # Yield as newline-delimited JSON
                yield json.dumps(chunk.model_dump()) + "\n"

                # Small delay to simulate realistic streaming
                await asyncio.sleep(0.01)

            metrics.successful_calls.add(1, {"operation": "stream_validated"})

        except Exception as e:
            logger.error(f"Streaming failed: {e}", extra={"stream_id": stream_id}, exc_info=True)

            metrics.failed_calls.add(1, {"operation": "stream_validated"})
            span.record_exception(e)

            # Yield error chunk so the consumer still sees a terminating chunk.
            error_chunk = StreamChunk(content="", chunk_index=-1, is_final=True, metadata={"error": str(e)})

            yield json.dumps(error_chunk.model_dump()) + "\n"
async def stream_agent_response(response_content: str, include_metadata: bool = True) -> AsyncIterator[str]:
    """
    Stream agent response with optional metadata.

    Args:
        response_content: Agent response to stream
        include_metadata: Whether to include metadata in chunks

    Yields:
        JSON chunks with validated structure
    """
    with tracer.start_as_current_span("mcp.stream_agent_response"):
        # Delegate chunking + validation to the shared streamer.
        async for serialized in stream_validated_response(response_content):
            yield serialized

        if not include_metadata:
            return

        # Trailing completion marker summarizing the full response.
        summary = StreamChunk(
            content="",
            chunk_index=-1,
            is_final=True,
            metadata={"type": "completion", "total_length": str(len(response_content))},
        )
        yield json.dumps(summary.model_dump()) + "\n"
# Module-level singleton shared by every caller of get_streaming_validator().
_streaming_validator = MCPStreamingValidator()


def get_streaming_validator() -> MCPStreamingValidator:
    """
    Get global streaming validator instance.

    Returns:
        MCPStreamingValidator singleton
    """
    return _streaming_validator