Coverage for src / mcp_server_langgraph / middleware / session_timeout.py: 94%

77 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Session Timeout Middleware - HIPAA 164.312(a)(2)(iii) 

3 

4Implements automatic logoff after period of inactivity. 

5Required for HIPAA compliance when processing PHI. 

6""" 

7 

8from collections.abc import Awaitable, Callable 

9from datetime import datetime, UTC 

10from typing import Any 

11 

12import jwt 

13from fastapi import Request, Response 

14from starlette.middleware.base import BaseHTTPMiddleware 

15from starlette.responses import JSONResponse 

16 

17from mcp_server_langgraph.auth.session import SessionStore, get_session_store 

18from mcp_server_langgraph.core.config import settings 

19from mcp_server_langgraph.observability.telemetry import logger, metrics 

20 

21 

class SessionTimeoutMiddleware(BaseHTTPMiddleware):
    """
    Automatic session timeout middleware (HIPAA 164.312(a)(2)(iii))

    Terminates inactive sessions after configured timeout period.
    Default timeout: 15 minutes (HIPAA recommendation).

    Features:
    - Configurable timeout period
    - Sliding window (activity extends session)
    - Audit logging of timeout events
    - Metrics tracking
    """

    # Paths that bypass timeout enforcement. A request path is public when it
    # equals one of these entries or lies underneath it (entry + "/...").
    _PUBLIC_PATHS: tuple[str, ...] = (
        "/health",
        "/metrics",
        "/api/v1/auth/login",
        "/api/v1/auth/register",
        "/docs",
        "/openapi.json",
    )

    def __init__(
        self,
        app: Any,
        timeout_seconds: int = 900,  # 15 minutes default
        session_store: SessionStore | None = None,
    ):
        """
        Initialize session timeout middleware

        Args:
            app: FastAPI application
            timeout_seconds: Inactivity timeout in seconds (default: 900 = 15 minutes)
            session_store: Session storage backend; when omitted (or falsy) the
                store returned by get_session_store() is used
        """
        super().__init__(app)
        self.timeout_seconds = timeout_seconds
        self.session_store = session_store or get_session_store()

        logger.info(
            "Session timeout middleware initialized",
            extra={"timeout_seconds": timeout_seconds, "timeout_minutes": timeout_seconds / 60},
        )

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """
        Check session activity and enforce timeout

        Requests to public endpoints, requests without a session ID and requests
        whose session is not in the store are passed through untouched. If the
        inactivity check itself fails (store error, unparsable timestamp) the
        error is logged and the request is passed through (fail open).

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/endpoint

        Returns:
            Response (401 if session timed out, otherwise normal response)
        """
        # Skip timeout check for public endpoints
        if self._is_public_endpoint(request.url.path):
            return await call_next(request)

        # Get session from request (if authenticated)
        session_id = self._get_session_id(request)
        if not session_id:
            # No session, continue normally
            return await call_next(request)

        # Set only when the session is positively known to have exceeded the timeout.
        expired_after: float | None = None

        # Only the inactivity check lives inside this try. call_next() is kept
        # out of it on purpose: otherwise a downstream exception would be
        # reported as a failed timeout check and the request dispatched twice.
        try:
            session = await self.session_store.get(session_id)

            # A missing session (already expired or deleted) is not handled here.
            if session:
                # Parse last accessed time
                last_accessed = datetime.fromisoformat(session.last_accessed.replace("Z", "+00:00"))
                if last_accessed.tzinfo is None:
                    # A timestamp without an offset cannot be subtracted from an
                    # aware datetime; interpret it as UTC, matching the "Z"
                    # form this middleware writes back.
                    last_accessed = last_accessed.replace(tzinfo=UTC)

                now = datetime.now(UTC)
                inactive_seconds = (now - last_accessed).total_seconds()

                if inactive_seconds > self.timeout_seconds:
                    expired_after = inactive_seconds
                else:
                    # Update last activity time (sliding window)
                    session.last_accessed = now.isoformat().replace("+00:00", "Z")
                    # NOTE(review): only session.metadata is handed to update();
                    # confirm the store also persists the new last_accessed.
                    await self.session_store.update(session.session_id, session.metadata)

        except Exception as e:
            logger.error(f"Session timeout check failed: {e}", exc_info=True)
            # Continue on error (fail open for availability)

        if expired_after is not None:
            # Session has timed out. A failure while cleaning up must not let
            # the expired session through, so the 401 is returned regardless.
            try:
                await self._handle_timeout(request, session_id, expired_after)
            except Exception as e:
                logger.error(f"Session timeout cleanup failed: {e}", exc_info=True)

            return JSONResponse(
                status_code=401,
                content={
                    "detail": "Session expired due to inactivity",
                    "inactive_seconds": int(expired_after),
                    "timeout_seconds": self.timeout_seconds,
                    "code": "SESSION_TIMEOUT",
                },
            )

        # Session is active (or could not be checked), continue
        return await call_next(request)

    async def _handle_timeout(self, request: Request, session_id: str, inactive_seconds: float) -> None:
        """
        Handle session timeout

        Deletes the session from the store, writes a warning-level log record
        and increments the successful_calls counter.

        Args:
            request: HTTP request
            session_id: Session ID that timed out
            inactive_seconds: Seconds of inactivity
        """
        # Delete the session
        await self.session_store.delete(session_id)

        # Log timeout event (HIPAA audit requirement)
        logger.warning(
            "HIPAA: Session timeout",
            extra={
                "session_id": session_id,
                "inactive_seconds": inactive_seconds,
                "timeout_seconds": self.timeout_seconds,
                "ip_address": request.client.host if request.client else "unknown",
                "path": request.url.path,
            },
        )

        # Track metrics
        # NOTE(review): inactive_seconds as a metric attribute can take many
        # distinct values; confirm the metrics backend tolerates that.
        metrics.successful_calls.add(
            1,
            {
                "operation": "session_timeout",
                "inactive_seconds": int(inactive_seconds),
            },
        )

    def _decode_jwt_token(self, token: str) -> dict[str, Any] | None:
        """
        Decode JWT token and extract payload

        Handles both symmetric (HS256) and asymmetric (RS256, ES256, etc.) algorithms.

        Args:
            token: JWT token string

        Returns:
            Decoded payload dict, or None if decoding fails
        """
        try:
            # Select the correct key based on the algorithm
            # Asymmetric algorithms (RS*, ES*, PS*) use public key
            # Symmetric algorithms (HS256, HS384, HS512) use secret key
            algorithm = settings.jwt_algorithm
            if algorithm.startswith(("RS", "ES", "PS")):
                # Asymmetric algorithm - use public key for verification
                key = settings.jwt_public_key or ""  # type: ignore[attr-defined]
            else:
                # Symmetric algorithm - use secret key
                key = str(settings.jwt_secret_key) if settings.jwt_secret_key else ""

            # Decode JWT to extract session ID
            # Note: We don't verify expiration here (verify_exp=False) because
            # session timeout is independent of JWT expiration
            payload: dict[str, Any] = jwt.decode(
                token,
                key,
                algorithms=[algorithm],
                options={"verify_exp": False},  # Don't verify expiration
            )

            return payload

        except jwt.InvalidTokenError as e:
            # Invalid JWT - log and continue to other methods
            logger.debug(f"Failed to decode JWT for session timeout: {e}")
            return None
        except Exception as e:
            # Unexpected error (e.g., missing jwt_secret_key)
            logger.warning(f"Unexpected error decoding JWT: {e}")
            return None

    def _get_session_id(self, request: Request) -> str | None:
        """
        Extract session ID from request

        Tries multiple sources:
        1. Authorization header (Bearer token)
        2. Cookie
        3. Request state (if already authenticated)

        Args:
            request: HTTP request

        Returns:
            Session ID or None
        """
        # Try Authorization header (JWT Bearer token)
        auth_header = request.headers.get("authorization", "")
        if auth_header.startswith("Bearer "):
            # Strip only the leading scheme; str.replace would also remove any
            # later "Bearer " occurrence inside the header value.
            token = auth_header.removeprefix("Bearer ").strip()
            payload = self._decode_jwt_token(token)

            if payload:
                # Try multiple possible session ID claim names
                # Standard claims: 'sid' (session ID), 'jti' (JWT ID), or custom 'session_id'
                session_id_from_jwt: str | None = payload.get("sid") or payload.get("session_id") or payload.get("jti")

                if session_id_from_jwt:
                    return str(session_id_from_jwt)

        # Try cookie
        session_id: str | None = request.cookies.get("session_id")
        if session_id:
            return session_id

        # Try request state (if already authenticated by previous middleware)
        if hasattr(request.state, "session_id"):
            state_session_id: str = request.state.session_id
            return state_session_id

        return None

    def _is_public_endpoint(self, path: str) -> bool:
        """
        Check if endpoint is public (no timeout enforcement)

        A path is public when it is exactly one of the configured public paths
        or a sub-path of one (e.g. "/health" and "/health/ready", but not
        "/healthcheck-admin"). Matching on a segment boundary keeps unrelated
        routes that merely share a prefix under timeout enforcement.

        Args:
            path: Request path

        Returns:
            True if public endpoint
        """
        return any(
            path == public_path or path.startswith(public_path + "/")
            for public_path in self._PUBLIC_PATHS
        )

259 

260 

def create_session_timeout_middleware(
    app: Any,
    timeout_minutes: int = 15,
    session_store: SessionStore | None = None,
) -> SessionTimeoutMiddleware:
    """
    Build a SessionTimeoutMiddleware from a timeout expressed in minutes

    Converts the timeout to seconds and passes it, together with ``app`` and
    ``session_store``, to the SessionTimeoutMiddleware constructor.

    Args:
        app: Application to wrap (forwarded unchanged to the middleware)
        timeout_minutes: Inactivity timeout in minutes (default: 15)
        session_store: Session storage backend (forwarded unchanged)

    Returns:
        SessionTimeoutMiddleware instance

    Example:
        from fastapi import FastAPI
        from mcp_server_langgraph.middleware.session_timeout import create_session_timeout_middleware

        app = FastAPI()

        # Wrap the app with a 15 minute inactivity timeout
        wrapped = create_session_timeout_middleware(app, timeout_minutes=15)
    """
    # The middleware is configured in seconds; convert once here.
    seconds_per_minute = 60
    timeout_in_seconds = timeout_minutes * seconds_per_minute

    middleware = SessionTimeoutMiddleware(
        app=app,
        timeout_seconds=timeout_in_seconds,
        session_store=session_store,
    )
    return middleware