Coverage for src/mcp_server_langgraph/resilience/rate_limit.py: 94% (104 statements)
1"""
2Token bucket rate limiter for pre-emptive rate limiting.
4Implements the token bucket algorithm to limit request rates before hitting
5upstream provider limits. This provides smoother request distribution and
6prevents 429/529 errors.
8Key features:
9- Burst capacity: Allows short bursts up to bucket capacity
10- Steady rate: Limits sustained throughput to refill rate
11- Async support: Non-blocking acquire with wait capability
12- Provider-aware: Different limits per LLM provider
14See ADR-0026 for design rationale.
16Reference: https://en.wikipedia.org/wiki/Token_bucket
17"""

import asyncio
import functools
import logging
import os
import threading
import time
from collections.abc import Callable
from typing import ParamSpec, TypeVar

from opentelemetry import trace

logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)

P = ParamSpec("P")
T = TypeVar("T")

# =============================================================================
# Provider Rate Limit Configuration
# =============================================================================
# These are RPM (requests per minute) limits based on provider documentation.
# Formula: refill_rate = RPM / 60 (converts to requests per second)
# Burst capacity = refill_rate * burst_seconds (allows ~10s worth of requests
# at the sustained rate)
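#
# Worked example (illustrative arithmetic only, not extra configuration):
# for "anthropic" at 50 RPM with burst_seconds=10,
#   refill_rate = 50 / 60 ≈ 0.83 tokens/sec
#   capacity   = 0.83 * 10 ≈ 8.3 tokens
# so roughly 8 requests can fire back-to-back before throughput settles at
# ~0.83 requests/sec.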

PROVIDER_RATE_LIMITS: dict[str, dict[str, float]] = {
    # Anthropic Claude (direct API): Tier 1 default (50 RPM)
    # Higher tiers can be set via RATE_LIMIT_ANTHROPIC_RPM env var
    "anthropic": {"rpm": 50.0, "burst_seconds": 10.0},
    # OpenAI GPT-4: Moderate tier (~500 RPM)
    "openai": {"rpm": 500.0, "burst_seconds": 10.0},
    # Google Gemini via AI Studio
    "google": {"rpm": 300.0, "burst_seconds": 10.0},
    # Google Vertex AI - Gemini models (vishnu-sandbox-20250310):
    # Gemini 2.5 Flash/Pro: 3.4M TPM → ~1400 RPM (2500 tokens/req)
    # Use 600 RPM to leave headroom
    "vertex_ai": {"rpm": 600.0, "burst_seconds": 15.0},
    # Anthropic Claude via Vertex AI (MaaS) - vishnu-sandbox-20250310:
    # Claude Opus 4.5: 1,200 RPM; Sonnet 4.5 and Haiku 4.5: higher
    # Use Opus 4.5 as baseline (most restrictive 4.5 model)
    "vertex_ai_anthropic": {"rpm": 1000.0, "burst_seconds": 10.0},
    # AWS Bedrock: Claude Opus is most restrictive (50 RPM)
    "bedrock": {"rpm": 50.0, "burst_seconds": 10.0},
    # Ollama: Local model, no upstream limits
    "ollama": {"rpm": 1000.0, "burst_seconds": 30.0},
    # Azure OpenAI: Similar to OpenAI
    "azure": {"rpm": 500.0, "burst_seconds": 10.0},
}

DEFAULT_RATE_LIMIT: dict[str, float] = {"rpm": 60.0, "burst_seconds": 10.0}


def get_provider_rate_config(provider: str) -> dict[str, float]:
    """
    Get rate limit configuration for a provider.

    Checks environment variable first (RATE_LIMIT_<PROVIDER>_RPM),
    then falls back to the default for that provider.

    Args:
        provider: Provider name (e.g., "anthropic", "openai")

    Returns:
        Dict with 'rpm' and 'burst_seconds' keys
    """
    # Check for environment variable override
    env_var = f"RATE_LIMIT_{provider.upper()}_RPM"
    env_value = os.getenv(env_var)

    config = PROVIDER_RATE_LIMITS.get(provider, DEFAULT_RATE_LIMIT).copy()

    if env_value is not None:
        try:
            config["rpm"] = float(env_value)
        except ValueError:
            logger.warning(
                f"Invalid {env_var}={env_value}, using default",
                extra={"env_var": env_var, "invalid_value": env_value},
            )

    return config
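
# Example (illustrative; the values are hypothetical, not shipped defaults):
# an env var override takes precedence over the table above.
#
#     os.environ["RATE_LIMIT_ANTHROPIC_RPM"] = "1000"  # e.g. a higher API tier
#     get_provider_rate_config("anthropic")
#     # -> {"rpm": 1000.0, "burst_seconds": 10.0}
#
# Only 'rpm' can be overridden this way; 'burst_seconds' always comes from the
# table (or DEFAULT_RATE_LIMIT for unknown providers).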

# =============================================================================
# Token Bucket Implementation
# =============================================================================


class TokenBucket:
    """
    Thread-safe token bucket rate limiter.

    The token bucket algorithm allows bursting up to a maximum capacity,
    then limits sustained throughput to the refill rate.

    Attributes:
        capacity: Maximum number of tokens in bucket (burst limit)
        refill_rate: Tokens added per second (sustained rate)
        tokens: Current number of available tokens

    Example:
        >>> bucket = TokenBucket(capacity=10, refill_rate=1.0)
        >>> if bucket.try_acquire():
        ...     await call_api()
        >>> # Or with async waiting:
        >>> await bucket.acquire()  # Waits if needed
        >>> await call_api()
    """

    def __init__(self, capacity: float, refill_rate: float):
        """
        Initialize token bucket.

        Args:
            capacity: Maximum tokens (burst capacity)
            refill_rate: Tokens per second (sustained rate)
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self._tokens = capacity  # Start full
        self._last_refill = time.monotonic()
        self._lock = threading.Lock()

    @property
    def tokens(self) -> float:
        """Get current token count (after refill)."""
        self._refill()
        return self._tokens

    def _refill(self) -> None:
        """Refill tokens based on elapsed time."""
        with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_refill
            self._last_refill = now

            # Add tokens based on elapsed time
            self._tokens = min(self.capacity, self._tokens + elapsed * self.refill_rate)

    def try_acquire(self, tokens: float = 1) -> bool:
        """
        Try to acquire tokens without blocking.

        Args:
            tokens: Number of tokens to acquire (default: 1)

        Returns:
            True if tokens acquired, False otherwise
        """
        # Can never acquire more than capacity
        if tokens > self.capacity:
            return False

        self._refill()

        with self._lock:
            if self._tokens >= tokens:
                self._tokens -= tokens
                return True
            return False

    async def acquire(self, tokens: float = 1, timeout: float | None = None) -> None:
        """
        Acquire tokens, waiting if necessary.

        Args:
            tokens: Number of tokens to acquire (default: 1)
            timeout: Maximum time to wait in seconds (None = wait forever)

        Raises:
            TimeoutError: If tokens cannot be acquired within timeout
            ValueError: If tokens > capacity
        """
        if tokens > self.capacity:
            raise ValueError(f"Cannot acquire {tokens} tokens (capacity: {self.capacity})")

        start = time.monotonic()

        while True:
            if self.try_acquire(tokens):
                return

            # Calculate wait time for enough tokens
            self._refill()
            tokens_needed = tokens - self._tokens
            wait_time = tokens_needed / self.refill_rate

            # Check timeout
            if timeout is not None:
                elapsed = time.monotonic() - start
                remaining = timeout - elapsed
                if remaining <= 0:
                    raise TimeoutError(f"Could not acquire {tokens} tokens within {timeout}s")
                wait_time = min(wait_time, remaining)

            # Wait for tokens to refill
            await asyncio.sleep(max(0.01, wait_time))  # Min 10ms to avoid busy loop
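
# Minimal usage sketch (hypothetical demo code, not part of the module): with
# capacity=3 and refill_rate=1.0, the first three acquire() calls return
# immediately and the remaining two each wait roughly one second.
#
#     async def demo() -> None:
#         bucket = TokenBucket(capacity=3, refill_rate=1.0)
#         for i in range(5):
#             await bucket.acquire()
#             print(i, f"{time.monotonic():.2f}")
#
#     asyncio.run(demo())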

# =============================================================================
# Provider Token Bucket Factory
# =============================================================================

_provider_token_buckets: dict[str, TokenBucket] = {}
_bucket_lock = threading.Lock()


def get_provider_token_bucket(provider: str) -> TokenBucket:
    """
    Get or create a token bucket for a specific LLM provider.

    Uses provider-specific rate limits from configuration or environment.

    Args:
        provider: LLM provider name (e.g., "anthropic", "openai")

    Returns:
        TokenBucket configured for the provider's rate limits
    """
    with _bucket_lock:
        if provider in _provider_token_buckets:
            return _provider_token_buckets[provider]

        config = get_provider_rate_config(provider)
        rpm = config["rpm"]
        burst_seconds = config["burst_seconds"]

        # Convert RPM to tokens per second
        refill_rate = rpm / 60.0

        # Burst capacity = refill_rate * burst_seconds
        capacity = refill_rate * burst_seconds

        bucket = TokenBucket(capacity=capacity, refill_rate=refill_rate)
        _provider_token_buckets[provider] = bucket

        logger.info(
            f"Created token bucket for {provider}",
            extra={
                "provider": provider,
                "rpm": rpm,
                "refill_rate": refill_rate,
                "capacity": capacity,
            },
        )

        return bucket
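
# Usage note (illustrative snippet, not part of the module): buckets are
# cached per process, so all callers naming the same provider share one
# bucket, and env var overrides are only read the first time a provider
# is seen.
#
#     a = get_provider_token_bucket("anthropic")
#     b = get_provider_token_bucket("anthropic")
#     assert a is b  # same shared bucket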

def reset_all_token_buckets() -> None:
    """
    Reset all token buckets (for testing).

    Warning: Only use for testing! In production, buckets should not be reset.
    """
    with _bucket_lock:
        _provider_token_buckets.clear()
        logger.warning("All token buckets reset (testing only)")
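
# Sketch of a test fixture (hypothetical pytest code, not part of the module):
# resetting between tests keeps env var overrides and token state from one
# test leaking into the next.
#
#     @pytest.fixture(autouse=True)
#     def _fresh_buckets():
#         reset_all_token_buckets()
#         yield
#         reset_all_token_buckets()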

# =============================================================================
# Rate Limiter Decorator
# =============================================================================


def rate_limited(
    provider: str | None = None,
    capacity: float | None = None,
    refill_rate: float | None = None,
    timeout: float | None = 30.0,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator to rate-limit a function using a token bucket.

    Args:
        provider: LLM provider name (uses provider-specific limits)
        capacity: Override bucket capacity (burst limit)
        refill_rate: Override refill rate (tokens/sec)
        timeout: Max wait time for tokens (None = wait forever)

    Usage:
        @rate_limited(provider="anthropic")
        async def call_claude(prompt: str) -> str:
            return await client.generate(prompt)

        # Or with explicit limits:
        @rate_limited(capacity=5, refill_rate=1.0)
        async def my_function():
            pass
    """

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        # Get or create bucket
        if provider:
            bucket = get_provider_token_bucket(provider)
        elif capacity is not None and refill_rate is not None:
            # Create a custom bucket for this decorator
            bucket = TokenBucket(capacity=capacity, refill_rate=refill_rate)
        else:
            raise ValueError("Must specify either 'provider' or both 'capacity' and 'refill_rate'")

        @functools.wraps(func)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            with tracer.start_as_current_span(
                f"rate_limit.{func.__name__}",
                attributes={
                    "rate_limit.provider": provider or "custom",
                    "rate_limit.capacity": bucket.capacity,
                    "rate_limit.refill_rate": bucket.refill_rate,
                    "rate_limit.tokens_available": bucket.tokens,
                },
            ) as span:
                # Acquire a token (may wait up to `timeout` seconds)
                await bucket.acquire(timeout=timeout)
                span.set_attribute("rate_limit.acquired", True)

                return await func(*args, **kwargs)  # type: ignore[misc, no-any-return]

        # Only works with async functions
        if not asyncio.iscoroutinefunction(func):
            raise TypeError(f"@rate_limited can only be applied to async functions, got {func.__name__}")

        return async_wrapper  # type: ignore[return-value]

    return decorator
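
# Caller-side sketch (hypothetical code assuming some async `client`, not part
# of the module): acquire() raises TimeoutError when the bucket stays empty
# past `timeout`, so callers can treat that as backpressure and shed or queue
# work instead of hammering the provider.
#
#     @rate_limited(provider="anthropic", timeout=5.0)
#     async def call_claude(prompt: str) -> str:
#         return await client.generate(prompt)
#
#     try:
#         reply = await call_claude("hello")
#     except TimeoutError:
#         reply = None  # e.g. serve a cached answer or enqueue a retry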