Coverage for src / mcp_server_langgraph / mcp / server_streamable.py: 59%
549 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
"""
MCP Server with StreamableHTTP transport
Implements the MCP StreamableHTTP specification (replaces deprecated SSE)

Implements Anthropic's best practices for writing tools for agents:
- Token-efficient responses with truncation
- Search-focused tools instead of list-all
- Response format control (concise vs detailed)
- Namespaced tools for clarity
- High-signal information in responses
"""

import json
import logging
import sys
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Literal

import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from langchain_core.messages import HumanMessage
from mcp.server import Server
from mcp.types import Resource, TextContent, Tool
from pydantic import AnyUrl, BaseModel, Field

from mcp_server_langgraph.api.auth_request_middleware import AuthRequestMiddleware
from mcp_server_langgraph.auth.factory import create_auth_middleware, create_user_provider
from mcp_server_langgraph.auth.middleware import AuthMiddleware
from mcp_server_langgraph.auth.openfga import OpenFGAClient
from mcp_server_langgraph.auth.user_provider import KeycloakUserProvider
from mcp_server_langgraph.core.agent import AgentState, get_agent_graph
from mcp_server_langgraph.core.config import Settings, settings
from mcp_server_langgraph.core.security import sanitize_for_logging
from mcp_server_langgraph.middleware.rate_limiter import custom_rate_limit_exceeded_handler, limiter
from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer
from mcp_server_langgraph.utils.response_optimizer import format_response
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """
    Lifespan context manager for application startup and shutdown.

    CRITICAL: This ensures observability is initialized before handling requests,
    preventing crashes when launching with: uvicorn mcp_server_langgraph.mcp.server_streamable:app

    Without this handler, logger/tracer/metrics usage in get_mcp_server() and downstream
    code would fail when observability hasn't been initialized yet.
    """
    # Startup
    # Local import: telemetry state must be read at startup time, not module-import time
    from mcp_server_langgraph.observability.telemetry import init_observability, is_initialized

    if not is_initialized():
        # Use a plain stdlib logger until the observability logger exists
        logger_temp = logging.getLogger(__name__)
        logger_temp.info("Initializing observability from startup event")

        # Initialize with file logging if configured
        enable_file_logging = getattr(settings, "enable_file_logging", False)
        init_observability(settings=settings, enable_file_logging=enable_file_logging)

        logger_temp.info("Observability initialized successfully")

    # Initialize global auth middleware for FastAPI dependencies
    # This must happen AFTER observability is initialized (for logging)
    try:
        from mcp_server_langgraph.auth.middleware import FASTAPI_AVAILABLE, set_global_auth_middleware

        if FASTAPI_AVAILABLE:
            # Get MCP server instance (creates it if needed)
            mcp_server = get_mcp_server()

            # Set the auth middleware globally for FastAPI dependencies
            set_global_auth_middleware(mcp_server.auth)

            logger.info("Global auth middleware initialized for FastAPI dependencies")
    except Exception as e:
        # Best-effort: the app can still serve MCP traffic without the global middleware
        logger.warning(f"Failed to initialize global auth middleware: {e}")

    yield

    # Shutdown - cleanup observability and close connections
    from mcp_server_langgraph.observability.telemetry import shutdown_observability

    logger.info("Application shutdown initiated")

    # Cleanup checkpointer resources (Redis connections, etc.)
    try:
        from mcp_server_langgraph.core.agent import cleanup_checkpointer

        agent_graph = get_agent_graph()  # type: ignore[func-returns-value]
        if agent_graph and hasattr(agent_graph, "checkpointer") and agent_graph.checkpointer:
            cleanup_checkpointer(agent_graph.checkpointer)
            logger.info("Checkpointer resources cleaned up")
    except Exception as e:
        logger.warning(f"Error cleaning up checkpointer: {e}")

    # Shutdown observability (flush spans, close exporters)
    shutdown_observability()

    # Close Prometheus client if initialized
    try:
        mcp_server = get_mcp_server()
        if hasattr(mcp_server, "prometheus_client") and mcp_server.prometheus_client:
            await mcp_server.prometheus_client.close()
            logger.info("Prometheus client closed")
    except Exception as e:
        logger.warning(f"Error closing Prometheus client: {e}")

    logger.info("Application shutdown complete")
# FastAPI application instance with OpenAPI metadata, tag taxonomy for the
# REST surface, shared error-response documentation, and the lifespan handler
# that initializes/tears down observability and auth.
app = FastAPI(
    title="MCP Server with LangGraph",
    description="AI Agent with fine-grained authorization and observability - StreamableHTTP transport",
    version=settings.service_version,
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_tags=[
        {
            "name": "API Metadata",
            "description": "API version information and metadata for client compatibility",
        },
        {
            "name": "mcp",
            "description": "Model Context Protocol (MCP) endpoints for agent interaction",
        },
        {
            "name": "health",
            "description": "Health check and system status endpoints",
        },
        {
            "name": "auth",
            "description": "Authentication and authorization endpoints",
        },
        {
            "name": "GDPR Compliance",
            "description": "GDPR data protection and privacy rights endpoints (Articles 15-21)",
        },
        {
            "name": "API Keys",
            "description": "API key management for programmatic access",
        },
        {
            "name": "Service Principals",
            "description": "Service principal (service account) management for machine-to-machine authentication",
        },
        {
            "name": "SCIM 2.0",
            "description": "System for Cross-domain Identity Management (SCIM) 2.0 user and group provisioning",
        },
    ],
    responses={
        401: {"description": "Unauthorized - Invalid or missing authentication token"},
        403: {"description": "Forbidden - Insufficient permissions"},
        429: {"description": "Too Many Requests - Rate limit exceeded"},
        500: {"description": "Internal Server Error"},
    },
    lifespan=lifespan,
)
# CORS middleware
# SECURITY: Use config-based origins instead of wildcard
# Empty list = no CORS (production default), specific origins in development
cors_origins = settings.get_cors_origins()
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=bool(cors_origins),  # Only allow credentials if origins specified
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["X-Request-ID", "X-RateLimit-Limit", "X-RateLimit-Remaining"],
)

# Rate limiting middleware
# SECURITY: Protect against DoS, brute force, and API abuse
# Uses tiered rate limits (anonymous: 10/min, free: 60/min, premium: 1000/min, enterprise: unlimited)
# Tracks by user ID (from JWT) > IP address > global anonymous
# Fail-open: allows requests if Redis is down (graceful degradation)
#
# NOTE: Use standard logging.getLogger() here instead of observability logger
# because this code runs at module import time, before lifespan initializes observability
_module_logger = logging.getLogger(__name__)
try:
    from slowapi.errors import RateLimitExceeded

    # Register rate limiter with app
    app.state.limiter = limiter

    # Register custom exception handler for rate limit exceeded
    app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)  # type: ignore[arg-type]

    _module_logger.info(
        "Rate limiting enabled - strategy: fixed-window, tiers: anonymous/free/standard/premium/enterprise, fail_open: True"
    )
except Exception as e:
    _module_logger.warning(f"Failed to initialize rate limiting: {e}. Requests will proceed without rate limits.")

# Authentication middleware for REST API endpoints (GDPR, API Keys, Service Principals, SCIM)
# This must be added AFTER rate limiting but BEFORE routes are defined
# Uses module-level auth middleware creation (lazy init for observability compatibility)
try:
    # Create user provider and auth middleware for REST API authentication
    # This enables request.state.user for all protected endpoints
    _module_user_provider = create_user_provider(settings)
    _module_auth_middleware = AuthMiddleware(
        secret_key=settings.jwt_secret_key,
        settings=settings,
        user_provider=_module_user_provider,
    )
    app.add_middleware(AuthRequestMiddleware, auth_middleware=_module_auth_middleware)
    _module_logger.info("AuthRequestMiddleware enabled for REST API authentication")
except Exception as e:
    _module_logger.warning(f"Failed to initialize AuthRequestMiddleware: {e}. REST API endpoints will require manual auth.")
class ChatInput(BaseModel):
    """
    Input schema for agent_chat tool.

    Follows Anthropic best practices:
    - Unambiguous parameter names (user_id not username, message not query)
    - Response format control for token efficiency
    - Clear field descriptions
    """

    message: str = Field(description="The user message to send to the agent", min_length=1, max_length=10000)
    token: str = Field(
        description=(
            "JWT authentication token. Obtain via /auth/login endpoint (HTTP) "
            "or external authentication service. Required for all tool calls."
        )
    )
    user_id: str = Field(description="User identifier for authentication and authorization")
    thread_id: str | None = Field(
        default=None,
        description="Optional thread ID for conversation continuity (e.g., 'conv_123')",
        pattern=r"^[a-zA-Z0-9_-]{1,128}$",  # SECURITY: Prevent CWE-20 injection into OpenFGA/Redis/logs
    )
    response_format: Literal["concise", "detailed"] = Field(
        default="concise",
        description=(
            "Response verbosity level. "
            "'concise' returns ~500 tokens (faster, less context). "
            "'detailed' returns ~2000 tokens (comprehensive, more context)."
        ),
    )

    # Backward compatibility - DEPRECATED
    username: str | None = Field(
        default=None, deprecated=True, description="DEPRECATED: Use 'user_id' instead. Maintained for backward compatibility."
    )

    @property
    def effective_user_id(self) -> str:
        """Get effective user ID, prioritizing user_id over deprecated username."""
        # hasattr guards against partially-constructed models (e.g. model_construct)
        return self.user_id if hasattr(self, "user_id") and self.user_id else (self.username or "")
class SearchConversationsInput(BaseModel):
    """Input schema for conversation_search tool."""

    query: str = Field(description="Search query to filter conversations", min_length=1, max_length=500)
    token: str = Field(
        description=(
            "JWT authentication token. Obtain via /auth/login endpoint (HTTP) "
            "or external authentication service. Required for all tool calls."
        )
    )
    user_id: str = Field(description="User identifier for authentication and authorization")
    limit: int = Field(default=10, ge=1, le=50, description="Maximum number of conversations to return (1-50)")

    # Backward compatibility - DEPRECATED
    username: str | None = Field(
        default=None, deprecated=True, description="DEPRECATED: Use 'user_id' instead. Maintained for backward compatibility."
    )

    @property
    def effective_user_id(self) -> str:
        """Get effective user ID, prioritizing user_id over deprecated username."""
        # hasattr guards against partially-constructed models (e.g. model_construct)
        return self.user_id if hasattr(self, "user_id") and self.user_id else (self.username or "")
class MCPAgentStreamableServer:
    """MCP Server with StreamableHTTP transport"""

    def __init__(
        self,
        openfga_client: OpenFGAClient | None = None,
        settings: Settings | None = None,
    ) -> None:
        """
        Initialize MCP Agent Streamable Server with optional dependency injection.

        Args:
            openfga_client: Optional OpenFGA client for authorization.
                If None, creates one from settings.
            settings: Optional Settings instance for runtime configuration.
                If provided, enables dynamic feature toggling (e.g., code execution).
                If None, uses global settings. This allows tests to inject custom
                configuration without module reloading.

        Raises:
            ValueError: If the JWT secret is missing, or if OpenFGA is not
                configured while running in production (fail-closed security).

        Example:
            # Default creation (production):
            server = MCPAgentStreamableServer()

            # Custom settings injection (testing):
            test_settings = Settings(enable_code_execution=True)
            server = MCPAgentStreamableServer(settings=test_settings)
        """
        # Store settings for runtime configuration
        # NOTE: When settings=None, we must reference the module-level 'settings'
        # imported at the top of this file. This allows tests to mock settings via
        # @patch("mcp_server_langgraph.mcp.server_streamable.settings", ...)
        # Using sys.modules[__name__] to avoid self-import (CodeQL py/import-own-module)
        self.settings = settings if settings is not None else sys.modules[__name__].settings

        self.server = Server("langgraph-agent")

        # Initialize OpenFGA client
        self.openfga = openfga_client or self._create_openfga_client()

        # Validate JWT secret is configured (fail-closed security pattern)
        if not self.settings.jwt_secret_key:
            msg = (
                "CRITICAL: JWT secret key not configured. "
                "Set JWT_SECRET_KEY environment variable or configure via Infisical. "
                "The service cannot start without a secure secret key."
            )
            raise ValueError(msg)

        # SECURITY: Fail-closed pattern - require OpenFGA in production
        if self.settings.environment == "production" and self.openfga is None:
            msg = (
                "CRITICAL: OpenFGA authorization is required in production mode. "
                "Configure OPENFGA_STORE_ID and OPENFGA_MODEL_ID environment variables, "
                "or set ENVIRONMENT=development for local testing. "
                "Fallback authorization is not secure enough for production use."
            )
            raise ValueError(msg)

        # Initialize auth using factory (respects settings.auth_provider)
        self.auth = create_auth_middleware(self.settings, openfga_client=self.openfga)

        self._setup_handlers()
349 def _create_openfga_client(self) -> OpenFGAClient | None:
350 """Create OpenFGA client from settings"""
351 if self.settings.openfga_store_id and self.settings.openfga_model_id:
352 logger.info(
353 "Initializing OpenFGA client",
354 extra={"store_id": self.settings.openfga_store_id, "model_id": self.settings.openfga_model_id},
355 )
356 return OpenFGAClient(
357 api_url=self.settings.openfga_api_url,
358 store_id=self.settings.openfga_store_id,
359 model_id=self.settings.openfga_model_id,
360 )
361 else:
362 logger.warning("OpenFGA not configured, authorization will use fallback mode")
363 return None
365 async def list_tools_public(self) -> list[Tool]:
366 """
367 Public API to list available tools.
369 This wraps the internal MCP handler to avoid accessing private SDK attributes.
370 """
371 # Call the registered handler
372 # The handler is registered below in _setup_handlers()
373 return await self._list_tools_handler() # type: ignore[no-any-return]
375 async def call_tool_public(self, name: str, arguments: dict[str, Any]) -> list[TextContent]:
376 """
377 Public API to call a tool.
379 This wraps the internal MCP handler to avoid accessing private SDK attributes.
380 """
381 return await self._call_tool_handler(name, arguments) # type: ignore[no-any-return]
383 async def list_resources_public(self) -> list[Resource]:
384 """
385 Public API to list available resources.
387 This wraps the internal MCP handler to avoid accessing private SDK attributes.
388 """
389 return await self._list_resources_handler() # type: ignore[no-any-return]
391 def _setup_handlers(self) -> None:
392 """Setup MCP protocol handlers and store references for public API"""
394 @self.server.list_tools() # type: ignore[no-untyped-call, untyped-decorator]
395 async def list_tools() -> list[Tool]:
396 """
397 List available tools.
399 Tools follow Anthropic best practices:
400 - Namespaced for clarity (agent_*, conversation_*)
401 - Search-focused instead of list-all
402 - Clear usage guidance in descriptions
403 - Token limits and expected response times documented
404 """
405 with tracer.start_as_current_span("mcp.list_tools"):
406 logger.info("Listing available tools")
407 tools = [
408 Tool(
409 name="agent_chat",
410 description=(
411 "Chat with the AI agent for questions, research, and problem-solving. "
412 "Returns responses optimized for agent consumption. "
413 "Response format: 'concise' (~500 tokens, 2-5 sec) or 'detailed' (~2000 tokens, 5-10 sec). "
414 "For specialized tasks like code execution or web search, use dedicated tools instead. "
415 "Rate limit: 60 requests/minute per user."
416 ),
417 inputSchema=ChatInput.model_json_schema(),
418 ),
419 Tool(
420 name="conversation_get",
421 description=(
422 "Retrieve a specific conversation thread by ID. "
423 "Returns conversation history with messages, participants, and metadata. "
424 "Response time: <1 second. "
425 "Use conversation_search to find conversation IDs first."
426 ),
427 inputSchema={
428 "type": "object",
429 "properties": {
430 "thread_id": {
431 "type": "string",
432 "description": "Conversation thread identifier (e.g., 'conv_abc123')",
433 },
434 "token": {
435 "type": "string",
436 "description": "JWT authentication token. Required for all tool calls.",
437 },
438 "user_id": {
439 "type": "string",
440 "description": "User identifier for authentication and authorization",
441 },
442 "username": {"type": "string", "description": "DEPRECATED: Use 'user_id' instead"},
443 },
444 "required": ["thread_id", "token", "user_id"],
445 },
446 ),
447 Tool(
448 name="conversation_search",
449 description=(
450 "Search conversations using keywords or filters. "
451 "Returns matching conversations sorted by relevance. "
452 "Much more efficient than listing all conversations. "
453 "Response time: <2 seconds. "
454 "Examples: 'project updates', 'conversations with alice', 'last week'. "
455 "Results limited to 50 conversations max to prevent context overflow."
456 ),
457 inputSchema=SearchConversationsInput.model_json_schema(),
458 ),
459 ]
461 # Add search_tools for progressive discovery (Anthropic best practice)
462 tools.append(
463 Tool(
464 name="search_tools",
465 description=(
466 "Search and discover available tools using progressive disclosure. "
467 "Query by keyword or category instead of loading all tool definitions. "
468 "Saves 98%+ tokens compared to list-all approach. "
469 "Detail levels: minimal (name+desc), standard (+params), full (+schema). "
470 "Categories: calculator, search, filesystem, execution. "
471 "Response time: <1 second."
472 ),
473 inputSchema={
474 "type": "object",
475 "properties": {
476 "query": {"type": "string", "description": "Search query (keyword)"},
477 "category": {"type": "string", "description": "Tool category filter"},
478 "detail_level": {
479 "type": "string",
480 "enum": ["minimal", "standard", "full"],
481 "description": "Level of detail in results",
482 },
483 },
484 },
485 )
486 )
488 # Add execute_python if code execution is enabled
489 if self.settings.enable_code_execution:
490 from mcp_server_langgraph.tools.code_execution_tools import ExecutePythonInput
492 tools.append(
493 Tool(
494 name="execute_python",
495 description=(
496 "Execute Python code in a secure sandboxed environment. "
497 "Security: Import whitelist, no eval/exec, resource limits (CPU, memory, timeout). "
498 "Backends: docker-engine (local/dev) or kubernetes (production). "
499 "Network: Configurable isolation (none/allowlist/unrestricted). "
500 "Response time: 1-30 seconds depending on code complexity. "
501 "Use for data processing, calculations, and Python-specific tasks."
502 ),
503 inputSchema=ExecutePythonInput.model_json_schema(),
504 )
505 )
507 return tools
509 # Store reference to handler for public API
510 self._list_tools_handler = list_tools
512 @self.server.call_tool() # type: ignore[untyped-decorator] # MCP library decorator lacks type stubs
513 async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
514 """Handle tool calls with OpenFGA authorization and tracing"""
516 with tracer.start_as_current_span("mcp.call_tool", attributes={"tool.name": name}) as span:
517 # SECURITY: Sanitize arguments before logging to prevent CWE-200/CWE-532 (token exposure in logs)
518 logger.info(f"Tool called: {name}", extra={"tool": name, "tool_args": sanitize_for_logging(arguments)})
519 metrics.tool_calls.add(1, {"tool": name})
521 # SECURITY: Require JWT token for all tool calls
522 token = arguments.get("token")
524 if not token:
525 logger.warning("No authentication token provided")
526 metrics.auth_failures.add(1)
527 msg = (
528 "Authentication token required. Provide 'token' parameter with a valid JWT. "
529 "Obtain token via /auth/login endpoint or external authentication service."
530 )
531 raise PermissionError(msg)
533 # Verify JWT token
534 token_verification = await self.auth.verify_token(token)
536 if not token_verification.valid:
537 logger.warning("Token verification failed", extra={"error": token_verification.error})
538 metrics.auth_failures.add(1)
539 msg = f"Invalid authentication token: {token_verification.error or 'token verification failed'}"
540 raise PermissionError(msg)
542 # Extract user_id from validated token payload
543 if not token_verification.payload: 543 ↛ 544line 543 didn't jump to line 544 because the condition on line 543 was never true
544 logger.error("Token payload is empty")
545 metrics.auth_failures.add(1)
546 msg = "Invalid token: missing user identifier"
547 raise PermissionError(msg)
549 # Extract username with defensive fallback
550 # Priority: preferred_username > username claim > sub parsing
551 # Keycloak uses UUID in 'sub', but OpenFGA needs 'user:username' format
552 # NOTE: Some Keycloak configurations may not include 'sub' in access tokens
553 username = token_verification.payload.get("preferred_username")
554 if not username: 554 ↛ 556line 554 didn't jump to line 556 because the condition on line 554 was never true
555 # Try 'username' claim (alternative standard claim)
556 username = token_verification.payload.get("username")
557 if not username: 557 ↛ 559line 557 didn't jump to line 559 because the condition on line 557 was never true
558 # Fallback: extract from sub if available
559 sub = token_verification.payload.get("sub", "")
560 if sub.startswith("user:"):
561 username = sub.split(":", 1)[1]
562 elif sub and ":" not in sub:
563 # Log warning for UUID-style subs (may cause issues)
564 logger.warning(
565 f"Using sub as username fallback (may be UUID): {sub[:8]}...",
566 extra={"sub_prefix": sub[:8]},
567 )
568 username = sub
570 # Final check: ensure we have a username
571 if not username: 571 ↛ 572line 571 didn't jump to line 572 because the condition on line 571 was never true
572 logger.error("Token missing user identifier (no sub, preferred_username, or username claim)")
573 metrics.auth_failures.add(1)
574 msg = "Invalid token: cannot extract username from claims"
575 raise PermissionError(msg)
577 # Normalize user_id to "user:username" format for OpenFGA compatibility
578 user_id = f"user:{username}" if not username.startswith("user:") else username
579 span.set_attribute("user.id", user_id)
581 logger.info("User authenticated via token", extra={"user_id": user_id, "tool": name})
583 # Check OpenFGA authorization
584 resource = f"tool:{name}"
586 authorized = await self.auth.authorize(user_id=user_id, relation="executor", resource=resource)
588 if not authorized: 588 ↛ 589line 588 didn't jump to line 589 because the condition on line 588 was never true
589 logger.warning(
590 "Authorization failed (OpenFGA)",
591 extra={"user_id": user_id, "resource": resource, "relation": "executor"},
592 )
593 metrics.authz_failures.add(1, {"resource": resource})
594 msg = f"Not authorized to execute {resource}"
595 raise PermissionError(msg)
597 logger.info("Authorization granted", extra={"user_id": user_id, "resource": resource})
599 # Route to appropriate handler (with backward compatibility)
600 if name == "agent_chat" or name == "chat": # Support old name for compatibility 600 ↛ 601line 600 didn't jump to line 601 because the condition on line 600 was never true
601 return await self._handle_chat(arguments, span, user_id)
602 elif name == "conversation_get" or name == "get_conversation": 602 ↛ 603line 602 didn't jump to line 603 because the condition on line 602 was never true
603 return await self._handle_get_conversation(arguments, span, user_id)
604 elif name == "conversation_search" or name == "list_conversations": 604 ↛ 605line 604 didn't jump to line 605 because the condition on line 604 was never true
605 return await self._handle_search_conversations(arguments, span, user_id)
606 elif name == "search_tools": 606 ↛ 607line 606 didn't jump to line 607 because the condition on line 606 was never true
607 return await self._handle_search_tools(arguments, span)
608 elif name == "execute_python": 608 ↛ 609line 608 didn't jump to line 609 because the condition on line 608 was never true
609 return await self._handle_execute_python(arguments, span, user_id)
610 else:
611 msg = f"Unknown tool: {name}"
612 raise ValueError(msg)
614 # Store reference to handler for public API
615 self._call_tool_handler = call_tool
617 @self.server.list_resources() # type: ignore[no-untyped-call, untyped-decorator]
618 async def list_resources() -> list[Resource]:
619 """List available resources"""
620 with tracer.start_as_current_span("mcp.list_resources"):
621 return [Resource(uri=AnyUrl("agent://config"), name="Agent Configuration", mimeType="application/json")]
623 # Store reference to handler for public API
624 self._list_resources_handler = list_resources
626 async def _handle_chat(self, arguments: dict[str, Any], span: Any, user_id: str) -> list[TextContent]:
627 """
628 Handle agent_chat tool invocation.
630 Implements Anthropic best practices:
631 - Response format control (concise vs detailed)
632 - Token-efficient responses with truncation
633 - Clear error messages
634 - Performance tracking
635 """
636 with tracer.start_as_current_span("agent.chat"):
637 # BUGFIX: Validate input with Pydantic schema to enforce length limits and required fields
638 try:
639 chat_input = ChatInput.model_validate(arguments)
640 except Exception as e:
641 # SECURITY: Sanitize arguments before logging to prevent token exposure in error logs
642 logger.error(f"Invalid chat input: {e}", extra={"arguments": sanitize_for_logging(arguments)})
643 msg = f"Invalid chat input: {e}"
644 raise ValueError(msg)
646 message = chat_input.message
647 thread_id = chat_input.thread_id or "default"
648 response_format_type = chat_input.response_format
650 span.set_attribute("message.length", len(message))
651 span.set_attribute("thread.id", thread_id)
652 span.set_attribute("user.id", user_id)
653 span.set_attribute("response.format", response_format_type)
655 # Check if user can access this conversation
656 # BUGFIX: Allow first-time conversation creation without pre-existing OpenFGA tuples
657 # For new conversations, we short-circuit authorization and will seed ownership after creation
658 conversation_resource = f"conversation:{thread_id}"
660 # Check if conversation exists by trying to get state from checkpointer
661 graph = get_agent_graph() # type: ignore[func-returns-value]
662 conversation_exists = False
663 if hasattr(graph, "checkpointer") and graph.checkpointer is not None:
664 try:
665 config = {"configurable": {"thread_id": thread_id}}
666 state_snapshot = await graph.aget_state(config)
667 conversation_exists = state_snapshot is not None and state_snapshot.values is not None
668 except Exception:
669 # If we can't check, assume it doesn't exist (fail-open for creation)
670 conversation_exists = False
672 # Only check authorization for existing conversations
673 if conversation_exists:
674 can_edit = await self.auth.authorize(user_id=user_id, relation="editor", resource=conversation_resource)
675 if not can_edit:
676 logger.warning("User cannot edit conversation", extra={"user_id": user_id, "thread_id": thread_id})
677 msg = (
678 f"Not authorized to edit conversation '{thread_id}'. "
679 f"Request access from conversation owner or use a different thread_id."
680 )
681 raise PermissionError(msg)
682 else:
683 # New conversation - user becomes implicit owner (OpenFGA tuples should be seeded after creation)
684 logger.info(
685 "Creating new conversation, user granted implicit ownership",
686 extra={"user_id": user_id, "thread_id": thread_id},
687 )
689 logger.info(
690 "Processing chat message",
691 extra={
692 "thread_id": thread_id,
693 "user_id": user_id,
694 "message_preview": message[:100],
695 "response_format": response_format_type,
696 },
697 )
699 # Create initial state with proper LangChain message objects
700 # BUGFIX: Use HumanMessage instead of dict to avoid type errors in graph nodes
701 initial_state: AgentState = {
702 "messages": [HumanMessage(content=message)],
703 "next_action": "",
704 "user_id": user_id,
705 "request_id": str(span.get_span_context().trace_id) if span.get_span_context() else None,
706 "routing_confidence": None,
707 "reasoning": None,
708 "compaction_applied": None,
709 "original_message_count": None,
710 "verification_passed": None,
711 "verification_score": None,
712 "verification_feedback": None,
713 "refinement_attempts": None,
714 "user_request": message,
715 }
717 # Run the agent graph
718 config = {"configurable": {"thread_id": thread_id}}
720 try:
721 result = await get_agent_graph().ainvoke(initial_state, config) # type: ignore[func-returns-value]
723 # Seed OpenFGA tuples for new conversations
724 if not conversation_exists and self.openfga is not None:
725 try:
726 # Create ownership, viewer, and editor tuples in batch
727 await self.openfga.write_tuples(
728 [
729 {"user": user_id, "relation": "owner", "object": conversation_resource},
730 {"user": user_id, "relation": "viewer", "object": conversation_resource},
731 {"user": user_id, "relation": "editor", "object": conversation_resource},
732 ]
733 )
735 logger.info(
736 "OpenFGA tuples seeded for new conversation",
737 extra={"user_id": user_id, "thread_id": thread_id},
738 )
740 except Exception as e:
741 # Log warning but don't fail the request
742 # The conversation was created successfully, ACL seeding is best-effort
743 logger.warning(
744 f"Failed to seed OpenFGA tuples for new conversation: {e}",
745 extra={"user_id": user_id, "thread_id": thread_id},
746 exc_info=True,
747 )
749 # Extract response
750 response_message = result["messages"][-1]
751 response_text = response_message.content
753 # Apply response formatting based on format type
754 # Follows Anthropic guidance: offer response_format enum parameter
755 formatted_response = format_response(response_text, format_type=response_format_type)
757 span.set_attribute("response.length.original", len(response_text))
758 span.set_attribute("response.length.formatted", len(formatted_response))
759 metrics.successful_calls.add(1, {"tool": "agent_chat", "format": response_format_type})
761 logger.info(
762 "Chat response generated",
763 extra={
764 "thread_id": thread_id,
765 "original_length": len(response_text),
766 "formatted_length": len(formatted_response),
767 "format": response_format_type,
768 },
769 )
771 return [TextContent(type="text", text=formatted_response)]
773 except Exception as e:
774 logger.error(f"Error processing chat: {e}", extra={"error": str(e), "thread_id": thread_id}, exc_info=True)
775 metrics.failed_calls.add(1, {"tool": "agent_chat", "error": type(e).__name__})
776 span.record_exception(e)
777 raise
    async def _handle_get_conversation(self, arguments: dict[str, Any], span: Any, user_id: str) -> list[TextContent]:
        """
        Retrieve conversation history from LangGraph checkpointer.

        Returns formatted conversation with messages, metadata, and participants.

        Args:
            arguments: Tool arguments; must contain "thread_id" (KeyError otherwise).
            span: Caller-provided span; NOTE(review): unused here — the method
                opens its own span and never records attributes on this one.
            user_id: Authenticated user; must hold the "viewer" relation.

        Returns:
            A single TextContent with the formatted history, or an explanatory
            message when history is unavailable.

        Raises:
            PermissionError: If the user lacks viewer access to the conversation.
        """
        with tracer.start_as_current_span("agent.get_conversation"):
            thread_id = arguments["thread_id"]

            # Check if user can view this conversation
            conversation_resource = f"conversation:{thread_id}"

            can_view = await self.auth.authorize(user_id=user_id, relation="viewer", resource=conversation_resource)

            if not can_view:
                logger.warning("User cannot view conversation", extra={"user_id": user_id, "thread_id": thread_id})
                msg = f"Not authorized to view conversation {thread_id}"
                raise PermissionError(msg)

            # Retrieve conversation from checkpointer
            graph = get_agent_graph()  # type: ignore[func-returns-value]

            # Without a checkpointer there is no persisted state to read back.
            if not hasattr(graph, "checkpointer") or graph.checkpointer is None:
                logger.warning("No checkpointer available, cannot retrieve conversation history")
                return [
                    TextContent(
                        type="text",
                        text="Conversation history unavailable: checkpointing is disabled. "
                        "Enable checkpointing to persist conversation history.",
                    )
                ]

            try:
                # LangGraph addresses a thread's state via configurable.thread_id.
                config = {"configurable": {"thread_id": thread_id}}
                state_snapshot = await graph.aget_state(config)

                if state_snapshot is None or state_snapshot.values is None:
                    logger.info("Conversation not found", extra={"thread_id": thread_id})
                    return [TextContent(type="text", text=f"Conversation '{thread_id}' not found or has no history.")]

                # Extract messages from state
                messages = state_snapshot.values.get("messages", [])

                if not messages:
                    return [TextContent(type="text", text=f"Conversation '{thread_id}' has no messages yet.")]

                # Format conversation history: header, then one numbered line per message.
                formatted_lines = [f"Conversation: {thread_id}", f"Messages: {len(messages)}", ""]

                for i, msg in enumerate(messages, 1):
                    msg_type = type(msg).__name__
                    # NOTE(review): assumes msg.content is a str; multimodal
                    # messages may carry a list — confirm before relying on slicing.
                    content = getattr(msg, "content", str(msg))

                    # Truncate long messages for readability
                    if len(content) > 200:
                        content = content[:200] + "..."

                    formatted_lines.append(f"{i}. [{msg_type}] {content}")

                # Add metadata if available
                metadata = state_snapshot.values.get("metadata", {})
                if metadata:
                    formatted_lines.append("")
                    formatted_lines.append("Metadata:")
                    for key, value in metadata.items():
                        formatted_lines.append(f" {key}: {value}")

                response_text = "\n".join(formatted_lines)

                logger.info(
                    "Conversation history retrieved",
                    extra={"thread_id": thread_id, "message_count": len(messages), "user_id": user_id},
                )

                return [TextContent(type="text", text=response_text)]

            except Exception as e:
                # Best-effort read: report the failure to the caller as text
                # rather than failing the whole tool call.
                logger.error(f"Failed to retrieve conversation history: {e}", extra={"thread_id": thread_id}, exc_info=True)
                return [
                    TextContent(
                        type="text",
                        text=f"Error retrieving conversation history: {e!s}",
                    )
                ]
864 async def _handle_search_conversations(self, arguments: dict[str, Any], span: Any, user_id: str) -> list[TextContent]:
865 """
866 Search conversations (replacing list-all approach).
868 Implements Anthropic best practice:
869 "Implement search-focused tools (like search_contacts) rather than
870 list-all tools (list_contacts)"
872 Benefits:
873 - Prevents context overflow with large conversation lists
874 - Forces agents to be specific in their requests
875 - More token-efficient
876 - Better for users with many conversations
877 """
878 with tracer.start_as_current_span("agent.search_conversations"):
879 # BUGFIX: Validate input with Pydantic schema to enforce query length and limit constraints
880 try:
881 search_input = SearchConversationsInput.model_validate(arguments)
882 except Exception as e:
883 # SECURITY: Sanitize arguments before logging to prevent token exposure in error logs
884 logger.error(f"Invalid search input: {e}", extra={"arguments": sanitize_for_logging(arguments)})
885 msg = f"Invalid search input: {e}"
886 raise ValueError(msg)
888 query = search_input.query
889 limit = search_input.limit
891 span.set_attribute("search.query", query)
892 span.set_attribute("search.limit", limit)
894 # Get all conversations user can view
895 all_conversations = await self.auth.list_accessible_resources(
896 user_id=user_id, relation="viewer", resource_type="conversation"
897 )
899 # Filter conversations based on query (simple implementation)
900 # In production, this would use a proper search index
901 # Normalize query and conversation names to handle spaces/underscores/hyphens
902 if query: 902 ↛ 911line 902 didn't jump to line 911 because the condition on line 902 was always true
903 normalized_query = query.lower().replace(" ", "_").replace("-", "_")
904 filtered_conversations = [
905 conv
906 for conv in all_conversations
907 if (query.lower() in conv.lower() or normalized_query in conv.lower().replace(" ", "_").replace("-", "_"))
908 ]
909 else:
910 # No query: return most recent (up to limit)
911 filtered_conversations = all_conversations
913 # Apply limit to prevent context overflow
914 # Follows Anthropic guidance: "Restrict responses to ~25,000 tokens"
915 limited_conversations = filtered_conversations[:limit]
917 # Build response with high-signal information
918 # Avoid technical IDs where possible
919 if not limited_conversations:
920 response_text = (
921 f"No conversations found matching '{query}'. "
922 f"Try a different search query or request access to more conversations."
923 )
924 else:
925 response_lines = [
926 (
927 f"Found {len(limited_conversations)} conversation(s) matching '{query}':"
928 if query
929 else f"Showing {len(limited_conversations)} recent conversation(s):"
930 )
931 ]
933 for i, conv_id in enumerate(limited_conversations, 1):
934 # Extract human-readable info from conversation ID
935 # In production, fetch metadata like title, date, participants
936 response_lines.append(f"{i}. {conv_id}")
938 # Add guidance if results were truncated
939 if len(filtered_conversations) > limit: 939 ↛ 940line 939 didn't jump to line 940 because the condition on line 939 was never true
940 response_lines.append(
941 f"\n[Showing {limit} of {len(filtered_conversations)} results. "
942 f"Use a more specific query to narrow results.]"
943 )
945 response_text = "\n".join(response_lines)
947 logger.info(
948 "Searched conversations",
949 extra={
950 "user_id": user_id,
951 "query": query,
952 "total_accessible": len(all_conversations),
953 "filtered_count": len(filtered_conversations),
954 "returned_count": len(limited_conversations),
955 },
956 )
958 return [TextContent(type="text", text=response_text)]
960 async def _handle_search_tools(self, arguments: dict[str, Any], span: Any) -> list[TextContent]:
961 """
962 Handle search_tools invocation for progressive tool discovery.
964 Implements Anthropic best practice for token-efficient tool discovery.
965 """
966 with tracer.start_as_current_span("tools.search"):
967 from mcp_server_langgraph.tools.tool_discovery import search_tools
969 # Extract arguments
970 query = arguments.get("query")
971 category = arguments.get("category")
972 detail_level = arguments.get("detail_level", "minimal")
974 logger.info(
975 "Searching tools",
976 extra={"query": query, "category": category, "detail_level": detail_level},
977 )
979 # Execute search_tools
980 result = search_tools.invoke(
981 {
982 "query": query,
983 "category": category,
984 "detail_level": detail_level,
985 }
986 )
988 span.set_attribute("tools.query", query or "")
989 span.set_attribute("tools.category", category or "")
990 span.set_attribute("tools.detail_level", detail_level)
992 return [TextContent(type="text", text=result)]
994 async def _handle_execute_python(self, arguments: dict[str, Any], span: Any, user_id: str) -> list[TextContent]:
995 """
996 Handle execute_python invocation for secure code execution.
998 Implements sandboxed Python execution with validation and resource limits.
999 """
1000 with tracer.start_as_current_span("code.execute"):
1001 import time
1003 from mcp_server_langgraph.tools.code_execution_tools import execute_python
1005 # Extract arguments
1006 code = arguments.get("code", "")
1007 timeout = arguments.get("timeout")
1009 logger.info(
1010 "Executing Python code",
1011 extra={
1012 "user_id": user_id,
1013 "code_length": len(code),
1014 "timeout": timeout,
1015 },
1016 )
1018 # Execute code
1019 start_time = time.time()
1020 result = execute_python.invoke({"code": code, "timeout": timeout})
1021 execution_time = time.time() - start_time
1023 span.set_attribute("code.length", len(code))
1024 span.set_attribute("code.execution_time", execution_time)
1025 span.set_attribute("code.success", "success" in result.lower())
1027 metrics.code_executions.add(1, {"user_id": user_id, "success": "success" in result.lower()})
1029 return [TextContent(type="text", text=result)]
# ============================================================================
# Lazy Singleton Pattern for MCP Server
# ============================================================================
# Prevents import-time initialization which causes issues with:
# - Missing environment variables
# - Observability not yet initialized
# - Test dependency injection

# Module-level cache for the singleton; populated lazily by get_mcp_server().
_mcp_server_instance: MCPAgentStreamableServer | None = None
def get_mcp_server() -> MCPAgentStreamableServer:
    """
    Return the process-wide MCPAgentStreamableServer, creating it on first use.

    Lazy construction guarantees the server is only built after observability
    has been set up, avoiding import-time side effects and logging errors.

    Returns:
        The MCPAgentStreamableServer singleton instance.

    Raises:
        RuntimeError: If observability has not been initialized yet.
    """
    from mcp_server_langgraph.observability.telemetry import is_initialized

    # Refuse to build the server before telemetry is ready.
    if not is_initialized():
        raise RuntimeError(
            "Observability must be initialized before creating MCP server. "
            "This should be done automatically via the startup event handler. "
            "If you see this error, observability initialization failed or was skipped."
        )

    global _mcp_server_instance
    if _mcp_server_instance is None:
        _mcp_server_instance = MCPAgentStreamableServer()
    return _mcp_server_instance
1073# ============================================================================
1074# Authentication Models
1075# ============================================================================
class LoginRequest(BaseModel):
    """Login request with username and password"""

    # Bounded lengths guard against oversized credential payloads.
    username: str = Field(description="Username", min_length=1, max_length=100)
    password: str = Field(description="Password", min_length=1, max_length=500)
class LoginResponse(BaseModel):
    """Login response with JWT token"""

    # Bearer token plus the identity/role data resolved at login time.
    access_token: str = Field(description="JWT access token")
    token_type: str = Field(default="bearer", description="Token type (always 'bearer')")
    expires_in: int = Field(description="Token expiration in seconds")
    user_id: str = Field(description="User identifier")
    username: str = Field(description="Username")
    roles: list[str] = Field(description="User roles")
class RefreshTokenRequest(BaseModel):
    """Token refresh request"""

    # Exactly one of these is expected: refresh_token for Keycloak,
    # current_token for the InMemory provider (enforced by the endpoint).
    refresh_token: str | None = Field(description="Refresh token (Keycloak only)", default=None)
    current_token: str | None = Field(description="Current access token (for InMemory provider)", default=None)
class RefreshTokenResponse(BaseModel):
    """Token refresh response"""

    # refresh_token is only populated for Keycloak-backed refreshes.
    access_token: str = Field(description="New JWT access token")
    token_type: str = Field(default="bearer", description="Token type")
    expires_in: int = Field(description="Token expiration in seconds")
    refresh_token: str | None = Field(default=None, description="New refresh token (Keycloak only)")
1112# FastAPI endpoints for MCP StreamableHTTP transport
@app.get("/")
async def root() -> dict[str, Any]:
    """Root endpoint with server info"""
    # Advertised endpoint map for discovery by clients.
    endpoints = {
        "auth": {
            "login": "/auth/login",
            "refresh": "/auth/refresh",
        },
        "message": "/message",
        "tools": "/tools",
        "resources": "/resources",
        "health": "/health",
    }
    # Capability flags mirror what the /message handler actually supports.
    capabilities = {
        "tools": {"listSupported": True, "callSupported": True},
        "resources": {"listSupported": True, "readSupported": True},
        "streaming": True,
        "authentication": {
            "methods": ["jwt"],
            "tokenRefresh": True,
            "providers": [settings.auth_provider],
        },
    }
    return {
        "name": "MCP Server with LangGraph",
        "version": settings.service_version,
        "transport": "streamable-http",
        "protocol": "mcp",
        "endpoints": endpoints,
        "capabilities": capabilities,
    }
@app.post("/auth/login", tags=["auth"])
async def login(request: LoginRequest) -> LoginResponse:
    """
    Authenticate user and return JWT token

    This endpoint accepts username and password, validates credentials,
    and returns a JWT token that can be used for subsequent tool calls.

    The token should be included in the 'token' field of all tool call requests.

    Example:
        POST /auth/login
        {
            "username": "your-username",
            "password": "your-secure-password"
        }

    Response:
        {
            "access_token": "eyJ...",
            "token_type": "bearer",
            "expires_in": 3600,
            "user_id": "user:your-username",
            "username": "your-username",
            "roles": ["user"]
        }

    Raises:
        HTTPException: 401 on bad credentials, 500 if token creation fails.

    Note: InMemoryUserProvider no longer seeds default users (v2.8.0+).
    Create users explicitly via provider.add_user() for testing.
    """
    with tracer.start_as_current_span("auth.login") as span:
        span.set_attribute("auth.username", request.username)

        # Get MCP server instance (which has auth middleware)
        mcp_server_instance = get_mcp_server()

        # Authenticate user via configured provider
        auth_result = await mcp_server_instance.auth.user_provider.authenticate(request.username, request.password)

        if not auth_result.authorized:
            logger.warning(
                "Login failed", extra={"username": request.username, "reason": auth_result.reason, "error": auth_result.error}
            )
            metrics.auth_failures.add(1)
            raise HTTPException(
                status_code=401,
                detail=f"Authentication failed: {auth_result.reason or auth_result.error or 'invalid credentials'}",
            )

        # Create JWT token
        # For InMemoryUserProvider, use the create_token method
        # For KeycloakUserProvider, tokens are returned directly in auth_result
        if auth_result.access_token:
            # Keycloak provider returns token directly
            access_token = auth_result.access_token
            expires_in = auth_result.expires_in or 3600
        else:
            # InMemoryUserProvider needs to create token
            try:
                access_token = mcp_server_instance.auth.create_token(
                    username=request.username, expires_in=mcp_server_instance.settings.jwt_expiration_seconds
                )
                expires_in = mcp_server_instance.settings.jwt_expiration_seconds
            except Exception as e:
                logger.error(f"Failed to create token: {e}", exc_info=True)
                # BUGFIX: chain the cause so the original traceback is kept (B904)
                raise HTTPException(status_code=500, detail="Failed to create authentication token") from e

        logger.info(
            "Login successful",
            extra={
                "username": request.username,
                "user_id": auth_result.user_id,
                "provider": type(mcp_server_instance.auth.user_provider).__name__,
            },
        )

        return LoginResponse(
            access_token=access_token,
            token_type="bearer",
            expires_in=expires_in,
            user_id=auth_result.user_id or "",
            username=auth_result.username or request.username,
            roles=auth_result.roles or [],
        )
@app.post("/auth/refresh", tags=["auth"])
async def refresh_token(request: RefreshTokenRequest) -> RefreshTokenResponse:
    """
    Refresh authentication token

    Supports two refresh methods:
    1. Keycloak: Uses refresh_token to get new access token
    2. InMemory: Validates current token and issues new one

    Example (Keycloak):
        POST /auth/refresh
        {
            "refresh_token": "eyJ..."
        }

    Example (InMemory):
        POST /auth/refresh
        {
            "current_token": "eyJ..."
        }

    Response:
        {
            "access_token": "eyJ...",
            "token_type": "bearer",
            "expires_in": 3600,
            "refresh_token": "eyJ..." // Keycloak only
        }

    Raises:
        HTTPException: 400 when neither token field is supplied or the
            provider does not support refresh tokens; 401 for invalid
            tokens; 500 for unexpected provider failures.
    """
    with tracer.start_as_current_span("auth.refresh") as span:
        mcp_server_instance = get_mcp_server()

        # Handle Keycloak refresh token
        if request.refresh_token:
            if not isinstance(mcp_server_instance.auth.user_provider, KeycloakUserProvider):
                logger.warning("Refresh token provided but provider is not Keycloak")
                raise HTTPException(
                    status_code=400,
                    detail="Refresh tokens are only supported with KeycloakUserProvider. "
                    "Current provider does not support refresh tokens.",
                )

            try:
                # Use KeycloakUserProvider's refresh_token method
                result = await mcp_server_instance.auth.user_provider.refresh_token(request.refresh_token)

                if not result.get("success"):
                    raise HTTPException(status_code=401, detail=f"Token refresh failed: {result.get('error')}")

                tokens = result["tokens"]

                logger.info("Token refreshed via Keycloak")

                return RefreshTokenResponse(
                    access_token=tokens["access_token"],
                    token_type="bearer",
                    expires_in=tokens.get("expires_in", 300),
                    refresh_token=tokens.get("refresh_token"),
                )

            except HTTPException:
                # Deliberate pass-through for the 401 raised above.
                raise
            except Exception as e:
                logger.error(f"Token refresh failed: {e}", exc_info=True)
                # BUGFIX: chain the cause so the original traceback is kept (B904)
                raise HTTPException(status_code=500, detail="Token refresh failed") from e

        # Handle InMemory token refresh (verify current token and issue new one)
        elif request.current_token:
            try:
                # Verify current token
                token_verification = await mcp_server_instance.auth.verify_token(request.current_token)

                if not token_verification.valid:
                    logger.warning("Current token invalid for refresh", extra={"error": token_verification.error})
                    raise HTTPException(status_code=401, detail=f"Invalid token: {token_verification.error}")

                # Extract user info from token
                if not token_verification.payload or "username" not in token_verification.payload:
                    raise HTTPException(status_code=401, detail="Invalid token: missing username")

                username = token_verification.payload["username"]
                span.set_attribute("auth.username", username)

                # Issue new token
                new_token = mcp_server_instance.auth.create_token(
                    username=username, expires_in=mcp_server_instance.settings.jwt_expiration_seconds
                )

                logger.info("Token refreshed for user", extra={"username": username})

                return RefreshTokenResponse(
                    access_token=new_token,
                    token_type="bearer",
                    expires_in=mcp_server_instance.settings.jwt_expiration_seconds,
                )

            except HTTPException:
                # Deliberate pass-through for the 401s raised above.
                raise
            except Exception as e:
                logger.error(f"Token refresh failed: {e}", exc_info=True)
                # BUGFIX: chain the cause so the original traceback is kept (B904)
                raise HTTPException(status_code=500, detail="Token refresh failed") from e

        else:
            raise HTTPException(
                status_code=400, detail="Either 'refresh_token' (Keycloak) or 'current_token' (InMemory) must be provided"
            )
async def stream_jsonrpc_response(data: dict[str, Any]) -> AsyncIterator[str]:
    """
    Stream a JSON-RPC response in chunks.

    Yields newline-delimited JSON. Currently the full payload is emitted as a
    single chunk; a future version could stream token-by-token for LLM output.
    """
    payload = json.dumps(data)
    yield f"{payload}\n"
@app.post("/message", response_model=None)
async def handle_message(request: Request) -> JSONResponse | StreamingResponse:
    """
    Handle MCP messages via StreamableHTTP POST

    This is the main endpoint for MCP protocol messages.
    Supports both regular and streaming responses.

    Per the JSON-RPC 2.0 convention, protocol-level errors are returned with
    HTTP 200 and an ``error`` object in the body; only parse errors (400)
    and unexpected faults (500) use non-200 HTTP status codes.
    """

    def _request_id() -> Any:
        # BUGFIX: the except handlers previously probed `"message" in locals()`
        # to recover the request id — fragile (silently breaks on rename) and
        # crashes with AttributeError when the body parsed to a non-dict.
        return message.get("id") if isinstance(message, dict) else None

    # Pre-bind so the error handlers can always reference it safely.
    message: dict[str, Any] = {}
    try:
        # Parse JSON-RPC request
        message = await request.json()

        with tracer.start_as_current_span("mcp.streamable.message") as span:
            span.set_attribute("mcp.method", message.get("method", "unknown"))
            span.set_attribute("mcp.id", str(message.get("id", "")))

            logger.info("Received MCP message", extra={"method": message.get("method"), "id": message.get("id")})

            method = message.get("method")
            message_id = message.get("id")

            # Handle different MCP methods
            if method == "initialize":
                mcp_server = get_mcp_server()
                response = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {
                        "protocolVersion": "2024-11-05",
                        "serverInfo": {"name": "langgraph-agent", "version": mcp_server.settings.service_version},
                        "capabilities": {
                            "tools": {"listChanged": False},
                            "resources": {"listChanged": False},
                            "prompts": {},
                            "logging": {},
                        },
                    },
                }
                return JSONResponse(response)

            elif method == "tools/list":
                # Use public API instead of private _tool_manager
                tools = await get_mcp_server().list_tools_public()
                response = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {"tools": [tool.model_dump(mode="json") for tool in tools]},
                }
                return JSONResponse(response)

            elif method == "tools/call":
                params = message.get("params", {})
                tool_name = params.get("name")
                arguments = params.get("arguments", {})

                # Check if client supports streaming
                accept_header = request.headers.get("accept", "")
                supports_streaming = "text/event-stream" in accept_header or "application/x-ndjson" in accept_header

                # Use public API instead of private _tool_manager
                result = await get_mcp_server().call_tool_public(tool_name, arguments)

                response_data = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {"content": [item.model_dump(mode="json") for item in result]},
                }

                # If streaming is supported, stream the response
                if supports_streaming:
                    return StreamingResponse(
                        stream_jsonrpc_response(response_data),
                        media_type="application/x-ndjson",
                        headers={"X-Content-Type-Options": "nosniff", "Cache-Control": "no-cache"},
                    )
                else:
                    return JSONResponse(response_data)

            elif method == "resources/list":
                # Use public API instead of private _resource_manager
                resources = await get_mcp_server().list_resources_public()
                response = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {"resources": [res.model_dump(mode="json") for res in resources]},
                }
                return JSONResponse(response)

            elif method == "resources/read":
                params = message.get("params", {})
                resource_uri = params.get("uri")

                # Handle resource read (implement based on your needs)
                response = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {
                        "contents": [
                            {
                                "uri": resource_uri,
                                "mimeType": "application/json",
                                "text": json.dumps({"config": "placeholder"}),
                            }
                        ]
                    },
                }
                return JSONResponse(response)

            else:
                raise HTTPException(status_code=400, detail=f"Unknown method: {method}")

    except PermissionError as e:
        logger.warning(f"Permission denied: {e}")
        return JSONResponse(
            status_code=200,  # JSON-RPC errors should return 200 with error in body
            content={
                "jsonrpc": "2.0",
                "id": _request_id(),
                "error": {"code": -32001, "message": str(e)},
            },
        )
    except json.JSONDecodeError as e:
        # BUGFIX: this handler must precede `except ValueError` — JSONDecodeError
        # is a ValueError subclass, so the previous ordering made this branch
        # unreachable and parse errors were misreported as -32602 with HTTP 200.
        logger.warning(f"JSON parse error: {e}")
        return JSONResponse(
            status_code=400,  # Parse errors can return 400
            content={
                "jsonrpc": "2.0",
                "id": None,
                "error": {"code": -32700, "message": "Parse error: Invalid JSON"},
            },
        )
    except ValueError as e:
        # Validation errors (missing fields, invalid arguments)
        logger.warning(f"Validation error: {e}")
        return JSONResponse(
            status_code=200,  # JSON-RPC errors should return 200 with error in body
            content={
                "jsonrpc": "2.0",
                "id": _request_id(),
                "error": {"code": -32602, "message": f"Invalid params: {e!s}"},
            },
        )
    except HTTPException as e:
        # HTTP exceptions (unknown methods, etc.)
        logger.warning(f"HTTP exception: {e.detail}")
        return JSONResponse(
            status_code=200,  # JSON-RPC errors should return 200 with error in body
            content={
                "jsonrpc": "2.0",
                "id": _request_id(),
                "error": {"code": -32601 if e.status_code == 400 else -32603, "message": e.detail},
            },
        )
    except Exception as e:
        logger.error(f"Error handling message: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                "jsonrpc": "2.0",
                "id": _request_id(),
                "error": {"code": -32603, "message": str(e)},
            },
        )
@app.get("/tools")
async def list_tools() -> dict[str, Any]:
    """List available tools (convenience endpoint)"""
    # Go through the server's public API rather than the private _tool_manager.
    server = get_mcp_server()
    tools = await server.list_tools_public()
    serialized = [tool.model_dump(mode="json") for tool in tools]
    return {"tools": serialized}
@app.get("/resources")
async def list_resources() -> dict[str, Any]:
    """List available resources (convenience endpoint)"""
    # Go through the server's public API rather than the private _resource_manager.
    server = get_mcp_server()
    resources = await server.list_resources_public()
    serialized = [res.model_dump(mode="json") for res in resources]
    return {"resources": serialized}
# Include health check routes
from mcp_server_langgraph.health.checks import app as health_app  # noqa: E402

# Health endpoints live on a sub-application mounted under /health.
app.mount("/health", health_app)

from mcp_server_langgraph.api.api_keys import router as api_keys_router  # noqa: E402
from mcp_server_langgraph.api.gdpr import router as gdpr_router  # noqa: E402
from mcp_server_langgraph.api.scim import router as scim_router  # noqa: E402
from mcp_server_langgraph.api.service_principals import router as service_principals_router  # noqa: E402

# Include REST API routes
from mcp_server_langgraph.api.version import router as version_router  # noqa: E402

# Register the REST routers (version, GDPR, API keys, service principals, SCIM).
app.include_router(version_router)
app.include_router(gdpr_router)
app.include_router(api_keys_router)
app.include_router(service_principals_router)
app.include_router(scim_router)
1551# ==============================================================================
1552# Custom OpenAPI Schema Generation
1553# ==============================================================================
1554# Pagination models need to be included in OpenAPI schema for API documentation,
1555# even if not all endpoints use them yet. This ensures consistent API contract.
def custom_openapi() -> dict[str, Any]:
    """
    Custom OpenAPI schema generator that includes pagination models.

    FastAPI only includes models in the OpenAPI schema if they're used in endpoint
    signatures. We explicitly add pagination models here to ensure they're documented
    for API consumers, even if not all endpoints implement pagination yet.

    This follows TDD principles - tests define the expected API contract first.
    """
    from typing import cast

    # Serve the cached schema when one has already been built.
    if app.openapi_schema:
        return cast(dict[str, Any], app.openapi_schema)  # type: ignore[redundant-cast]

    from fastapi.openapi.utils import get_openapi

    from mcp_server_langgraph.api.pagination import PaginationMetadata, PaginationParams

    # Generate the base schema from the registered routes.
    schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
        tags=app.openapi_tags,
    )

    # Ensure the components/schemas containers exist, then insert the
    # pagination models' JSON schemas.
    components = schema.setdefault("components", {})
    schemas = components.setdefault("schemas", {})
    schemas["PaginationParams"] = PaginationParams.model_json_schema()
    schemas["PaginationMetadata"] = PaginationMetadata.model_json_schema()

    # Note: PaginatedResponse is generic, so we document it in the description
    # Individual endpoint responses (e.g., PaginatedResponse[APIKey]) will be
    # generated automatically when endpoints use them
    app.openapi_schema = schema
    return cast(dict[str, Any], app.openapi_schema)  # type: ignore[redundant-cast]
# Apply custom OpenAPI schema
# Overrides FastAPI's default generator so the pagination models above are
# included in the published schema.
app.openapi = custom_openapi  # type: ignore[method-assign]
def main() -> None:
    """Entry point for console script"""
    import atexit

    from mcp_server_langgraph.core.config import settings
    from mcp_server_langgraph.observability.telemetry import init_observability, shutdown_observability

    # Observability must be live before the server is constructed or logs.
    init_observability(settings=settings, enable_file_logging=getattr(settings, "enable_file_logging", False))

    # Fallback shutdown hook; the app lifespan remains the primary path.
    atexit.register(shutdown_observability)

    # SECURITY: Validate CORS configuration before starting server
    settings.validate_cors_config()

    port_value = settings.get_secret("PORT", fallback="8000")
    port = int(port_value) if port_value else 8000

    logger.info(f"Starting MCP StreamableHTTP server on port {port_value}")

    uvicorn.run(
        app,
        host="0.0.0.0",  # nosec B104 - Required for containerized deployment
        port=port,
        log_level=settings.log_level.lower(),
        access_log=True,
    )
# Allow running the server directly via `python server_streamable.py`.
if __name__ == "__main__":
    main()