Coverage for src / mcp_server_langgraph / mcp / server_streamable.py: 59%

549 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2MCP Server with StreamableHTTP transport 

3Implements the MCP StreamableHTTP specification (replaces deprecated SSE) 

4 

5Implements Anthropic's best practices for writing tools for agents: 

6- Token-efficient responses with truncation 

7- Search-focused tools instead of list-all 

8- Response format control (concise vs detailed) 

9- Namespaced tools for clarity 

10- High-signal information in responses 

11""" 

12 

13import json 

14import logging 

15import sys 

16from collections.abc import AsyncIterator 

17from contextlib import asynccontextmanager 

18from typing import Any, Literal 

19 

20import uvicorn 

21from fastapi import FastAPI, HTTPException, Request 

22from fastapi.middleware.cors import CORSMiddleware 

23from fastapi.responses import JSONResponse, StreamingResponse 

24from langchain_core.messages import HumanMessage 

25from mcp.server import Server 

26from mcp.types import Resource, TextContent, Tool 

27from pydantic import AnyUrl, BaseModel, Field 

28 

29from mcp_server_langgraph.api.auth_request_middleware import AuthRequestMiddleware 

30from mcp_server_langgraph.auth.factory import create_auth_middleware, create_user_provider 

31from mcp_server_langgraph.auth.middleware import AuthMiddleware 

32from mcp_server_langgraph.auth.openfga import OpenFGAClient 

33from mcp_server_langgraph.auth.user_provider import KeycloakUserProvider 

34from mcp_server_langgraph.core.agent import AgentState, get_agent_graph 

35from mcp_server_langgraph.core.config import Settings, settings 

36from mcp_server_langgraph.core.security import sanitize_for_logging 

37from mcp_server_langgraph.middleware.rate_limiter import custom_rate_limit_exceeded_handler, limiter 

38from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer 

39from mcp_server_langgraph.utils.response_optimizer import format_response 

40 

41 

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """
    Lifespan context manager for application startup and shutdown.

    CRITICAL: This ensures observability is initialized before handling requests,
    preventing crashes when launching with: uvicorn mcp_server_langgraph.mcp.server_streamable:app

    Without this handler, logger/tracer/metrics usage in get_mcp_server() and downstream
    code would fail when observability hasn't been initialized yet.

    Yields:
        None once startup is complete; everything after the ``yield`` runs at shutdown.
    """
    # Startup
    # Imported lazily so this module can be imported without triggering telemetry setup.
    from mcp_server_langgraph.observability.telemetry import init_observability, is_initialized

    if not is_initialized():
        # Use a plain stdlib logger here: the observability logger is not ready yet.
        logger_temp = logging.getLogger(__name__)
        logger_temp.info("Initializing observability from startup event")

        # Initialize with file logging if configured
        enable_file_logging = getattr(settings, "enable_file_logging", False)
        init_observability(settings=settings, enable_file_logging=enable_file_logging)

        logger_temp.info("Observability initialized successfully")

    # Initialize global auth middleware for FastAPI dependencies
    # This must happen AFTER observability is initialized (for logging)
    try:
        from mcp_server_langgraph.auth.middleware import FASTAPI_AVAILABLE, set_global_auth_middleware

        if FASTAPI_AVAILABLE:
            # Get MCP server instance (creates it if needed)
            mcp_server = get_mcp_server()

            # Set the auth middleware globally for FastAPI dependencies
            set_global_auth_middleware(mcp_server.auth)

            logger.info("Global auth middleware initialized for FastAPI dependencies")
    except Exception as e:
        # Best-effort: startup continues; dependent endpoints fall back to manual auth.
        logger.warning(f"Failed to initialize global auth middleware: {e}")

    yield

    # Shutdown - cleanup observability and close connections
    from mcp_server_langgraph.observability.telemetry import shutdown_observability

    logger.info("Application shutdown initiated")

    # Cleanup checkpointer resources (Redis connections, etc.)
    try:
        from mcp_server_langgraph.core.agent import cleanup_checkpointer

        agent_graph = get_agent_graph()  # type: ignore[func-returns-value]
        if agent_graph and hasattr(agent_graph, "checkpointer") and agent_graph.checkpointer:
            cleanup_checkpointer(agent_graph.checkpointer)
            logger.info("Checkpointer resources cleaned up")
    except Exception as e:
        logger.warning(f"Error cleaning up checkpointer: {e}")

    # Shutdown observability (flush spans, close exporters)
    shutdown_observability()

    # Close Prometheus client if initialized
    try:
        mcp_server = get_mcp_server()
        if hasattr(mcp_server, "prometheus_client") and mcp_server.prometheus_client:
            await mcp_server.prometheus_client.close()
            logger.info("Prometheus client closed")
    except Exception as e:
        logger.warning(f"Error closing Prometheus client: {e}")

    logger.info("Application shutdown complete")

113 

114 

# Module-level ASGI application instance.
# The `lifespan` handler defined above performs startup/shutdown work; OpenAPI
# tags group endpoints registered later in this module and mounted routers.
app = FastAPI(
    title="MCP Server with LangGraph",
    description="AI Agent with fine-grained authorization and observability - StreamableHTTP transport",
    version=settings.service_version,
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_tags=[
        {
            "name": "API Metadata",
            "description": "API version information and metadata for client compatibility",
        },
        {
            "name": "mcp",
            "description": "Model Context Protocol (MCP) endpoints for agent interaction",
        },
        {
            "name": "health",
            "description": "Health check and system status endpoints",
        },
        {
            "name": "auth",
            "description": "Authentication and authorization endpoints",
        },
        {
            "name": "GDPR Compliance",
            "description": "GDPR data protection and privacy rights endpoints (Articles 15-21)",
        },
        {
            "name": "API Keys",
            "description": "API key management for programmatic access",
        },
        {
            "name": "Service Principals",
            "description": "Service principal (service account) management for machine-to-machine authentication",
        },
        {
            "name": "SCIM 2.0",
            "description": "System for Cross-domain Identity Management (SCIM) 2.0 user and group provisioning",
        },
    ],
    # Default error responses advertised for every route in the OpenAPI schema.
    responses={
        401: {"description": "Unauthorized - Invalid or missing authentication token"},
        403: {"description": "Forbidden - Insufficient permissions"},
        429: {"description": "Too Many Requests - Rate limit exceeded"},
        500: {"description": "Internal Server Error"},
    },
    lifespan=lifespan,
)

163 

# CORS middleware
# SECURITY: Use config-based origins instead of wildcard
# Empty list = no CORS (production default), specific origins in development
cors_origins = settings.get_cors_origins()
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=bool(cors_origins),  # Only allow credentials if origins specified
    allow_methods=["*"],
    allow_headers=["*"],
    # Response headers browsers may expose to JS clients (tracing + rate-limit info).
    expose_headers=["X-Request-ID", "X-RateLimit-Limit", "X-RateLimit-Remaining"],
)

176 

# Rate limiting middleware
# SECURITY: Protect against DoS, brute force, and API abuse
# Uses tiered rate limits (anonymous: 10/min, free: 60/min, premium: 1000/min, enterprise: unlimited)
# Tracks by user ID (from JWT) > IP address > global anonymous
# Fail-open: allows requests if Redis is down (graceful degradation)
#
# NOTE: Use standard logging.getLogger() here instead of observability logger
# because this code runs at module import time, before lifespan initializes observability
_module_logger = logging.getLogger(__name__)
try:
    # slowapi imported lazily so a missing/broken install degrades to "no rate limits".
    from slowapi.errors import RateLimitExceeded

    # Register rate limiter with app
    app.state.limiter = limiter

    # Register custom exception handler for rate limit exceeded
    app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)  # type: ignore[arg-type]

    _module_logger.info(
        "Rate limiting enabled - strategy: fixed-window, tiers: anonymous/free/standard/premium/enterprise, fail_open: True"
    )
except Exception as e:
    # Fail-open by design: the service still serves requests without rate limiting.
    _module_logger.warning(f"Failed to initialize rate limiting: {e}. Requests will proceed without rate limits.")

200 

# Authentication middleware for REST API endpoints (GDPR, API Keys, Service Principals, SCIM)
# This must be added AFTER rate limiting but BEFORE routes are defined
# Uses module-level auth middleware creation (lazy init for observability compatibility)
try:
    # Create user provider and auth middleware for REST API authentication
    # This enables request.state.user for all protected endpoints
    _module_user_provider = create_user_provider(settings)
    _module_auth_middleware = AuthMiddleware(
        secret_key=settings.jwt_secret_key,
        settings=settings,
        user_provider=_module_user_provider,
    )
    app.add_middleware(AuthRequestMiddleware, auth_middleware=_module_auth_middleware)
    _module_logger.info("AuthRequestMiddleware enabled for REST API authentication")
except Exception as e:
    # Fail-open at import time: endpoints stay reachable but must enforce auth themselves.
    _module_logger.warning(f"Failed to initialize AuthRequestMiddleware: {e}. REST API endpoints will require manual auth.")

217 

218 

class ChatInput(BaseModel):
    """
    Input schema for agent_chat tool.

    Follows Anthropic best practices:
    - Unambiguous parameter names (user_id not username, message not query)
    - Response format control for token efficiency
    - Clear field descriptions
    """

    message: str = Field(description="The user message to send to the agent", min_length=1, max_length=10000)
    token: str = Field(
        description=(
            "JWT authentication token. Obtain via /auth/login endpoint (HTTP) "
            "or external authentication service. Required for all tool calls."
        )
    )
    user_id: str = Field(description="User identifier for authentication and authorization")
    thread_id: str | None = Field(
        default=None,
        description="Optional thread ID for conversation continuity (e.g., 'conv_123')",
        pattern=r"^[a-zA-Z0-9_-]{1,128}$",  # SECURITY: Prevent CWE-20 injection into OpenFGA/Redis/logs
    )
    response_format: Literal["concise", "detailed"] = Field(
        default="concise",
        description=(
            "Response verbosity level. "
            "'concise' returns ~500 tokens (faster, less context). "
            "'detailed' returns ~2000 tokens (comprehensive, more context)."
        ),
    )

    # Backward compatibility - DEPRECATED
    username: str | None = Field(
        default=None, deprecated=True, description="DEPRECATED: Use 'user_id' instead. Maintained for backward compatibility."
    )

    @property
    def effective_user_id(self) -> str:
        """Get effective user ID, prioritizing user_id over deprecated username.

        FIX: removed the redundant ``hasattr(self, "user_id")`` guard — user_id is a
        required field on this model, so the attribute always exists after validation.
        """
        return self.user_id or self.username or ""

261 

class SearchConversationsInput(BaseModel):
    """Input schema for conversation_search tool."""

    query: str = Field(description="Search query to filter conversations", min_length=1, max_length=500)
    token: str = Field(
        description=(
            "JWT authentication token. Obtain via /auth/login endpoint (HTTP) "
            "or external authentication service. Required for all tool calls."
        )
    )
    user_id: str = Field(description="User identifier for authentication and authorization")
    limit: int = Field(default=10, ge=1, le=50, description="Maximum number of conversations to return (1-50)")

    # Backward compatibility - DEPRECATED
    username: str | None = Field(
        default=None, deprecated=True, description="DEPRECATED: Use 'user_id' instead. Maintained for backward compatibility."
    )

    @property
    def effective_user_id(self) -> str:
        """Get effective user ID, prioritizing user_id over deprecated username.

        FIX: removed the redundant ``hasattr(self, "user_id")`` guard — user_id is a
        required field on this model, so the attribute always exists after validation.
        """
        return self.user_id or self.username or ""

285 

286class MCPAgentStreamableServer: 

287 """MCP Server with StreamableHTTP transport""" 

288 

    def __init__(
        self,
        openfga_client: OpenFGAClient | None = None,
        settings: Settings | None = None,
    ) -> None:
        """
        Initialize MCP Agent Streamable Server with optional dependency injection.

        Args:
            openfga_client: Optional OpenFGA client for authorization.
                If None, creates one from settings.
            settings: Optional Settings instance for runtime configuration.
                If provided, enables dynamic feature toggling (e.g., code execution).
                If None, uses global settings. This allows tests to inject custom
                configuration without module reloading.

        Raises:
            ValueError: If the JWT secret key is missing, or if running in
                production without OpenFGA configured (fail-closed security).

        Example:
            # Default creation (production):
            server = MCPAgentStreamableServer()

            # Custom settings injection (testing):
            test_settings = Settings(enable_code_execution=True)
            server = MCPAgentStreamableServer(settings=test_settings)
        """
        # Store settings for runtime configuration
        # NOTE: When settings=None, we must reference the module-level 'settings'
        # imported at the top of this file. This allows tests to mock settings via
        # @patch("mcp_server_langgraph.mcp.server_streamable.settings", ...)
        # Using sys.modules[__name__] to avoid self-import (CodeQL py/import-own-module)
        self.settings = settings if settings is not None else sys.modules[__name__].settings

        self.server = Server("langgraph-agent")

        # Initialize OpenFGA client
        self.openfga = openfga_client or self._create_openfga_client()

        # Validate JWT secret is configured (fail-closed security pattern)
        if not self.settings.jwt_secret_key:
            msg = (
                "CRITICAL: JWT secret key not configured. "
                "Set JWT_SECRET_KEY environment variable or configure via Infisical. "
                "The service cannot start without a secure secret key."
            )
            raise ValueError(msg)

        # SECURITY: Fail-closed pattern - require OpenFGA in production
        if self.settings.environment == "production" and self.openfga is None:
            msg = (
                "CRITICAL: OpenFGA authorization is required in production mode. "
                "Configure OPENFGA_STORE_ID and OPENFGA_MODEL_ID environment variables, "
                "or set ENVIRONMENT=development for local testing. "
                "Fallback authorization is not secure enough for production use."
            )
            raise ValueError(msg)

        # Initialize auth using factory (respects settings.auth_provider)
        self.auth = create_auth_middleware(self.settings, openfga_client=self.openfga)

        # Registers the list_tools / call_tool / list_resources handlers on self.server.
        self._setup_handlers()

348 

349 def _create_openfga_client(self) -> OpenFGAClient | None: 

350 """Create OpenFGA client from settings""" 

351 if self.settings.openfga_store_id and self.settings.openfga_model_id: 

352 logger.info( 

353 "Initializing OpenFGA client", 

354 extra={"store_id": self.settings.openfga_store_id, "model_id": self.settings.openfga_model_id}, 

355 ) 

356 return OpenFGAClient( 

357 api_url=self.settings.openfga_api_url, 

358 store_id=self.settings.openfga_store_id, 

359 model_id=self.settings.openfga_model_id, 

360 ) 

361 else: 

362 logger.warning("OpenFGA not configured, authorization will use fallback mode") 

363 return None 

364 

365 async def list_tools_public(self) -> list[Tool]: 

366 """ 

367 Public API to list available tools. 

368 

369 This wraps the internal MCP handler to avoid accessing private SDK attributes. 

370 """ 

371 # Call the registered handler 

372 # The handler is registered below in _setup_handlers() 

373 return await self._list_tools_handler() # type: ignore[no-any-return] 

374 

375 async def call_tool_public(self, name: str, arguments: dict[str, Any]) -> list[TextContent]: 

376 """ 

377 Public API to call a tool. 

378 

379 This wraps the internal MCP handler to avoid accessing private SDK attributes. 

380 """ 

381 return await self._call_tool_handler(name, arguments) # type: ignore[no-any-return] 

382 

383 async def list_resources_public(self) -> list[Resource]: 

384 """ 

385 Public API to list available resources. 

386 

387 This wraps the internal MCP handler to avoid accessing private SDK attributes. 

388 """ 

389 return await self._list_resources_handler() # type: ignore[no-any-return] 

390 

    def _setup_handlers(self) -> None:
        """Setup MCP protocol handlers and store references for public API"""

        @self.server.list_tools()  # type: ignore[no-untyped-call, untyped-decorator]
        async def list_tools() -> list[Tool]:
            """
            List available tools.

            Tools follow Anthropic best practices:
            - Namespaced for clarity (agent_*, conversation_*)
            - Search-focused instead of list-all
            - Clear usage guidance in descriptions
            - Token limits and expected response times documented
            """
            with tracer.start_as_current_span("mcp.list_tools"):
                logger.info("Listing available tools")
                tools = [
                    Tool(
                        name="agent_chat",
                        description=(
                            "Chat with the AI agent for questions, research, and problem-solving. "
                            "Returns responses optimized for agent consumption. "
                            "Response format: 'concise' (~500 tokens, 2-5 sec) or 'detailed' (~2000 tokens, 5-10 sec). "
                            "For specialized tasks like code execution or web search, use dedicated tools instead. "
                            "Rate limit: 60 requests/minute per user."
                        ),
                        inputSchema=ChatInput.model_json_schema(),
                    ),
                    Tool(
                        name="conversation_get",
                        description=(
                            "Retrieve a specific conversation thread by ID. "
                            "Returns conversation history with messages, participants, and metadata. "
                            "Response time: <1 second. "
                            "Use conversation_search to find conversation IDs first."
                        ),
                        # Hand-written schema (no Pydantic model exists for this tool's input).
                        inputSchema={
                            "type": "object",
                            "properties": {
                                "thread_id": {
                                    "type": "string",
                                    "description": "Conversation thread identifier (e.g., 'conv_abc123')",
                                },
                                "token": {
                                    "type": "string",
                                    "description": "JWT authentication token. Required for all tool calls.",
                                },
                                "user_id": {
                                    "type": "string",
                                    "description": "User identifier for authentication and authorization",
                                },
                                "username": {"type": "string", "description": "DEPRECATED: Use 'user_id' instead"},
                            },
                            "required": ["thread_id", "token", "user_id"],
                        },
                    ),
                    Tool(
                        name="conversation_search",
                        description=(
                            "Search conversations using keywords or filters. "
                            "Returns matching conversations sorted by relevance. "
                            "Much more efficient than listing all conversations. "
                            "Response time: <2 seconds. "
                            "Examples: 'project updates', 'conversations with alice', 'last week'. "
                            "Results limited to 50 conversations max to prevent context overflow."
                        ),
                        inputSchema=SearchConversationsInput.model_json_schema(),
                    ),
                ]

                # Add search_tools for progressive discovery (Anthropic best practice)
                tools.append(
                    Tool(
                        name="search_tools",
                        description=(
                            "Search and discover available tools using progressive disclosure. "
                            "Query by keyword or category instead of loading all tool definitions. "
                            "Saves 98%+ tokens compared to list-all approach. "
                            "Detail levels: minimal (name+desc), standard (+params), full (+schema). "
                            "Categories: calculator, search, filesystem, execution. "
                            "Response time: <1 second."
                        ),
                        inputSchema={
                            "type": "object",
                            "properties": {
                                "query": {"type": "string", "description": "Search query (keyword)"},
                                "category": {"type": "string", "description": "Tool category filter"},
                                "detail_level": {
                                    "type": "string",
                                    "enum": ["minimal", "standard", "full"],
                                    "description": "Level of detail in results",
                                },
                            },
                        },
                    )
                )

                # Add execute_python if code execution is enabled
                if self.settings.enable_code_execution:
                    # Imported lazily so the dependency is only required when the feature is on.
                    from mcp_server_langgraph.tools.code_execution_tools import ExecutePythonInput

                    tools.append(
                        Tool(
                            name="execute_python",
                            description=(
                                "Execute Python code in a secure sandboxed environment. "
                                "Security: Import whitelist, no eval/exec, resource limits (CPU, memory, timeout). "
                                "Backends: docker-engine (local/dev) or kubernetes (production). "
                                "Network: Configurable isolation (none/allowlist/unrestricted). "
                                "Response time: 1-30 seconds depending on code complexity. "
                                "Use for data processing, calculations, and Python-specific tasks."
                            ),
                            inputSchema=ExecutePythonInput.model_json_schema(),
                        )
                    )

                return tools

        # Store reference to handler for public API
        self._list_tools_handler = list_tools

        @self.server.call_tool()  # type: ignore[untyped-decorator]  # MCP library decorator lacks type stubs
        async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
            """Handle tool calls with OpenFGA authorization and tracing"""

            with tracer.start_as_current_span("mcp.call_tool", attributes={"tool.name": name}) as span:
                # SECURITY: Sanitize arguments before logging to prevent CWE-200/CWE-532 (token exposure in logs)
                logger.info(f"Tool called: {name}", extra={"tool": name, "tool_args": sanitize_for_logging(arguments)})
                metrics.tool_calls.add(1, {"tool": name})

                # SECURITY: Require JWT token for all tool calls
                token = arguments.get("token")

                if not token:
                    logger.warning("No authentication token provided")
                    metrics.auth_failures.add(1)
                    msg = (
                        "Authentication token required. Provide 'token' parameter with a valid JWT. "
                        "Obtain token via /auth/login endpoint or external authentication service."
                    )
                    raise PermissionError(msg)

                # Verify JWT token
                token_verification = await self.auth.verify_token(token)

                if not token_verification.valid:
                    logger.warning("Token verification failed", extra={"error": token_verification.error})
                    metrics.auth_failures.add(1)
                    msg = f"Invalid authentication token: {token_verification.error or 'token verification failed'}"
                    raise PermissionError(msg)

                # Extract user_id from validated token payload
                if not token_verification.payload:
                    logger.error("Token payload is empty")
                    metrics.auth_failures.add(1)
                    msg = "Invalid token: missing user identifier"
                    raise PermissionError(msg)

                # Extract username with defensive fallback
                # Priority: preferred_username > username claim > sub parsing
                # Keycloak uses UUID in 'sub', but OpenFGA needs 'user:username' format
                # NOTE: Some Keycloak configurations may not include 'sub' in access tokens
                username = token_verification.payload.get("preferred_username")
                if not username:
                    # Try 'username' claim (alternative standard claim)
                    username = token_verification.payload.get("username")
                if not username:
                    # Fallback: extract from sub if available
                    sub = token_verification.payload.get("sub", "")
                    if sub.startswith("user:"):
                        username = sub.split(":", 1)[1]
                    elif sub and ":" not in sub:
                        # Log warning for UUID-style subs (may cause issues)
                        logger.warning(
                            f"Using sub as username fallback (may be UUID): {sub[:8]}...",
                            extra={"sub_prefix": sub[:8]},
                        )
                        username = sub

                # Final check: ensure we have a username
                if not username:
                    logger.error("Token missing user identifier (no sub, preferred_username, or username claim)")
                    metrics.auth_failures.add(1)
                    msg = "Invalid token: cannot extract username from claims"
                    raise PermissionError(msg)

                # Normalize user_id to "user:username" format for OpenFGA compatibility
                user_id = f"user:{username}" if not username.startswith("user:") else username
                span.set_attribute("user.id", user_id)

                logger.info("User authenticated via token", extra={"user_id": user_id, "tool": name})

                # Check OpenFGA authorization
                resource = f"tool:{name}"

                authorized = await self.auth.authorize(user_id=user_id, relation="executor", resource=resource)

                if not authorized:
                    logger.warning(
                        "Authorization failed (OpenFGA)",
                        extra={"user_id": user_id, "resource": resource, "relation": "executor"},
                    )
                    metrics.authz_failures.add(1, {"resource": resource})
                    msg = f"Not authorized to execute {resource}"
                    raise PermissionError(msg)

                logger.info("Authorization granted", extra={"user_id": user_id, "resource": resource})

                # Route to appropriate handler (with backward compatibility)
                if name == "agent_chat" or name == "chat":  # Support old name for compatibility
                    return await self._handle_chat(arguments, span, user_id)
                elif name == "conversation_get" or name == "get_conversation":
                    return await self._handle_get_conversation(arguments, span, user_id)
                elif name == "conversation_search" or name == "list_conversations":
                    return await self._handle_search_conversations(arguments, span, user_id)
                elif name == "search_tools":
                    return await self._handle_search_tools(arguments, span)
                elif name == "execute_python":
                    return await self._handle_execute_python(arguments, span, user_id)
                else:
                    msg = f"Unknown tool: {name}"
                    raise ValueError(msg)

        # Store reference to handler for public API
        self._call_tool_handler = call_tool

        @self.server.list_resources()  # type: ignore[no-untyped-call, untyped-decorator]
        async def list_resources() -> list[Resource]:
            """List available resources"""
            with tracer.start_as_current_span("mcp.list_resources"):
                return [Resource(uri=AnyUrl("agent://config"), name="Agent Configuration", mimeType="application/json")]

        # Store reference to handler for public API
        self._list_resources_handler = list_resources

625 

626 async def _handle_chat(self, arguments: dict[str, Any], span: Any, user_id: str) -> list[TextContent]: 

627 """ 

628 Handle agent_chat tool invocation. 

629 

630 Implements Anthropic best practices: 

631 - Response format control (concise vs detailed) 

632 - Token-efficient responses with truncation 

633 - Clear error messages 

634 - Performance tracking 

635 """ 

636 with tracer.start_as_current_span("agent.chat"): 

637 # BUGFIX: Validate input with Pydantic schema to enforce length limits and required fields 

638 try: 

639 chat_input = ChatInput.model_validate(arguments) 

640 except Exception as e: 

641 # SECURITY: Sanitize arguments before logging to prevent token exposure in error logs 

642 logger.error(f"Invalid chat input: {e}", extra={"arguments": sanitize_for_logging(arguments)}) 

643 msg = f"Invalid chat input: {e}" 

644 raise ValueError(msg) 

645 

646 message = chat_input.message 

647 thread_id = chat_input.thread_id or "default" 

648 response_format_type = chat_input.response_format 

649 

650 span.set_attribute("message.length", len(message)) 

651 span.set_attribute("thread.id", thread_id) 

652 span.set_attribute("user.id", user_id) 

653 span.set_attribute("response.format", response_format_type) 

654 

655 # Check if user can access this conversation 

656 # BUGFIX: Allow first-time conversation creation without pre-existing OpenFGA tuples 

657 # For new conversations, we short-circuit authorization and will seed ownership after creation 

658 conversation_resource = f"conversation:{thread_id}" 

659 

660 # Check if conversation exists by trying to get state from checkpointer 

661 graph = get_agent_graph() # type: ignore[func-returns-value] 

662 conversation_exists = False 

663 if hasattr(graph, "checkpointer") and graph.checkpointer is not None: 

664 try: 

665 config = {"configurable": {"thread_id": thread_id}} 

666 state_snapshot = await graph.aget_state(config) 

667 conversation_exists = state_snapshot is not None and state_snapshot.values is not None 

668 except Exception: 

669 # If we can't check, assume it doesn't exist (fail-open for creation) 

670 conversation_exists = False 

671 

672 # Only check authorization for existing conversations 

673 if conversation_exists: 

674 can_edit = await self.auth.authorize(user_id=user_id, relation="editor", resource=conversation_resource) 

675 if not can_edit: 

676 logger.warning("User cannot edit conversation", extra={"user_id": user_id, "thread_id": thread_id}) 

677 msg = ( 

678 f"Not authorized to edit conversation '{thread_id}'. " 

679 f"Request access from conversation owner or use a different thread_id." 

680 ) 

681 raise PermissionError(msg) 

682 else: 

683 # New conversation - user becomes implicit owner (OpenFGA tuples should be seeded after creation) 

684 logger.info( 

685 "Creating new conversation, user granted implicit ownership", 

686 extra={"user_id": user_id, "thread_id": thread_id}, 

687 ) 

688 

689 logger.info( 

690 "Processing chat message", 

691 extra={ 

692 "thread_id": thread_id, 

693 "user_id": user_id, 

694 "message_preview": message[:100], 

695 "response_format": response_format_type, 

696 }, 

697 ) 

698 

699 # Create initial state with proper LangChain message objects 

700 # BUGFIX: Use HumanMessage instead of dict to avoid type errors in graph nodes 

701 initial_state: AgentState = { 

702 "messages": [HumanMessage(content=message)], 

703 "next_action": "", 

704 "user_id": user_id, 

705 "request_id": str(span.get_span_context().trace_id) if span.get_span_context() else None, 

706 "routing_confidence": None, 

707 "reasoning": None, 

708 "compaction_applied": None, 

709 "original_message_count": None, 

710 "verification_passed": None, 

711 "verification_score": None, 

712 "verification_feedback": None, 

713 "refinement_attempts": None, 

714 "user_request": message, 

715 } 

716 

717 # Run the agent graph 

718 config = {"configurable": {"thread_id": thread_id}} 

719 

720 try: 

721 result = await get_agent_graph().ainvoke(initial_state, config) # type: ignore[func-returns-value] 

722 

723 # Seed OpenFGA tuples for new conversations 

724 if not conversation_exists and self.openfga is not None: 

725 try: 

726 # Create ownership, viewer, and editor tuples in batch 

727 await self.openfga.write_tuples( 

728 [ 

729 {"user": user_id, "relation": "owner", "object": conversation_resource}, 

730 {"user": user_id, "relation": "viewer", "object": conversation_resource}, 

731 {"user": user_id, "relation": "editor", "object": conversation_resource}, 

732 ] 

733 ) 

734 

735 logger.info( 

736 "OpenFGA tuples seeded for new conversation", 

737 extra={"user_id": user_id, "thread_id": thread_id}, 

738 ) 

739 

740 except Exception as e: 

741 # Log warning but don't fail the request 

742 # The conversation was created successfully, ACL seeding is best-effort 

743 logger.warning( 

744 f"Failed to seed OpenFGA tuples for new conversation: {e}", 

745 extra={"user_id": user_id, "thread_id": thread_id}, 

746 exc_info=True, 

747 ) 

748 

749 # Extract response 

750 response_message = result["messages"][-1] 

751 response_text = response_message.content 

752 

753 # Apply response formatting based on format type 

754 # Follows Anthropic guidance: offer response_format enum parameter 

755 formatted_response = format_response(response_text, format_type=response_format_type) 

756 

757 span.set_attribute("response.length.original", len(response_text)) 

758 span.set_attribute("response.length.formatted", len(formatted_response)) 

759 metrics.successful_calls.add(1, {"tool": "agent_chat", "format": response_format_type}) 

760 

761 logger.info( 

762 "Chat response generated", 

763 extra={ 

764 "thread_id": thread_id, 

765 "original_length": len(response_text), 

766 "formatted_length": len(formatted_response), 

767 "format": response_format_type, 

768 }, 

769 ) 

770 

771 return [TextContent(type="text", text=formatted_response)] 

772 

773 except Exception as e: 

774 logger.error(f"Error processing chat: {e}", extra={"error": str(e), "thread_id": thread_id}, exc_info=True) 

775 metrics.failed_calls.add(1, {"tool": "agent_chat", "error": type(e).__name__}) 

776 span.record_exception(e) 

777 raise 

778 

779 async def _handle_get_conversation(self, arguments: dict[str, Any], span: Any, user_id: str) -> list[TextContent]: 

780 """ 

781 Retrieve conversation history from LangGraph checkpointer. 

782 

783 Returns formatted conversation with messages, metadata, and participants. 

784 """ 

785 with tracer.start_as_current_span("agent.get_conversation"): 

786 thread_id = arguments["thread_id"] 

787 

788 # Check if user can view this conversation 

789 conversation_resource = f"conversation:{thread_id}" 

790 

791 can_view = await self.auth.authorize(user_id=user_id, relation="viewer", resource=conversation_resource) 

792 

793 if not can_view: 793 ↛ 794line 793 didn't jump to line 794 because the condition on line 793 was never true

794 logger.warning("User cannot view conversation", extra={"user_id": user_id, "thread_id": thread_id}) 

795 msg = f"Not authorized to view conversation {thread_id}" 

796 raise PermissionError(msg) 

797 

798 # Retrieve conversation from checkpointer 

799 graph = get_agent_graph() # type: ignore[func-returns-value] 

800 

801 if not hasattr(graph, "checkpointer") or graph.checkpointer is None: 801 ↛ 802line 801 didn't jump to line 802 because the condition on line 801 was never true

802 logger.warning("No checkpointer available, cannot retrieve conversation history") 

803 return [ 

804 TextContent( 

805 type="text", 

806 text="Conversation history unavailable: checkpointing is disabled. " 

807 "Enable checkpointing to persist conversation history.", 

808 ) 

809 ] 

810 

811 try: 

812 config = {"configurable": {"thread_id": thread_id}} 

813 state_snapshot = await graph.aget_state(config) 

814 

815 if state_snapshot is None or state_snapshot.values is None: 815 ↛ 816line 815 didn't jump to line 816 because the condition on line 815 was never true

816 logger.info("Conversation not found", extra={"thread_id": thread_id}) 

817 return [TextContent(type="text", text=f"Conversation '{thread_id}' not found or has no history.")] 

818 

819 # Extract messages from state 

820 messages = state_snapshot.values.get("messages", []) 

821 

822 if not messages: 822 ↛ 823line 822 didn't jump to line 823 because the condition on line 822 was never true

823 return [TextContent(type="text", text=f"Conversation '{thread_id}' has no messages yet.")] 

824 

825 # Format conversation history 

826 formatted_lines = [f"Conversation: {thread_id}", f"Messages: {len(messages)}", ""] 

827 

828 for i, msg in enumerate(messages, 1): 

829 msg_type = type(msg).__name__ 

830 content = getattr(msg, "content", str(msg)) 

831 

832 # Truncate long messages for readability 

833 if len(content) > 200: 833 ↛ 834line 833 didn't jump to line 834 because the condition on line 833 was never true

834 content = content[:200] + "..." 

835 

836 formatted_lines.append(f"{i}. [{msg_type}] {content}") 

837 

838 # Add metadata if available 

839 metadata = state_snapshot.values.get("metadata", {}) 

840 if metadata: 840 ↛ 841line 840 didn't jump to line 841 because the condition on line 840 was never true

841 formatted_lines.append("") 

842 formatted_lines.append("Metadata:") 

843 for key, value in metadata.items(): 

844 formatted_lines.append(f" {key}: {value}") 

845 

846 response_text = "\n".join(formatted_lines) 

847 

848 logger.info( 

849 "Conversation history retrieved", 

850 extra={"thread_id": thread_id, "message_count": len(messages), "user_id": user_id}, 

851 ) 

852 

853 return [TextContent(type="text", text=response_text)] 

854 

855 except Exception as e: 

856 logger.error(f"Failed to retrieve conversation history: {e}", extra={"thread_id": thread_id}, exc_info=True) 

857 return [ 

858 TextContent( 

859 type="text", 

860 text=f"Error retrieving conversation history: {e!s}", 

861 ) 

862 ] 

863 

864 async def _handle_search_conversations(self, arguments: dict[str, Any], span: Any, user_id: str) -> list[TextContent]: 

865 """ 

866 Search conversations (replacing list-all approach). 

867 

868 Implements Anthropic best practice: 

869 "Implement search-focused tools (like search_contacts) rather than 

870 list-all tools (list_contacts)" 

871 

872 Benefits: 

873 - Prevents context overflow with large conversation lists 

874 - Forces agents to be specific in their requests 

875 - More token-efficient 

876 - Better for users with many conversations 

877 """ 

878 with tracer.start_as_current_span("agent.search_conversations"): 

879 # BUGFIX: Validate input with Pydantic schema to enforce query length and limit constraints 

880 try: 

881 search_input = SearchConversationsInput.model_validate(arguments) 

882 except Exception as e: 

883 # SECURITY: Sanitize arguments before logging to prevent token exposure in error logs 

884 logger.error(f"Invalid search input: {e}", extra={"arguments": sanitize_for_logging(arguments)}) 

885 msg = f"Invalid search input: {e}" 

886 raise ValueError(msg) 

887 

888 query = search_input.query 

889 limit = search_input.limit 

890 

891 span.set_attribute("search.query", query) 

892 span.set_attribute("search.limit", limit) 

893 

894 # Get all conversations user can view 

895 all_conversations = await self.auth.list_accessible_resources( 

896 user_id=user_id, relation="viewer", resource_type="conversation" 

897 ) 

898 

899 # Filter conversations based on query (simple implementation) 

900 # In production, this would use a proper search index 

901 # Normalize query and conversation names to handle spaces/underscores/hyphens 

902 if query: 902 ↛ 911line 902 didn't jump to line 911 because the condition on line 902 was always true

903 normalized_query = query.lower().replace(" ", "_").replace("-", "_") 

904 filtered_conversations = [ 

905 conv 

906 for conv in all_conversations 

907 if (query.lower() in conv.lower() or normalized_query in conv.lower().replace(" ", "_").replace("-", "_")) 

908 ] 

909 else: 

910 # No query: return most recent (up to limit) 

911 filtered_conversations = all_conversations 

912 

913 # Apply limit to prevent context overflow 

914 # Follows Anthropic guidance: "Restrict responses to ~25,000 tokens" 

915 limited_conversations = filtered_conversations[:limit] 

916 

917 # Build response with high-signal information 

918 # Avoid technical IDs where possible 

919 if not limited_conversations: 

920 response_text = ( 

921 f"No conversations found matching '{query}'. " 

922 f"Try a different search query or request access to more conversations." 

923 ) 

924 else: 

925 response_lines = [ 

926 ( 

927 f"Found {len(limited_conversations)} conversation(s) matching '{query}':" 

928 if query 

929 else f"Showing {len(limited_conversations)} recent conversation(s):" 

930 ) 

931 ] 

932 

933 for i, conv_id in enumerate(limited_conversations, 1): 

934 # Extract human-readable info from conversation ID 

935 # In production, fetch metadata like title, date, participants 

936 response_lines.append(f"{i}. {conv_id}") 

937 

938 # Add guidance if results were truncated 

939 if len(filtered_conversations) > limit: 939 ↛ 940line 939 didn't jump to line 940 because the condition on line 939 was never true

940 response_lines.append( 

941 f"\n[Showing {limit} of {len(filtered_conversations)} results. " 

942 f"Use a more specific query to narrow results.]" 

943 ) 

944 

945 response_text = "\n".join(response_lines) 

946 

947 logger.info( 

948 "Searched conversations", 

949 extra={ 

950 "user_id": user_id, 

951 "query": query, 

952 "total_accessible": len(all_conversations), 

953 "filtered_count": len(filtered_conversations), 

954 "returned_count": len(limited_conversations), 

955 }, 

956 ) 

957 

958 return [TextContent(type="text", text=response_text)] 

959 

960 async def _handle_search_tools(self, arguments: dict[str, Any], span: Any) -> list[TextContent]: 

961 """ 

962 Handle search_tools invocation for progressive tool discovery. 

963 

964 Implements Anthropic best practice for token-efficient tool discovery. 

965 """ 

966 with tracer.start_as_current_span("tools.search"): 

967 from mcp_server_langgraph.tools.tool_discovery import search_tools 

968 

969 # Extract arguments 

970 query = arguments.get("query") 

971 category = arguments.get("category") 

972 detail_level = arguments.get("detail_level", "minimal") 

973 

974 logger.info( 

975 "Searching tools", 

976 extra={"query": query, "category": category, "detail_level": detail_level}, 

977 ) 

978 

979 # Execute search_tools 

980 result = search_tools.invoke( 

981 { 

982 "query": query, 

983 "category": category, 

984 "detail_level": detail_level, 

985 } 

986 ) 

987 

988 span.set_attribute("tools.query", query or "") 

989 span.set_attribute("tools.category", category or "") 

990 span.set_attribute("tools.detail_level", detail_level) 

991 

992 return [TextContent(type="text", text=result)] 

993 

994 async def _handle_execute_python(self, arguments: dict[str, Any], span: Any, user_id: str) -> list[TextContent]: 

995 """ 

996 Handle execute_python invocation for secure code execution. 

997 

998 Implements sandboxed Python execution with validation and resource limits. 

999 """ 

1000 with tracer.start_as_current_span("code.execute"): 

1001 import time 

1002 

1003 from mcp_server_langgraph.tools.code_execution_tools import execute_python 

1004 

1005 # Extract arguments 

1006 code = arguments.get("code", "") 

1007 timeout = arguments.get("timeout") 

1008 

1009 logger.info( 

1010 "Executing Python code", 

1011 extra={ 

1012 "user_id": user_id, 

1013 "code_length": len(code), 

1014 "timeout": timeout, 

1015 }, 

1016 ) 

1017 

1018 # Execute code 

1019 start_time = time.time() 

1020 result = execute_python.invoke({"code": code, "timeout": timeout}) 

1021 execution_time = time.time() - start_time 

1022 

1023 span.set_attribute("code.length", len(code)) 

1024 span.set_attribute("code.execution_time", execution_time) 

1025 span.set_attribute("code.success", "success" in result.lower()) 

1026 

1027 metrics.code_executions.add(1, {"user_id": user_id, "success": "success" in result.lower()}) 

1028 

1029 return [TextContent(type="text", text=result)] 

1030 

1031 

1032# ============================================================================ 

1033# Lazy Singleton Pattern for MCP Server 

1034# ============================================================================ 

1035# Prevents import-time initialization which causes issues with: 

1036# - Missing environment variables 

1037# - Observability not yet initialized 

1038# - Test dependency injection 

1039 

# Module-level singleton slot; populated lazily by get_mcp_server() below.
_mcp_server_instance: MCPAgentStreamableServer | None = None

1041 

1042 

def get_mcp_server() -> MCPAgentStreamableServer:
    """
    Get or create the MCP server instance (lazy singleton).

    The server is constructed only after observability has been initialized,
    which avoids import-time side effects and logging errors.

    Returns:
        The shared MCPAgentStreamableServer instance.

    Raises:
        RuntimeError: If observability has not been initialized yet.
    """
    from mcp_server_langgraph.observability.telemetry import is_initialized

    # Guard: the server must not be created before telemetry is ready;
    # normally the startup event handler initializes observability first.
    if not is_initialized():
        msg = (
            "Observability must be initialized before creating MCP server. "
            "This should be done automatically via the startup event handler. "
            "If you see this error, observability initialization failed or was skipped."
        )
        raise RuntimeError(msg)

    global _mcp_server_instance
    instance = _mcp_server_instance
    if instance is None:
        instance = MCPAgentStreamableServer()
        _mcp_server_instance = instance
    return instance

1071 

1072 

1073# ============================================================================ 

1074# Authentication Models 

1075# ============================================================================ 

1076 

1077 

class LoginRequest(BaseModel):
    """Login request with username and password"""

    # Plain-text credentials; only length is validated here — actual
    # authentication is delegated to the configured user provider.
    username: str = Field(description="Username", min_length=1, max_length=100)
    password: str = Field(description="Password", min_length=1, max_length=500)

1083 

1084 

class LoginResponse(BaseModel):
    """Login response with JWT token"""

    # Bearer token the client passes on subsequent tool calls.
    access_token: str = Field(description="JWT access token")
    token_type: str = Field(default="bearer", description="Token type (always 'bearer')")
    expires_in: int = Field(description="Token expiration in seconds")
    user_id: str = Field(description="User identifier")
    username: str = Field(description="Username")
    roles: list[str] = Field(description="User roles")

1094 

1095 

class RefreshTokenRequest(BaseModel):
    """Token refresh request"""

    # Exactly one of these should be supplied, depending on the auth provider:
    # refresh_token for Keycloak, current_token for the in-memory provider.
    refresh_token: str | None = Field(description="Refresh token (Keycloak only)", default=None)
    current_token: str | None = Field(description="Current access token (for InMemory provider)", default=None)

1101 

1102 

class RefreshTokenResponse(BaseModel):
    """Token refresh response"""

    access_token: str = Field(description="New JWT access token")
    token_type: str = Field(default="bearer", description="Token type")
    expires_in: int = Field(description="Token expiration in seconds")
    # Only populated when the Keycloak provider rotates refresh tokens.
    refresh_token: str | None = Field(default=None, description="New refresh token (Keycloak only)")

1110 

1111 

1112# FastAPI endpoints for MCP StreamableHTTP transport 

@app.get("/")
async def root() -> dict[str, Any]:
    """Root endpoint describing the server, its endpoints, and its capabilities."""
    # Map of the HTTP surface this server exposes.
    endpoint_map = {
        "auth": {
            "login": "/auth/login",
            "refresh": "/auth/refresh",
        },
        "message": "/message",
        "tools": "/tools",
        "resources": "/resources",
        "health": "/health",
    }
    # Advertised MCP capabilities, including the configured auth provider.
    capability_map = {
        "tools": {"listSupported": True, "callSupported": True},
        "resources": {"listSupported": True, "readSupported": True},
        "streaming": True,
        "authentication": {
            "methods": ["jwt"],
            "tokenRefresh": True,
            "providers": [settings.auth_provider],
        },
    }
    return {
        "name": "MCP Server with LangGraph",
        "version": settings.service_version,
        "transport": "streamable-http",
        "protocol": "mcp",
        "endpoints": endpoint_map,
        "capabilities": capability_map,
    }

1142 

1143 

@app.post("/auth/login", tags=["auth"])
async def login(request: LoginRequest) -> LoginResponse:
    """
    Authenticate user and return JWT token

    This endpoint accepts username and password, validates credentials,
    and returns a JWT token that can be used for subsequent tool calls.

    The token should be included in the 'token' field of all tool call requests.

    Example:
        POST /auth/login
        {
            "username": "your-username",
            "password": "your-secure-password"
        }

    Response:
        {
            "access_token": "eyJ...",
            "token_type": "bearer",
            "expires_in": 3600,
            "user_id": "user:your-username",
            "username": "your-username",
            "roles": ["user"]
        }

    Note: InMemoryUserProvider no longer seeds default users (v2.8.0+).
    Create users explicitly via provider.add_user() for testing.

    Raises:
        HTTPException: 401 on bad credentials, 500 if token creation fails.
    """
    with tracer.start_as_current_span("auth.login") as span:
        span.set_attribute("auth.username", request.username)

        # Get MCP server instance (which has auth middleware)
        mcp_server_instance = get_mcp_server()

        # Authenticate user via configured provider
        auth_result = await mcp_server_instance.auth.user_provider.authenticate(request.username, request.password)

        if not auth_result.authorized:
            logger.warning(
                "Login failed", extra={"username": request.username, "reason": auth_result.reason, "error": auth_result.error}
            )
            metrics.auth_failures.add(1)
            raise HTTPException(
                status_code=401,
                detail=f"Authentication failed: {auth_result.reason or auth_result.error or 'invalid credentials'}",
            )

        # Create JWT token
        # For InMemoryUserProvider, use the create_token method
        # For KeycloakUserProvider, tokens are returned directly in auth_result
        if auth_result.access_token:
            # Keycloak provider returns token directly
            access_token = auth_result.access_token
            expires_in = auth_result.expires_in or 3600
        else:
            # InMemoryUserProvider needs to create token
            try:
                access_token = mcp_server_instance.auth.create_token(
                    username=request.username, expires_in=mcp_server_instance.settings.jwt_expiration_seconds
                )
                expires_in = mcp_server_instance.settings.jwt_expiration_seconds
            except Exception as e:
                logger.error(f"Failed to create token: {e}", exc_info=True)
                # BUGFIX: chain the cause so the original traceback is preserved (PEP 3134 / flake8 B904)
                raise HTTPException(status_code=500, detail="Failed to create authentication token") from e

        logger.info(
            "Login successful",
            extra={
                "username": request.username,
                "user_id": auth_result.user_id,
                "provider": type(mcp_server_instance.auth.user_provider).__name__,
            },
        )

        return LoginResponse(
            access_token=access_token,
            token_type="bearer",
            expires_in=expires_in,
            user_id=auth_result.user_id or "",
            username=auth_result.username or request.username,
            roles=auth_result.roles or [],
        )

1228 

1229 

@app.post("/auth/refresh", tags=["auth"])
async def refresh_token(request: RefreshTokenRequest) -> RefreshTokenResponse:
    """
    Refresh authentication token

    Supports two refresh methods:
    1. Keycloak: Uses refresh_token to get new access token
    2. InMemory: Validates current token and issues new one

    Example (Keycloak):
        POST /auth/refresh
        {
            "refresh_token": "eyJ..."
        }

    Example (InMemory):
        POST /auth/refresh
        {
            "current_token": "eyJ..."
        }

    Response:
        {
            "access_token": "eyJ...",
            "token_type": "bearer",
            "expires_in": 3600,
            "refresh_token": "eyJ..."  // Keycloak only
        }

    Raises:
        HTTPException: 400 if neither/wrong credential is supplied for the
            active provider, 401 if the token is invalid, 500 on internal errors.
    """
    with tracer.start_as_current_span("auth.refresh") as span:
        mcp_server_instance = get_mcp_server()

        # Handle Keycloak refresh token
        if request.refresh_token:
            if not isinstance(mcp_server_instance.auth.user_provider, KeycloakUserProvider):
                logger.warning("Refresh token provided but provider is not Keycloak")
                raise HTTPException(
                    status_code=400,
                    detail="Refresh tokens are only supported with KeycloakUserProvider. "
                    "Current provider does not support refresh tokens.",
                )

            try:
                # Use KeycloakUserProvider's refresh_token method
                result = await mcp_server_instance.auth.user_provider.refresh_token(request.refresh_token)

                if not result.get("success"):
                    raise HTTPException(status_code=401, detail=f"Token refresh failed: {result.get('error')}")

                tokens = result["tokens"]

                logger.info("Token refreshed via Keycloak")

                return RefreshTokenResponse(
                    access_token=tokens["access_token"],
                    token_type="bearer",
                    expires_in=tokens.get("expires_in", 300),
                    refresh_token=tokens.get("refresh_token"),
                )

            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"Token refresh failed: {e}", exc_info=True)
                # BUGFIX: chain the cause so the original traceback is preserved (PEP 3134 / flake8 B904)
                raise HTTPException(status_code=500, detail="Token refresh failed") from e

        # Handle InMemory token refresh (verify current token and issue new one)
        elif request.current_token:
            try:
                # Verify current token
                token_verification = await mcp_server_instance.auth.verify_token(request.current_token)

                if not token_verification.valid:
                    logger.warning("Current token invalid for refresh", extra={"error": token_verification.error})
                    raise HTTPException(status_code=401, detail=f"Invalid token: {token_verification.error}")

                # Extract user info from token
                if not token_verification.payload or "username" not in token_verification.payload:
                    raise HTTPException(status_code=401, detail="Invalid token: missing username")

                username = token_verification.payload["username"]
                span.set_attribute("auth.username", username)

                # Issue new token
                new_token = mcp_server_instance.auth.create_token(
                    username=username, expires_in=mcp_server_instance.settings.jwt_expiration_seconds
                )

                logger.info("Token refreshed for user", extra={"username": username})

                return RefreshTokenResponse(
                    access_token=new_token,
                    token_type="bearer",
                    expires_in=mcp_server_instance.settings.jwt_expiration_seconds,
                )

            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"Token refresh failed: {e}", exc_info=True)
                # BUGFIX: chain the cause so the original traceback is preserved (PEP 3134 / flake8 B904)
                raise HTTPException(status_code=500, detail="Token refresh failed") from e

        else:
            raise HTTPException(
                status_code=400, detail="Either 'refresh_token' (Keycloak) or 'current_token' (InMemory) must be provided"
            )

1336 

1337 

async def stream_jsonrpc_response(data: dict[str, Any]) -> AsyncIterator[str]:
    """
    Stream a JSON-RPC response as newline-delimited JSON (NDJSON).

    Currently emits the complete payload as one NDJSON line; a future
    version could stream token-by-token for LLM responses.
    """
    payload = json.dumps(data)
    yield f"{payload}\n"

1347 

1348 

@app.post("/message", response_model=None)
async def handle_message(request: Request) -> JSONResponse | StreamingResponse:
    """
    Handle MCP messages via StreamableHTTP POST

    This is the main endpoint for MCP protocol messages.
    Supports both regular and streaming responses.

    Dispatches on the JSON-RPC "method" field (initialize, tools/list,
    tools/call, resources/list, resources/read). Protocol-level failures are
    converted to JSON-RPC error objects: -32700 parse, -32601 unknown method,
    -32602 invalid params, -32001 permission denied, -32603 internal.
    """
    try:
        # Parse JSON-RPC request
        message = await request.json()

        with tracer.start_as_current_span("mcp.streamable.message") as span:
            span.set_attribute("mcp.method", message.get("method", "unknown"))
            span.set_attribute("mcp.id", str(message.get("id", "")))

            logger.info("Received MCP message", extra={"method": message.get("method"), "id": message.get("id")})

            method = message.get("method")
            message_id = message.get("id")

            # Handle different MCP methods
            if method == "initialize":
                mcp_server = get_mcp_server()
                response = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {
                        "protocolVersion": "2024-11-05",
                        "serverInfo": {"name": "langgraph-agent", "version": mcp_server.settings.service_version},
                        "capabilities": {
                            "tools": {"listChanged": False},
                            "resources": {"listChanged": False},
                            "prompts": {},
                            "logging": {},
                        },
                    },
                }
                return JSONResponse(response)

            elif method == "tools/list":
                # Use public API instead of private _tool_manager
                tools = await get_mcp_server().list_tools_public()
                response = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {"tools": [tool.model_dump(mode="json") for tool in tools]},
                }
                return JSONResponse(response)

            elif method == "tools/call":
                params = message.get("params", {})
                tool_name = params.get("name")
                arguments = params.get("arguments", {})

                # Check if client supports streaming
                accept_header = request.headers.get("accept", "")
                supports_streaming = "text/event-stream" in accept_header or "application/x-ndjson" in accept_header

                # Use public API instead of private _tool_manager
                result = await get_mcp_server().call_tool_public(tool_name, arguments)

                response_data = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {"content": [item.model_dump(mode="json") for item in result]},
                }

                # If streaming is supported, stream the response
                if supports_streaming:
                    return StreamingResponse(
                        stream_jsonrpc_response(response_data),
                        media_type="application/x-ndjson",
                        headers={"X-Content-Type-Options": "nosniff", "Cache-Control": "no-cache"},
                    )
                else:
                    return JSONResponse(response_data)

            elif method == "resources/list":
                # Use public API instead of private _resource_manager
                resources = await get_mcp_server().list_resources_public()
                response = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {"resources": [res.model_dump(mode="json") for res in resources]},
                }
                return JSONResponse(response)

            elif method == "resources/read":
                params = message.get("params", {})
                resource_uri = params.get("uri")

                # Handle resource read (implement based on your needs)
                response = {
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {
                        "contents": [
                            {
                                "uri": resource_uri,
                                "mimeType": "application/json",
                                "text": json.dumps({"config": "placeholder"}),
                            }
                        ]
                    },
                }
                return JSONResponse(response)

            else:
                raise HTTPException(status_code=400, detail=f"Unknown method: {method}")

    except PermissionError as e:
        logger.warning(f"Permission denied: {e}")
        return JSONResponse(
            status_code=200,  # JSON-RPC errors should return 200 with error in body
            content={
                "jsonrpc": "2.0",
                "id": message.get("id") if "message" in locals() else None,
                "error": {"code": -32001, "message": str(e)},
            },
        )
    except json.JSONDecodeError as e:
        # Invalid JSON in request body
        # BUGFIX: this handler must precede `except ValueError` — JSONDecodeError
        # subclasses ValueError, so the original ordering made this branch
        # unreachable and misreported parse errors as -32602 "Invalid params".
        logger.warning(f"JSON parse error: {e}")
        return JSONResponse(
            status_code=400,  # Parse errors can return 400
            content={
                "jsonrpc": "2.0",
                "id": None,
                "error": {"code": -32700, "message": "Parse error: Invalid JSON"},
            },
        )
    except ValueError as e:
        # Validation errors (missing fields, invalid arguments)
        logger.warning(f"Validation error: {e}")
        return JSONResponse(
            status_code=200,  # JSON-RPC errors should return 200 with error in body
            content={
                "jsonrpc": "2.0",
                "id": message.get("id") if "message" in locals() else None,
                "error": {"code": -32602, "message": f"Invalid params: {e!s}"},
            },
        )
    except HTTPException as e:
        # HTTP exceptions (unknown methods, etc.)
        logger.warning(f"HTTP exception: {e.detail}")
        return JSONResponse(
            status_code=200,  # JSON-RPC errors should return 200 with error in body
            content={
                "jsonrpc": "2.0",
                "id": message.get("id") if "message" in locals() else None,
                "error": {"code": -32601 if e.status_code == 400 else -32603, "message": e.detail},
            },
        )
    except Exception as e:
        logger.error(f"Error handling message: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                "jsonrpc": "2.0",
                "id": message.get("id") if "message" in locals() else None,
                "error": {"code": -32603, "message": str(e)},
            },
        )

1513 

1514 

@app.get("/tools")
async def list_tools() -> dict[str, Any]:
    """List available tools (convenience endpoint)"""
    # Go through the public API rather than the private _tool_manager.
    server = get_mcp_server()
    available = await server.list_tools_public()
    return {"tools": [t.model_dump(mode="json") for t in available]}

1521 

1522 

@app.get("/resources")
async def list_resources() -> dict[str, Any]:
    """List available resources (convenience endpoint)"""
    # Go through the public API rather than the private _resource_manager.
    server = get_mcp_server()
    available = await server.list_resources_public()
    return {"resources": [r.model_dump(mode="json") for r in available]}

1529 

1530 

# Include health check routes.
# NOTE: late imports (noqa: E402) — these modules are imported after `app`
# exists because they attach to it at import time.
from mcp_server_langgraph.health.checks import app as health_app  # noqa: E402

app.mount("/health", health_app)

from mcp_server_langgraph.api.api_keys import router as api_keys_router  # noqa: E402
from mcp_server_langgraph.api.gdpr import router as gdpr_router  # noqa: E402
from mcp_server_langgraph.api.scim import router as scim_router  # noqa: E402
from mcp_server_langgraph.api.service_principals import router as service_principals_router  # noqa: E402

# Include REST API routes
from mcp_server_langgraph.api.version import router as version_router  # noqa: E402

app.include_router(version_router)
app.include_router(gdpr_router)
app.include_router(api_keys_router)
app.include_router(service_principals_router)
app.include_router(scim_router)

1549 

1550 

1551# ============================================================================== 

1552# Custom OpenAPI Schema Generation 

1553# ============================================================================== 

1554# Pagination models need to be included in OpenAPI schema for API documentation, 

1555# even if not all endpoints use them yet. This ensures consistent API contract. 

1556 

1557 

def custom_openapi() -> dict[str, Any]:
    """
    Custom OpenAPI schema generator that includes pagination models.

    FastAPI only documents models referenced by endpoint signatures; the
    pagination models are injected here explicitly so the API contract is
    complete for consumers even before every endpoint adopts pagination.

    This follows TDD principles - tests define the expected API contract first.
    """
    from typing import cast

    # Serve the cached schema if one was already generated.
    if app.openapi_schema:
        return cast(dict[str, Any], app.openapi_schema)  # type: ignore[redundant-cast]

    from fastapi.openapi.utils import get_openapi

    # Build the base schema from the registered routes.
    schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
        tags=app.openapi_tags,
    )

    from mcp_server_langgraph.api.pagination import PaginationMetadata, PaginationParams

    # Inject the pagination models into components/schemas, creating the
    # containers if the generated schema lacks them.
    components = schema.setdefault("components", {})
    schemas = components.setdefault("schemas", {})
    schemas["PaginationParams"] = PaginationParams.model_json_schema()
    schemas["PaginationMetadata"] = PaginationMetadata.model_json_schema()

    # Note: PaginatedResponse is generic, so we document it in the description
    # Individual endpoint responses (e.g., PaginatedResponse[APIKey]) will be
    # generated automatically when endpoints use them

    app.openapi_schema = schema
    return cast(dict[str, Any], app.openapi_schema)  # type: ignore[redundant-cast]

1606 

1607 

# Apply custom OpenAPI schema
# Replace FastAPI's default schema generator with the pagination-aware one;
# the `method-assign` ignore silences mypy about overwriting a bound method.
app.openapi = custom_openapi  # type: ignore[method-assign]

1610 

1611 

def main() -> None:
    """
    Entry point for the console script.

    Initializes the observability system, registers a shutdown fallback,
    validates the CORS configuration, resolves the listening port from the
    secret store (fallback 8000), and runs the FastAPI app under uvicorn.
    Blocks until the server exits.
    """
    # Initialize observability system before creating server
    import atexit

    from mcp_server_langgraph.core.config import settings
    from mcp_server_langgraph.observability.telemetry import init_observability, shutdown_observability

    # Initialize with settings and enable file logging if configured
    init_observability(settings=settings, enable_file_logging=getattr(settings, "enable_file_logging", False))

    # Register shutdown handler as fallback (lifespan is primary)
    atexit.register(shutdown_observability)

    # SECURITY: Validate CORS configuration before starting server
    settings.validate_cors_config()

    # Resolve the port ONCE. The original queried the secret store twice
    # (once for the log message, once for parsing), which duplicated work
    # and could log a different value than the one actually bound.
    port_str = settings.get_secret("PORT", fallback="8000")
    port = int(port_str) if port_str else 8000

    logger.info(f"Starting MCP StreamableHTTP server on port {port_str}")

    uvicorn.run(
        app,
        host="0.0.0.0",  # nosec B104 - Required for containerized deployment
        port=port,
        log_level=settings.log_level.lower(),
        access_log=True,
    )

1641 

1642 

# Allow running this module directly (e.g. `python server_streamable.py`).
if __name__ == "__main__":
    main()