Coverage for src / mcp_server_langgraph / resilience / config.py: 99%
71 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-08 06:31 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-08 06:31 +0000
1"""
2Resilience configuration with environment variable support.
4Centralized configuration for all resilience patterns:
5- Circuit breaker thresholds and timeouts
6- Retry policies and backoff strategies
7- Timeout values per operation type
8- Bulkhead concurrency limits
9- Provider-aware concurrency limits
11Provider rate limit references:
12- Anthropic: https://docs.anthropic.com/en/api/rate-limits (Tier 1: 50 RPM → ~8 concurrent)
13- OpenAI: https://platform.openai.com/docs/guides/rate-limits (~500 RPM → ~15 concurrent)
14- Google Vertex AI: https://cloud.google.com/vertex-ai/generative-ai/docs/quotas (DSQ, ~5000 RPM)
15- AWS Bedrock: https://docs.aws.amazon.com/bedrock/latest/userguide/quotas.html (Claude Opus: 50 RPM)
16"""
18import os
19from enum import Enum
21from pydantic import BaseModel, ConfigDict, Field
24# =============================================================================
25# Provider-Specific Concurrency Limits
26# =============================================================================
27# These limits are based on typical rate limits and assume ~5-10s average LLM latency.
28# Formula: concurrent_limit ≈ (requests_per_minute / 60) * avg_latency_seconds
29#
30# Example: Anthropic Tier 1 = 50 RPM, 10s latency → (50/60) * 10 ≈ 8 concurrent
32PROVIDER_CONCURRENCY_LIMITS: dict[str, int] = {
33 # Anthropic Claude (direct API): Conservative Tier 1 (50 RPM)
34 # Higher tiers (Tier 4: 4000 RPM) can override via env var
35 "anthropic": 8,
36 # OpenAI GPT-4: Moderate tier (~500 RPM typical)
37 "openai": 15,
38 # Google Gemini via AI Studio or direct API
39 # Gemini 2.5 Flash: 5000 RPM per region (very generous)
40 "google": 15,
41 # Google Vertex AI (Gemini models): Higher limits for vishnu-sandbox-20250310
42 # Gemini 2.5 Flash/Pro: 3.4M TPM → ~1400 RPM (2500 tokens/req)
43 # Conservative: 15 concurrent (assumes 1s avg latency)
44 "vertex_ai": 15,
45 # Anthropic Claude via Vertex AI (MaaS) - vishnu-sandbox-20250310:
46 # Claude Opus 4.5: 12M TPM, Sonnet 4.5: 3M TPM, Haiku 4.5: 10M TPM
47 # Lowest is Sonnet 4.5: 3M TPM → ~1200 RPM (2500 tokens/req)
48 # With ~3s avg latency: 20 req/sec * 3s = 60 concurrent possible
49 # Conservative: 15 concurrent (balanced across all 4.5 models)
50 "vertex_ai_anthropic": 15,
51 # AWS Bedrock: Claude Opus is most restrictive (50 RPM)
52 # Sonnet has higher limits (500 RPM) but we use conservative default
53 "bedrock": 8,
54 # Ollama: Local model, no upstream rate limits
55 # Limited only by local hardware resources
56 "ollama": 50,
57 # Azure OpenAI: Similar to OpenAI, but deployment-specific
58 "azure": 15,
59}
61# Default limit for unknown providers
62DEFAULT_PROVIDER_LIMIT = 10
65def get_provider_limit(provider: str) -> int:
66 """
67 Get the concurrency limit for a specific provider.
69 Checks environment variables first (e.g., BULKHEAD_ANTHROPIC_LIMIT=5),
70 then falls back to the default for that provider.
72 Args:
73 provider: Provider name (e.g., "anthropic", "openai", "vertex_ai")
75 Returns:
76 Concurrency limit for the provider
77 """
78 # Check for environment variable override
79 env_var = f"BULKHEAD_{provider.upper()}_LIMIT"
80 env_value = os.getenv(env_var)
82 if env_value is not None:
83 try:
84 return int(env_value)
85 except ValueError:
86 # Invalid value, log and use default
87 import logging
89 logging.warning(
90 f"Invalid {env_var}={env_value}, using default for {provider}",
91 extra={"env_var": env_var, "invalid_value": env_value},
92 )
94 # Use provider-specific limit or default
95 return PROVIDER_CONCURRENCY_LIMITS.get(provider, DEFAULT_PROVIDER_LIMIT)
class JitterStrategy(str, Enum):
    """Jitter strategies for retry backoff.

    Inherits from ``str`` so values compare/serialize as plain strings
    (e.g. when loaded from environment variables).

    See: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
    """

    SIMPLE = "simple"  # +/- 20% of base delay (current implicit behavior)
    FULL = "full"  # random(0, delay) - high variance, good for many clients
    DECORRELATED = "decorrelated"  # min(max, random(base, prev*3)) - best for overload
class CircuitBreakerConfig(BaseModel):
    """Circuit breaker configuration for a service (one instance per service name)."""

    name: str = Field(description="Service name")
    # Failures tolerated before the breaker opens.
    fail_max: int = Field(default=5, description="Max failures before opening")
    # Seconds the breaker stays open before allowing traffic again.
    timeout_duration: int = Field(default=60, description="Seconds to stay open")
    # Exception type counted as a failure (default: every Exception).
    expected_exception: type = Field(default=Exception, description="Exception type to track")

    # `type` values are not natively validatable by pydantic; permit them explicitly.
    model_config = ConfigDict(arbitrary_types_allowed=True)
class OverloadRetryConfig(BaseModel):
    """Configuration specific to overload/529 error handling.

    These settings provide more aggressive retry behavior for overload errors,
    which typically require longer wait times and more attempts to recover.
    Defaults are deliberately larger than RetryConfig's (6 attempts vs 3,
    60s backoff cap vs 10s).
    """

    max_attempts: int = Field(default=6, description="Max attempts for overload errors")
    exponential_base: float = Field(default=2.0, description="Backoff base for overload")
    exponential_max: float = Field(default=60.0, description="Max backoff for overload (seconds)")
    initial_delay: float = Field(default=5.0, description="Initial delay for overload (seconds)")
    # Decorrelated jitter by default — per the AWS reference cited on
    # JitterStrategy, it behaves best under sustained overload.
    jitter_strategy: JitterStrategy = Field(
        default=JitterStrategy.DECORRELATED,
        description="Jitter strategy for overload retries",
    )
    # Whether a server-supplied Retry-After header should be honored,
    # capped at retry_after_max seconds.
    honor_retry_after: bool = Field(default=True, description="Honor Retry-After header")
    retry_after_max: float = Field(default=120.0, description="Max Retry-After to honor (seconds)")
class RetryConfig(BaseModel):
    """Retry configuration for standard (non-overload) failures."""

    max_attempts: int = Field(default=3, description="Maximum retry attempts")
    exponential_base: float = Field(default=2.0, description="Exponential backoff base")
    exponential_max: float = Field(default=10.0, description="Maximum backoff in seconds")
    jitter: bool = Field(default=True, description="Add random jitter to backoff")
    jitter_strategy: JitterStrategy = Field(
        default=JitterStrategy.SIMPLE,
        description="Jitter strategy for standard retries",
    )
    # Overload (529) errors get their own, more aggressive retry policy.
    overload: OverloadRetryConfig = Field(
        default_factory=OverloadRetryConfig,
        description="Configuration for overload (529) error handling",
    )
class TimeoutConfig(BaseModel):
    """Timeout configuration per operation type (all values in seconds)."""

    default: int = Field(default=30, description="Default timeout in seconds")
    llm: int = Field(default=60, description="LLM operation timeout")
    auth: int = Field(default=5, description="Auth operation timeout")
    db: int = Field(default=10, description="Database operation timeout")
    http: int = Field(default=15, description="HTTP request timeout")
class BulkheadConfig(BaseModel):
    """Bulkhead (max-concurrency) configuration per resource type."""

    llm_limit: int = Field(default=10, description="Max concurrent LLM calls")
    openfga_limit: int = Field(default=50, description="Max concurrent OpenFGA checks")
    redis_limit: int = Field(default=100, description="Max concurrent Redis operations")
    db_limit: int = Field(default=20, description="Max concurrent DB queries")
class ResilienceConfig(BaseModel):
    """Master resilience configuration.

    Aggregates circuit-breaker, retry, timeout, and bulkhead settings.
    Build from environment variables with ``from_env()``, or construct
    directly (e.g. in tests via ``set_resilience_config``).
    """

    # Master switch: when False, callers may bypass resilience patterns entirely.
    enabled: bool = Field(default=True, description="Enable resilience patterns")

    # Circuit breaker configs per service, keyed by service name.
    circuit_breakers: dict[str, CircuitBreakerConfig] = Field(
        default_factory=lambda: {
            "llm": CircuitBreakerConfig(name="llm", fail_max=5, timeout_duration=60),
            "openfga": CircuitBreakerConfig(name="openfga", fail_max=10, timeout_duration=30),
            "redis": CircuitBreakerConfig(name="redis", fail_max=5, timeout_duration=30),
            "keycloak": CircuitBreakerConfig(name="keycloak", fail_max=5, timeout_duration=60),
            "prometheus": CircuitBreakerConfig(name="prometheus", fail_max=3, timeout_duration=30),
        }
    )

    # Retry configuration
    retry: RetryConfig = Field(default_factory=RetryConfig)

    # Timeout configuration
    timeout: TimeoutConfig = Field(default_factory=TimeoutConfig)

    # Bulkhead configuration
    bulkhead: BulkheadConfig = Field(default_factory=BulkheadConfig)

    @staticmethod
    def _jitter_from_env(env_var: str, default: JitterStrategy) -> JitterStrategy:
        """Parse a JitterStrategy from *env_var*.

        Falls back to *default* when the variable is unset or holds a value
        that is not a recognized strategy (case-insensitive).
        """
        raw = os.getenv(env_var, default.value).lower()
        try:
            return JitterStrategy(raw)
        except ValueError:
            return default

    @classmethod
    def from_env(cls) -> "ResilienceConfig":
        """Load configuration from environment variables.

        Each field falls back to its model default when the corresponding
        variable is unset. Note: numeric variables holding non-numeric text
        raise ValueError here (int()/float() conversion).
        """
        jitter_strategy = cls._jitter_from_env("RETRY_JITTER_STRATEGY", JitterStrategy.SIMPLE)
        overload_jitter_strategy = cls._jitter_from_env(
            "RETRY_OVERLOAD_JITTER_STRATEGY", JitterStrategy.DECORRELATED
        )

        return cls(
            enabled=os.getenv("RESILIENCE_ENABLED", "true").lower() == "true",
            retry=RetryConfig(
                max_attempts=int(os.getenv("RETRY_MAX_ATTEMPTS", "3")),
                exponential_base=float(os.getenv("RETRY_EXPONENTIAL_BASE", "2.0")),
                exponential_max=float(os.getenv("RETRY_EXPONENTIAL_MAX", "10.0")),
                jitter=os.getenv("RETRY_JITTER", "true").lower() == "true",
                jitter_strategy=jitter_strategy,
                overload=OverloadRetryConfig(
                    max_attempts=int(os.getenv("RETRY_OVERLOAD_MAX_ATTEMPTS", "6")),
                    exponential_base=float(os.getenv("RETRY_OVERLOAD_EXPONENTIAL_BASE", "2.0")),
                    exponential_max=float(os.getenv("RETRY_OVERLOAD_EXPONENTIAL_MAX", "60.0")),
                    initial_delay=float(os.getenv("RETRY_OVERLOAD_INITIAL_DELAY", "5.0")),
                    jitter_strategy=overload_jitter_strategy,
                    honor_retry_after=os.getenv("RETRY_OVERLOAD_HONOR_RETRY_AFTER", "true").lower() == "true",
                    retry_after_max=float(os.getenv("RETRY_OVERLOAD_RETRY_AFTER_MAX", "120.0")),
                ),
            ),
            timeout=TimeoutConfig(
                default=int(os.getenv("TIMEOUT_DEFAULT", "30")),
                llm=int(os.getenv("TIMEOUT_LLM", "60")),
                auth=int(os.getenv("TIMEOUT_AUTH", "5")),
                db=int(os.getenv("TIMEOUT_DB", "10")),
                http=int(os.getenv("TIMEOUT_HTTP", "15")),
            ),
            bulkhead=BulkheadConfig(
                llm_limit=int(os.getenv("BULKHEAD_LLM_LIMIT", "10")),
                openfga_limit=int(os.getenv("BULKHEAD_OPENFGA_LIMIT", "50")),
                redis_limit=int(os.getenv("BULKHEAD_REDIS_LIMIT", "100")),
                db_limit=int(os.getenv("BULKHEAD_DB_LIMIT", "20")),
            ),
        )
# Process-wide singleton, created lazily on first access.
_resilience_config: ResilienceConfig | None = None


def get_resilience_config() -> ResilienceConfig:
    """Return the global ResilienceConfig, building it from the environment
    on first use (lazy singleton)."""
    global _resilience_config
    config = _resilience_config
    if config is None:
        config = ResilienceConfig.from_env()
        _resilience_config = config
    return config
def set_resilience_config(config: ResilienceConfig) -> None:
    """Set global resilience configuration (for testing).

    Replaces the module-level singleton so subsequent
    ``get_resilience_config()`` calls return *config* instead of loading
    from the environment.
    """
    global _resilience_config
    _resilience_config = config