Coverage for src / mcp_server_langgraph / auth / middleware.py: 70%
301 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2Authentication and Authorization middleware with OpenFGA integration
4Now supports:
5- Pluggable user providers (InMemory, Keycloak, custom)
6- Session management (Token-based or Session-based)
7- Fine-grained authorization via OpenFGA
8"""
10from functools import wraps
11from typing import Any, Optional
13from pydantic import BaseModel, ConfigDict, Field
15from mcp_server_langgraph.auth.openfga import OpenFGAClient
16from mcp_server_langgraph.auth.session import SessionData, SessionStore
17from mcp_server_langgraph.auth.user_provider import AuthResponse, InMemoryUserProvider, TokenVerification, UserProvider
18from mcp_server_langgraph.observability.telemetry import logger, tracer
20# FastAPI imports for dependency injection (optional, only if using FastAPI endpoints)
21try:
22 from fastapi import Depends, HTTPException, Request, status # noqa: F401 (Depends used at line 929, 970)
23 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
25 FASTAPI_AVAILABLE = True
26except ImportError:
27 FASTAPI_AVAILABLE = False
29# ============================================================================
30# Helper Functions
31# ============================================================================
def normalize_user_id(user_id: str) -> str:
    """
    Strip an optional "<prefix>:" from a user identifier.

    Both the OpenFGA style ("user:alice") and the bare style ("alice") are
    accepted, so clients may send either form:

    - "alice"      -> "alice"
    - "user:alice" -> "alice"
    - "uid:123"    -> "123"

    Only the first colon acts as the separator; everything after it is kept
    verbatim (e.g. "user:a:b" -> "a:b"). Falsy input is returned unchanged.

    Args:
        user_id: User identifier, with or without a prefix

    Returns:
        The identifier with its prefix (if any) removed
    """
    # Guard clause: empty string / None pass straight through untouched
    if not user_id:
        return user_id

    _prefix, separator, remainder = user_id.partition(":")
    # No separator found -> the value was already a bare username
    return remainder if separator else user_id
64# ============================================================================
65# Pydantic Models for Middleware Operations
66# ============================================================================
class AuthorizationResult(BaseModel):
    """
    Outcome of a single authorization check, as a validated Pydantic model.

    Returned from authorize() operations. Field assignments are re-validated
    (``validate_assignment=True``) and surrounding whitespace is stripped from
    string fields (``str_strip_whitespace=True``).
    """

    model_config = ConfigDict(
        frozen=False,
        validate_assignment=True,
        str_strip_whitespace=True,
        json_schema_extra={
            "example": {
                "authorized": True,
                "user_id": "user:alice",
                "relation": "executor",
                "resource": "tool:chat",
                "reason": None,
                "used_fallback": False,
            }
        },
    )

    # Required fields (``...`` marks them as mandatory)
    authorized: bool = Field(..., description="Whether access is authorized")
    user_id: str = Field(..., description="User identifier that was checked")
    relation: str = Field(..., description="Relation that was checked")
    resource: str = Field(..., description="Resource that was checked")
    # Optional fields
    reason: str | None = Field(default=None, description="Reason for denial if not authorized")
    used_fallback: bool = Field(default=False, description="Whether fallback authorization was used")

    def to_dict(self) -> dict[str, Any]:
        """Return a plain dict with ``None``-valued fields omitted (legacy callers)."""
        return self.model_dump(exclude_none=True)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AuthorizationResult":
        """Build an instance by passing the dict's items as keyword arguments."""
        return cls(**data)
class AuthMiddleware:
    """
    Authentication and authorization handler with OpenFGA

    Combines authentication (via pluggable user providers) with fine-grained
    relationship-based authorization using OpenFGA.

    Supports multiple authentication backends:
    - InMemoryUserProvider (development/testing)
    - KeycloakUserProvider (production)
    - Custom providers
    """

    def __init__(
        self,
        secret_key: str | None = None,
        openfga_client: OpenFGAClient | None = None,
        user_provider: UserProvider | None = None,
        session_store: SessionStore | None = None,
        settings: Any | None = None,
    ):
        """
        Initialize AuthMiddleware

        Args:
            secret_key: Secret key for JWT tokens (used by InMemoryUserProvider).
                Must be provided via environment variable or settings.
            openfga_client: OpenFGA client for authorization
            user_provider: User provider instance (defaults to InMemoryUserProvider for backward compatibility)
            session_store: Session store for session-based authentication (optional)
            settings: Application settings; ``allow_auth_fallback`` and ``environment``
                are read from it to control fallback authorization
        """
        self.secret_key = secret_key
        self.openfga = openfga_client
        self.session_store = session_store
        self.settings = settings

        # Use provided user provider or default to in-memory for backward compatibility
        if user_provider is None:
            logger.info("No user provider specified, defaulting to InMemoryUserProvider")
            user_provider = InMemoryUserProvider(secret_key=secret_key)

        self.user_provider = user_provider

        # For backward compatibility: expose users_db if using InMemoryUserProvider
        if isinstance(user_provider, InMemoryUserProvider):
            self.users_db = user_provider.users_db
        else:
            self.users_db = {}  # Empty dict for non-inmemory providers

        logger.info(
            "AuthMiddleware initialized",
            extra={
                "provider_type": type(user_provider).__name__,
                "openfga_enabled": openfga_client is not None,
                "session_enabled": session_store is not None,
                "allow_auth_fallback": getattr(settings, "allow_auth_fallback", None) if settings else None,
            },
        )

    async def authenticate(self, username: str, password: str | None = None) -> AuthResponse:
        """
        Authenticate user by username

        Args:
            username: Username to authenticate (accepts both "alice" and "user:alice" formats)
            password: Password (required for some providers like Keycloak)

        Returns:
            AuthResponse with authentication result
        """
        with tracer.start_as_current_span("auth.authenticate") as span:
            # Normalize username to handle both "alice" and "user:alice" formats
            normalized_username = normalize_user_id(username)
            span.set_attribute("auth.username", normalized_username)

            # Delegate to user provider (returns Pydantic AuthResponse)
            result = await self.user_provider.authenticate(normalized_username, password)

            # Imported once here (both outcomes need it); kept function-local as in the
            # original — presumably to avoid an import cycle, TODO confirm.
            from mcp_server_langgraph.core.security import sanitize_for_logging

            if result.authorized:
                logger.info(
                    "User authenticated",
                    extra=sanitize_for_logging({"username": username, "user_id": result.user_id}),
                )
            else:
                logger.warning(
                    "Authentication failed", extra=sanitize_for_logging({"username": username, "reason": result.reason})
                )

            return result

    async def authorize(self, user_id: str, relation: str, resource: str, context: dict[str, Any] | None = None) -> bool:
        """
        Check if user is authorized using OpenFGA

        When OpenFGA is not configured, a role-based fallback is used ONLY if
        ``settings.allow_auth_fallback`` is truthy AND the environment is not
        production; otherwise access is denied (fail closed).

        Args:
            user_id: User identifier (e.g., "user:alice")
            relation: Relation to check (e.g., "executor", "viewer")
            resource: Resource identifier (e.g., "tool:chat")
            context: Additional context for authorization (forwarded to OpenFGA)

        Returns:
            True if authorized, False otherwise (including on OpenFGA errors)
        """
        with tracer.start_as_current_span("auth.authorize") as span:
            span.set_attribute("user.id", user_id)
            span.set_attribute("auth.relation", relation)
            span.set_attribute("auth.resource", resource)

            # Use OpenFGA if available
            if self.openfga:
                try:
                    authorized = await self.openfga.check_permission(
                        user=user_id, relation=relation, object=resource, context=context
                    )

                    span.set_attribute("auth.authorized", authorized)
                    logger.info(
                        "Authorization check (OpenFGA)",
                        extra={"user_id": user_id, "relation": relation, "resource": resource, "authorized": authorized},
                    )

                    return authorized

                except Exception as e:
                    logger.error(
                        f"OpenFGA authorization check failed: {e}",
                        extra={"user_id": user_id, "relation": relation, "resource": resource},
                        exc_info=True,
                    )
                    # Fail closed - deny access on error
                    return False

            # SECURITY CONTROL (OpenAI Codex Finding #1): Check if fallback authorization is allowed
            # When OpenFGA is not available, check configuration to determine if we should:
            # 1. Fail closed (deny all access) - secure default for production
            # 2. Fall back to role-based checks - only if explicitly enabled for dev/test
            allow_fallback = getattr(self.settings, "allow_auth_fallback", False) if self.settings else False
            environment = getattr(self.settings, "environment", "production") if self.settings else "production"

            # Defense in depth: NEVER allow fallback in production, even if misconfigured.
            # Compare case-/whitespace-insensitively so values such as "Production" or
            # " production " cannot slip past this guard.
            if str(environment).strip().lower() == "production":
                logger.error(
                    "Authorization DENIED: OpenFGA unavailable in production environment. "
                    "Fallback authorization is not permitted in production for security reasons.",
                    extra={
                        "user_id": user_id,
                        "relation": relation,
                        "resource": resource,
                        "environment": environment,
                        "allow_auth_fallback": allow_fallback,
                    },
                )
                return False

            # Check if fallback is explicitly enabled
            if not allow_fallback:
                logger.warning(
                    "Authorization DENIED: OpenFGA unavailable and fallback authorization is disabled. "
                    "Set ALLOW_AUTH_FALLBACK=true to enable role-based fallback in development/test.",
                    extra={
                        "user_id": user_id,
                        "relation": relation,
                        "resource": resource,
                        "allow_auth_fallback": allow_fallback,
                        "environment": environment,
                    },
                )
                return False

            # Fallback: simple permission check (only when explicitly allowed in non-production)
            logger.warning(
                "OpenFGA not available, using fallback authorization (explicitly enabled)",
                extra={
                    "allow_auth_fallback": allow_fallback,
                    "environment": environment,
                },
            )

            # Extract username from user_id (handle worker-safe IDs for pytest-xdist)
            # InMemoryUserProvider is test-only, so this test-specific logic is acceptable
            # Examples: "user:alice" → "alice", "user:test_gw0_alice" → "alice"
            if ":" in user_id:
                id_part = user_id.split(":", 1)[1]  # Remove "user:" prefix
                # Check if it's a worker-safe ID (format: test_gw\d+_username)
                import re

                match = re.match(r"test_gw\d+_(.*)", id_part)
                username = match.group(1) if match else id_part
            else:
                username = user_id

            # Resolve the user's roles - in-memory fast path, otherwise query the provider
            user_roles = []

            if isinstance(self.user_provider, InMemoryUserProvider):
                # Fast path: Use in-memory users_db
                if username not in self.users_db:
                    logger.warning(
                        "Fallback authorization denied - user not found",
                        extra={"user_id": user_id, "username": username, "provider": "InMemory"},
                    )
                    return False
                user = self.users_db[username]
                user_roles = user["roles"]

            else:
                # For external providers (Keycloak, etc.): Query the provider
                try:
                    user_data = await self.user_provider.get_user_by_username(username)
                    if not user_data:
                        logger.warning(
                            "Fallback authorization denied - user not found in provider",
                            extra={"user_id": user_id, "username": username, "provider": type(self.user_provider).__name__},
                        )
                        return False
                    user_roles = user_data.roles
                    logger.info(
                        "Fetched user from provider for fallback authorization",
                        extra={
                            "user_id": user_id,
                            "username": username,
                            "provider": type(self.user_provider).__name__,
                            "roles": user_roles,
                        },
                    )
                except Exception as e:
                    logger.error(
                        f"Failed to fetch user from provider for fallback authorization: {e}",
                        extra={"user_id": user_id, "username": username, "provider": type(self.user_provider).__name__},
                        exc_info=True,
                    )
                    # Fail closed - deny access if we can't verify user
                    return False

            # Admin users have access to everything
            if "admin" in user_roles:
                logger.info(
                    "Fallback authorization granted - admin user",
                    extra={"user_id": user_id, "username": username, "relation": relation, "resource": resource},
                )
                return True

            # Basic resource-based checks
            if relation == "executor" and resource.startswith("tool:"):
                authorized = "premium" in user_roles or "user" in user_roles
                if authorized:
                    logger.info(
                        "Fallback authorization granted - tool executor",
                        extra={
                            "user_id": user_id,
                            "username": username,
                            "relation": relation,
                            "resource": resource,
                            "roles": user_roles,
                        },
                    )
                return authorized

            if relation in ("viewer", "editor") and resource.startswith("conversation:"):
                # SECURITY: Scope conversation access by ownership in fallback mode
                # Extract thread_id from resource (format: "conversation:thread_id")
                thread_id = resource.split(":", 1)[1] if ":" in resource else ""

                # Allow access only if:
                # 1. Thread is the default/unnamed thread
                # 2. Thread explicitly belongs to this user (prefixed with username)
                # 3. User is accessing their own user-scoped conversations
                if thread_id == "default" or thread_id == "":
                    logger.info(
                        "Fallback authorization granted - default conversation",
                        extra={"user_id": user_id, "username": username, "relation": relation, "resource": resource},
                    )
                    return True

                # Check if conversation belongs to this user
                # Format: "conversation:username_thread" or "conversation:user:username_thread"
                # NOTE(review): prefix matching means "alice" also matches threads named
                # "alice_bob_..." — acceptable only because this path is dev/test-only.
                if thread_id.startswith(f"{username}_"):
                    logger.info(
                        "Fallback authorization granted - user-owned conversation",
                        extra={"user_id": user_id, "username": username, "relation": relation, "resource": resource},
                    )
                    return True

                # Also support the raw (non-worker-stripped) id part as thread prefix
                user_id_normalized = user_id.split(":")[-1] if ":" in user_id else user_id
                if thread_id.startswith(f"{user_id_normalized}_"):
                    logger.info(
                        "Fallback authorization granted - user-owned conversation (normalized)",
                        extra={"user_id": user_id, "username": username, "relation": relation, "resource": resource},
                    )
                    return True

                # Deny access to conversations not owned by this user
                logger.warning(
                    "Fallback authorization denied conversation access",
                    extra={
                        "user_id": user_id,
                        "username": username,
                        "thread_id": thread_id,
                        "relation": relation,
                        "reason": "conversation_not_owned_by_user",
                    },
                )
                return False

            # Default deny
            logger.warning(
                "Fallback authorization denied - no matching rule",
                extra={
                    "user_id": user_id,
                    "username": username,
                    "relation": relation,
                    "resource": resource,
                    "roles": user_roles,
                },
            )
            return False

    def _get_mock_resources(self, user_id: str, relation: str, resource_type: str) -> list[str]:
        """
        Get mock resources for development/testing when OpenFGA is not available.

        Provides sample data to enable development and testing without authorization infrastructure.
        Conversation resources are scoped per user; ``relation`` is accepted but not used.

        Args:
            user_id: User identifier (used to scope conversation resources)
            relation: Relation to check (e.g., "executor", "viewer") - currently unused
            resource_type: Type of resources (e.g., "tool", "conversation", "user")

        Returns:
            List of mock resource identifiers (empty for unknown resource types)
        """
        # Extract username from user_id (handle both "user:alice" and "alice" formats)
        username = user_id.split(":")[-1] if ":" in user_id else user_id

        # Mock data for different resource types
        mock_data = {
            "tool": [
                "tool:agent_chat",
                "tool:conversation_get",
                "tool:conversation_search",
            ],
            "conversation": [
                # User-scoped conversations to maintain RBAC semantics
                f"conversation:{username}_demo_thread_1",
                f"conversation:{username}_demo_thread_2",
                f"conversation:{username}_demo_thread_3",
                f"conversation:{username}_sample_conversation",
            ],
            "user": [
                "user:alice",
                "user:bob",
                "user:charlie",
            ],
        }

        # Return mock data for the requested type
        return mock_data.get(resource_type, [])

    async def list_accessible_resources(self, user_id: str, relation: str, resource_type: str) -> list[str]:
        """
        List all resources user has access to

        Without OpenFGA, returns mock resources when the global settings enable
        mock authorization, otherwise an empty list. OpenFGA errors also yield
        an empty list.

        Args:
            user_id: User identifier
            relation: Relation to check (e.g., "executor", "viewer")
            resource_type: Type of resources (e.g., "tool", "conversation")

        Returns:
            List of accessible resource identifiers
        """
        if not self.openfga:
            # In development mode, return mock data for better developer experience
            # SECURITY: Mock data only enabled when explicitly configured (defaults to dev-only)
            mock_enabled = False
            app_settings = None
            try:
                from mcp_server_langgraph.core.config import settings as app_settings

                mock_enabled = bool(app_settings.get_mock_authorization_enabled())
            except Exception as e:
                # Best-effort: settings unavailable -> treat mock data as disabled, but
                # leave a trace instead of silently swallowing the error.
                logger.debug(f"Could not determine mock authorization setting: {e}", exc_info=True)

            if mock_enabled:
                logger.info(
                    "OpenFGA not available, using mock resources",
                    extra={
                        "user_id": user_id,
                        "relation": relation,
                        "resource_type": resource_type,
                        "environment": getattr(app_settings, "environment", None),
                    },
                )
                return self._get_mock_resources(user_id, relation, resource_type)

            logger.warning("OpenFGA not available for resource listing, no mock data enabled")
            return []

        try:
            resources = await self.openfga.list_objects(user=user_id, relation=relation, object_type=resource_type)

            logger.info(
                "Listed accessible resources",
                extra={"user_id": user_id, "relation": relation, "resource_type": resource_type, "count": len(resources)},
            )

            return resources

        except Exception as e:
            logger.error(f"Failed to list accessible resources: {e}", exc_info=True)
            return []

    def create_token(self, username: str, expires_in: int = 3600) -> str:
        """
        Create JWT token for user (InMemoryUserProvider only)

        For Keycloak provider, tokens are issued by Keycloak itself.

        Args:
            username: Username
            expires_in: Token expiration in seconds

        Returns:
            JWT token string

        Raises:
            ValueError: If the provider doesn't support token creation (errors for
                unknown users come from the provider's create_token)
        """
        # Check if provider supports token creation
        if isinstance(self.user_provider, InMemoryUserProvider):
            return self.user_provider.create_token(username, expires_in)

        # For other providers, we can't create tokens
        msg = f"Token creation not supported for provider type: {type(self.user_provider).__name__}"
        raise ValueError(msg)

    async def verify_token(self, token: str) -> TokenVerification:
        """
        Verify and decode JWT token

        Delegates to the configured user provider, which handles self-issued
        (InMemoryUserProvider) or IdP-issued (e.g. Keycloak) tokens.

        Args:
            token: JWT token to verify

        Returns:
            TokenVerification with validation result
        """
        # Delegate to user provider (returns Pydantic TokenVerification)
        result = await self.user_provider.verify_token(token)

        if result.valid:
            logger.info("Token verified", extra={"sub": result.payload.get("sub") if result.payload else None})
        else:
            logger.warning("Token verification failed", extra={"error": result.error})

        return result

    # Session Management Methods
    # Each method is a no-op (returning None / False / [] / 0) when no session store is configured.

    async def create_session(
        self,
        user_id: str,
        username: str,
        roles: list[str],
        metadata: dict[str, Any] | None = None,
        ttl_seconds: int | None = None,
    ) -> str | None:
        """
        Create a new session

        Args:
            user_id: User identifier
            username: Username
            roles: User roles
            metadata: Additional session metadata
            ttl_seconds: Session TTL in seconds

        Returns:
            Session ID or None if session store not configured
        """
        if not self.session_store:
            logger.warning("Session store not configured, cannot create session")
            return None

        with tracer.start_as_current_span("auth.create_session") as span:
            span.set_attribute("user.id", user_id)

            session_id = await self.session_store.create(
                user_id=user_id, username=username, roles=roles, metadata=metadata, ttl_seconds=ttl_seconds
            )

            logger.info("Session created", extra={"session_id": session_id, "user_id": user_id})
            return session_id

    async def get_session(self, session_id: str) -> SessionData | None:
        """
        Get session data

        Args:
            session_id: Session identifier

        Returns:
            Session data, or None if not found/expired or no session store is configured
        """
        if not self.session_store:
            return None

        with tracer.start_as_current_span("auth.get_session") as span:
            span.set_attribute("session.id", session_id)

            session = await self.session_store.get(session_id)

            if session:
                logger.debug(f"Session retrieved: {session_id}")
            else:
                logger.debug(f"Session not found or expired: {session_id}")

            return session

    async def refresh_session(self, session_id: str, ttl_seconds: int | None = None) -> bool:
        """
        Refresh session expiration

        Args:
            session_id: Session identifier
            ttl_seconds: New TTL in seconds

        Returns:
            True if refreshed successfully
        """
        if not self.session_store:
            return False

        with tracer.start_as_current_span("auth.refresh_session") as span:
            span.set_attribute("session.id", session_id)

            refreshed = await self.session_store.refresh(session_id, ttl_seconds)

            if refreshed:
                logger.info(f"Session refreshed: {session_id}")
            else:
                logger.warning(f"Failed to refresh session: {session_id}")

            return refreshed

    async def revoke_session(self, session_id: str) -> bool:
        """
        Revoke (delete) a session

        Args:
            session_id: Session identifier

        Returns:
            True if revoked successfully
        """
        if not self.session_store:
            return False

        with tracer.start_as_current_span("auth.revoke_session") as span:
            span.set_attribute("session.id", session_id)

            revoked = await self.session_store.delete(session_id)

            if revoked:
                logger.info(f"Session revoked: {session_id}")
            else:
                logger.warning(f"Session not found for revocation: {session_id}")

            return revoked

    async def list_user_sessions(self, user_id: str) -> list[SessionData]:
        """
        List all active sessions for a user

        Args:
            user_id: User identifier

        Returns:
            List of session data
        """
        if not self.session_store:
            return []

        with tracer.start_as_current_span("auth.list_user_sessions") as span:
            span.set_attribute("user.id", user_id)

            sessions = await self.session_store.list_user_sessions(user_id)

            logger.info(f"Listed {len(sessions)} sessions for user {user_id}")
            return sessions

    async def revoke_user_sessions(self, user_id: str) -> int:
        """
        Revoke all sessions for a user

        Args:
            user_id: User identifier

        Returns:
            Number of sessions revoked
        """
        if not self.session_store:
            return 0

        with tracer.start_as_current_span("auth.revoke_user_sessions") as span:
            span.set_attribute("user.id", user_id)

            count = await self.session_store.delete_user_sessions(user_id)

            logger.info(f"Revoked {count} sessions for user {user_id}")
            return count
def require_auth(
    relation: str | None = None,
    resource: str | None = None,
    openfga_client: OpenFGAClient | None = None,
    auth_middleware: Optional["AuthMiddleware"] = None,
) -> Any:
    """
    Decorator for requiring authentication/authorization on async functions.

    The wrapped function must be called with ``username`` (and optionally
    ``password``) or ``user_id`` keyword arguments. On success, the resolved
    ``user_id`` is injected into the call's kwargs.

    Args:
        relation: Required relation (e.g., "executor")
        resource: Resource to check access to
        openfga_client: OpenFGA client instance (used only when no auth_middleware is given)
        auth_middleware: Optional AuthMiddleware instance (for testing with pre-seeded users)

    Returns:
        A decorator that wraps an async function.

    Raises (from the wrapped call):
        PermissionError: If no identity is supplied, authentication fails,
            or the authorization check denies access.
    """

    def decorator(func: Any) -> Any:
        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Use provided auth_middleware or create a fresh instance per call
            auth = auth_middleware if auth_middleware is not None else AuthMiddleware(openfga_client=openfga_client)
            username = kwargs.get("username")
            password = kwargs.get("password")
            user_id = kwargs.get("user_id")

            if not username and not user_id:
                msg = "Authentication required"
                raise PermissionError(msg)

            # Authenticate (only when a username is supplied).
            # NOTE(review): a caller-supplied ``user_id`` without ``username`` is trusted
            # as-is and is NOT re-authenticated here — confirm callers only pass
            # user_id values that were already authenticated upstream.
            if username:
                auth_result = await auth.authenticate(username, password)
                if not auth_result.authorized:
                    msg = "Authentication failed"
                    raise PermissionError(msg)
                user_id = auth_result.user_id

            # Authorize if relation and resource specified
            if relation and resource:
                if not await auth.authorize(user_id, relation, resource):  # type: ignore[arg-type]
                    msg = f"Not authorized: {user_id} cannot {relation} {resource}"
                    raise PermissionError(msg)

            # Add user_id to kwargs if authenticated
            kwargs["user_id"] = user_id
            return await func(*args, **kwargs)

        return wrapper

    return decorator
async def verify_token(token: str, secret_key: str | None = None) -> TokenVerification:
    """
    Verify a JWT without a pre-built AuthMiddleware.

    A throwaway AuthMiddleware (with its default InMemoryUserProvider) is
    constructed for this call, and verification is delegated to it.

    Args:
        token: JWT token to verify
        secret_key: Signing key; when falsy, a hard-coded placeholder key is used

    Returns:
        TokenVerification with validation result
    """
    # NOTE(review): the fallback below is a publicly visible literal; relying on it
    # outside tests means tokens are checked against a well-known secret.
    effective_secret = secret_key if secret_key else "your-secret-key-change-in-production"
    middleware = AuthMiddleware(secret_key=effective_secret)
    verification = await middleware.verify_token(token)
    return verification
795# ============================================================================
796# FastAPI Dependency Injection Support
797# ============================================================================
if FASTAPI_AVAILABLE:  # noqa: C901
    # Global auth middleware instance (set by application)
    _global_auth_middleware: AuthMiddleware | None = None

    def set_global_auth_middleware(auth: AuthMiddleware) -> None:
        """
        Set global auth middleware instance for FastAPI dependencies.

        This should be called during application startup.

        Args:
            auth: AuthMiddleware instance configured with user provider, OpenFGA, etc.
        """
        global _global_auth_middleware
        _global_auth_middleware = auth

    def get_auth_middleware() -> AuthMiddleware:
        """
        Get global auth middleware instance.

        Returns:
            AuthMiddleware instance

        Raises:
            RuntimeError: If auth middleware not initialized
        """
        if _global_auth_middleware is None:
            msg = "Auth middleware not initialized. Call set_global_auth_middleware() during app startup."
            raise RuntimeError(msg)
        return _global_auth_middleware

    # HTTP Bearer security scheme for JWT tokens (auto_error=False: returns None instead of raising)
    bearer_scheme = HTTPBearer(auto_error=False)

    def _extract_bearer_token(auth_header: str | None) -> str | None:
        """Return the token from an ``Authorization: Bearer <token>`` header value, else None."""
        if not auth_header:
            return None
        scheme, _, credentials = auth_header.partition(" ")
        # RFC 7235 §2.1: the auth-scheme name is case-insensitive ("bearer", "BEARER", ...)
        if scheme.lower() != "bearer":
            return None
        # Tolerate extra whitespace around the token; an empty token counts as absent
        return credentials.strip() or None

    async def get_current_user(
        request: Request,
    ) -> dict[str, Any]:
        """
        FastAPI dependency for extracting authenticated user from request.

        Supports multiple authentication methods:
        1. User already set in request.state.user (by middleware) - checked first
        2. JWT token in Authorization header (Bearer token)

        Args:
            request: FastAPI request object

        Returns:
            User dict with user_id, keycloak_id, username, roles, email

        Raises:
            HTTPException: If authentication fails (401)
        """
        # Check if user already set by middleware
        if hasattr(request.state, "user") and request.state.user:
            return request.state.user  # type: ignore[no-any-return]

        # Extract bearer token from Authorization header
        token = _extract_bearer_token(request.headers.get("Authorization"))

        # Try to authenticate with Bearer token
        if token:
            auth = get_auth_middleware()
            verification = await auth.verify_token(token)

            if verification.valid and verification.payload:
                # Extract Keycloak UUID from sub claim (required for Admin API calls)
                keycloak_id = verification.payload.get("sub")

                # Priority: preferred_username (Keycloak) > username (InMemory) > extract from sub (fallback)
                # Keycloak uses a UUID in 'sub', but OpenFGA needs 'user:username' format
                username = verification.payload.get("preferred_username") or verification.payload.get("username")
                if not username:
                    # Fallback to extracting username from sub (for non-Keycloak IdPs without username field)
                    sub = keycloak_id or "unknown"
                    # If sub is in "user:username" format, extract username
                    if sub.startswith("user:"):
                        id_part = sub.replace("user:", "")
                        # Handle worker-safe IDs (e.g., "user:test_gw0_charlie" → "charlie")
                        import re

                        match = re.match(r"test_gw\d+_(.*)", id_part)
                        username = match.group(1) if match else id_part
                    else:
                        username = sub

                # For user_id, use sub directly if it's already in "user:*" format, otherwise normalize from username
                # This preserves worker-safe IDs like "user:test_gw0_alice" from InMemoryUserProvider tokens
                if keycloak_id and keycloak_id.startswith("user:"):
                    user_id = keycloak_id  # Use sub directly (preserves worker-safe IDs)
                else:
                    # Normalize to "user:username" format for OpenFGA compatibility
                    user_id = f"user:{username}" if not username.startswith("user:") else username

                user_data = {
                    "user_id": user_id,
                    "keycloak_id": keycloak_id,  # Raw UUID for Keycloak Admin API
                    "username": username,
                    "roles": verification.payload.get("roles", []),
                    "email": verification.payload.get("email"),
                }
                # Cache in request state for subsequent calls
                request.state.user = user_data
                return user_data

            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=f"Invalid token: {verification.error}",
                headers={"WWW-Authenticate": "Bearer"},
            )

        # No authentication provided
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"},
        )

    async def get_current_user_with_auth(
        user: dict[str, Any],
        relation: str | None = None,
        resource: str | None = None,
    ) -> dict[str, Any]:
        """
        Authorize an already-authenticated user (helper for FastAPI dependencies).

        Use this when you need both authentication and authorization. The
        authorization check runs only when both ``relation`` and ``resource``
        are given.

        Example:
            async def chat_user(user: dict[str, Any] = Depends(get_current_user)) -> dict[str, Any]:
                return await get_current_user_with_auth(user, relation="viewer", resource="tool:chat")

            @app.get("/protected")
            async def protected_endpoint(user: dict[str, Any] = Depends(chat_user)):
                return {"user": user}

        Args:
            user: User dict from get_current_user dependency
            relation: Required relation (e.g., "executor", "viewer")
            resource: Resource to check access to (e.g., "tool:chat")

        Returns:
            User dict if authorized

        Raises:
            HTTPException: If authorization fails (403)
        """
        if relation and resource:
            auth = get_auth_middleware()
            user_id = user.get("user_id", "")

            authorized = await auth.authorize(user_id=user_id, relation=relation, resource=resource)

            if not authorized:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Not authorized: {user_id} cannot {relation} {resource}",
                )

        return user

    def require_auth_dependency(relation: str | None = None, resource: str | None = None) -> Any:
        """
        Create a FastAPI dependency for authentication + authorization.

        This replaces the @require_auth decorator for FastAPI routes.

        Example:
            from fastapi import Depends

            @app.get("/tools")
            async def list_tools(user: Dict = Depends(require_auth_dependency(relation="executor", resource="tool:*"))):
                return {"tools": [...]}

        Args:
            relation: Required relation (e.g., "executor")
            resource: Resource to check access to

        Returns:
            FastAPI dependency function (async, returns the user dict)
        """

        async def dependency(
            request: Request,
            # Declared via Depends so FastAPI runs the security scheme (and documents it
            # in OpenAPI) instead of treating this parameter as a request-body field.
            credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
        ) -> dict[str, Any]:
            # Get authenticated user
            user = await get_current_user(request)

            # Check authorization if required
            if relation and resource:
                auth = get_auth_middleware()
                authorized = await auth.authorize(user_id=user["user_id"], relation=relation, resource=resource)

                if not authorized:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail=f"Not authorized: {user['user_id']} cannot {relation} {resource}",
                    )

            return user

        return dependency