Coverage for src/mcp_server_langgraph/api/auth_request_middleware.py: 77%

41 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2FastAPI Request Middleware for Authentication 

3 

4This middleware intercepts incoming HTTP requests, extracts Bearer tokens, 

5verifies them, and sets request.state.user for downstream handlers. 

6 

7Benefits of middleware approach over dependency injection: 

8- Runs before route handlers (earlier in request lifecycle) 

9- Simpler testing (no dependency override complexity) 

10- More explicit request flow 

11- Easier to understand and maintain 

12""" 

13 

14from typing import Any 

15 

16from starlette.middleware.base import BaseHTTPMiddleware 

17from starlette.requests import Request 

18 

19from mcp_server_langgraph.auth.middleware import AuthMiddleware 

20from mcp_server_langgraph.observability.telemetry import logger 

21 

22 

class AuthRequestMiddleware(BaseHTTPMiddleware):
    """
    FastAPI request middleware for JWT-based authentication.

    Extracts Bearer tokens from Authorization headers, verifies them through
    the supplied ``AuthMiddleware`` and, on success, stores the resulting user
    dict on ``request.state.user``.

    The middleware never rejects a request by itself: when the header is
    missing, the token is invalid, or verification raises, ``request.state.user``
    is simply left unset and the request continues down the chain so that
    endpoints can decide how to treat unauthenticated callers.

    Usage:
        app = FastAPI()
        auth_middleware = AuthMiddleware(secret_key=settings.secret_key)
        app.add_middleware(AuthRequestMiddleware, auth_middleware=auth_middleware)
    """

    # Prefix used for normalized user identifiers ("user:<name>").
    _USER_PREFIX = "user:"

    def __init__(self, app, auth_middleware: AuthMiddleware):  # type: ignore[no-untyped-def]
        """
        Initialize the auth request middleware.

        Args:
            app: FastAPI application
            auth_middleware: AuthMiddleware instance for token verification
        """
        super().__init__(app)
        self.auth_middleware = auth_middleware

    @staticmethod
    def _bearer_token(auth_header: str) -> str:
        """
        Return the credentials of a ``Bearer`` Authorization header value.

        The scheme name is compared case-insensitively and surrounding
        whitespace is ignored, so ``"bearer abc"`` and ``"Bearer   abc"`` both
        yield ``"abc"``.

        Args:
            auth_header: Raw value of the Authorization header ("" if absent)

        Returns:
            The token string, or "" when the header does not carry a
            non-empty Bearer credential.
        """
        scheme, _, credentials = auth_header.strip().partition(" ")
        if scheme.lower() != "bearer":
            return ""
        return credentials.strip()

    async def dispatch(self, request: Request, call_next):  # type: ignore[no-untyped-def]
        """
        Process incoming request, extract and verify auth token.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            HTTP response from downstream handlers
        """
        token = self._bearer_token(request.headers.get("Authorization", ""))

        # Only attempt verification when a Bearer credential was supplied
        if token:
            try:
                verification = await self.auth_middleware.verify_token(token)

                if verification.valid and verification.payload:
                    # Map the token payload onto the user dict exposed to handlers
                    user_data = self._extract_user_from_payload(verification.payload)
                    request.state.user = user_data

                    logger.debug(
                        "Request authenticated via middleware",
                        extra={
                            "user_id": user_data.get("user_id"),
                            "username": user_data.get("username"),
                            "path": request.url.path,
                        },
                    )
                else:
                    # Invalid token - don't set request.state.user
                    # Let endpoints decide how to handle unauthenticated requests
                    logger.debug(
                        "Token verification failed",
                        extra={
                            "error": verification.error,
                            "path": request.url.path,
                        },
                    )

            except Exception as e:
                # Deliberate best-effort: any failure while verifying or mapping
                # the token leaves the request unauthenticated instead of
                # turning it into a server error.
                logger.warning(
                    f"Token verification exception: {e}",
                    extra={"path": request.url.path},
                    exc_info=True,
                )

        # Continue to next middleware/handler
        # If authentication failed, request.state.user won't be set,
        # and endpoints can return 401 as needed
        return await call_next(request)

    def _extract_user_from_payload(self, payload: dict[str, Any]) -> dict[str, Any]:
        """
        Extract user information from JWT payload.

        Handles both InMemory tokens and Keycloak tokens with proper field mapping.

        Username priority: ``preferred_username`` > ``username`` > derived from
        ``sub``. When derived from a ``"user:<id>"`` sub, a worker-safe id such
        as ``"test_gw0_charlie"`` is reduced to ``"charlie"``. If no ``sub`` is
        present either, the username falls back to ``"unknown"``.

        Args:
            payload: JWT token payload

        Returns:
            Dict with keys ``user_id`` (always in ``"user:*"`` form),
            ``keycloak_id`` (the raw ``sub`` claim, may be None), ``username``,
            ``roles`` (``[]`` when the claim is absent) and ``email``.
        """
        prefix = self._USER_PREFIX

        # Raw sub claim (if present)
        keycloak_id = payload.get("sub")

        username = payload.get("preferred_username") or payload.get("username")

        if not username:
            # Fallback to extracting username from sub
            sub = keycloak_id or "unknown"

            if sub.startswith(prefix):
                import re

                # Strip only the leading prefix; later "user:" occurrences
                # inside the id are part of the id and must be kept.
                id_part = sub[len(prefix):]

                # Handle worker-safe IDs (e.g., "user:test_gw0_charlie" -> "charlie")
                match = re.match(r"test_gw\d+_(.*)", id_part)
                username = match.group(1) if match else id_part
            else:
                username = sub

        # For user_id, use sub directly if it's already in "user:*" format
        # (this preserves worker-safe IDs like "user:test_gw0_alice");
        # otherwise normalize the username to "user:username" format.
        if keycloak_id and keycloak_id.startswith(prefix):
            user_id = keycloak_id
        elif username.startswith(prefix):
            user_id = username
        else:
            user_id = f"{prefix}{username}"

        return {
            "user_id": user_id,
            "keycloak_id": keycloak_id,  # Raw sub claim as received
            "username": username,
            "roles": payload.get("roles", []),
            "email": payload.get("email"),
        }

155 "email": payload.get("email"), 

156 }