Coverage for src / mcp_server_langgraph / auth / session.py: 81%

332 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Session Management for Authentication 

3 

4Provides pluggable session storage backends with support for: 

5- Redis (production) 

6- In-memory (development/testing) 

7- Session lifecycle management 

8- Sliding expiration windows 

9- Concurrent session limits 

10""" 

11 

12import json 

13import secrets 

14from abc import ABC, abstractmethod 

15from collections.abc import Mapping 

16from datetime import datetime, timedelta, UTC 

17from typing import Any, cast 

18 

19from pydantic import BaseModel, ConfigDict, Field, field_validator 

20 

21try: 

22 import redis.asyncio as redis 

23 

24 REDIS_AVAILABLE = True 

25except ImportError: 

26 redis = None # type: ignore[assignment] 

27 REDIS_AVAILABLE = False 

28 

29from mcp_server_langgraph.observability.telemetry import logger, tracer 

30 

31 

class SessionData(BaseModel):
    """
    Type-safe session data structure with validation

    Uses Pydantic for automatic validation, serialization, and better IDE support.
    All datetime fields are stored as ISO format strings for Redis compatibility.
    """

    session_id: str = Field(..., description="Unique session identifier", min_length=32)
    user_id: str = Field(..., description="User identifier (e.g., 'user:alice')")
    username: str = Field(..., description="Username")
    roles: list[str] = Field(default_factory=list, description="User roles")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional session metadata")
    created_at: str = Field(..., description="Session creation timestamp (ISO format)")
    last_accessed: str = Field(..., description="Last access timestamp (ISO format)")
    expires_at: str = Field(..., description="Expiration timestamp (ISO format)")

    model_config = ConfigDict(
        frozen=False,  # Allow updates for last_accessed, metadata
        validate_assignment=True,  # Validate on field assignment
        str_strip_whitespace=True,  # Strip whitespace from strings
        json_schema_extra={
            "example": {
                "session_id": "a1b2c3d4e5f6...",
                "user_id": "user:alice",
                "username": "alice",
                "roles": ["user", "admin"],
                "metadata": {"ip": "192.168.1.1"},
                "created_at": "2025-01-01T00:00:00.000000",
                "last_accessed": "2025-01-01T00:00:00.000000",
                "expires_at": "2025-01-02T00:00:00.000000",
            }
        },
    )

    @field_validator("session_id")
    @classmethod
    def validate_session_id(cls, v: str) -> str:
        """Validate session ID is properly formatted and secure.

        Raises:
            ValueError: If the ID is empty or shorter than 32 characters.
        """
        # Overlaps with ``min_length=32`` on the field; kept as an explicit guard.
        if not v or len(v) < 32:
            msg = "Session ID must be at least 32 characters for security"
            raise ValueError(msg)
        return v

    @field_validator("user_id")
    @classmethod
    def validate_user_id(cls, v: str) -> str:
        """Validate user ID format.

        Raises:
            ValueError: If the user ID is empty.
        """
        if not v:
            msg = "User ID cannot be empty"
            raise ValueError(msg)
        return v

    @field_validator("created_at", "last_accessed", "expires_at")
    @classmethod
    def validate_timestamp(cls, v: str) -> str:
        """Validate timestamp is in ISO format and normalize Zulu time to explicit timezone.

        A trailing ``Z`` is rewritten as ``+00:00``; any other valid ISO string is
        returned unchanged.

        Raises:
            ValueError: If the (normalized) value is not parseable by
                ``datetime.fromisoformat``.
        """
        # Only the final designator is rewritten, never other characters.
        normalized = f"{v[:-1]}+00:00" if v.endswith("Z") else v
        try:
            datetime.fromisoformat(normalized)
        except (ValueError, TypeError) as err:
            msg = f"Timestamp must be in ISO format, got: {v}"
            # Chain the parser error so the root cause stays visible in tracebacks.
            raise ValueError(msg) from err
        return normalized

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for backward compatibility"""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SessionData":
        """Create SessionData from dictionary for backward compatibility (validates all fields)"""
        return cls(**data)

107 

108 

class SessionStore(ABC):
    """Interface that every session storage backend must implement.

    Concrete stores provide creation, lookup, mutation, and removal of sessions,
    plus bulk per-user and inactivity-based queries. All I/O methods are async.
    """

    @abstractmethod
    async def create(
        self,
        user_id: str,
        username: str,
        roles: list[str],
        metadata: dict[str, Any] | None = None,
        ttl_seconds: int | None = None,
    ) -> str:
        """
        Open a new session for a user.

        Args:
            user_id: Identifier of the owning user
            username: Display/login name of the user
            roles: Roles granted to the user
            metadata: Optional extra data to attach to the session
            ttl_seconds: Optional lifetime in seconds

        Returns:
            The identifier of the newly created session
        """

    @abstractmethod
    async def get(self, session_id: str) -> SessionData | None:
        """
        Look up a session.

        Args:
            session_id: Identifier of the session to fetch

        Returns:
            The session, or None when it does not exist or has expired
        """

    @abstractmethod
    async def update(self, session_id: str, metadata: dict[str, Any]) -> bool:
        """
        Merge new metadata into a session.

        Args:
            session_id: Identifier of the session to modify
            metadata: Key/value pairs to merge in

        Returns:
            Whether the update was applied
        """

    @abstractmethod
    async def refresh(self, session_id: str, ttl_seconds: int | None = None) -> bool:
        """
        Push back a session's expiration time.

        Args:
            session_id: Identifier of the session to extend
            ttl_seconds: Optional new lifetime in seconds

        Returns:
            Whether the session was refreshed
        """

    @abstractmethod
    async def delete(self, session_id: str) -> bool:
        """
        Remove a single session.

        Args:
            session_id: Identifier of the session to remove

        Returns:
            True when a session was removed, False when none was found
        """

    @abstractmethod
    async def list_user_sessions(self, user_id: str) -> list[SessionData]:
        """
        Collect every session owned by a user.

        Args:
            user_id: Identifier of the owning user

        Returns:
            The user's sessions
        """

    @abstractmethod
    async def delete_user_sessions(self, user_id: str) -> int:
        """
        Remove every session owned by a user.

        Args:
            user_id: Identifier of the owning user

        Returns:
            How many sessions were removed
        """

    @abstractmethod
    async def get_inactive_sessions(self, cutoff_date: datetime) -> list[SessionData]:
        """
        Find sessions whose last access predates a cutoff.

        Args:
            cutoff_date: Sessions last accessed before this moment are returned

        Returns:
            The inactive sessions
        """

    @abstractmethod
    async def delete_inactive_sessions(self, cutoff_date: datetime) -> int:
        """
        Remove sessions whose last access predates a cutoff.

        Args:
            cutoff_date: Sessions last accessed before this moment are removed

        Returns:
            How many sessions were removed
        """

    def _generate_session_id(self) -> str:
        """Return a fresh URL-safe random session identifier built from 32 random bytes."""
        entropy_bytes = 32
        return secrets.token_urlsafe(entropy_bytes)

236 

237 

class InMemorySessionStore(SessionStore):
    """
    In-memory session store for development and testing

    WARNING: Not suitable for production use:
    - Data lost on restart
    - No clustering support
    - No persistence

    Expired sessions are purged lazily: they remain in memory until a method
    that looks them up (get, update, refresh, list/scan methods) notices the
    expiry and deletes them.
    """

    def __init__(
        self,
        default_ttl_seconds: int = 86400,  # 24 hours
        sliding_window: bool = True,
        max_concurrent_sessions: int = 5,
    ):
        """
        Initialize in-memory session store

        Args:
            default_ttl_seconds: Default session TTL
            sliding_window: Enable sliding expiration (get() bumps last_accessed)
            max_concurrent_sessions: Max sessions per user
        """
        self.sessions: dict[str, SessionData] = {}
        self.user_sessions: dict[str, list[str]] = {}  # user_id -> [session_ids], oldest first
        self.default_ttl = default_ttl_seconds
        self.sliding_window = sliding_window
        self.max_concurrent = max_concurrent_sessions

        logger.info(
            "Initialized InMemorySessionStore",
            extra={
                "default_ttl": default_ttl_seconds,
                "sliding_window": sliding_window,
                "max_concurrent": max_concurrent_sessions,
            },
        )

    def _is_expired(self, session: SessionData, now: datetime | None = None) -> bool:
        """Return True when ``session.expires_at`` lies in the past; never mutates the session."""
        if now is None:
            now = datetime.now(UTC)
        return now > datetime.fromisoformat(session.expires_at)

    async def create(
        self,
        user_id: str,
        username: str,
        roles: list[str],
        metadata: dict[str, Any] | None = None,
        ttl_seconds: int | None = None,
    ) -> str:
        """Create new session, evicting the user's oldest session when the limit is reached"""
        with tracer.start_as_current_span("session.create") as span:
            span.set_attribute("user.id", user_id)

            # Check concurrent session limit
            if user_id in self.user_sessions and len(self.user_sessions[user_id]) >= self.max_concurrent:
                # Remove oldest session (list is kept in creation order)
                oldest_session_id = self.user_sessions[user_id].pop(0)
                self.sessions.pop(oldest_session_id, None)
                logger.info(f"Removed oldest session for user {user_id} (max concurrent limit reached)")

            # Generate session; a falsy ttl_seconds (None or 0) falls back to the default
            session_id = self._generate_session_id()
            ttl = ttl_seconds or self.default_ttl
            now = datetime.now(UTC)

            session_data = SessionData(
                session_id=session_id,
                user_id=user_id,
                username=username,
                roles=roles,
                metadata=metadata or {},
                created_at=now.isoformat(),
                last_accessed=now.isoformat(),
                expires_at=(now + timedelta(seconds=ttl)).isoformat(),
            )

            self.sessions[session_id] = session_data

            # Track user sessions
            if user_id not in self.user_sessions:
                self.user_sessions[user_id] = []
            self.user_sessions[user_id].append(session_id)

            from mcp_server_langgraph.core.security import sanitize_for_logging

            logger.info(
                "Session created",
                extra=sanitize_for_logging({"session_id": session_id, "user_id": user_id, "ttl_seconds": ttl}),
            )

            return session_id

    async def get(self, session_id: str) -> SessionData | None:
        """Get session with expiration check; expired sessions are deleted and None is returned"""
        with tracer.start_as_current_span("session.get") as span:
            span.set_attribute("session.id", session_id)

            session = self.sessions.get(session_id)
            if session is None:
                return None

            if self._is_expired(session):
                await self.delete(session_id)
                return None

            # Update last accessed time (sliding window)
            if self.sliding_window:
                session.last_accessed = datetime.now(UTC).isoformat()

            return session

    async def update(self, session_id: str, metadata: dict[str, Any]) -> bool:
        """Merge metadata into a live session; returns False for missing or expired sessions"""
        session = self.sessions.get(session_id)
        if session is None:
            return False

        if self._is_expired(session):
            # An expired session must not be modified (Redis drops expired keys,
            # so its update() would also report False here).
            await self.delete(session_id)
            return False

        session.metadata.update(metadata)
        session.last_accessed = datetime.now(UTC).isoformat()

        logger.info(f"Session metadata updated: {session_id}")
        return True

    async def refresh(self, session_id: str, ttl_seconds: int | None = None) -> bool:
        """Extend a live session's expiration; returns False for missing or expired sessions"""
        session = self.sessions.get(session_id)
        if session is None:
            return False

        now = datetime.now(UTC)
        if self._is_expired(session, now):
            # Refreshing must not resurrect a session that has already expired.
            await self.delete(session_id)
            return False

        ttl = ttl_seconds or self.default_ttl

        session.last_accessed = now.isoformat()
        session.expires_at = (now + timedelta(seconds=ttl)).isoformat()

        logger.info(f"Session refreshed: {session_id}, new TTL: {ttl}s")
        return True

    async def delete(self, session_id: str) -> bool:
        """Delete session; returns False if it was not present"""
        if session_id not in self.sessions:
            return False

        session = self.sessions.pop(session_id)
        user_id = session.user_id

        # Remove from user sessions tracking (drop the user entry once empty)
        if user_id in self.user_sessions:
            try:
                self.user_sessions[user_id].remove(session_id)
                if not self.user_sessions[user_id]:
                    del self.user_sessions[user_id]
            except ValueError:
                pass  # Already evicted from the tracking list

        logger.info(f"Session deleted: {session_id}")
        return True

    async def list_user_sessions(self, user_id: str) -> list[SessionData]:
        """List all active sessions for a user (via get(), so expired ones are purged)"""
        if user_id not in self.user_sessions:
            return []

        sessions = []
        for session_id in self.user_sessions[user_id][:]:  # Copy to allow modification
            session = await self.get(session_id)  # Will auto-delete expired
            if session:
                sessions.append(session)

        return sessions

    async def delete_user_sessions(self, user_id: str) -> int:
        """Delete all sessions for a user and return how many were removed"""
        if user_id not in self.user_sessions:
            return 0

        session_ids = self.user_sessions[user_id][:]
        count = 0

        for session_id in session_ids:
            if await self.delete(session_id):
                count += 1

        logger.info(f"Deleted {count} sessions for user {user_id}")
        return count

    async def get_inactive_sessions(self, cutoff_date: datetime) -> list[SessionData]:
        """Get sessions that haven't been accessed since cutoff date.

        Expired sessions are purged rather than reported. Scanned sessions are
        NOT touched: their last_accessed value is read as stored.
        """
        inactive_sessions = []
        now = datetime.now(UTC)

        for session_id, session in list(self.sessions.items()):
            try:
                if self._is_expired(session, now):
                    await self.delete(session_id)
                    continue

                # Read last_accessed directly. Going through get() would bump it to
                # "now" under the sliding window, hiding the very inactivity we test.
                last_accessed = datetime.fromisoformat(session.last_accessed)

                # Add to inactive list if older than cutoff
                if last_accessed < cutoff_date:
                    inactive_sessions.append(session)

            except (ValueError, AttributeError) as e:
                logger.warning(f"Error parsing session {session_id}: {e}")
                continue

        logger.info(f"Found {len(inactive_sessions)} inactive sessions before {cutoff_date.isoformat()}")
        return inactive_sessions

    async def delete_inactive_sessions(self, cutoff_date: datetime) -> int:
        """Delete sessions that haven't been accessed since cutoff date and return the count"""
        inactive_sessions = await self.get_inactive_sessions(cutoff_date)
        count = 0

        for session in inactive_sessions:
            if await self.delete(session.session_id):
                count += 1

        logger.info(f"Deleted {count} inactive sessions before {cutoff_date.isoformat()}")
        return count

461 

462 

class RedisSessionStore(SessionStore):
    """
    Redis-backed session store for production use

    Features:
    - Persistent storage
    - Clustering support
    - Automatic expiration via Redis TTL
    - High performance

    Layout: each session is a hash at ``session:<id>`` with a Redis TTL, and each
    user has a list ``user_sessions:<user_id>`` of session ids in creation order.
    Works with both ``decode_responses=True`` (str replies) and ``False`` (bytes).
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379/0",
        default_ttl_seconds: int = 86400,
        sliding_window: bool = True,
        max_concurrent_sessions: int = 5,
        ssl: bool = False,
        decode_responses: bool = True,
        password: str | None = None,
        ttl_seconds: int | None = None,
    ):
        """
        Initialize Redis session store

        Args:
            redis_url: Redis connection URL
            default_ttl_seconds: Default session TTL
            sliding_window: Enable sliding expiration (get() bumps last_accessed)
            max_concurrent_sessions: Max sessions per user
            ssl: Use SSL/TLS
            decode_responses: Decode responses to strings
            password: Redis password (optional)
            ttl_seconds: Alias for default_ttl_seconds (for backward compatibility)

        Raises:
            ImportError: If the redis package is not installed
        """
        if not REDIS_AVAILABLE:
            msg = "Redis not available. Add 'redis[hiredis]>=5.0.0' to pyproject.toml dependencies, then run: uv sync"
            raise ImportError(msg)

        # Support both ttl_seconds and default_ttl_seconds for backward compatibility
        if ttl_seconds is not None:
            default_ttl_seconds = ttl_seconds

        self.redis_url = redis_url
        self.default_ttl = default_ttl_seconds
        self.sliding_window = sliding_window
        self.max_concurrent = max_concurrent_sessions
        self.decode_responses = decode_responses

        # Initialize Redis client
        self.redis = redis.from_url(  # type: ignore[no-untyped-call]
            redis_url,
            password=password,
            ssl=ssl,
            decode_responses=decode_responses,
            encoding="utf-8",
        )

        logger.info(
            "Initialized RedisSessionStore",
            extra={
                "redis_url": redis_url,
                "default_ttl": default_ttl_seconds,
                "sliding_window": sliding_window,
                "max_concurrent": max_concurrent_sessions,
            },
        )

    @staticmethod
    def _decode(value: Any) -> Any:
        """Return ``value`` decoded as UTF-8 if Redis returned bytes, otherwise unchanged."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    def _parse_session(self, data: Mapping[Any, Any]) -> SessionData:
        """Build a SessionData from a raw ``HGETALL`` reply (str or bytes keys/values).

        Raises:
            ValueError: (pydantic ValidationError) if required fields are missing/invalid.
        """
        fields = {self._decode(key): self._decode(value) for key, value in data.items()}

        # Parse metadata safely using json.loads instead of eval
        metadata_str = fields.get("metadata", "{}")
        try:
            metadata = json.loads(metadata_str) if metadata_str else {}
        except (json.JSONDecodeError, TypeError):
            metadata = {}

        roles_str = fields.get("roles")
        return SessionData(
            session_id=cast(str, fields.get("session_id")),
            user_id=cast(str, fields.get("user_id")),
            username=cast(str, fields.get("username")),
            roles=roles_str.split(",") if roles_str else [],  # Stored comma-separated
            metadata=metadata,
            created_at=cast(str, fields.get("created_at")),
            last_accessed=cast(str, fields.get("last_accessed")),
            expires_at=cast(str, fields.get("expires_at")),
        )

    async def _extend_user_index_ttl(self, user_sessions_key: str, ttl: int) -> None:
        """Ensure the per-user index outlives a session with ``ttl`` seconds left; never shortens it."""
        target = ttl + 3600  # Extra hour beyond the session itself
        # Redis TTL replies are negative when the key has no expiry (-1) or is missing (-2)
        current = await self.redis.ttl(user_sessions_key)
        if current < target:
            await self.redis.expire(user_sessions_key, target)

    async def create(
        self,
        user_id: str,
        username: str,
        roles: list[str],
        metadata: dict[str, Any] | None = None,
        ttl_seconds: int | None = None,
    ) -> str:
        """Create new session in Redis, evicting the user's oldest one when the limit is reached"""
        with tracer.start_as_current_span("session.create") as span:
            span.set_attribute("user.id", user_id)

            # Check concurrent session limit
            user_sessions_key = f"user_sessions:{user_id}"
            session_ids = await self.redis.lrange(user_sessions_key, 0, -1)

            if len(session_ids) >= self.max_concurrent:
                # Remove oldest session
                oldest_session_id = self._decode(session_ids[0])
                await self.delete(oldest_session_id)
                # delete() prunes the index only if the session hash still exists. When
                # Redis already expired the hash, the stale id would otherwise stay at the
                # head of the list and the limit would silently stop being enforced.
                await self.redis.lrem(user_sessions_key, 0, oldest_session_id)
                logger.info(f"Removed oldest session for user {user_id}")

            # Generate session; a falsy ttl_seconds (None or 0) falls back to the default
            session_id = self._generate_session_id()
            ttl = ttl_seconds or self.default_ttl
            now = datetime.now(UTC)

            session_data = {
                "session_id": session_id,
                "user_id": user_id,
                "username": username,
                "roles": ",".join(roles),  # Store as comma-separated
                "metadata": json.dumps(metadata or {}),  # Store as JSON string
                "created_at": now.isoformat(),
                "last_accessed": now.isoformat(),
                "expires_at": (now + timedelta(seconds=ttl)).isoformat(),
            }

            # Store session in Redis
            session_key = f"session:{session_id}"
            await self.redis.hset(session_key, mapping=cast(Mapping[str | bytes, bytes | float | int | str], session_data))
            await self.redis.expire(session_key, ttl)

            # Track user sessions; the index TTL is only ever extended so that a
            # short-lived new session cannot orphan an older, longer-lived one.
            await self.redis.rpush(user_sessions_key, session_id)
            await self._extend_user_index_ttl(user_sessions_key, ttl)

            from mcp_server_langgraph.core.security import sanitize_for_logging

            logger.info(
                "Session created in Redis",
                extra=sanitize_for_logging({"session_id": session_id, "user_id": user_id, "ttl_seconds": ttl}),
            )

            return session_id

    async def get(self, session_id: str) -> SessionData | None:
        """Get session from Redis; returns None when the hash is absent (expired or deleted)"""
        with tracer.start_as_current_span("session.get") as span:
            span.set_attribute("session.id", session_id)

            session_key = f"session:{session_id}"
            data = await self.redis.hgetall(session_key)

            if not data:
                return None

            # Convert Redis data to SessionData (Pydantic validates automatically)
            session = self._parse_session(data)

            # Update last accessed (sliding window)
            if self.sliding_window:
                now_iso = datetime.now(UTC).isoformat()
                await self.redis.hset(session_key, "last_accessed", now_iso)
                # Keep the returned model in sync with what was just persisted
                session.last_accessed = now_iso

            return session

    async def update(self, session_id: str, metadata: dict[str, Any]) -> bool:
        """Merge metadata into a session; returns False when the session does not exist"""
        session_key = f"session:{session_id}"
        exists = await self.redis.exists(session_key)

        if not exists:
            return False

        # Get current session to preserve Pydantic validation
        session = await self.get(session_id)
        if not session:
            return False

        # Update metadata on Pydantic model
        session.metadata.update(metadata)
        session.last_accessed = datetime.now(UTC).isoformat()

        # Persist to Redis
        await self.redis.hset(
            session_key, mapping={"metadata": json.dumps(session.metadata), "last_accessed": session.last_accessed}
        )

        logger.info(f"Session metadata updated in Redis: {session_id}")
        return True

    async def refresh(self, session_id: str, ttl_seconds: int | None = None) -> bool:
        """Extend a session's expiration; returns False when the session does not exist"""
        session_key = f"session:{session_id}"
        exists = await self.redis.exists(session_key)

        if not exists:
            return False

        ttl = ttl_seconds or self.default_ttl
        now = datetime.now(UTC)
        new_expires_at = (now + timedelta(seconds=ttl)).isoformat()

        await self.redis.hset(session_key, mapping={"last_accessed": now.isoformat(), "expires_at": new_expires_at})
        await self.redis.expire(session_key, ttl)

        # The per-user index must outlive the refreshed session; otherwise
        # delete_user_sessions() would no longer find it.
        user_id = self._decode(await self.redis.hget(session_key, "user_id"))
        if user_id:
            await self._extend_user_index_ttl(f"user_sessions:{user_id}", ttl)

        logger.info(f"Session refreshed in Redis: {session_id}, TTL: {ttl}s")
        return True

    async def delete(self, session_id: str) -> bool:
        """Delete session from Redis and drop it from its user's index; True if a hash was deleted"""
        session_key = f"session:{session_id}"

        # Get user_id before deleting (decoded so the index key is correct for bytes replies)
        user_id = self._decode(await self.redis.hget(session_key, "user_id"))

        # Delete session
        deleted = await self.redis.delete(session_key)

        if deleted and user_id:
            # Remove from user sessions list
            user_sessions_key = f"user_sessions:{user_id}"
            await self.redis.lrem(user_sessions_key, 0, session_id)

        logger.info(f"Session deleted from Redis: {session_id}")
        return bool(deleted)

    async def list_user_sessions(self, user_id: str) -> list[SessionData]:
        """List all sessions for a user that still exist in Redis"""
        user_sessions_key = f"user_sessions:{user_id}"
        session_ids = await self.redis.lrange(user_sessions_key, 0, -1)

        sessions = []
        for raw_session_id in session_ids:
            session = await self.get(self._decode(raw_session_id))
            if session:
                sessions.append(session)

        return sessions

    async def delete_user_sessions(self, user_id: str) -> int:
        """Delete all sessions for a user plus the user's index; returns how many hashes were deleted"""
        user_sessions_key = f"user_sessions:{user_id}"
        session_ids = await self.redis.lrange(user_sessions_key, 0, -1)

        count = 0
        for raw_session_id in session_ids:
            if await self.delete(self._decode(raw_session_id)):
                count += 1

        # Delete user sessions list
        await self.redis.delete(user_sessions_key)

        logger.info(f"Deleted {count} sessions from Redis for user {user_id}")
        return count

    async def get_inactive_sessions(self, cutoff_date: datetime) -> list[SessionData]:
        """Get sessions that haven't been accessed since cutoff date.

        The scan is read-only: hashes are read directly rather than through get(),
        which under the sliding window would persist last_accessed=now for every
        scanned session and erase the inactivity it is meant to detect.
        """
        inactive_sessions = []

        # Scan all session keys
        cursor = 0
        while True:
            cursor, keys = await self.redis.scan(cursor, match="session:*", count=100)

            for raw_key in keys:
                key = self._decode(raw_key)
                session_id = key.removeprefix("session:")

                data = await self.redis.hgetall(key)
                if not data:
                    continue  # Expired or deleted between SCAN and HGETALL

                try:
                    session = self._parse_session(data)

                    # Parse last_accessed timestamp
                    last_accessed = datetime.fromisoformat(session.last_accessed)

                    # Add to inactive list if older than cutoff
                    if last_accessed < cutoff_date:
                        inactive_sessions.append(session)

                except (ValueError, AttributeError) as e:
                    # Includes pydantic ValidationError (a ValueError subclass) for
                    # malformed hashes, so one bad entry does not abort the scan.
                    logger.warning(f"Error parsing session {session_id}: {e}")
                    continue

            if cursor == 0:
                break

        logger.info(f"Found {len(inactive_sessions)} inactive sessions in Redis before {cutoff_date.isoformat()}")
        return inactive_sessions

    async def delete_inactive_sessions(self, cutoff_date: datetime) -> int:
        """Delete sessions that haven't been accessed since cutoff date and return the count"""
        inactive_sessions = await self.get_inactive_sessions(cutoff_date)
        count = 0

        for session in inactive_sessions:
            if await self.delete(session.session_id):
                count += 1

        logger.info(f"Deleted {count} inactive sessions from Redis before {cutoff_date.isoformat()}")
        return count

765 

766 

def create_session_store(backend: str = "memory", redis_url: str | None = None, **kwargs: Any) -> SessionStore:
    """
    Build a session store for the requested backend.

    Args:
        backend: Backend name, matched case-insensitively: "memory" or "redis"
        redis_url: Redis connection URL; mandatory for the "redis" backend
        **kwargs: Passed through unchanged to the store constructor

    Returns:
        A newly constructed SessionStore

    Raises:
        ValueError: If the backend name is unknown, or "redis" is requested without redis_url
        ImportError: If "redis" is requested but the redis package is unavailable
    """
    normalized = backend.lower()

    if normalized == "redis":
        if not redis_url:
            msg = "redis_url required for Redis backend"
            raise ValueError(msg)
        if not REDIS_AVAILABLE:
            msg = "Redis not available. Add 'redis[hiredis]>=5.0.0' to pyproject.toml dependencies, then run: uv sync"
            raise ImportError(msg)
        logger.info("Creating RedisSessionStore")
        return RedisSessionStore(redis_url=redis_url, **kwargs)

    if normalized != "memory":
        msg = f"Unknown session backend: {normalized}. Supported: 'memory', 'redis'"
        raise ValueError(msg)

    logger.info("Creating InMemorySessionStore")
    return InMemorySessionStore(**kwargs)

803 

804 

# Process-wide session store; populated by set_session_store() or lazily by get_session_store()
_session_store: SessionStore | None = None


def get_session_store() -> SessionStore:
    """
    Return the process-wide session store (usable as a FastAPI dependency).

    Returns:
        The registered SessionStore

    Warning:
        When nothing has been registered through set_session_store(), a default
        InMemorySessionStore is created, stored globally, and returned, and a
        warning is logged. That fallback is meant for tests or for calls made
        before create_auth_middleware() has run.

        In production, call create_auth_middleware() during startup so that the
        configured store (Redis or memory, per settings) is registered.

    Example:
        @app.get("/api/sessions")
        async def list_sessions(session_store: SessionStore = Depends(get_session_store)):
            # Use session_store
            pass
    """
    global _session_store

    if _session_store is not None:
        return _session_store

    # No store registered yet: fall back to an in-memory store (tests / pre-middleware only)
    logger.warning(
        "Session store not registered globally, using fallback in-memory store. "
        "This may indicate create_auth_middleware() was not called. "
        "In production, register the session store via set_session_store()."
    )
    _session_store = InMemorySessionStore()
    return _session_store

843 

844 

def set_session_store(session_store: SessionStore) -> None:
    """
    Register the store that get_session_store() will hand out from now on.

    Args:
        session_store: Store instance to install as the global one; any
            previously registered store is replaced

    Example:
        # During application startup
        redis_store = create_session_store("redis", redis_url="redis://localhost:6379")
        set_session_store(redis_store)
    """
    global _session_store
    _session_store = session_store