Coverage for src/mcp_server_langgraph/middleware/session_timeout.py: 94%
77 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2Session Timeout Middleware - HIPAA 164.312(a)(2)(iii)
4Implements automatic logoff after period of inactivity.
5Required for HIPAA compliance when processing PHI.
6"""
8from collections.abc import Awaitable, Callable
9from datetime import datetime, UTC
10from typing import Any
12import jwt
13from fastapi import Request, Response
14from starlette.middleware.base import BaseHTTPMiddleware
15from starlette.responses import JSONResponse
17from mcp_server_langgraph.auth.session import SessionStore, get_session_store
18from mcp_server_langgraph.core.config import settings
19from mcp_server_langgraph.observability.telemetry import logger, metrics
class SessionTimeoutMiddleware(BaseHTTPMiddleware):
    """
    Automatic session timeout middleware (HIPAA 164.312(a)(2)(iii))

    Terminates inactive sessions after configured timeout period.
    Default timeout: 15 minutes (HIPAA recommendation).

    Features:
    - Configurable timeout period
    - Sliding window (activity extends session)
    - Audit logging of timeout events
    - Metrics tracking

    Failure policy:
    - If the inactivity check itself fails (e.g. the session store raises or
      the stored timestamp cannot be parsed), the request is allowed through
      ("fail open") so a session-store problem does not take the API down.
    - Once a session has been determined to be timed out, the request is
      always rejected with 401, even if deleting the session or emitting the
      audit record fails.
    """

    # Paths exempt from timeout enforcement. A path matches when it equals an
    # entry or lies beneath it ("/health" covers "/health/ready" but NOT
    # "/healthcare"), so unrelated routes cannot bypass the timeout simply by
    # sharing a string prefix.
    PUBLIC_PATHS: tuple[str, ...] = (
        "/health",
        "/metrics",
        "/api/v1/auth/login",
        "/api/v1/auth/register",
        "/docs",
        "/openapi.json",
    )

    # JWT algorithm families verified with a public key; every other algorithm
    # is treated as an HMAC shared-secret algorithm (HS256/384/512).
    _ASYMMETRIC_ALGORITHM_PREFIXES: tuple[str, ...] = ("RS", "ES", "PS", "EdDSA")

    def __init__(
        self,
        app: Any,
        timeout_seconds: int = 900,  # 15 minutes default
        session_store: SessionStore | None = None,
    ):
        """
        Initialize session timeout middleware

        Args:
            app: FastAPI application
            timeout_seconds: Inactivity timeout in seconds (default: 900 = 15 minutes)
            session_store: Session storage backend (defaults to get_session_store())
        """
        super().__init__(app)
        self.timeout_seconds = timeout_seconds
        self.session_store = session_store or get_session_store()

        logger.info(
            "Session timeout middleware initialized",
            extra={"timeout_seconds": timeout_seconds, "timeout_minutes": timeout_seconds / 60},
        )

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """
        Check session activity and enforce timeout

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/endpoint

        Returns:
            Response (401 if session timed out, otherwise normal response)
        """
        # Skip timeout check for public endpoints
        if self._is_public_endpoint(request.url.path):
            return await call_next(request)

        # Get session from request (if authenticated)
        session_id = self._get_session_id(request)
        if not session_id:
            # No session, continue normally
            return await call_next(request)

        inactive_seconds = await self._check_and_touch_session(session_id)

        if inactive_seconds is not None:
            # Session has timed out. Cleanup and auditing are best-effort, but
            # the rejection is not: a timed-out session must never reach the
            # endpoint just because deleting it or logging failed.
            try:
                await self._handle_timeout(request, session_id, inactive_seconds)
            except Exception as e:
                logger.error(f"Session timeout handling failed: {e}", exc_info=True)

            return JSONResponse(
                status_code=401,
                content={
                    "detail": "Session expired due to inactivity",
                    "inactive_seconds": int(inactive_seconds),
                    "timeout_seconds": self.timeout_seconds,
                    "code": "SESSION_TIMEOUT",
                },
            )

        # Session is active, unknown, or could not be checked: continue.
        # call_next is deliberately outside any try/except so a failure in the
        # downstream endpoint is never mistaken for a timeout-check failure and
        # the endpoint is never executed a second time.
        return await call_next(request)

    async def _check_and_touch_session(self, session_id: str) -> float | None:
        """
        Check a session's inactivity and refresh its activity timestamp.

        Args:
            session_id: Session ID extracted from the request

        Returns:
            Seconds of inactivity if the session exceeded ``timeout_seconds``;
            otherwise None (session active, not found, or the check failed;
            failures are logged and fail open).
        """
        try:
            session = await self.session_store.get(session_id)
            if not session:
                # Session not found (already expired or deleted)
                return None

            # Parse last accessed time (stored as ISO-8601, possibly with "Z")
            last_accessed = datetime.fromisoformat(session.last_accessed.replace("Z", "+00:00"))
            if last_accessed.tzinfo is None:
                # NOTE(review): naive timestamps are assumed to be UTC, matching
                # the "Z"-suffixed UTC format written below. Without this,
                # subtracting a naive datetime from the aware `now` raises
                # TypeError and silently disables the timeout.
                last_accessed = last_accessed.replace(tzinfo=UTC)

            now = datetime.now(UTC)
            inactive_seconds = (now - last_accessed).total_seconds()

            if inactive_seconds > self.timeout_seconds:
                return inactive_seconds

            # Update last activity time (sliding window).
            # NOTE(review): only session.metadata is passed to update(); confirm
            # the store persists last_accessed (or refreshes it itself).
            session.last_accessed = now.isoformat().replace("+00:00", "Z")
            await self.session_store.update(session.session_id, session.metadata)
        except Exception as e:
            logger.error(f"Session timeout check failed: {e}", exc_info=True)
            # Continue on error (fail open for availability)
        return None

    async def _handle_timeout(self, request: Request, session_id: str, inactive_seconds: float) -> None:
        """
        Handle session timeout: delete the session, audit-log and count it.

        Args:
            request: HTTP request
            session_id: Session ID that timed out
            inactive_seconds: Seconds of inactivity
        """
        # Delete the session
        await self.session_store.delete(session_id)

        # Log timeout event (HIPAA audit requirement)
        logger.warning(
            "HIPAA: Session timeout",
            extra={
                "session_id": session_id,
                "inactive_seconds": inactive_seconds,
                "timeout_seconds": self.timeout_seconds,
                "ip_address": request.client.host if request.client else "unknown",
                "path": request.url.path,
            },
        )

        # Track metrics.
        # NOTE(review): "inactive_seconds" is an unbounded attribute value and
        # may cause high metric cardinality; consider bucketing it.
        metrics.successful_calls.add(
            1,
            {
                "operation": "session_timeout",
                "inactive_seconds": int(inactive_seconds),
            },
        )

    def _decode_jwt_token(self, token: str) -> dict[str, Any] | None:
        """
        Decode JWT token and extract payload

        Handles both symmetric (HS256) and asymmetric (RS*, ES*, PS*, EdDSA)
        algorithms. The signature is verified; expiration is not.

        Args:
            token: JWT token string

        Returns:
            Decoded payload dict, or None if decoding fails
        """
        try:
            algorithm = settings.jwt_algorithm
            if algorithm.startswith(self._ASYMMETRIC_ALGORITHM_PREFIXES):
                # Asymmetric algorithm - use public key for verification
                key = settings.jwt_public_key or ""  # type: ignore[attr-defined]
            else:
                # Symmetric algorithm - use secret key.
                # NOTE(review): if jwt_secret_key is a secret wrapper type whose
                # str() is masked, this yields the wrong key; confirm its type.
                key = str(settings.jwt_secret_key) if settings.jwt_secret_key else ""

            # Session timeout is independent of JWT expiration, so an expired
            # token must still yield its session ID (verify_exp=False).
            payload: dict[str, Any] = jwt.decode(
                token,
                key,
                algorithms=[algorithm],
                options={"verify_exp": False},
            )
            return payload

        except jwt.InvalidTokenError as e:
            # Invalid JWT - log and continue to other methods
            logger.debug(f"Failed to decode JWT for session timeout: {e}")
            return None
        except Exception as e:
            # Unexpected error (e.g., missing jwt_secret_key)
            logger.warning(f"Unexpected error decoding JWT: {e}")
            return None

    def _get_session_id(self, request: Request) -> str | None:
        """
        Extract session ID from request

        Tries multiple sources, in order:
        1. Authorization header (Bearer JWT: 'sid', then 'session_id', then 'jti')
        2. "session_id" cookie
        3. request.state.session_id (if set by a previous middleware)

        Args:
            request: HTTP request

        Returns:
            Session ID or None
        """
        # Try Authorization header (JWT Bearer token). The auth scheme name is
        # case-insensitive, so "bearer" and "BEARER" are accepted too.
        scheme, _, credentials = request.headers.get("authorization", "").partition(" ")
        if scheme.lower() == "bearer":
            token = credentials.strip()
            payload = self._decode_jwt_token(token) if token else None

            if payload:
                # Standard claims: 'sid' (session ID), 'jti' (JWT ID), or custom 'session_id'
                session_id_from_jwt = payload.get("sid") or payload.get("session_id") or payload.get("jti")
                if session_id_from_jwt:
                    return str(session_id_from_jwt)

        # Try cookie
        session_id: str | None = request.cookies.get("session_id")
        if session_id:
            return session_id

        # Try request state (if already authenticated by previous middleware)
        if hasattr(request.state, "session_id"):
            state_session_id: str = request.state.session_id
            return state_session_id

        return None

    def _is_public_endpoint(self, path: str) -> bool:
        """
        Check if endpoint is public (no timeout enforcement)

        Args:
            path: Request path

        Returns:
            True if ``path`` equals a PUBLIC_PATHS entry or is nested beneath one
        """
        return any(path == public_path or path.startswith(public_path + "/") for public_path in self.PUBLIC_PATHS)
def create_session_timeout_middleware(
    app: Any,
    timeout_minutes: int = 15,
    session_store: SessionStore | None = None,
) -> SessionTimeoutMiddleware:
    """
    Create session timeout middleware

    Convenience wrapper that converts a timeout expressed in minutes into the
    ``timeout_seconds`` argument of SessionTimeoutMiddleware.

    Args:
        app: FastAPI/ASGI application to wrap
        timeout_minutes: Inactivity timeout in minutes (default: 15)
        session_store: Session storage backend (None -> get_session_store())

    Returns:
        SessionTimeoutMiddleware instance wrapping ``app``

    Example:
        from fastapi import FastAPI
        from mcp_server_langgraph.middleware.session_timeout import (
            SessionTimeoutMiddleware,
            create_session_timeout_middleware,
        )

        app = FastAPI()

        # Register the middleware on the app (HIPAA compliant, 15 minutes)
        app.add_middleware(SessionTimeoutMiddleware, timeout_seconds=900)

        # Or wrap an ASGI app directly, specifying the timeout in minutes
        asgi_app = create_session_timeout_middleware(app, timeout_minutes=15)
    """
    return SessionTimeoutMiddleware(
        app=app,
        timeout_seconds=timeout_minutes * 60,
        session_store=session_store,
    )