Coverage for src / mcp_server_langgraph / auth / session.py: 81%
332 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2Session Management for Authentication
4Provides pluggable session storage backends with support for:
5- Redis (production)
6- In-memory (development/testing)
7- Session lifecycle management
8- Sliding expiration windows
9- Concurrent session limits
10"""
12import json
13import secrets
14from abc import ABC, abstractmethod
15from collections.abc import Mapping
16from datetime import datetime, timedelta, UTC
17from typing import Any, cast
19from pydantic import BaseModel, ConfigDict, Field, field_validator
21try:
22 import redis.asyncio as redis
24 REDIS_AVAILABLE = True
25except ImportError:
26 redis = None # type: ignore[assignment]
27 REDIS_AVAILABLE = False
29from mcp_server_langgraph.observability.telemetry import logger, tracer
class SessionData(BaseModel):
    """
    Validated record describing one authenticated session.

    Pydantic checks every field on construction and, because
    ``validate_assignment`` is enabled, again whenever a field is reassigned.
    The three timestamps are held as ISO-format strings (not ``datetime``
    objects) so the record can be written to Redis without conversion.
    """

    session_id: str = Field(..., description="Unique session identifier", min_length=32)
    user_id: str = Field(..., description="User identifier (e.g., 'user:alice')")
    username: str = Field(..., description="Username")
    roles: list[str] = Field(default_factory=list, description="User roles")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional session metadata")
    created_at: str = Field(..., description="Session creation timestamp (ISO format)")
    last_accessed: str = Field(..., description="Last access timestamp (ISO format)")
    expires_at: str = Field(..., description="Expiration timestamp (ISO format)")

    model_config = ConfigDict(
        frozen=False,  # Stores mutate last_accessed / expires_at / metadata in place
        validate_assignment=True,  # Re-run the validators on every attribute assignment
        str_strip_whitespace=True,  # Trim surrounding whitespace from string fields
        json_schema_extra={
            "example": {
                "session_id": "a1b2c3d4e5f6...",
                "user_id": "user:alice",
                "username": "alice",
                "roles": ["user", "admin"],
                "metadata": {"ip": "192.168.1.1"},
                "created_at": "2025-01-01T00:00:00.000000",
                "last_accessed": "2025-01-01T00:00:00.000000",
                "expires_at": "2025-01-02T00:00:00.000000",
            }
        },
    )

    @field_validator("session_id")
    @classmethod
    def validate_session_id(cls, v: str) -> str:
        """Reject session IDs shorter than 32 characters (including empty ones)."""
        if v and len(v) >= 32:
            return v
        msg = "Session ID must be at least 32 characters for security"
        raise ValueError(msg)

    @field_validator("user_id")
    @classmethod
    def validate_user_id(cls, v: str) -> str:
        """Reject an empty user ID; any non-empty string is accepted unchanged."""
        if v:
            return v
        msg = "User ID cannot be empty"
        raise ValueError(msg)

    @field_validator("created_at", "last_accessed", "expires_at")
    @classmethod
    def validate_timestamp(cls, v: str) -> str:
        """
        Require a string ``datetime.fromisoformat`` can parse.

        A value ending in "Z" (Zulu time) is rewritten with an explicit
        "+00:00" offset first, and that rewritten string is what is stored.

        Raises:
            ValueError: If the (rewritten) value is not a parseable ISO timestamp
        """
        candidate = v
        if v.endswith("Z"):
            candidate = v.replace("Z", "+00:00")

        try:
            # Parsed only to prove validity; the string form is what we keep
            datetime.fromisoformat(candidate)
        except (ValueError, TypeError):
            msg = f"Timestamp must be in ISO format, got: {v}"
            raise ValueError(msg)

        return candidate

    def to_dict(self) -> dict[str, Any]:
        """Return the session as a plain dictionary (backward-compatibility helper)."""
        as_mapping = self.model_dump()
        return as_mapping

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SessionData":
        """Build a validated SessionData from a plain dictionary (backward-compatibility helper)."""
        instance = cls(**data)
        return instance
class SessionStore(ABC):
    """
    Contract that every session storage backend implements.

    All storage operations are coroutines. Concrete backends supply the nine
    abstract methods below; session ID generation is shared.
    """

    @abstractmethod
    async def create(
        self,
        user_id: str,
        username: str,
        roles: list[str],
        metadata: dict[str, Any] | None = None,
        ttl_seconds: int | None = None,
    ) -> str:
        """
        Open a new session for a user.

        Args:
            user_id: Identifier of the session owner
            username: Owner's username
            roles: Roles granted to the owner
            metadata: Optional extra data to attach to the session
            ttl_seconds: Session lifetime in seconds

        Returns:
            ID of the newly created session
        """

    @abstractmethod
    async def get(self, session_id: str) -> SessionData | None:
        """
        Look up a session.

        Args:
            session_id: ID of the session to fetch

        Returns:
            The session data, or None when the session is missing or expired
        """

    @abstractmethod
    async def update(self, session_id: str, metadata: dict[str, Any]) -> bool:
        """
        Apply metadata changes to a session.

        Args:
            session_id: ID of the session to modify
            metadata: Metadata entries to update

        Returns:
            True on success, False otherwise
        """

    @abstractmethod
    async def refresh(self, session_id: str, ttl_seconds: int | None = None) -> bool:
        """
        Push a session's expiration forward.

        Args:
            session_id: ID of the session to refresh
            ttl_seconds: New lifetime in seconds

        Returns:
            True on success, False otherwise
        """

    @abstractmethod
    async def delete(self, session_id: str) -> bool:
        """
        Remove a single session.

        Args:
            session_id: ID of the session to remove

        Returns:
            True if a session was deleted, False if none was found
        """

    @abstractmethod
    async def list_user_sessions(self, user_id: str) -> list[SessionData]:
        """
        Enumerate a user's sessions.

        Args:
            user_id: Identifier of the session owner

        Returns:
            Session data for each of the user's sessions
        """

    @abstractmethod
    async def delete_user_sessions(self, user_id: str) -> int:
        """
        Remove every session belonging to a user.

        Args:
            user_id: Identifier of the session owner

        Returns:
            How many sessions were deleted
        """

    @abstractmethod
    async def get_inactive_sessions(self, cutoff_date: datetime) -> list[SessionData]:
        """
        Find sessions not accessed since a cutoff.

        Args:
            cutoff_date: Sessions whose last_accessed is before this are returned

        Returns:
            Session data for each inactive session
        """

    @abstractmethod
    async def delete_inactive_sessions(self, cutoff_date: datetime) -> int:
        """
        Remove sessions not accessed since a cutoff.

        Args:
            cutoff_date: Sessions whose last_accessed is before this are deleted

        Returns:
            How many sessions were deleted
        """

    def _generate_session_id(self) -> str:
        """Return a fresh URL-safe token built from 32 random bytes (via ``secrets``)."""
        token = secrets.token_urlsafe(32)
        return token
class InMemorySessionStore(SessionStore):
    """
    In-memory session store for development and testing

    Sessions live in a plain dictionary keyed by session ID, with a second
    dictionary mapping each user ID to the ordered list of its session IDs
    (oldest first).

    WARNING: Not suitable for production use:
    - Data lost on restart
    - No clustering support
    - No persistence
    """

    def __init__(
        self,
        default_ttl_seconds: int = 86400,  # 24 hours
        sliding_window: bool = True,
        max_concurrent_sessions: int = 5,
    ) -> None:
        """
        Initialize in-memory session store

        Args:
            default_ttl_seconds: TTL applied when a call does not supply one
            sliding_window: When True, get() stamps last_accessed on every hit
            max_concurrent_sessions: Max sessions per user before the oldest is evicted
        """
        self.sessions: dict[str, SessionData] = {}
        self.user_sessions: dict[str, list[str]] = {}  # user_id -> [session_ids], oldest first
        self.default_ttl = default_ttl_seconds
        self.sliding_window = sliding_window
        self.max_concurrent = max_concurrent_sessions

        logger.info(
            "Initialized InMemorySessionStore",
            extra={
                "default_ttl": default_ttl_seconds,
                "sliding_window": sliding_window,
                "max_concurrent": max_concurrent_sessions,
            },
        )

    @staticmethod
    def _is_expired(session: SessionData) -> bool:
        """Return True when the session's expires_at lies in the past (read-only check)."""
        return datetime.now(UTC) > datetime.fromisoformat(session.expires_at)

    async def create(
        self,
        user_id: str,
        username: str,
        roles: list[str],
        metadata: dict[str, Any] | None = None,
        ttl_seconds: int | None = None,
    ) -> str:
        """
        Create a new session, evicting the user's oldest one if the limit is reached.

        Args:
            user_id: User identifier
            username: Username
            roles: User roles
            metadata: Additional session metadata (defaults to an empty dict)
            ttl_seconds: Time-to-live in seconds; a falsy value (None or 0)
                falls back to the store default

        Returns:
            Session ID
        """
        with tracer.start_as_current_span("session.create") as span:
            span.set_attribute("user.id", user_id)

            # Enforce the concurrent session limit by dropping the oldest session
            if user_id in self.user_sessions and len(self.user_sessions[user_id]) >= self.max_concurrent:
                oldest_session_id = self.user_sessions[user_id].pop(0)
                self.sessions.pop(oldest_session_id, None)
                logger.info(f"Removed oldest session for user {user_id} (max concurrent limit reached)")

            session_id = self._generate_session_id()
            ttl = ttl_seconds or self.default_ttl
            now = datetime.now(UTC)

            session_data = SessionData(
                session_id=session_id,
                user_id=user_id,
                username=username,
                roles=roles,
                metadata=metadata or {},
                created_at=now.isoformat(),
                last_accessed=now.isoformat(),
                expires_at=(now + timedelta(seconds=ttl)).isoformat(),
            )

            self.sessions[session_id] = session_data

            # Track the session under its owner (append keeps oldest-first order)
            self.user_sessions.setdefault(user_id, []).append(session_id)

            from mcp_server_langgraph.core.security import sanitize_for_logging

            logger.info(
                "Session created",
                extra=sanitize_for_logging({"session_id": session_id, "user_id": user_id, "ttl_seconds": ttl}),
            )

            return session_id

    async def get(self, session_id: str) -> SessionData | None:
        """
        Get a session, deleting it on the spot if it has expired.

        With the sliding window enabled a successful lookup stamps
        last_accessed with the current time; expires_at is left untouched
        (use refresh() to extend it).

        Returns:
            The stored session object, or None if not found/expired
        """
        with tracer.start_as_current_span("session.get") as span:
            span.set_attribute("session.id", session_id)

            session = self.sessions.get(session_id)
            if session is None:
                return None

            if self._is_expired(session):
                await self.delete(session_id)
                return None

            if self.sliding_window:
                session.last_accessed = datetime.now(UTC).isoformat()

            return session

    async def update(self, session_id: str, metadata: dict[str, Any]) -> bool:
        """
        Merge entries into a session's metadata and stamp last_accessed.

        Returns:
            True if the session exists, False otherwise
        """
        if session_id not in self.sessions:
            return False

        session = self.sessions[session_id]
        session.metadata.update(metadata)
        session.last_accessed = datetime.now(UTC).isoformat()

        logger.info(f"Session metadata updated: {session_id}")
        return True

    async def refresh(self, session_id: str, ttl_seconds: int | None = None) -> bool:
        """
        Restart a session's lifetime from now.

        Args:
            session_id: Session identifier
            ttl_seconds: New TTL in seconds; a falsy value falls back to the store default

        Returns:
            True if the session exists, False otherwise
        """
        if session_id not in self.sessions:
            return False

        session = self.sessions[session_id]
        ttl = ttl_seconds or self.default_ttl
        now = datetime.now(UTC)

        session.last_accessed = now.isoformat()
        session.expires_at = (now + timedelta(seconds=ttl)).isoformat()

        logger.info(f"Session refreshed: {session_id}, new TTL: {ttl}s")
        return True

    async def delete(self, session_id: str) -> bool:
        """
        Delete a session and its entry in the per-user tracking list.

        Returns:
            True if deleted, False if not found
        """
        if session_id not in self.sessions:
            return False

        session = self.sessions.pop(session_id)
        user_id = session.user_id

        # Keep the per-user index consistent; drop the user entry once it is empty
        tracked = self.user_sessions.get(user_id)
        if tracked is not None:
            if session_id in tracked:
                tracked.remove(session_id)
            if not tracked:
                del self.user_sessions[user_id]

        logger.info(f"Session deleted: {session_id}")
        return True

    async def list_user_sessions(self, user_id: str) -> list[SessionData]:
        """
        List all active sessions for a user.

        Each session is fetched through get(), so expired sessions are removed
        along the way and, under the sliding window, listed sessions count as accessed.
        """
        if user_id not in self.user_sessions:
            return []

        sessions = []
        for session_id in self.user_sessions[user_id][:]:  # Copy: get() may delete entries
            session = await self.get(session_id)
            if session:
                sessions.append(session)

        return sessions

    async def delete_user_sessions(self, user_id: str) -> int:
        """
        Delete all sessions for a user.

        Returns:
            Number of sessions deleted
        """
        if user_id not in self.user_sessions:
            return 0

        session_ids = self.user_sessions[user_id][:]  # Copy: delete() mutates the list
        count = 0

        for session_id in session_ids:
            if await self.delete(session_id):
                count += 1

        logger.info(f"Deleted {count} sessions for user {user_id}")
        return count

    async def get_inactive_sessions(self, cutoff_date: datetime) -> list[SessionData]:
        """
        Get sessions that haven't been accessed since cutoff date.

        Expired sessions encountered during the scan are deleted and skipped.
        The scan deliberately does not go through get(): under the sliding
        window get() rewrites last_accessed to "now", which would make every
        session look freshly used and hide all inactive ones.

        Args:
            cutoff_date: Return sessions with last_accessed before this date.
                A naive datetime is interpreted as UTC, matching the UTC
                timestamps this store writes.

        Returns:
            List of inactive session data
        """
        cutoff = cutoff_date if cutoff_date.tzinfo is not None else cutoff_date.replace(tzinfo=UTC)
        inactive_sessions = []

        for session_id, session in list(self.sessions.items()):  # Copy: delete() mutates the dict
            try:
                if self._is_expired(session):
                    await self.delete(session_id)
                    continue

                last_accessed = datetime.fromisoformat(session.last_accessed)
                if last_accessed < cutoff:
                    inactive_sessions.append(session)

            except (ValueError, AttributeError) as e:
                logger.warning(f"Error parsing session {session_id}: {e}")
                continue

        logger.info(f"Found {len(inactive_sessions)} inactive sessions before {cutoff_date.isoformat()}")
        return inactive_sessions

    async def delete_inactive_sessions(self, cutoff_date: datetime) -> int:
        """
        Delete sessions that haven't been accessed since cutoff date.

        Args:
            cutoff_date: Delete sessions with last_accessed before this date

        Returns:
            Number of sessions deleted
        """
        inactive_sessions = await self.get_inactive_sessions(cutoff_date)
        count = 0

        for session in inactive_sessions:
            if await self.delete(session.session_id):
                count += 1

        logger.info(f"Deleted {count} inactive sessions before {cutoff_date.isoformat()}")
        return count
class RedisSessionStore(SessionStore):
    """
    Redis-backed session store for production use

    Each session is a Redis hash at ``session:<session_id>`` with a key TTL;
    each user's session IDs are kept, oldest first, in a Redis list at
    ``user_sessions:<user_id>``.

    Features:
    - Persistent storage
    - Clustering support
    - Automatic expiration via Redis TTL
    - High performance
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379/0",
        default_ttl_seconds: int = 86400,
        sliding_window: bool = True,
        max_concurrent_sessions: int = 5,
        ssl: bool = False,
        decode_responses: bool = True,
        password: str | None = None,
        ttl_seconds: int | None = None,
    ) -> None:
        """
        Initialize Redis session store

        Args:
            redis_url: Redis connection URL
            default_ttl_seconds: Default session TTL
            sliding_window: When True, get() stamps last_accessed on every hit
            max_concurrent_sessions: Max sessions per user
            ssl: Use SSL/TLS
            decode_responses: Decode responses to strings
            password: Redis password (optional)
            ttl_seconds: Alias for default_ttl_seconds (for backward compatibility);
                takes precedence when given

        Raises:
            ImportError: If the redis package could not be imported
        """
        if not REDIS_AVAILABLE:
            msg = "Redis not available. Add 'redis[hiredis]>=5.0.0' to pyproject.toml dependencies, then run: uv sync"
            raise ImportError(msg)

        # Support both ttl_seconds and default_ttl_seconds for backward compatibility
        if ttl_seconds is not None:
            default_ttl_seconds = ttl_seconds

        self.redis_url = redis_url
        self.default_ttl = default_ttl_seconds
        self.sliding_window = sliding_window
        self.max_concurrent = max_concurrent_sessions
        self.decode_responses = decode_responses

        self.redis = redis.from_url(  # type: ignore[no-untyped-call]
            redis_url,
            password=password,
            ssl=ssl,
            decode_responses=decode_responses,
            encoding="utf-8",
        )

        logger.info(
            "Initialized RedisSessionStore",
            extra={
                # The URL may embed credentials (redis://user:pass@host); never log them
                "redis_url": self._redact_url(redis_url),
                "default_ttl": default_ttl_seconds,
                "sliding_window": sliding_window,
                "max_concurrent": max_concurrent_sessions,
            },
        )

    @staticmethod
    def _redact_url(url: str) -> str:
        """Return the URL with any password in its user-info part replaced by ``***``."""
        from urllib.parse import urlsplit, urlunsplit

        try:
            parts = urlsplit(url)
        except ValueError:
            return "<unparseable redis url>"

        userinfo, at_sign, hostinfo = parts.netloc.rpartition("@")
        if not at_sign or ":" not in userinfo:
            return url  # No password component present

        username = userinfo.split(":", 1)[0]
        return urlunsplit(parts._replace(netloc=f"{username}:***@{hostinfo}"))

    @staticmethod
    def _text(value: Any) -> Any:
        """Decode bytes to str (as returned when decode_responses=False); pass anything else through."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    async def _prune_user_sessions(self, user_sessions_key: str) -> list[str]:
        """
        Return the user's session IDs that still exist, oldest first.

        IDs whose session hash is gone (for example, expired through its Redis
        TTL) are removed from the tracking list so they stop counting against
        the concurrent-session limit.
        """
        live_ids: list[str] = []
        for raw_id in await self.redis.lrange(user_sessions_key, 0, -1):
            session_id = self._text(raw_id)
            if await self.redis.exists(f"session:{session_id}"):
                live_ids.append(session_id)
            else:
                await self.redis.lrem(user_sessions_key, 0, session_id)
        return live_ids

    async def _load(self, session_id: str, *, touch: bool) -> SessionData | None:
        """
        Read a session hash and turn it into SessionData.

        Args:
            session_id: Session identifier
            touch: When True, write the current time to last_accessed in Redis
                after reading (the returned object still carries the value
                that was read)

        Returns:
            Session data, or None if the hash does not exist
        """
        session_key = f"session:{session_id}"
        raw = await self.redis.hgetall(session_key)

        if not raw:
            return None

        # Normalise keys/values so the lookups below work with or without decode_responses
        data = {self._text(field): self._text(value) for field, value in raw.items()}

        # Metadata is stored as a JSON string; unreadable metadata degrades to {}
        metadata_str = data.get("metadata", "{}")
        try:
            metadata = json.loads(metadata_str) if metadata_str else {}
        except (json.JSONDecodeError, TypeError):
            metadata = {}

        roles_field = data.get("roles")

        session = SessionData(
            session_id=cast(str, data.get("session_id")),
            user_id=cast(str, data.get("user_id")),
            username=cast(str, data.get("username")),
            roles=roles_field.split(",") if roles_field else [],  # Stored comma-separated
            metadata=metadata,
            created_at=cast(str, data.get("created_at")),
            last_accessed=cast(str, data.get("last_accessed")),
            expires_at=cast(str, data.get("expires_at")),
        )

        if touch:
            await self.redis.hset(session_key, "last_accessed", datetime.now(UTC).isoformat())

        return session

    async def create(
        self,
        user_id: str,
        username: str,
        roles: list[str],
        metadata: dict[str, Any] | None = None,
        ttl_seconds: int | None = None,
    ) -> str:
        """
        Create a new session in Redis, evicting the user's oldest live session(s)
        when the concurrent-session limit is reached.

        Args:
            user_id: User identifier
            username: Username
            roles: User roles (stored comma-separated)
            metadata: Additional session metadata (stored as a JSON string)
            ttl_seconds: Time-to-live in seconds; a falsy value (None or 0)
                falls back to the store default

        Returns:
            Session ID
        """
        with tracer.start_as_current_span("session.create") as span:
            span.set_attribute("user.id", user_id)

            user_sessions_key = f"user_sessions:{user_id}"

            # Count only sessions that still exist; stale IDs would otherwise
            # make eviction target an already-gone session forever
            live_ids = await self._prune_user_sessions(user_sessions_key)

            # Make room for exactly one new session
            surplus = len(live_ids) - self.max_concurrent + 1
            for oldest_session_id in live_ids[: max(surplus, 0)]:
                await self.delete(oldest_session_id)
                logger.info(f"Removed oldest session for user {user_id}")

            session_id = self._generate_session_id()
            ttl = ttl_seconds or self.default_ttl
            now = datetime.now(UTC)

            session_data = {
                "session_id": session_id,
                "user_id": user_id,
                "username": username,
                "roles": ",".join(roles),  # Store as comma-separated
                "metadata": json.dumps(metadata or {}),  # Store as JSON string
                "created_at": now.isoformat(),
                "last_accessed": now.isoformat(),
                "expires_at": (now + timedelta(seconds=ttl)).isoformat(),
            }

            session_key = f"session:{session_id}"
            await self.redis.hset(session_key, mapping=cast(Mapping[str | bytes, bytes | float | int | str], session_data))
            await self.redis.expire(session_key, ttl)

            # Track the session under its owner
            await self.redis.rpush(user_sessions_key, session_id)
            await self.redis.expire(user_sessions_key, ttl + 3600)  # Extra hour

            from mcp_server_langgraph.core.security import sanitize_for_logging

            logger.info(
                "Session created in Redis",
                extra=sanitize_for_logging({"session_id": session_id, "user_id": user_id, "ttl_seconds": ttl}),
            )

            return session_id

    async def get(self, session_id: str) -> SessionData | None:
        """
        Get a session from Redis.

        With the sliding window enabled, last_accessed is rewritten in Redis on
        every hit; the key's TTL is not extended here (refresh() does that).

        Returns:
            Session data, or None if not found
        """
        with tracer.start_as_current_span("session.get") as span:
            span.set_attribute("session.id", session_id)
            return await self._load(session_id, touch=self.sliding_window)

    async def update(self, session_id: str, metadata: dict[str, Any]) -> bool:
        """
        Merge entries into a session's metadata and stamp last_accessed.

        Returns:
            True if the session exists, False otherwise
        """
        session_key = f"session:{session_id}"
        exists = await self.redis.exists(session_key)

        if not exists:
            return False

        # Go through the model so the merged result passes Pydantic validation
        session = await self.get(session_id)
        if not session:
            return False

        session.metadata.update(metadata)
        session.last_accessed = datetime.now(UTC).isoformat()

        await self.redis.hset(
            session_key, mapping={"metadata": json.dumps(session.metadata), "last_accessed": session.last_accessed}
        )

        logger.info(f"Session metadata updated in Redis: {session_id}")
        return True

    async def refresh(self, session_id: str, ttl_seconds: int | None = None) -> bool:
        """
        Restart a session's lifetime from now (hash fields and Redis key TTL).

        Args:
            session_id: Session identifier
            ttl_seconds: New TTL in seconds; a falsy value falls back to the store default

        Returns:
            True if the session exists, False otherwise
        """
        session_key = f"session:{session_id}"
        exists = await self.redis.exists(session_key)

        if not exists:
            return False

        ttl = ttl_seconds or self.default_ttl
        now = datetime.now(UTC)
        new_expires_at = (now + timedelta(seconds=ttl)).isoformat()

        await self.redis.hset(session_key, mapping={"last_accessed": now.isoformat(), "expires_at": new_expires_at})
        await self.redis.expire(session_key, ttl)

        logger.info(f"Session refreshed in Redis: {session_id}, TTL: {ttl}s")
        return True

    async def delete(self, session_id: str) -> bool:
        """
        Delete a session hash and remove its ID from the owner's tracking list.

        Returns:
            True if deleted, False if not found
        """
        session_key = f"session:{session_id}"

        # The owner must be read before the hash disappears; decode so a bytes
        # reply does not end up as "b'...'" inside the list key
        user_id = self._text(await self.redis.hget(session_key, "user_id"))

        deleted = await self.redis.delete(session_key)

        if deleted and user_id:
            user_sessions_key = f"user_sessions:{user_id}"
            await self.redis.lrem(user_sessions_key, 0, session_id)

        logger.info(f"Session deleted from Redis: {session_id}")
        return bool(deleted)

    async def list_user_sessions(self, user_id: str) -> list[SessionData]:
        """
        List all active sessions for a user.

        IDs whose session no longer exists are skipped. Sessions are fetched
        through get(), so under the sliding window listing counts as access.
        """
        user_sessions_key = f"user_sessions:{user_id}"
        session_ids = await self.redis.lrange(user_sessions_key, 0, -1)

        sessions = []
        for raw_id in session_ids:
            session = await self.get(self._text(raw_id))
            if session:
                sessions.append(session)

        return sessions

    async def delete_user_sessions(self, user_id: str) -> int:
        """
        Delete all sessions for a user, then the tracking list itself.

        Returns:
            Number of sessions deleted
        """
        user_sessions_key = f"user_sessions:{user_id}"
        session_ids = await self.redis.lrange(user_sessions_key, 0, -1)

        count = 0
        for raw_id in session_ids:
            if await self.delete(self._text(raw_id)):
                count += 1

        # Drop the list even if some IDs were stale
        await self.redis.delete(user_sessions_key)

        logger.info(f"Deleted {count} sessions from Redis for user {user_id}")
        return count

    async def get_inactive_sessions(self, cutoff_date: datetime) -> list[SessionData]:
        """
        Get sessions that haven't been accessed since cutoff date.

        Walks every ``session:*`` key with SCAN. Sessions are read without
        touching last_accessed, so running the scan does not itself mark
        sessions as active. Hashes that fail validation are logged and skipped.

        Args:
            cutoff_date: Return sessions with last_accessed before this date.
                A naive datetime is interpreted as UTC, matching the UTC
                timestamps this store writes.

        Returns:
            List of inactive session data
        """
        cutoff = cutoff_date if cutoff_date.tzinfo is not None else cutoff_date.replace(tzinfo=UTC)
        inactive_sessions = []
        seen: set[str] = set()  # SCAN may hand back the same key more than once

        cursor = 0
        while True:
            cursor, keys = await self.redis.scan(cursor, match="session:*", count=100)

            for key in keys:
                session_id = self._text(key).removeprefix("session:")
                if session_id in seen:
                    continue
                seen.add(session_id)

                try:
                    session = await self._load(session_id, touch=False)
                    if session is None:
                        continue  # Expired between SCAN and read

                    last_accessed = datetime.fromisoformat(session.last_accessed)
                    if last_accessed < cutoff:
                        inactive_sessions.append(session)

                except (ValueError, AttributeError) as e:
                    # Pydantic's ValidationError is a ValueError, so malformed hashes land here
                    logger.warning(f"Error parsing session {session_id}: {e}")
                    continue

            if cursor == 0:
                break

        logger.info(f"Found {len(inactive_sessions)} inactive sessions in Redis before {cutoff_date.isoformat()}")
        return inactive_sessions

    async def delete_inactive_sessions(self, cutoff_date: datetime) -> int:
        """
        Delete sessions that haven't been accessed since cutoff date.

        Args:
            cutoff_date: Delete sessions with last_accessed before this date

        Returns:
            Number of sessions deleted
        """
        inactive_sessions = await self.get_inactive_sessions(cutoff_date)
        count = 0

        for session in inactive_sessions:
            if await self.delete(session.session_id):
                count += 1

        logger.info(f"Deleted {count} inactive sessions from Redis before {cutoff_date.isoformat()}")
        return count
def create_session_store(backend: str = "memory", redis_url: str | None = None, **kwargs: Any) -> SessionStore:
    """
    Build a session store for the requested backend.

    The backend name is matched case-insensitively.

    Args:
        backend: "memory" or "redis"
        redis_url: Redis connection URL (required for redis backend)
        **kwargs: Forwarded unchanged to the chosen store's constructor

    Returns:
        SessionStore instance

    Raises:
        ValueError: If backend is unknown or redis_url missing for redis backend
        ImportError: If the redis backend is requested but redis is not installed
    """
    kind = backend.lower()

    if kind == "memory":
        logger.info("Creating InMemorySessionStore")
        return InMemorySessionStore(**kwargs)

    if kind != "redis":
        msg = f"Unknown session backend: {kind}. Supported: 'memory', 'redis'"
        raise ValueError(msg)

    # Redis backend: validate prerequisites before constructing the store
    if not redis_url:
        msg = "redis_url required for Redis backend"
        raise ValueError(msg)

    if not REDIS_AVAILABLE:
        msg = "Redis not available. Add 'redis[hiredis]>=5.0.0' to pyproject.toml dependencies, then run: uv sync"
        raise ImportError(msg)

    logger.info("Creating RedisSessionStore")
    return RedisSessionStore(redis_url=redis_url, **kwargs)
# Process-wide session store; None until set_session_store() registers one
# or get_session_store() installs its in-memory fallback
_session_store: SessionStore | None = None


def get_session_store() -> SessionStore:
    """
    FastAPI dependency returning the process-wide session store.

    Returns:
        SessionStore instance

    Warning:
        When nothing has been registered through set_session_store(), the first
        call logs a warning, creates an InMemorySessionStore with default
        settings, and keeps it as the global store. That fallback is meant for
        tests or for calls made before create_auth_middleware() has run.

        In production, make sure create_auth_middleware() runs during app
        startup so the configured store (Redis or Memory based on settings)
        is registered.

    Example:
        @app.get("/api/sessions")
        async def list_sessions(session_store: SessionStore = Depends(get_session_store)):
            # Use session_store
            pass
    """
    global _session_store

    if _session_store is not None:
        return _session_store

    # Nothing registered yet: install the in-memory fallback and say so loudly
    logger.warning(
        "Session store not registered globally, using fallback in-memory store. "
        "This may indicate create_auth_middleware() was not called. "
        "In production, register the session store via set_session_store()."
    )
    _session_store = InMemorySessionStore()
    return _session_store
def set_session_store(session_store: SessionStore) -> None:
    """
    Register the store that get_session_store() will hand out.

    Replaces whatever store was registered before, including the in-memory
    fallback if one had already been created.

    Args:
        session_store: Session store to use globally

    Example:
        # At application startup
        redis_store = create_session_store("redis", redis_url="redis://localhost:6379")
        set_session_store(redis_store)
    """
    global _session_store
    _session_store = session_store