Coverage for src / mcp_server_langgraph / mcp / streaming.py: 83%

90 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2MCP Streaming with Pydantic AI Validation 

3 

4Provides type-safe streaming responses for MCP server with validation. 

5""" 

6 

7import asyncio 

8import json 

9from collections.abc import AsyncIterator 

10 

11from pydantic import BaseModel, Field 

12 

13from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer 

14 

15 

class StreamChunk(BaseModel):
    """One validated piece of a streamed response."""

    # Text carried by this chunk
    content: str = Field(description="Content chunk")
    # Position of the chunk within its stream
    chunk_index: int = Field(description="Index in the stream sequence")
    # Marks the chunk that terminates the stream; defaults to False
    is_final: bool = Field(
        default=False,
        description="Whether this is the final chunk",
    )
    # String-to-string annotations; a fresh empty dict per instance by default
    metadata: dict[str, str] = Field(
        default_factory=dict,
        description="Chunk metadata",
    )

23 

24 

class StreamedResponse(BaseModel):
    """Aggregate view of a stream: its chunks plus summary fields."""

    chunks: list[StreamChunk] = Field(description="All received chunks")
    total_length: int = Field(description="Total content length")
    chunk_count: int = Field(description="Number of chunks received")
    is_complete: bool = Field(description="Whether stream completed successfully")
    # Stays None unless a message is supplied
    error_message: str | None = Field(
        default=None,
        description="Error message if stream failed",
    )

    def get_full_content(self) -> str:
        """
        Join the content of every chunk, in list order.

        Returns:
            The concatenation of all chunk contents
        """
        pieces = [piece.content for piece in self.chunks]
        return "".join(pieces)

42 

43 

class MCPStreamingValidator:
    """
    Validates MCP stream chunks and assembles finished streams.

    Chunks accepted by ``validate_chunk`` are buffered per stream id until
    ``finalize_stream`` turns them into a StreamedResponse and drops the
    buffer.
    """

    def __init__(self) -> None:
        """Create a validator with no streams in progress."""
        # Buffered chunks, keyed by stream id
        self.active_streams: dict[str, list[StreamChunk]] = {}

    async def validate_chunk(self, chunk_data: dict, stream_id: str) -> StreamChunk | None:  # type: ignore[type-arg]
        """
        Parse raw chunk data and buffer it under its stream.

        Args:
            chunk_data: Raw keyword data used to build a StreamChunk
            stream_id: Unique stream identifier

        Returns:
            The validated StreamChunk, or None when any step fails
        """
        with tracer.start_as_current_span("mcp.validate_chunk") as span:
            span.set_attribute("stream.id", stream_id)

            try:
                parsed = StreamChunk(**chunk_data)

                # First chunk of a stream creates its buffer on the fly
                self.active_streams.setdefault(stream_id, []).append(parsed)

                span.set_attribute("chunk.index", parsed.chunk_index)
                span.set_attribute("chunk.is_final", parsed.is_final)

                debug_fields = {
                    "stream_id": stream_id,
                    "chunk_index": parsed.chunk_index,
                    "content_length": len(parsed.content),
                }
                logger.debug("Chunk validated", extra=debug_fields)

                metrics.successful_calls.add(1, {"operation": "validate_chunk"})
            except Exception as exc:
                # Failures are reported through logs, metrics and the span;
                # the caller only sees None
                logger.error(f"Chunk validation failed: {exc}", extra={"stream_id": stream_id}, exc_info=True)

                metrics.failed_calls.add(1, {"operation": "validate_chunk"})
                span.record_exception(exc)

                return None
            else:
                return parsed

    async def finalize_stream(self, stream_id: str) -> StreamedResponse:
        """
        Build the final response for a stream and forget its buffer.

        Args:
            stream_id: Stream to finalize

        Returns:
            StreamedResponse holding every buffered chunk; for an unknown
            stream id, an empty incomplete response with the error message
            "Stream not found"
        """
        with tracer.start_as_current_span("mcp.finalize_stream") as span:
            span.set_attribute("stream.id", stream_id)

            try:
                collected = self.active_streams[stream_id]
            except KeyError:
                logger.warning(f"Attempted to finalize unknown stream: {stream_id}")

                return StreamedResponse(
                    chunks=[],
                    total_length=0,
                    chunk_count=0,
                    is_complete=False,
                    error_message="Stream not found",
                )

            # Summary figures over the buffered chunks
            chunk_count = len(collected)
            total_length = sum(len(item.content) for item in collected)

            # A stream counts as complete once any chunk carries the final flag
            is_complete = any(item.is_final for item in collected)
            failure_note = "Stream incomplete" if not is_complete else None

            response = StreamedResponse(
                chunks=collected,
                total_length=total_length,
                chunk_count=chunk_count,
                is_complete=is_complete,
                error_message=failure_note,
            )

            span.set_attribute("stream.total_length", total_length)
            span.set_attribute("stream.chunk_count", chunk_count)
            span.set_attribute("stream.is_complete", is_complete)

            summary_fields = {
                "stream_id": stream_id,
                "total_length": total_length,
                "chunk_count": chunk_count,
                "is_complete": is_complete,
            }
            logger.info("Stream finalized", extra=summary_fields)

            # Drop the buffer only after the response has been built and logged
            del self.active_streams[stream_id]

            metrics.successful_calls.add(1, {"operation": "finalize_stream"})

            return response

156 

157 

async def stream_validated_response(
    content: str,
    chunk_size: int = 100,
    stream_id: str | None = None,
    *,
    chunk_delay: float = 0.01,
) -> AsyncIterator[str]:
    """
    Stream response with validation as newline-delimited JSON.

    The content is cut into ``chunk_size``-character slices; each slice is
    wrapped in a StreamChunk and emitted as one JSON line. Empty content
    produces no chunks at all. Any exception raised while streaming,
    including an invalid ``chunk_size``, is logged, counted as a failed call
    and reported to the consumer as a single error chunk with
    ``chunk_index=-1``, ``is_final=True`` and the message stored under
    ``metadata["error"]``.

    Args:
        content: Content to stream
        chunk_size: Size of each chunk in characters; must be at least 1
        stream_id: Optional stream identifier ("default" when omitted or empty)
        chunk_delay: Seconds to pause after each emitted chunk

    Yields:
        JSON-serialized StreamChunk objects, one per line
    """
    stream_id = stream_id or "default"

    with tracer.start_as_current_span("mcp.stream_validated") as span:
        span.set_attribute("stream.id", stream_id)
        span.set_attribute("content.length", len(content))
        span.set_attribute("chunk.size", chunk_size)

        try:
            # range() raises on a zero step but silently yields nothing for a
            # negative one, which would look like a successful empty stream.
            # Reject both here so a bad size always surfaces as an error chunk.
            if chunk_size < 1:
                raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

            # Split content into chunks; materialized up front because every
            # chunk's metadata carries the total count
            chunks = [content[i : i + chunk_size] for i in range(0, len(content), chunk_size)]

            total_chunks = len(chunks)
            span.set_attribute("stream.total_chunks", total_chunks)

            logger.info(
                "Starting validated stream",
                extra={"stream_id": stream_id, "total_chunks": total_chunks, "content_length": len(content)},
            )

            # Stream each chunk
            for idx, chunk_content in enumerate(chunks):
                is_final = idx == total_chunks - 1

                chunk = StreamChunk(
                    content=chunk_content,
                    chunk_index=idx,
                    is_final=is_final,
                    metadata={"stream_id": stream_id, "total_chunks": str(total_chunks)},
                )

                # Yield as newline-delimited JSON
                yield json.dumps(chunk.model_dump()) + "\n"

                # Pause between chunks to simulate realistic streaming
                await asyncio.sleep(chunk_delay)

            metrics.successful_calls.add(1, {"operation": "stream_validated"})

        except Exception as e:
            logger.error(f"Streaming failed: {e}", extra={"stream_id": stream_id}, exc_info=True)

            metrics.failed_calls.add(1, {"operation": "stream_validated"})
            span.record_exception(e)

            # Yield error chunk so the consumer still receives a final marker
            error_chunk = StreamChunk(content="", chunk_index=-1, is_final=True, metadata={"error": str(e)})

            yield json.dumps(error_chunk.model_dump()) + "\n"

218 

219 

async def stream_agent_response(response_content: str, include_metadata: bool = True) -> AsyncIterator[str]:
    """
    Relay an agent response as validated JSON chunks.

    Args:
        response_content: Agent response to stream
        include_metadata: When true, append a trailing completion chunk

    Yields:
        Newline-terminated JSON chunks, optionally followed by one
        content-free chunk whose metadata reports the response length
    """
    with tracer.start_as_current_span("mcp.stream_agent_response"):
        # Pass through everything the validated stream produces
        async for serialized in stream_validated_response(response_content):
            yield serialized

        if not include_metadata:
            return

        # Trailing chunk: empty content, index -1, flagged final
        completion_info = {"type": "completion", "total_length": str(len(response_content))}
        trailer = StreamChunk(content="", chunk_index=-1, is_final=True, metadata=completion_info)

        yield f"{json.dumps(trailer.model_dump())}\n"

246 

247 

# Single module-level validator, created at import time
_streaming_validator = MCPStreamingValidator()


def get_streaming_validator() -> MCPStreamingValidator:
    """
    Return the module-level streaming validator.

    Returns:
        The MCPStreamingValidator instance created when this module was
        imported; every call returns the same object
    """
    return _streaming_validator