Coverage for src / mcp_server_langgraph / resilience / config.py: 98%
59 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2Resilience configuration with environment variable support.
4Centralized configuration for all resilience patterns:
5- Circuit breaker thresholds and timeouts
6- Retry policies and backoff strategies
7- Timeout values per operation type
8- Bulkhead concurrency limits
9"""
11import os
12from enum import Enum
14from pydantic import BaseModel, ConfigDict, Field
class JitterStrategy(str, Enum):
    """Jitter strategies for retry backoff.

    See: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
    """

    # Members also subclass ``str``, so each one compares equal to its raw
    # value and can be looked up from it, e.g. ``JitterStrategy("full")``.

    # Roughly +/- 20% around the base delay (the previously implicit behavior).
    SIMPLE = "simple"
    # random(0, delay): high variance, a good fit when many clients retry at once.
    FULL = "full"
    # min(max, random(base, prev * 3)): the preferred choice for overload.
    DECORRELATED = "decorrelated"
class CircuitBreakerConfig(BaseModel):
    """Circuit breaker configuration for a service"""

    # NOTE(review): arbitrary types are presumably enabled for the
    # ``type``-annotated ``expected_exception`` field below - confirm.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Required: there is no default, so every breaker must be named.
    name: str = Field(
        description="Service name",
    )
    fail_max: int = Field(
        default=5,
        description="Max failures before opening",
    )
    timeout_duration: int = Field(
        default=60,
        description="Seconds to stay open",
    )
    expected_exception: type = Field(
        default=Exception,
        description="Exception type to track",
    )
class OverloadRetryConfig(BaseModel):
    """Configuration specific to overload/529 error handling.

    These settings provide more aggressive retry behavior for overload errors,
    which typically require longer wait times and more attempts to recover.
    """

    # --- Attempt budget and backoff curve -------------------------------
    max_attempts: int = Field(
        default=6,
        description="Max attempts for overload errors",
    )
    exponential_base: float = Field(
        default=2.0,
        description="Backoff base for overload",
    )
    exponential_max: float = Field(
        default=60.0,
        description="Max backoff for overload (seconds)",
    )
    initial_delay: float = Field(
        default=5.0,
        description="Initial delay for overload (seconds)",
    )

    # --- Jitter -----------------------------------------------------------
    jitter_strategy: JitterStrategy = Field(
        default=JitterStrategy.DECORRELATED,
        description="Jitter strategy for overload retries",
    )

    # --- Retry-After header handling --------------------------------------
    honor_retry_after: bool = Field(
        default=True,
        description="Honor Retry-After header",
    )
    retry_after_max: float = Field(
        default=120.0,
        description="Max Retry-After to honor (seconds)",
    )
class RetryConfig(BaseModel):
    """Retry configuration"""

    # --- Standard retry policy --------------------------------------------
    max_attempts: int = Field(
        default=3,
        description="Maximum retry attempts",
    )
    exponential_base: float = Field(
        default=2.0,
        description="Exponential backoff base",
    )
    exponential_max: float = Field(
        default=10.0,
        description="Maximum backoff in seconds",
    )
    jitter: bool = Field(
        default=True,
        description="Add random jitter to backoff",
    )
    jitter_strategy: JitterStrategy = Field(
        default=JitterStrategy.SIMPLE,
        description="Jitter strategy for standard retries",
    )

    # --- Overload (529) policy --------------------------------------------
    # Built through a factory so every RetryConfig gets its own nested instance.
    overload: OverloadRetryConfig = Field(
        default_factory=OverloadRetryConfig,
        description="Configuration for overload (529) error handling",
    )
class TimeoutConfig(BaseModel):
    """Timeout configuration per operation type"""

    # ``default`` is documented in seconds; the per-operation fields carry no
    # explicit unit in their descriptions.
    # NOTE(review): presumably also seconds - confirm.
    default: int = Field(
        default=30,
        description="Default timeout in seconds",
    )
    llm: int = Field(
        default=60,
        description="LLM operation timeout",
    )
    auth: int = Field(
        default=5,
        description="Auth operation timeout",
    )
    db: int = Field(
        default=10,
        description="Database operation timeout",
    )
    http: int = Field(
        default=15,
        description="HTTP request timeout",
    )
class BulkheadConfig(BaseModel):
    """Bulkhead configuration per resource type"""

    # Each limit is the maximum number of concurrent operations allowed
    # against the named resource.
    llm_limit: int = Field(
        default=10,
        description="Max concurrent LLM calls",
    )
    openfga_limit: int = Field(
        default=50,
        description="Max concurrent OpenFGA checks",
    )
    redis_limit: int = Field(
        default=100,
        description="Max concurrent Redis operations",
    )
    db_limit: int = Field(
        default=20,
        description="Max concurrent DB queries",
    )
# Spellings (after stripping and lower-casing) that switch a boolean flag on.
_TRUTHY_ENV_VALUES = frozenset({"true", "1", "yes", "on"})


def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment.

    Returns ``default`` when the variable is unset. Otherwise the value is
    true only for one of the accepted truthy spellings; anything else,
    including an empty string, is false.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY_ENV_VALUES


def _env_int(name: str, default: int) -> int:
    """Read an integer from the environment, naming the variable on failure.

    An unset or blank variable yields ``default``.

    Raises:
        ValueError: If the variable holds text that is not an integer.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError as exc:
        raise ValueError(f"Environment variable {name} must be an integer, got {raw!r}") from exc


def _env_float(name: str, default: float) -> float:
    """Read a float from the environment, naming the variable on failure.

    An unset or blank variable yields ``default``.

    Raises:
        ValueError: If the variable holds text that is not a number.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return float(raw)
    except ValueError as exc:
        raise ValueError(f"Environment variable {name} must be a number, got {raw!r}") from exc


def _env_jitter_strategy(name: str, default: JitterStrategy) -> JitterStrategy:
    """Read a jitter strategy from the environment.

    Matching ignores case and surrounding whitespace. An unset or
    unrecognized value falls back to ``default`` rather than raising.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return JitterStrategy(raw.strip().lower())
    except ValueError:
        return default


class ResilienceConfig(BaseModel):
    """Master resilience configuration"""

    enabled: bool = Field(default=True, description="Enable resilience patterns")

    # Circuit breaker configs per service
    circuit_breakers: dict[str, CircuitBreakerConfig] = Field(
        default_factory=lambda: {
            "llm": CircuitBreakerConfig(name="llm", fail_max=5, timeout_duration=60),
            "openfga": CircuitBreakerConfig(name="openfga", fail_max=10, timeout_duration=30),
            "redis": CircuitBreakerConfig(name="redis", fail_max=5, timeout_duration=30),
            "keycloak": CircuitBreakerConfig(name="keycloak", fail_max=5, timeout_duration=60),
            "prometheus": CircuitBreakerConfig(name="prometheus", fail_max=3, timeout_duration=30),
        }
    )

    # Retry configuration
    retry: RetryConfig = Field(default_factory=RetryConfig)

    # Timeout configuration
    timeout: TimeoutConfig = Field(default_factory=TimeoutConfig)

    # Bulkhead configuration
    bulkhead: BulkheadConfig = Field(default_factory=BulkheadConfig)

    @classmethod
    def from_env(cls) -> "ResilienceConfig":
        """Load configuration from environment variables.

        Reads ``RESILIENCE_ENABLED`` plus the ``RETRY_*``, ``RETRY_OVERLOAD_*``,
        ``TIMEOUT_*`` and ``BULKHEAD_*`` variables. Any variable that is unset
        uses the same default as the corresponding model field. Circuit
        breakers are not read from the environment and keep their defaults.

        Returns:
            A new ``ResilienceConfig`` populated from the environment.

        Raises:
            ValueError: If a numeric variable holds non-numeric text. The
                message names the offending variable.
        """
        overload = OverloadRetryConfig(
            max_attempts=_env_int("RETRY_OVERLOAD_MAX_ATTEMPTS", 6),
            exponential_base=_env_float("RETRY_OVERLOAD_EXPONENTIAL_BASE", 2.0),
            exponential_max=_env_float("RETRY_OVERLOAD_EXPONENTIAL_MAX", 60.0),
            initial_delay=_env_float("RETRY_OVERLOAD_INITIAL_DELAY", 5.0),
            jitter_strategy=_env_jitter_strategy(
                "RETRY_OVERLOAD_JITTER_STRATEGY", JitterStrategy.DECORRELATED
            ),
            honor_retry_after=_env_bool("RETRY_OVERLOAD_HONOR_RETRY_AFTER", True),
            retry_after_max=_env_float("RETRY_OVERLOAD_RETRY_AFTER_MAX", 120.0),
        )

        return cls(
            enabled=_env_bool("RESILIENCE_ENABLED", True),
            retry=RetryConfig(
                max_attempts=_env_int("RETRY_MAX_ATTEMPTS", 3),
                exponential_base=_env_float("RETRY_EXPONENTIAL_BASE", 2.0),
                exponential_max=_env_float("RETRY_EXPONENTIAL_MAX", 10.0),
                jitter=_env_bool("RETRY_JITTER", True),
                jitter_strategy=_env_jitter_strategy("RETRY_JITTER_STRATEGY", JitterStrategy.SIMPLE),
                overload=overload,
            ),
            timeout=TimeoutConfig(
                default=_env_int("TIMEOUT_DEFAULT", 30),
                llm=_env_int("TIMEOUT_LLM", 60),
                auth=_env_int("TIMEOUT_AUTH", 5),
                db=_env_int("TIMEOUT_DB", 10),
                http=_env_int("TIMEOUT_HTTP", 15),
            ),
            bulkhead=BulkheadConfig(
                llm_limit=_env_int("BULKHEAD_LLM_LIMIT", 10),
                openfga_limit=_env_int("BULKHEAD_OPENFGA_LIMIT", 50),
                redis_limit=_env_int("BULKHEAD_REDIS_LIMIT", 100),
                db_limit=_env_int("BULKHEAD_DB_LIMIT", 20),
            ),
        )
# Process-wide cached configuration. Starts as None and is filled in either
# lazily by get_resilience_config() or explicitly by set_resilience_config().
_resilience_config: ResilienceConfig | None = None
def get_resilience_config() -> ResilienceConfig:
    """Return the shared resilience configuration, creating it on first use.

    The first call builds the configuration with
    ``ResilienceConfig.from_env()`` and stores it in the module-level global;
    later calls return that stored instance.
    """
    global _resilience_config
    cached = _resilience_config
    if cached is not None:
        return cached
    _resilience_config = ResilienceConfig.from_env()
    return _resilience_config
def set_resilience_config(config: ResilienceConfig) -> None:
    """Replace the shared resilience configuration (intended for tests).

    Args:
        config: Instance that subsequent ``get_resilience_config()`` calls
            will return.
    """
    global _resilience_config
    _resilience_config = config