Coverage for src/mcp_server_langgraph/middleware/rate_limiter.py: 93%
85 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
"""
Rate limiting middleware for FastAPI with tiered limits.

Provides DoS protection with:
- Tiered rate limits (anonymous, free, standard, premium, enterprise)
- Redis-backed distributed rate limiting
- User-based and IP-based limiting
- Automatic X-RateLimit-* headers
- Graceful degradation (fail-open if Redis is down)

See ADR-0027 for design rationale.
"""
14from collections.abc import Callable
15from typing import Any
17from fastapi import Request
18from slowapi import Limiter
19from slowapi.errors import RateLimitExceeded
20from slowapi.util import get_remote_address
22from mcp_server_langgraph.core.config import settings
23from mcp_server_langgraph.observability.telemetry import logger
# Default request budget per user tier, in slowapi's "<count>/<period>" notation.
# Tier names double as the role names looked up in the user's JWT claims.
RATE_LIMITS: dict[str, str] = {
    "anonymous": "10/minute",
    "free": "60/minute",
    "standard": "300/minute",
    "premium": "1000/minute",
    # Effectively unlimited
    "enterprise": "999999/minute",
}
# Per-endpoint budgets that take precedence over the tier limits.
ENDPOINT_RATE_LIMITS: dict[str, str] = {
    # Prevent brute force
    "auth_login": "10/minute",
    # Prevent spam registration
    "auth_register": "5/minute",
    # LLM-heavy endpoints (cost control)
    "llm_chat": "30/minute",
    # Search endpoints
    "search": "100/minute",
    # Read-only endpoints
    "read": "200/minute",
}
def get_user_id_from_jwt(request: Request) -> str | None:
    """
    Return the authenticated user's ID, if one is attached to the request.

    IMPORTANT: The lookup is purely synchronous. It only reads
    request.state.user (expected to be populated by AuthMiddleware) and never
    verifies a token itself, which avoids event loop issues in slowapi's
    synchronous context.

    Args:
        request: FastAPI request object

    Returns:
        The "user_id" entry of request.state.user when that attribute exists
        and is truthy; None when it is missing, falsy, or reading it fails
    """
    try:
        state = request.state

        # Auth middleware has not run (or did not attach a user)
        if not hasattr(state, "user"):
            return None

        user = state.user
        if not user:
            return None

        return user.get("user_id")  # type: ignore[no-any-return]
    except Exception as e:
        # Deliberately no token-verification fallback: treat any failure as
        # "no user" so the caller can fall back to a less specific key
        logger.debug(f"Failed to extract user ID from JWT: {e}")
        return None
def get_user_tier(request: Request) -> str:
    """
    Determine the rate limit tier of the user attached to the request.

    IMPORTANT: This function is synchronous and relies on request.state.user
    being set by AuthMiddleware. It does NOT attempt async token verification
    to avoid event loop issues in slowapi's synchronous context.

    Resolution order for an authenticated user:
    1. The highest tier named in the "roles" claim
       (enterprise > premium > standard > free)
    2. The "tier" claim, then the "plan" claim, when it names a known tier
    3. "free"

    Args:
        request: FastAPI request object

    Returns:
        User tier (anonymous, free, standard, premium, enterprise);
        "anonymous" when no user is attached or the claims cannot be read
    """
    try:
        user = getattr(request.state, "user", None)
        if not user:
            # No fallback to token verification - auth middleware must run first
            return "anonymous"

        # A "roles" claim that is present but null must not raise here: the
        # broad handler below would demote an authenticated user to "anonymous".
        roles = user.get("roles") or []

        # Checked from most to least privileged so the best matching role wins
        for role_tier in ("enterprise", "premium", "standard", "free"):
            if role_tier in roles:
                return role_tier

        # Fallback to a dedicated tier/plan field; unknown values count as "free"
        tier = user.get("tier") or user.get("plan") or "free"
        return tier if tier in RATE_LIMITS else "free"
    except Exception as e:
        logger.debug(f"Failed to extract tier from JWT: {e}")
        return "anonymous"
def get_rate_limit_key(request: Request) -> str:
    """
    Pick the bucket a request is counted against.

    Candidates are tried from most to least specific:
    1. "user:<id>" when get_user_id_from_jwt yields a user ID
    2. "ip:<address>" when get_remote_address yields an address
    3. "global:anonymous" as the shared last-resort bucket

    Args:
        request: FastAPI request object

    Returns:
        Rate limit key string
    """
    user_id = get_user_id_from_jwt(request)
    if user_id:
        return f"user:{user_id}"

    ip_address = get_remote_address(request)
    return f"ip:{ip_address}" if ip_address else "global:anonymous"
def get_rate_limit_for_tier(tier: str) -> str:
    """
    Look up the rate limit string configured for a tier.

    Args:
        tier: User tier name

    Returns:
        Rate limit string (e.g., "60/minute"); the "free" limit is used for
        tier names that are not in RATE_LIMITS
    """
    if tier in RATE_LIMITS:
        return RATE_LIMITS[tier]
    return RATE_LIMITS["free"]
def get_dynamic_limit(request: Request) -> str:
    """
    Resolve the rate limit that applies to the request's user tier.

    The resolved tier, limit and rate limit key are also logged at debug level.

    Args:
        request: FastAPI request object

    Returns:
        Rate limit string for the user's tier
    """
    user_tier = get_user_tier(request)
    tier_limit = get_rate_limit_for_tier(user_tier)

    log_context = {
        "tier": user_tier,
        "limit": tier_limit,
        "key": get_rate_limit_key(request),
    }
    logger.debug("Rate limit determined", extra=log_context)

    return tier_limit
180# Configure Redis storage for distributed rate limiting
def get_redis_storage_uri() -> str:
    """
    Build the Redis URI used as storage for rate limiting.

    Host, port and database index are read from the settings attributes
    redis_host, redis_port and redis_rate_limit_db, defaulting to
    "localhost", 6379 and 3 when an attribute is not defined.

    NOTE(review): this function only formats the URI. Whether slowapi falls
    back to in-memory storage when Redis is unreachable depends on how the
    Limiter is configured - confirm.

    Returns:
        Redis URI string of the form "redis://<host>:<port>/<db>"
    """
    host = getattr(settings, "redis_host", "localhost")
    port = getattr(settings, "redis_port", 6379)
    # DB 3 is the default database reserved for rate limiting
    db_index = getattr(settings, "redis_rate_limit_db", 3)

    return f"redis://{host}:{port}/{db_index}"
# Module-level limiter shared by setup_rate_limiting and the decorators below
limiter = Limiter(
    # Bucket selection: user ID > IP > global anonymous
    key_func=get_rate_limit_key,
    # Dynamic limits based on tier.
    # NOTE(review): slowapi resolves callable limits itself; confirm it invokes
    # get_dynamic_limit with the request argument this function is declared to take.
    default_limits=[get_dynamic_limit],
    # Supported strategies: fixed-window, moving-window
    strategy="fixed-window",
    storage_uri=get_redis_storage_uri(),
    # Add X-RateLimit-* headers
    headers_enabled=True,
    # Fail-open: allow requests if rate limiting fails
    swallow_errors=True,
)
async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):  # type: ignore[no-untyped-def]
    """
    Custom handler for rate limit exceeded errors.

    Builds a structured 429 response with retry information, increments the
    error counter metric and logs the violation.

    NOTE(review): the tier and limit reported here are derived from the
    request's user tier; `exc` is not consulted, so when an endpoint-specific
    limit triggered the error the reported limit may differ from the one that
    was actually exceeded - confirm this is intended.

    Args:
        request: FastAPI request object
        exc: RateLimitExceeded exception

    Returns:
        JSONResponse with status 429, the serialized RateLimitExceededError as
        body, and Retry-After / X-RateLimit-* headers
    """
    # Imported inside the handler, as in the original implementation
    from fastapi.responses import JSONResponse

    from mcp_server_langgraph.core.exceptions import RateLimitExceededError
    from mcp_server_langgraph.observability.telemetry import error_counter

    # Default retry delay in seconds; shared by the error body and the header
    retry_after_seconds = 60

    # Get user tier for error message
    tier = get_user_tier(request)
    limit = get_rate_limit_for_tier(tier)

    # Create custom exception
    rate_limit_error = RateLimitExceededError(
        message=f"Rate limit exceeded for tier '{tier}' ({limit})",
        metadata={
            "tier": tier,
            "limit": limit,
            "retry_after": retry_after_seconds,
        },
    )

    # Emit metric
    try:
        error_counter.add(
            1,
            attributes={
                "error_code": rate_limit_error.error_code,
                "tier": tier,
                "endpoint": request.url.path,
            },
        )
    except Exception as metric_err:
        # Don't let metrics failure break error handling, but leave a trace
        # instead of discarding the failure silently
        logger.debug(f"Failed to record rate limit metric: {metric_err}")

    # Log rate limit violation
    logger.warning(
        "Rate limit exceeded",
        extra={
            "tier": tier,
            "limit": limit,
            "key": get_rate_limit_key(request),
            "endpoint": request.url.path,
            "method": request.method,
        },
    )

    # Return structured error response
    return JSONResponse(
        status_code=429,
        content=rate_limit_error.to_dict(),
        headers={
            "Retry-After": str(retry_after_seconds),
            # Numeric part of "<count>/<period>"
            "X-RateLimit-Limit": limit.partition("/")[0],
            "X-RateLimit-Remaining": "0",
        },
    )
def setup_rate_limiting(app: Any) -> None:
    """
    Wire rate limiting into a FastAPI application.

    Stores the module-level limiter on app.state and registers
    custom_rate_limit_exceeded_handler for RateLimitExceeded.

    NOTE(review): no rate limiting middleware is added here; confirm whether
    default limits are expected to apply to routes without a limiter decorator.

    Args:
        app: FastAPI application instance

    Usage:
        from fastapi import FastAPI
        from mcp_server_langgraph.middleware.rate_limiter import setup_rate_limiting

        app = FastAPI()
        setup_rate_limiting(app)
    """
    # Add limiter to app state
    app.state.limiter = limiter

    # Register custom exception handler
    app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)

    try:
        config_summary = {
            "storage": get_redis_storage_uri(),
            "strategy": "fixed-window",
            "tiers": list(RATE_LIMITS),
        }
        logger.info("Rate limiting configured", extra=config_summary)
    except RuntimeError:
        # Observability not initialized yet, skip logging
        pass
312# Decorators for endpoint-specific rate limits
def rate_limit_for_auth(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate an authentication endpoint with the "auth_login" rate limit."""
    auth_limit = ENDPOINT_RATE_LIMITS["auth_login"]
    apply_limit = limiter.limit(auth_limit)
    return apply_limit(func)  # type: ignore[no-any-return]
def rate_limit_for_llm(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate an LLM endpoint with the "llm_chat" rate limit."""
    llm_limit = ENDPOINT_RATE_LIMITS["llm_chat"]
    apply_limit = limiter.limit(llm_limit)
    return apply_limit(func)  # type: ignore[no-any-return]
def rate_limit_for_search(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate a search endpoint with the "search" rate limit."""
    search_limit = ENDPOINT_RATE_LIMITS["search"]
    apply_limit = limiter.limit(search_limit)
    return apply_limit(func)  # type: ignore[no-any-return]
def exempt_from_rate_limit(func: Callable[..., Any]) -> Callable[..., Any]:
    """Mark an endpoint (health checks, metrics) as exempt from rate limiting."""
    exempted = limiter.exempt(func)  # type: ignore[no-untyped-call]
    return exempted  # type: ignore[no-any-return]