Coverage for src / mcp_server_langgraph / middleware / rate_limiter.py: 93%

85 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Rate limiting middleware for FastAPI with tiered limits. 

3 

4Provides DoS protection with: 

5- Tiered rate limits (anonymous, free, standard, premium, enterprise) 

6- Redis-backed distributed rate limiting 

7- User-based and IP-based limiting 

8- Automatic X-RateLimit-* headers 

9- Graceful degradation (fail-open if Redis is down) 

10 

11See ADR-0027 for design rationale. 

12""" 

13 

14from collections.abc import Callable 

15from typing import Any 

16 

17from fastapi import Request 

18from slowapi import Limiter 

19from slowapi.errors import RateLimitExceeded 

20from slowapi.util import get_remote_address 

21 

22from mcp_server_langgraph.core.config import settings 

23from mcp_server_langgraph.observability.telemetry import logger 

24 

# Per-tier request budgets, written as "<count>/minute" limit strings.
# Keys are the tier names returned by get_user_tier().
RATE_LIMITS = {
    "anonymous": "10/minute",
    "free": "60/minute",
    "standard": "300/minute",
    "premium": "1000/minute",
    "enterprise": "999999/minute",  # Large enough to be effectively unlimited
}

33 

# Fixed limits for specific endpoint categories. These are applied through the
# rate_limit_for_* decorators in this module instead of the tier-based limit.
ENDPOINT_RATE_LIMITS = {
    "auth_login": "10/minute",  # Slow down brute-force login attempts
    "auth_register": "5/minute",  # Slow down spam registrations
    "llm_chat": "30/minute",  # LLM-backed endpoints (cost control)
    "search": "100/minute",  # Search endpoints
    "read": "200/minute",  # Read-only endpoints
}

42 

43 

def get_user_id_from_jwt(request: Request) -> str | None:
    """
    Return the authenticated user's ID stored on the request, if there is one.

    This function is synchronous on purpose: it only reads ``request.state.user``
    (expected to be populated by AuthMiddleware before rate limiting runs) and
    never verifies a token itself, so it is safe to call from slowapi's
    synchronous key function.

    Args:
        request: FastAPI request object

    Returns:
        The ``user_id`` entry of ``request.state.user``, or None when no user
        is attached to the request or the lookup fails
    """
    try:
        authenticated_user = getattr(request.state, "user", None)
        if not authenticated_user:
            # Auth middleware did not attach a user; no token verification is
            # attempted here.
            return None
        return authenticated_user.get("user_id")  # type: ignore[no-any-return]
    except Exception as exc:
        logger.debug(f"Failed to extract user ID from JWT: {exc}")
        return None

70 

71 

def get_user_tier(request: Request) -> str:
    """
    Determine the rate limit tier of the caller.

    IMPORTANT: This function is synchronous and relies on request.state.user
    being set by AuthMiddleware. It does NOT attempt async token verification
    to avoid event loop issues in slowapi's synchronous context.

    Resolution order for an authenticated user:
    1. The highest-priority tier name found in the user's ``roles``
       (enterprise > premium > standard > free).
    2. The user's ``tier`` field, then ``plan`` field, when it names a known tier.
    3. ``"free"``.

    Args:
        request: FastAPI request object

    Returns:
        User tier (anonymous, free, standard, premium, enterprise).
        ``"anonymous"`` is returned when no user is attached to the request or
        when reading the user data fails.
    """
    try:
        user = getattr(request.state, "user", None)
        if not user:
            # No fallback to token verification - auth middleware must run first
            return "anonymous"

        # ``roles`` may be present but null; dict.get() only applies its default
        # when the key is missing, so normalize explicitly. Without this a
        # null value raised TypeError below and demoted the user to anonymous.
        roles = user.get("roles") or []

        # Highest tier wins when several tier roles are present.
        for role_tier in ("enterprise", "premium", "standard", "free"):
            if role_tier in roles:
                return role_tier

        # Fallback to checking a dedicated tier/plan field
        tier = user.get("tier") or user.get("plan") or "free"
        return tier if tier in RATE_LIMITS else "free"

    except Exception as e:
        # Deliberately broad: tier lookup must never break request handling.
        logger.debug(f"Failed to extract tier from JWT: {e}")
        return "anonymous"

112 

113 

def get_rate_limit_key(request: Request) -> str:
    """
    Build the bucket key that rate limiting is tracked under.

    The most specific identity available is used:
    1. ``user:<id>`` when an authenticated user ID is present
    2. ``ip:<address>`` when a remote address can be determined
    3. ``global:anonymous`` as the shared last-resort bucket

    Args:
        request: FastAPI request object

    Returns:
        Rate limit key string
    """
    user_id = get_user_id_from_jwt(request)
    if user_id:
        return f"user:{user_id}"

    client_ip = get_remote_address(request)
    return f"ip:{client_ip}" if client_ip else "global:anonymous"

140 

141 

def get_rate_limit_for_tier(tier: str) -> str:
    """
    Look up the limit string configured for a tier.

    Args:
        tier: User tier name

    Returns:
        Rate limit string (e.g., "60/minute"); unknown tiers get the "free" limit
    """
    free_limit = RATE_LIMITS["free"]
    return RATE_LIMITS.get(tier, free_limit)

153 

154 

def get_dynamic_limit(request: Request) -> str:
    """
    Resolve the rate limit that applies to the caller's tier.

    The resolved tier, limit and rate limit key are also emitted as a debug
    log record.

    Args:
        request: FastAPI request object

    Returns:
        Rate limit string for the user's tier
    """
    tier = get_user_tier(request)
    limit = get_rate_limit_for_tier(tier)

    log_fields = {
        "tier": tier,
        "limit": limit,
        "key": get_rate_limit_key(request),
    }
    logger.debug("Rate limit determined", extra=log_fields)

    return limit

178 

179 

# Redis storage location used for distributed rate limiting
def get_redis_storage_uri() -> str:
    """
    Build the Redis URI that rate limit counters are stored under.

    Host, port and database are read from ``settings`` (``redis_host``,
    ``redis_port``, ``redis_rate_limit_db``), defaulting to ``localhost``,
    ``6379`` and database ``3`` when a setting is absent.

    NOTE(review): whether slowapi falls back to in-memory storage when Redis is
    unavailable depends on how the Limiter is configured, not on this
    function - confirm against the slowapi documentation.

    Returns:
        Redis URI string of the form ``redis://<host>:<port>/<db>``
    """
    host = getattr(settings, "redis_host", "localhost")
    port = getattr(settings, "redis_port", 6379)
    database = getattr(settings, "redis_rate_limit_db", 3)  # DB 3 is the rate limiting default

    return f"redis://{host}:{port}/{database}"

195 

196 

# Shared slowapi limiter used by setup_rate_limiting() and the decorators below.
# NOTE(review): confirm how slowapi invokes callable entries in default_limits
# (with or without arguments) - get_dynamic_limit requires a request argument.
limiter = Limiter(
    key_func=get_rate_limit_key,  # user ID, then IP, then a global bucket
    default_limits=[get_dynamic_limit],  # Tier-based limit resolved per request
    storage_uri=get_redis_storage_uri(),
    strategy="fixed-window",  # Alternative: moving-window
    headers_enabled=True,  # Emit X-RateLimit-* response headers
    swallow_errors=True,  # Fail-open: allow requests if rate limiting fails
)

206 

207 

async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):  # type: ignore[no-untyped-def]
    """
    Build the structured HTTP 429 response for a rate-limited request.

    The handler records an error metric, logs the violation, and returns the
    serialized ``RateLimitExceededError`` together with ``Retry-After`` and
    ``X-RateLimit-*`` headers.

    NOTE: ``exc`` is not inspected. The limit reported in the message and
    headers is the caller's tier limit, which can differ from the limit that
    actually tripped when an endpoint-specific limit is applied to the route.

    Args:
        request: FastAPI request object
        exc: RateLimitExceeded exception (currently unused)

    Returns:
        JSONResponse with status 429 and the rate limit error payload
    """
    from fastapi.responses import JSONResponse

    from mcp_server_langgraph.core.exceptions import RateLimitExceededError
    from mcp_server_langgraph.observability.telemetry import error_counter

    # Single source for the retry hint so the payload metadata and the
    # Retry-After header cannot drift apart.
    retry_after_seconds = 60

    # Get user tier for error message
    tier = get_user_tier(request)
    limit = get_rate_limit_for_tier(tier)

    # Create custom exception
    rate_limit_error = RateLimitExceededError(
        message=f"Rate limit exceeded for tier '{tier}' ({limit})",
        metadata={
            "tier": tier,
            "limit": limit,
            "retry_after": retry_after_seconds,
        },
    )

    # Emit metric. Best-effort: a metrics failure must not break error
    # handling, but it is logged instead of being silently discarded.
    try:
        error_counter.add(
            1,
            attributes={
                "error_code": rate_limit_error.error_code,
                "tier": tier,
                "endpoint": request.url.path,
            },
        )
    except Exception as metric_error:
        logger.debug(f"Failed to record rate limit metric: {metric_error}")

    # Log rate limit violation
    logger.warning(
        "Rate limit exceeded",
        extra={
            "tier": tier,
            "limit": limit,
            "key": get_rate_limit_key(request),
            "endpoint": request.url.path,
            "method": request.method,
        },
    )

    # Return structured error response
    return JSONResponse(
        status_code=429,
        content=rate_limit_error.to_dict(),
        headers={
            "Retry-After": str(retry_after_seconds),
            # Limit strings look like "<count>/minute"; expose only the count.
            "X-RateLimit-Limit": limit.split("/")[0],
            "X-RateLimit-Remaining": "0",
        },
    )

276 

277 

def setup_rate_limiting(app: Any) -> None:
    """
    Attach the module's limiter and 429 handler to a FastAPI application.

    The limiter is stored on ``app.state.limiter`` and
    ``custom_rate_limit_exceeded_handler`` is registered for
    ``RateLimitExceeded``. The configuration is then logged; a ``RuntimeError``
    raised while logging is ignored.

    Args:
        app: FastAPI application instance

    Usage:
        from fastapi import FastAPI
        from mcp_server_langgraph.middleware.rate_limiter import setup_rate_limiting

        app = FastAPI()
        setup_rate_limiting(app)
    """
    # Make the limiter reachable through the application state
    app.state.limiter = limiter

    # Route RateLimitExceeded errors to the structured 429 handler
    app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)

    try:
        configuration = {
            "storage": get_redis_storage_uri(),
            "strategy": "fixed-window",
            "tiers": list(RATE_LIMITS.keys()),
        }
        logger.info("Rate limiting configured", extra=configuration)
    except RuntimeError:
        # Observability not initialized yet, skip logging
        pass

310 

311 

312# Decorators for endpoint-specific rate limits 

def rate_limit_for_auth(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap an authentication endpoint with the "auth_login" endpoint limit."""
    apply_limit = limiter.limit(ENDPOINT_RATE_LIMITS["auth_login"])
    return apply_limit(func)  # type: ignore[no-any-return]

316 

317 

def rate_limit_for_llm(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap an LLM endpoint with the "llm_chat" endpoint limit."""
    apply_limit = limiter.limit(ENDPOINT_RATE_LIMITS["llm_chat"])
    return apply_limit(func)  # type: ignore[no-any-return]

321 

322 

def rate_limit_for_search(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap a search endpoint with the "search" endpoint limit."""
    apply_limit = limiter.limit(ENDPOINT_RATE_LIMITS["search"])
    return apply_limit(func)  # type: ignore[no-any-return]

326 

327 

def exempt_from_rate_limit(func: Callable[..., Any]) -> Callable[..., Any]:
    """Mark an endpoint as exempt from rate limiting (health checks, metrics)."""
    exempted = limiter.exempt(func)  # type: ignore[no-untyped-call]
    return exempted  # type: ignore[no-any-return]