Coverage for src/mcp_server_langgraph/resilience/rate_limit.py: 94% (104 statements)
coverage.py v7.12.0, created at 2025-12-08 06:31 +0000

1""" 

2Token bucket rate limiter for pre-emptive rate limiting. 

3 

4Implements the token bucket algorithm to limit request rates before hitting 

5upstream provider limits. This provides smoother request distribution and 

6prevents 429/529 errors. 

7 

8Key features: 

9- Burst capacity: Allows short bursts up to bucket capacity 

10- Steady rate: Limits sustained throughput to refill rate 

11- Async support: Non-blocking acquire with wait capability 

12- Provider-aware: Different limits per LLM provider 

13 

14See ADR-0026 for design rationale. 

15 

16Reference: https://en.wikipedia.org/wiki/Token_bucket 

17""" 

18 

19import asyncio 

20import functools 

21import logging 

22import os 

23import threading 

24import time 

25from collections.abc import Callable 

26from typing import ParamSpec, TypeVar 

27 

28from opentelemetry import trace 

29 

30logger = logging.getLogger(__name__) 

31tracer = trace.get_tracer(__name__) 

32 

33P = ParamSpec("P") 

34T = TypeVar("T") 

35 

36 

37# ============================================================================= 

38# Provider Rate Limit Configuration 

39# ============================================================================= 

# These are RPM (requests per minute) limits based on provider documentation.
# Formula: refill_rate = rpm / 60 (converts RPM to requests per second)
# Burst capacity = refill_rate * burst_seconds (allows ~10s worth of requests)

PROVIDER_RATE_LIMITS: dict[str, dict[str, float]] = {
    # Anthropic Claude (direct API): Tier 1 default (50 RPM)
    # Higher tiers can be set via RATE_LIMIT_ANTHROPIC_RPM env var
    "anthropic": {"rpm": 50.0, "burst_seconds": 10.0},
    # OpenAI GPT-4: Moderate tier (~500 RPM)
    "openai": {"rpm": 500.0, "burst_seconds": 10.0},
    # Google Gemini via AI Studio
    "google": {"rpm": 300.0, "burst_seconds": 10.0},
    # Google Vertex AI - Gemini models (vishnu-sandbox-20250310):
    # Gemini 2.5 Flash/Pro: 3.4M TPM → ~1400 RPM (2500 tokens/req)
    # Use 600 RPM to leave headroom
    "vertex_ai": {"rpm": 600.0, "burst_seconds": 15.0},
    # Anthropic Claude via Vertex AI (MaaS) - vishnu-sandbox-20250310:
    # Claude Opus 4.5: 1,200 RPM; Sonnet 4.5 and Haiku 4.5: higher
    # Use Opus 4.5 as baseline (most restrictive 4.5 model)
    "vertex_ai_anthropic": {"rpm": 1000.0, "burst_seconds": 10.0},
    # AWS Bedrock: Claude Opus is most restrictive (50 RPM)
    "bedrock": {"rpm": 50.0, "burst_seconds": 10.0},
    # Ollama: Local model, no upstream limits
    "ollama": {"rpm": 1000.0, "burst_seconds": 30.0},
    # Azure OpenAI: Similar to OpenAI
    "azure": {"rpm": 500.0, "burst_seconds": 10.0},
}

DEFAULT_RATE_LIMIT: dict[str, float] = {"rpm": 60.0, "burst_seconds": 10.0}
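
# Worked example of the formula above (illustrative, using the "anthropic" entry):
#   refill_rate = 50.0 / 60.0    # ≈ 0.83 tokens/sec sustained
#   capacity = 0.83 * 10.0       # ≈ 8.3 tokens of burst headroom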

def get_provider_rate_config(provider: str) -> dict[str, float]:
    """
    Get rate limit configuration for a provider.

    Checks the environment variable (RATE_LIMIT_<PROVIDER>_RPM) first,
    then falls back to the default for that provider.

    Args:
        provider: Provider name (e.g., "anthropic", "openai")

    Returns:
        Dict with 'rpm' and 'burst_seconds' keys
    """
    # Check for environment variable override
    env_var = f"RATE_LIMIT_{provider.upper()}_RPM"
    env_value = os.getenv(env_var)

    config = PROVIDER_RATE_LIMITS.get(provider, DEFAULT_RATE_LIMIT).copy()

    if env_value is not None:
        try:
            config["rpm"] = float(env_value)
        except ValueError:
            logger.warning(
                f"Invalid {env_var}={env_value}, using default",
                extra={"env_var": env_var, "invalid_value": env_value},
            )

    return config
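
# Usage sketch for the override path (illustrative; the value is hypothetical):
#   os.environ["RATE_LIMIT_ANTHROPIC_RPM"] = "100"
#   get_provider_rate_config("anthropic")
#   # -> {"rpm": 100.0, "burst_seconds": 10.0}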

# =============================================================================
# Token Bucket Implementation
# =============================================================================


class TokenBucket:
    """
    Thread-safe token bucket rate limiter.

    The token bucket algorithm allows bursting up to a maximum capacity,
    then limits sustained throughput to the refill rate.

    Attributes:
        capacity: Maximum number of tokens in bucket (burst limit)
        refill_rate: Tokens added per second (sustained rate)
        tokens: Current number of available tokens

    Example:
        >>> bucket = TokenBucket(capacity=10, refill_rate=1.0)
        >>> if bucket.try_acquire():
        ...     await call_api()
        >>> # Or with async waiting:
        >>> await bucket.acquire()  # Waits if needed
        >>> await call_api()
    """

    def __init__(self, capacity: float, refill_rate: float):
        """
        Initialize token bucket.

        Args:
            capacity: Maximum tokens (burst capacity)
            refill_rate: Tokens per second (sustained rate)
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self._tokens = capacity  # Start full
        self._last_refill = time.monotonic()
        self._lock = threading.Lock()

    @property
    def tokens(self) -> float:
        """Get current token count (after refill)."""
        self._refill()
        return self._tokens

    def _refill(self) -> None:
        """Refill tokens based on elapsed time."""
        with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_refill
            self._last_refill = now

            # Add tokens based on elapsed time
            self._tokens = min(self.capacity, self._tokens + elapsed * self.refill_rate)

    def try_acquire(self, tokens: float = 1) -> bool:
        """
        Try to acquire tokens without blocking.

        Args:
            tokens: Number of tokens to acquire (default: 1)

        Returns:
            True if tokens acquired, False otherwise
        """
        # Can never acquire more than capacity
        if tokens > self.capacity:
            return False

        self._refill()

        with self._lock:
            if self._tokens >= tokens:
                self._tokens -= tokens
                return True
            return False

    async def acquire(self, tokens: float = 1, timeout: float | None = None) -> None:
        """
        Acquire tokens, waiting if necessary.

        Args:
            tokens: Number of tokens to acquire (default: 1)
            timeout: Maximum time to wait in seconds (None = wait forever)

        Raises:
            TimeoutError: If tokens cannot be acquired within the timeout
            ValueError: If tokens > capacity
        """
        if tokens > self.capacity:
            raise ValueError(f"Cannot acquire {tokens} tokens (capacity: {self.capacity})")

        start = time.monotonic()

        while True:
            if self.try_acquire(tokens):
                return

            # Calculate wait time for enough tokens
            self._refill()
            tokens_needed = tokens - self._tokens
            wait_time = tokens_needed / self.refill_rate

            # Check timeout
            if timeout is not None:
                elapsed = time.monotonic() - start
                remaining = timeout - elapsed
                if remaining <= 0:
                    raise TimeoutError(f"Could not acquire {tokens} tokens within {timeout}s")
                wait_time = min(wait_time, remaining)

            # Wait for tokens to refill
            await asyncio.sleep(max(0.01, wait_time))  # Min 10ms to avoid busy loop
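
# Minimal asyncio sketch (illustrative, not part of the module's API): five
# workers share a 2-token bucket refilling at 1 token/sec, so two acquire
# immediately and the rest are smoothed to roughly one-second intervals.
#
#   async def _demo() -> None:
#       bucket = TokenBucket(capacity=2, refill_rate=1.0)
#
#       async def worker(i: int) -> None:
#           await bucket.acquire()
#           print(f"worker {i} acquired at {time.monotonic():.2f}")
#
#       await asyncio.gather(*(worker(i) for i in range(5)))
#
#   asyncio.run(_demo())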

# =============================================================================
# Provider Token Bucket Factory
# =============================================================================

_provider_token_buckets: dict[str, TokenBucket] = {}
_bucket_lock = threading.Lock()


def get_provider_token_bucket(provider: str) -> TokenBucket:
    """
    Get or create a token bucket for a specific LLM provider.

    Uses provider-specific rate limits from configuration or environment.

    Args:
        provider: LLM provider name (e.g., "anthropic", "openai")

    Returns:
        TokenBucket configured for the provider's rate limits
    """
    with _bucket_lock:
        if provider in _provider_token_buckets:
            return _provider_token_buckets[provider]

        config = get_provider_rate_config(provider)
        rpm = config["rpm"]
        burst_seconds = config["burst_seconds"]

        # Convert RPM to tokens per second
        refill_rate = rpm / 60.0

        # Burst capacity = refill_rate * burst_seconds
        capacity = refill_rate * burst_seconds

        bucket = TokenBucket(capacity=capacity, refill_rate=refill_rate)
        _provider_token_buckets[provider] = bucket

        logger.info(
            f"Created token bucket for {provider}",
            extra={
                "provider": provider,
                "rpm": rpm,
                "refill_rate": refill_rate,
                "capacity": capacity,
            },
        )

        return bucket
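
# Cache-behavior sketch (illustrative): buckets are created once per provider
# and shared by every caller, so the limit is global within the process.
#   b1 = get_provider_token_bucket("openai")
#   b2 = get_provider_token_bucket("openai")
#   assert b1 is b2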

def reset_all_token_buckets() -> None:
    """
    Reset all token buckets (for testing).

    Warning: Only use for testing! In production, buckets should not be reset.
    """
    with _bucket_lock:
        _provider_token_buckets.clear()
        logger.warning("All token buckets reset (testing only)")
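
# Test-fixture sketch (illustrative; assumes pytest is available):
#   @pytest.fixture(autouse=True)
#   def _fresh_buckets():
#       reset_all_token_buckets()
#       yield
#       reset_all_token_buckets()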

# =============================================================================
# Rate Limiter Decorator
# =============================================================================


def rate_limited(
    provider: str | None = None,
    capacity: float | None = None,
    refill_rate: float | None = None,
    timeout: float | None = 30.0,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator to rate-limit a function using a token bucket.

    Args:
        provider: LLM provider name (uses provider-specific limits)
        capacity: Override bucket capacity (burst limit)
        refill_rate: Override refill rate (tokens/sec)
        timeout: Max wait time for tokens (None = wait forever)

    Usage:
        @rate_limited(provider="anthropic")
        async def call_claude(prompt: str) -> str:
            return await client.generate(prompt)

        # Or with explicit limits:
        @rate_limited(capacity=5, refill_rate=1.0)
        async def my_function():
            pass
    """

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        # Get or create bucket
        if provider:
            bucket = get_provider_token_bucket(provider)
        elif capacity is not None and refill_rate is not None:
            # Create a custom bucket for this decorator
            bucket = TokenBucket(capacity=capacity, refill_rate=refill_rate)
        else:
            raise ValueError("Must specify either 'provider' or both 'capacity' and 'refill_rate'")

        @functools.wraps(func)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            with tracer.start_as_current_span(
                f"rate_limit.{func.__name__}",
                attributes={
                    "rate_limit.provider": provider or "custom",
                    "rate_limit.capacity": bucket.capacity,
                    "rate_limit.refill_rate": bucket.refill_rate,
                    "rate_limit.tokens_available": bucket.tokens,
                },
            ) as span:
                # Acquire token (may wait)
                await bucket.acquire(timeout=timeout)
                span.set_attribute("rate_limit.acquired", True)

                return await func(*args, **kwargs)  # type: ignore[misc, no-any-return]

        # Only works with async functions
        if not asyncio.iscoroutinefunction(func):
            raise TypeError(f"@rate_limited can only be applied to async functions, got {func.__name__}")

        return async_wrapper  # type: ignore[return-value]

    return decorator
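
# Timeout-handling sketch (illustrative; `call_api` is hypothetical). With
# capacity=1 and refill_rate=0.1, a second immediate call needs ~10s for a
# token, so the 0.5s timeout elapses and acquire() raises TimeoutError.
#
#   @rate_limited(capacity=1, refill_rate=0.1, timeout=0.5)
#   async def call_api() -> str:
#       return "ok"
#
#   async def main() -> None:
#       await call_api()        # consumes the only token
#       try:
#           await call_api()    # cannot refill in time
#       except TimeoutError:
#           print("rate limited: no token within 0.5s")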