Coverage for src / mcp_server_langgraph / auth / middleware.py: 70%

301 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Authentication and Authorization middleware with OpenFGA integration 

3 

4Now supports: 

5- Pluggable user providers (InMemory, Keycloak, custom) 

6- Session management (Token-based or Session-based) 

7- Fine-grained authorization via OpenFGA 

8""" 

9 

from collections.abc import Awaitable, Callable
from functools import wraps
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field

from mcp_server_langgraph.auth.openfga import OpenFGAClient
from mcp_server_langgraph.auth.session import SessionData, SessionStore
from mcp_server_langgraph.auth.user_provider import AuthResponse, InMemoryUserProvider, TokenVerification, UserProvider
from mcp_server_langgraph.observability.telemetry import logger, tracer

19 

20# FastAPI imports for dependency injection (optional, only if using FastAPI endpoints) 

21try: 

22 from fastapi import Depends, HTTPException, Request, status # noqa: F401 (Depends used at line 929, 970) 

23 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer 

24 

25 FASTAPI_AVAILABLE = True 

26except ImportError: 

27 FASTAPI_AVAILABLE = False 

28 

29# ============================================================================ 

30# Helper Functions 

31# ============================================================================ 

32 

33 

def normalize_user_id(user_id: str) -> str:
    """
    Strip an optional ``prefix:`` from a user identifier.

    Clients may send either the OpenFGA-style form (``"user:alice"``) or a bare
    username (``"alice"``); both normalize to ``"alice"``. Only the first colon
    acts as the separator, so any later colons are preserved.

    Examples:
        - ``"alice"`` -> ``"alice"``
        - ``"user:alice"`` -> ``"alice"``
        - ``"uid:123"`` -> ``"123"``

    Args:
        user_id: User identifier, with or without a ``prefix:`` part.

    Returns:
        The identifier with everything up to and including the first colon
        removed. Falsy input (empty string or ``None``) is returned unchanged.
    """
    # Guard falsy values first so "" and None round-trip untouched.
    if not user_id:
        return user_id

    _prefix, separator, remainder = user_id.partition(":")
    # An empty separator means there was no colon, hence no prefix to strip.
    return remainder if separator else user_id

62 

63 

64# ============================================================================ 

65# Pydantic Models for Middleware Operations 

66# ============================================================================ 

67 

68 

class AuthorizationResult(BaseModel):
    """
    Outcome of a single authorization check.

    Records what was checked (user, relation, resource), whether access was
    granted, an optional denial reason, and whether the decision came from the
    fallback path instead of OpenFGA. Intended as the typed result of
    authorize()-style operations.
    """

    # Mutable model whose assignments are re-validated; string fields are
    # whitespace-stripped on input.
    model_config = ConfigDict(
        frozen=False,
        validate_assignment=True,
        str_strip_whitespace=True,
        json_schema_extra={
            "example": {
                "authorized": True,
                "user_id": "user:alice",
                "relation": "executor",
                "resource": "tool:chat",
                "reason": None,
                "used_fallback": False,
            }
        },
    )

    # Required fields (no default given).
    authorized: bool = Field(description="Whether access is authorized")
    user_id: str = Field(description="User identifier that was checked")
    relation: str = Field(description="Relation that was checked")
    resource: str = Field(description="Resource that was checked")

    # Optional fields.
    reason: str | None = Field(default=None, description="Reason for denial if not authorized")
    used_fallback: bool = Field(default=False, description="Whether fallback authorization was used")

    def to_dict(self) -> dict[str, Any]:
        """Return the result as a plain dict, omitting ``None`` fields (legacy callers)."""
        payload = self.model_dump(exclude_none=True)
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AuthorizationResult":
        """Build an AuthorizationResult from a mapping of field names to values."""
        return cls(**data)

107 

108 

class AuthMiddleware:
    """
    Authentication and authorization handler with OpenFGA

    Combines authentication (via pluggable user providers) with fine-grained
    relationship-based authorization using OpenFGA.

    Supports multiple authentication backends:
    - InMemoryUserProvider (development/testing)
    - KeycloakUserProvider (production)
    - Custom providers

    When no OpenFGA client is configured, authorize() fails closed unless the
    settings explicitly enable ``allow_auth_fallback`` outside production, in
    which case a simple role-based fallback is applied.
    """

    def __init__(
        self,
        secret_key: str | None = None,
        openfga_client: OpenFGAClient | None = None,
        user_provider: UserProvider | None = None,
        session_store: SessionStore | None = None,
        settings: Any | None = None,
    ):
        """
        Initialize AuthMiddleware

        Args:
            secret_key: Secret key for JWT tokens (used by InMemoryUserProvider).
                Must be provided via environment variable or settings.
            openfga_client: OpenFGA client for authorization
            user_provider: User provider instance (defaults to InMemoryUserProvider for backward compatibility)
            session_store: Session store for session-based authentication (optional)
            settings: Application settings (for authorization fallback control)
        """
        self.secret_key = secret_key
        self.openfga = openfga_client
        self.session_store = session_store
        self.settings = settings

        # Use provided user provider or default to in-memory for backward compatibility
        if user_provider is None:
            logger.info("No user provider specified, defaulting to InMemoryUserProvider")
            user_provider = InMemoryUserProvider(secret_key=secret_key)

        self.user_provider = user_provider

        # For backward compatibility: expose users_db if using InMemoryUserProvider
        if isinstance(user_provider, InMemoryUserProvider):
            self.users_db = user_provider.users_db
        else:
            self.users_db = {}  # Empty dict for non-inmemory providers

        logger.info(
            "AuthMiddleware initialized",
            extra={
                "provider_type": type(user_provider).__name__,
                "openfga_enabled": openfga_client is not None,
                "session_enabled": session_store is not None,
                "allow_auth_fallback": getattr(settings, "allow_auth_fallback", None) if settings else None,
            },
        )

    async def authenticate(self, username: str, password: str | None = None) -> AuthResponse:
        """
        Authenticate user by username

        Args:
            username: Username to authenticate (accepts both "alice" and "user:alice" formats)
            password: Password (required for some providers like Keycloak)

        Returns:
            AuthResponse with authentication result
        """
        with tracer.start_as_current_span("auth.authenticate") as span:
            # Normalize username to handle both "alice" and "user:alice" formats
            normalized_username = normalize_user_id(username)
            span.set_attribute("auth.username", normalized_username)

            # Delegate to user provider (returns Pydantic AuthResponse)
            result = await self.user_provider.authenticate(normalized_username, password)

            # Imported at call time (kept from the original layout); needed by both branches.
            from mcp_server_langgraph.core.security import sanitize_for_logging

            if result.authorized:
                logger.info(
                    "User authenticated",
                    extra=sanitize_for_logging({"username": username, "user_id": result.user_id}),
                )
            else:
                logger.warning(
                    "Authentication failed", extra=sanitize_for_logging({"username": username, "reason": result.reason})
                )

            return result

    async def authorize(self, user_id: str, relation: str, resource: str, context: dict[str, Any] | None = None) -> bool:
        """
        Check if user is authorized using OpenFGA

        Decision order:
        1. If an OpenFGA client is configured, its answer is final; any error
           while checking denies access (fail closed).
        2. Otherwise the role-based fallback is used only when settings
           explicitly allow it and the environment is not production.
        3. In every other case access is denied.

        Args:
            user_id: User identifier (e.g., "user:alice")
            relation: Relation to check (e.g., "executor", "viewer")
            resource: Resource identifier (e.g., "tool:chat")
            context: Additional context for authorization (forwarded to OpenFGA only)

        Returns:
            True if authorized, False otherwise
        """
        with tracer.start_as_current_span("auth.authorize") as span:
            span.set_attribute("user.id", user_id)
            span.set_attribute("auth.relation", relation)
            span.set_attribute("auth.resource", resource)

            # Use OpenFGA if available
            if self.openfga:
                try:
                    authorized = await self.openfga.check_permission(
                        user=user_id, relation=relation, object=resource, context=context
                    )

                    span.set_attribute("auth.authorized", authorized)
                    logger.info(
                        "Authorization check (OpenFGA)",
                        extra={"user_id": user_id, "relation": relation, "resource": resource, "authorized": authorized},
                    )

                    return authorized

                except Exception as e:
                    logger.error(
                        f"OpenFGA authorization check failed: {e}",
                        extra={"user_id": user_id, "relation": relation, "resource": resource},
                        exc_info=True,
                    )
                    # Fail closed - deny access on error
                    return False

            # SECURITY CONTROL (OpenAI Codex Finding #1): without OpenFGA, either fail
            # closed (secure default) or use role-based checks when explicitly enabled.
            if not self._fallback_permitted(user_id, relation, resource):
                return False

            return await self._fallback_authorize(user_id, relation, resource)

    @staticmethod
    def _is_production_environment(environment: Any) -> bool:
        """Return True if ``environment`` names production, ignoring case and surrounding whitespace."""
        if environment == "production":
            return True
        # Normalize so a misconfigured value like "Production" cannot bypass the guard.
        return isinstance(environment, str) and environment.strip().lower() == "production"

    def _fallback_permitted(self, user_id: str, relation: str, resource: str) -> bool:
        """
        Decide whether role-based fallback authorization may run (OpenFGA absent).

        Logs the reason for every decision. Returns False in production
        (regardless of configuration) and when ``allow_auth_fallback`` is not set.
        """
        allow_fallback = getattr(self.settings, "allow_auth_fallback", False) if self.settings else False
        environment = getattr(self.settings, "environment", "production") if self.settings else "production"

        # Defense in depth: NEVER allow fallback in production, even if misconfigured
        if self._is_production_environment(environment):
            logger.error(
                "Authorization DENIED: OpenFGA unavailable in production environment. "
                "Fallback authorization is not permitted in production for security reasons.",
                extra={
                    "user_id": user_id,
                    "relation": relation,
                    "resource": resource,
                    "environment": environment,
                    "allow_auth_fallback": allow_fallback,
                },
            )
            return False

        # Check if fallback is explicitly enabled
        if not allow_fallback:
            logger.warning(
                "Authorization DENIED: OpenFGA unavailable and fallback authorization is disabled. "
                "Set ALLOW_AUTH_FALLBACK=true to enable role-based fallback in development/test.",
                extra={
                    "user_id": user_id,
                    "relation": relation,
                    "resource": resource,
                    "allow_auth_fallback": allow_fallback,
                    "environment": environment,
                },
            )
            return False

        logger.warning(
            "OpenFGA not available, using fallback authorization (explicitly enabled)",
            extra={
                "allow_auth_fallback": allow_fallback,
                "environment": environment,
            },
        )
        return True

    @staticmethod
    def _fallback_username(user_id: str) -> str:
        """
        Extract the bare username from a user_id for fallback lookups.

        Handles worker-safe IDs used under pytest-xdist; InMemoryUserProvider is
        test-only, so this test-specific logic is acceptable.
        Examples: "user:alice" -> "alice", "user:test_gw0_alice" -> "alice".
        """
        import re

        if ":" not in user_id:
            return user_id

        id_part = user_id.split(":", 1)[1]  # Remove "user:" prefix
        # Worker-safe ID format: test_gw<N>_<username>
        match = re.match(r"test_gw\d+_(.*)", id_part)
        return match.group(1) if match else id_part

    async def _fallback_roles(self, user_id: str, username: str) -> list[str] | None:
        """
        Look up ``username``'s roles for fallback authorization.

        Returns:
            The user's roles, or None when the user cannot be found or verified
            (callers must treat None as deny).
        """
        if isinstance(self.user_provider, InMemoryUserProvider):
            # Fast path: Use in-memory users_db
            if username not in self.users_db:
                logger.warning(
                    "Fallback authorization denied - user not found",
                    extra={"user_id": user_id, "username": username, "provider": "InMemory"},
                )
                return None
            return self.users_db[username]["roles"]

        # For external providers (Keycloak, etc.): Query the provider
        provider_name = type(self.user_provider).__name__
        try:
            user_data = await self.user_provider.get_user_by_username(username)
            if not user_data:
                logger.warning(
                    "Fallback authorization denied - user not found in provider",
                    extra={"user_id": user_id, "username": username, "provider": provider_name},
                )
                return None
            # Guard against providers reporting roles=None, which would otherwise
            # raise TypeError on the membership checks below.
            user_roles = user_data.roles or []
            logger.info(
                "Fetched user from provider for fallback authorization",
                extra={
                    "user_id": user_id,
                    "username": username,
                    "provider": provider_name,
                    "roles": user_roles,
                },
            )
            return user_roles
        except Exception as e:
            logger.error(
                f"Failed to fetch user from provider for fallback authorization: {e}",
                extra={"user_id": user_id, "username": username, "provider": provider_name},
                exc_info=True,
            )
            # Fail closed - deny access if we can't verify user
            return None

    async def _fallback_authorize(self, user_id: str, relation: str, resource: str) -> bool:
        """Run the role-based fallback check (only called once fallback is permitted)."""
        username = self._fallback_username(user_id)
        user_roles = await self._fallback_roles(user_id, username)
        if user_roles is None:
            return False
        return self._evaluate_fallback_rules(user_id, username, relation, resource, user_roles)

    def _evaluate_fallback_rules(
        self, user_id: str, username: str, relation: str, resource: str, user_roles: list[str]
    ) -> bool:
        """
        Apply the simple fallback permission rules.

        - "admin" role: everything is allowed.
        - "executor" on "tool:*": allowed for "premium" or "user" roles.
        - "viewer"/"editor" on "conversation:*": allowed for the default/unnamed
          thread or threads prefixed with the user's name.
        - Anything else: denied.
        """
        # Admin users have access to everything
        if "admin" in user_roles:
            logger.info(
                "Fallback authorization granted - admin user",
                extra={"user_id": user_id, "username": username, "relation": relation, "resource": resource},
            )
            return True

        # Basic resource-based checks
        if relation == "executor" and resource.startswith("tool:"):
            authorized = "premium" in user_roles or "user" in user_roles
            if authorized:
                logger.info(
                    "Fallback authorization granted - tool executor",
                    extra={
                        "user_id": user_id,
                        "username": username,
                        "relation": relation,
                        "resource": resource,
                        "roles": user_roles,
                    },
                )
            return authorized

        if relation in ("viewer", "editor") and resource.startswith("conversation:"):
            # SECURITY: Scope conversation access by ownership in fallback mode
            # Extract thread_id from resource (format: "conversation:thread_id")
            thread_id = resource.split(":", 1)[1] if ":" in resource else ""

            # 1. The default/unnamed thread is accessible to everyone
            if thread_id == "default" or thread_id == "":
                logger.info(
                    "Fallback authorization granted - default conversation",
                    extra={"user_id": user_id, "username": username, "relation": relation, "resource": resource},
                )
                return True

            # 2. Thread explicitly belongs to this user (prefixed with username)
            # Format: "conversation:username_thread"
            if thread_id.startswith(f"{username}_"):
                logger.info(
                    "Fallback authorization granted - user-owned conversation",
                    extra={"user_id": user_id, "username": username, "relation": relation, "resource": resource},
                )
                return True

            # 3. Thread prefixed with the raw id after the last colon (e.g. worker-safe IDs)
            user_id_normalized = user_id.split(":")[-1] if ":" in user_id else user_id
            if thread_id.startswith(f"{user_id_normalized}_"):
                logger.info(
                    "Fallback authorization granted - user-owned conversation (normalized)",
                    extra={"user_id": user_id, "username": username, "relation": relation, "resource": resource},
                )
                return True

            # Deny access to conversations not owned by this user
            logger.warning(
                "Fallback authorization denied conversation access",
                extra={
                    "user_id": user_id,
                    "username": username,
                    "thread_id": thread_id,
                    "relation": relation,
                    "reason": "conversation_not_owned_by_user",
                },
            )
            return False

        # Default deny
        logger.warning(
            "Fallback authorization denied - no matching rule",
            extra={
                "user_id": user_id,
                "username": username,
                "relation": relation,
                "resource": resource,
                "roles": user_roles,
            },
        )
        return False

    def _get_mock_resources(self, user_id: str, relation: str, resource_type: str) -> list[str]:
        """
        Get mock resources for development/testing when OpenFGA is not available.

        Provides sample data to enable development and testing without authorization infrastructure.
        Resources are scoped per user to maintain proper RBAC semantics.

        Args:
            user_id: User identifier (used to scope conversation resources)
            relation: Relation to check (e.g., "executor", "viewer"); currently not used to filter
            resource_type: Type of resources (e.g., "tool", "conversation")

        Returns:
            List of mock resource identifiers scoped to the user ([] for unknown types)
        """
        # Extract username from user_id (handle both "user:alice" and "alice" formats)
        username = user_id.split(":")[-1] if ":" in user_id else user_id

        # Mock data for different resource types
        mock_data = {
            "tool": [
                "tool:agent_chat",
                "tool:conversation_get",
                "tool:conversation_search",
            ],
            "conversation": [
                # User-scoped conversations to maintain RBAC semantics
                f"conversation:{username}_demo_thread_1",
                f"conversation:{username}_demo_thread_2",
                f"conversation:{username}_demo_thread_3",
                f"conversation:{username}_sample_conversation",
            ],
            "user": [
                "user:alice",
                "user:bob",
                "user:charlie",
            ],
        }

        # Return mock data for the requested type
        return mock_data.get(resource_type, [])

    async def list_accessible_resources(self, user_id: str, relation: str, resource_type: str) -> list[str]:
        """
        List all resources user has access to

        Without OpenFGA, returns mock resources only when the global settings
        enable mock authorization; otherwise (or if that cannot be determined)
        returns an empty list. OpenFGA errors also yield an empty list.

        Args:
            user_id: User identifier
            relation: Relation to check (e.g., "executor", "viewer")
            resource_type: Type of resources (e.g., "tool", "conversation")

        Returns:
            List of accessible resource identifiers
        """
        if not self.openfga:
            # In development mode, return mock data for better developer experience
            # SECURITY: Mock data only enabled when explicitly configured (defaults to dev-only)
            try:
                from mcp_server_langgraph.core.config import settings

                if settings.get_mock_authorization_enabled():
                    logger.info(
                        "OpenFGA not available, using mock resources",
                        extra={
                            "user_id": user_id,
                            "relation": relation,
                            "resource_type": resource_type,
                            "environment": settings.environment,
                        },
                    )
                    return self._get_mock_resources(user_id, relation, resource_type)
            except Exception:
                # Best effort: if settings can't be loaded/queried, behave as if mock data
                # is disabled (fail closed), but leave a trace instead of swallowing silently.
                logger.debug(
                    "Could not determine mock authorization setting; returning no resources",
                    exc_info=True,
                )

            logger.warning("OpenFGA not available for resource listing, no mock data enabled")
            return []

        try:
            resources = await self.openfga.list_objects(user=user_id, relation=relation, object_type=resource_type)

            logger.info(
                "Listed accessible resources",
                extra={"user_id": user_id, "relation": relation, "resource_type": resource_type, "count": len(resources)},
            )

            return resources

        except Exception as e:
            logger.error(f"Failed to list accessible resources: {e}", exc_info=True)
            return []

    def create_token(self, username: str, expires_in: int = 3600) -> str:
        """
        Create JWT token for user (InMemoryUserProvider only)

        For Keycloak provider, tokens are issued by Keycloak itself.

        Args:
            username: Username
            expires_in: Token expiration in seconds

        Returns:
            JWT token string

        Raises:
            ValueError: If user not found or provider doesn't support token creation
        """
        # Check if provider supports token creation
        if isinstance(self.user_provider, InMemoryUserProvider):
            return self.user_provider.create_token(username, expires_in)

        # For other providers, we can't create tokens
        msg = f"Token creation not supported for provider type: {type(self.user_provider).__name__}"
        raise ValueError(msg)

    async def verify_token(self, token: str) -> TokenVerification:
        """
        Verify and decode JWT token

        Supports both self-issued tokens (InMemoryUserProvider) and
        Keycloak-issued tokens (KeycloakUserProvider).

        Args:
            token: JWT token to verify

        Returns:
            TokenVerification with validation result
        """
        # Delegate to user provider (returns Pydantic TokenVerification)
        result = await self.user_provider.verify_token(token)

        if result.valid:
            logger.info("Token verified", extra={"sub": result.payload.get("sub") if result.payload else None})
        else:
            logger.warning("Token verification failed", extra={"error": result.error})

        return result

    # Session Management Methods

    async def create_session(
        self,
        user_id: str,
        username: str,
        roles: list[str],
        metadata: dict[str, Any] | None = None,
        ttl_seconds: int | None = None,
    ) -> str | None:
        """
        Create a new session

        Args:
            user_id: User identifier
            username: Username
            roles: User roles
            metadata: Additional session metadata
            ttl_seconds: Session TTL in seconds

        Returns:
            Session ID or None if session store not configured
        """
        if not self.session_store:
            logger.warning("Session store not configured, cannot create session")
            return None

        with tracer.start_as_current_span("auth.create_session") as span:
            span.set_attribute("user.id", user_id)

            session_id = await self.session_store.create(
                user_id=user_id, username=username, roles=roles, metadata=metadata, ttl_seconds=ttl_seconds
            )

            logger.info("Session created", extra={"session_id": session_id, "user_id": user_id})
            return session_id

    async def get_session(self, session_id: str) -> SessionData | None:
        """
        Get session data

        Args:
            session_id: Session identifier

        Returns:
            Session data or None (also None when no session store is configured)
        """
        if not self.session_store:
            return None

        with tracer.start_as_current_span("auth.get_session") as span:
            span.set_attribute("session.id", session_id)

            session = await self.session_store.get(session_id)

            if session:
                logger.debug(f"Session retrieved: {session_id}")
            else:
                logger.debug(f"Session not found or expired: {session_id}")

            return session

    async def refresh_session(self, session_id: str, ttl_seconds: int | None = None) -> bool:
        """
        Refresh session expiration

        Args:
            session_id: Session identifier
            ttl_seconds: New TTL in seconds

        Returns:
            True if refreshed successfully (False when no session store is configured)
        """
        if not self.session_store:
            return False

        with tracer.start_as_current_span("auth.refresh_session") as span:
            span.set_attribute("session.id", session_id)

            refreshed = await self.session_store.refresh(session_id, ttl_seconds)

            if refreshed:
                logger.info(f"Session refreshed: {session_id}")
            else:
                logger.warning(f"Failed to refresh session: {session_id}")

            return refreshed

    async def revoke_session(self, session_id: str) -> bool:
        """
        Revoke (delete) a session

        Args:
            session_id: Session identifier

        Returns:
            True if revoked successfully (False when no session store is configured)
        """
        if not self.session_store:
            return False

        with tracer.start_as_current_span("auth.revoke_session") as span:
            span.set_attribute("session.id", session_id)

            revoked = await self.session_store.delete(session_id)

            if revoked:
                logger.info(f"Session revoked: {session_id}")
            else:
                logger.warning(f"Session not found for revocation: {session_id}")

            return revoked

    async def list_user_sessions(self, user_id: str) -> list[SessionData]:
        """
        List all active sessions for a user

        Args:
            user_id: User identifier

        Returns:
            List of session data ([] when no session store is configured)
        """
        if not self.session_store:
            return []

        with tracer.start_as_current_span("auth.list_user_sessions") as span:
            span.set_attribute("user.id", user_id)

            sessions = await self.session_store.list_user_sessions(user_id)

            logger.info(f"Listed {len(sessions)} sessions for user {user_id}")
            return sessions

    async def revoke_user_sessions(self, user_id: str) -> int:
        """
        Revoke all sessions for a user

        Args:
            user_id: User identifier

        Returns:
            Number of sessions revoked (0 when no session store is configured)
        """
        if not self.session_store:
            return 0

        with tracer.start_as_current_span("auth.revoke_user_sessions") as span:
            span.set_attribute("user.id", user_id)

            count = await self.session_store.delete_user_sessions(user_id)

            logger.info(f"Revoked {count} sessions for user {user_id}")
            return count

726 

727 

def require_auth(
    relation: str | None = None,
    resource: str | None = None,
    openfga_client: OpenFGAClient | None = None,
    auth_middleware: Optional["AuthMiddleware"] = None,
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
    """
    Decorator for requiring authentication/authorization on a coroutine function.

    The wrapped function must receive identity information as keyword arguments:

    - ``username`` (and optionally ``password``): authenticated via the
      middleware; the resulting ``user_id`` replaces any passed-in value.
    - ``user_id`` alone: used as-is. NOTE: no authentication is performed in
      this case — the caller is trusted to have established the identity.

    If both ``relation`` and ``resource`` are given, the user must also be
    authorized for them. On success the wrapped function is called with
    ``user_id`` set in its kwargs.

    Args:
        relation: Required relation (e.g., "executor")
        resource: Resource to check access to
        openfga_client: OpenFGA client instance (used only when
            ``auth_middleware`` is not supplied; a new AuthMiddleware is then
            created for every call)
        auth_middleware: Optional AuthMiddleware instance (for testing with pre-seeded users)

    Returns:
        A decorator producing an async wrapper around the decorated function.

    Raises (from the wrapped call):
        PermissionError: If no identity is supplied, authentication fails, or
            authorization is denied.
    """

    def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Use provided auth_middleware or create new instance
            auth = auth_middleware if auth_middleware is not None else AuthMiddleware(openfga_client=openfga_client)
            username = kwargs.get("username")
            password = kwargs.get("password")
            user_id = kwargs.get("user_id")

            if not username and not user_id:
                msg = "Authentication required"
                raise PermissionError(msg)

            # Authenticate (only when a username is supplied)
            if username:
                auth_result = await auth.authenticate(username, password)
                if not auth_result.authorized:
                    msg = "Authentication failed"
                    raise PermissionError(msg)
                user_id = auth_result.user_id

            # Authorize if relation and resource specified
            if relation and resource:
                if not await auth.authorize(user_id, relation, resource):  # type: ignore[arg-type]
                    msg = f"Not authorized: {user_id} cannot {relation} {resource}"
                    raise PermissionError(msg)

            # Add user_id to kwargs if authenticated
            kwargs["user_id"] = user_id
            return await func(*args, **kwargs)

        return wrapper

    return decorator

778 

779 

async def verify_token(token: str, secret_key: str | None = None) -> TokenVerification:
    """
    Standalone token verification function

    Builds a throwaway AuthMiddleware (default InMemoryUserProvider) and
    delegates to its verify_token().

    SECURITY: when ``secret_key`` is not supplied (or is empty), a hard-coded,
    publicly known development secret is used. Any token signed with that
    secret will verify, so production callers must always pass the real key.
    A warning is logged whenever the default is used.

    Args:
        token: JWT token to verify
        secret_key: Secret key for verification

    Returns:
        TokenVerification with validation result
    """
    if not secret_key:
        logger.warning(
            "verify_token() called without secret_key; using the built-in development default secret. "
            "Tokens signed with this well-known key will verify - do not rely on it in production."
        )
        secret_key = "your-secret-key-change-in-production"

    auth = AuthMiddleware(secret_key=secret_key)
    return await auth.verify_token(token)

793 

794 

795# ============================================================================ 

796# FastAPI Dependency Injection Support 

797# ============================================================================ 

798 

if FASTAPI_AVAILABLE:  # noqa: C901
    # Global auth middleware instance (set by application)
    _global_auth_middleware: AuthMiddleware | None = None

    def set_global_auth_middleware(auth: AuthMiddleware) -> None:
        """
        Set global auth middleware instance for FastAPI dependencies.

        This should be called during application startup.

        Args:
            auth: AuthMiddleware instance configured with user provider, OpenFGA, etc.
        """
        global _global_auth_middleware
        _global_auth_middleware = auth

    def get_auth_middleware() -> AuthMiddleware:
        """
        Get global auth middleware instance.

        Returns:
            AuthMiddleware instance

        Raises:
            RuntimeError: If auth middleware not initialized
        """
        if _global_auth_middleware is None:
            msg = "Auth middleware not initialized. Call set_global_auth_middleware() during app startup."
            raise RuntimeError(msg)
        return _global_auth_middleware

    # HTTP Bearer security scheme for JWT tokens. auto_error=False: missing
    # credentials resolve to None and the dependencies below raise their own 401s.
    bearer_scheme = HTTPBearer(auto_error=False)

    def _extract_bearer_token(auth_header: str | None) -> str | None:
        """
        Return the token from an ``Authorization: Bearer <token>`` header value.

        The scheme name is matched case-insensitively (auth schemes are
        case-insensitive per RFC 7235 section 2.1), so ``bearer <token>`` is
        accepted too. Returns None when the header is missing, uses a different
        scheme, or carries no token.
        """
        if not auth_header:
            return None
        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    def _username_from_sub(sub: str) -> str:
        """
        Derive a username from a ``sub`` claim when the token has no username claim.

        ``"user:alice"`` -> ``"alice"``; worker-safe test IDs such as
        ``"user:test_gw0_charlie"`` -> ``"charlie"``; values without the
        ``user:`` prefix are returned unchanged.
        """
        import re

        if not sub.startswith("user:"):
            return sub
        # Strip only the leading prefix (str.replace would remove every occurrence).
        id_part = sub.removeprefix("user:")
        match = re.match(r"test_gw\d+_(.*)", id_part)
        return match.group(1) if match else id_part

    async def get_current_user(
        request: Request,
    ) -> dict[str, Any]:
        """
        FastAPI dependency for extracting authenticated user from request.

        Supports multiple authentication methods:
        1. User already set in request.state.user (by middleware or an earlier call)
        2. JWT token in Authorization header (Bearer token)

        The resolved user is cached on ``request.state.user``.

        Args:
            request: FastAPI request object

        Returns:
            User dict with user_id, keycloak_id, username, roles, email.

        Raises:
            HTTPException: If authentication fails (401)
        """
        # Check if user already set by middleware
        cached_user = getattr(request.state, "user", None)
        if cached_user:
            return cached_user  # type: ignore[no-any-return]

        token = _extract_bearer_token(request.headers.get("Authorization"))
        if not token:
            # No authentication provided
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Authentication required",
                headers={"WWW-Authenticate": "Bearer"},
            )

        auth = get_auth_middleware()
        verification = await auth.verify_token(token)

        if not (verification.valid and verification.payload):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=f"Invalid token: {verification.error}",
                headers={"WWW-Authenticate": "Bearer"},
            )

        payload = verification.payload
        # Keycloak puts a UUID in 'sub' (required for Admin API calls), but
        # OpenFGA needs the 'user:username' format, so both are tracked.
        keycloak_id = payload.get("sub")

        # Priority: preferred_username (Keycloak) > username (InMemory) > derived from sub (fallback)
        username = payload.get("preferred_username") or payload.get("username")
        if not username:
            username = _username_from_sub(keycloak_id or "unknown")

        # Use sub directly if it is already "user:*" (preserves worker-safe IDs such as
        # "user:test_gw0_alice" from InMemoryUserProvider tokens); otherwise normalize.
        if keycloak_id and keycloak_id.startswith("user:"):
            user_id = keycloak_id
        else:
            user_id = username if username.startswith("user:") else f"user:{username}"

        user_data = {
            "user_id": user_id,
            "keycloak_id": keycloak_id,  # Raw UUID for Keycloak Admin API
            "username": username,
            "roles": payload.get("roles", []),
            "email": payload.get("email"),
        }
        # Cache in request state for subsequent calls
        request.state.user = user_data
        return user_data

    async def get_current_user_with_auth(
        user: dict[str, Any],
        relation: str | None = None,
        resource: str | None = None,
    ) -> dict[str, Any]:
        """
        Authorize an already-authenticated user.

        Helper for building FastAPI dependencies that need both authentication
        and authorization. The authorization check runs only when both
        ``relation`` and ``resource`` are given.

        Example:
            async def chat_user(user: dict[str, Any] = Depends(get_current_user)) -> dict[str, Any]:
                return await get_current_user_with_auth(user, relation="executor", resource="tool:chat")

            @app.get("/protected")
            async def protected_endpoint(user: dict[str, Any] = Depends(chat_user)):
                return {"user": user}

        Args:
            user: User dict from get_current_user dependency
            relation: Required relation (e.g., "executor", "viewer")
            resource: Resource to check access to (e.g., "tool:chat")

        Returns:
            User dict if authorized

        Raises:
            HTTPException: If authorization fails (403)
        """
        if relation and resource:
            auth = get_auth_middleware()
            user_id = user.get("user_id", "")

            authorized = await auth.authorize(user_id=user_id, relation=relation, resource=resource)

            if not authorized:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Not authorized: {user_id} cannot {relation} {resource}",
                )

        return user

    def require_auth_dependency(
        relation: str | None = None,
        resource: str | None = None,
    ) -> Callable[..., Awaitable[dict[str, Any]]]:
        """
        Create a FastAPI dependency for authentication + authorization.

        This replaces the @require_auth decorator for FastAPI routes.

        Example:
            from fastapi import Depends

            @app.get("/tools")
            async def list_tools(user: Dict = Depends(require_auth_dependency(relation="executor", resource="tool:*"))):
                return {"tools": [...]}

        Args:
            relation: Required relation (e.g., "executor")
            resource: Resource to check access to

        Returns:
            FastAPI dependency function that resolves to the user dict
            (401 if unauthenticated, 403 if unauthorized).
        """

        async def dependency(
            request: Request,
            credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
        ) -> dict[str, Any]:
            # ``credentials`` is declared so FastAPI registers the bearer security scheme;
            # get_current_user() parses the header itself. It must go through Depends():
            # with a bare default, FastAPI would treat this pydantic-model parameter as a
            # request-body field and alter body handling of the routes using it.
            user = await get_current_user(request)

            # Check authorization if required
            if relation and resource:
                auth = get_auth_middleware()
                authorized = await auth.authorize(user_id=user["user_id"], relation=relation, resource=resource)

                if not authorized:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail=f"Not authorized: {user['user_id']} cannot {relation} {resource}",
                    )

            return user

        return dependency