Coverage for src / mcp_server_langgraph / resilience / config.py: 99%

71 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-08 06:31 +0000

1""" 

2Resilience configuration with environment variable support. 

3 

4Centralized configuration for all resilience patterns: 

5- Circuit breaker thresholds and timeouts 

6- Retry policies and backoff strategies 

7- Timeout values per operation type 

8- Bulkhead concurrency limits 

9- Provider-aware concurrency limits 

10 

11Provider rate limit references: 

12- Anthropic: https://docs.anthropic.com/en/api/rate-limits (Tier 1: 50 RPM → ~8 concurrent) 

13- OpenAI: https://platform.openai.com/docs/guides/rate-limits (~500 RPM → ~15 concurrent) 

14- Google Vertex AI: https://cloud.google.com/vertex-ai/generative-ai/docs/quotas (DSQ, ~5000 RPM) 

15- AWS Bedrock: https://docs.aws.amazon.com/bedrock/latest/userguide/quotas.html (Claude Opus: 50 RPM) 

16""" 

17 

18import os 

19from enum import Enum 

20 

21from pydantic import BaseModel, ConfigDict, Field 

22 

23 

# =============================================================================
# Provider-Specific Concurrency Limits
# =============================================================================
# These limits are based on typical rate limits and assume ~5-10s average LLM latency.
# Formula: concurrent_limit ≈ (requests_per_minute / 60) * avg_latency_seconds
#
# Example: Anthropic Tier 1 = 50 RPM, 10s latency → (50/60) * 10 ≈ 8 concurrent

PROVIDER_CONCURRENCY_LIMITS: dict[str, int] = {
    # Anthropic Claude (direct API): Conservative Tier 1 (50 RPM)
    # Higher tiers (Tier 4: 4000 RPM) can override via env var
    "anthropic": 8,
    # OpenAI GPT-4: Moderate tier (~500 RPM typical)
    "openai": 15,
    # Google Gemini via AI Studio or direct API
    # Gemini 2.5 Flash: 5000 RPM per region (very generous)
    "google": 15,
    # Google Vertex AI (Gemini models): Higher limits for vishnu-sandbox-20250310
    # Gemini 2.5 Flash/Pro: 3.4M TPM → ~1400 RPM (2500 tokens/req)
    # Conservative: 15 concurrent (assumes 1s avg latency)
    "vertex_ai": 15,
    # Anthropic Claude via Vertex AI (MaaS) - vishnu-sandbox-20250310:
    # Claude Opus 4.5: 12M TPM, Sonnet 4.5: 3M TPM, Haiku 4.5: 10M TPM
    # Lowest is Sonnet 4.5: 3M TPM → ~1200 RPM (2500 tokens/req)
    # With ~3s avg latency: 20 req/sec * 3s = 60 concurrent possible
    # Conservative: 15 concurrent (balanced across all 4.5 models)
    "vertex_ai_anthropic": 15,
    # AWS Bedrock: Claude Opus is most restrictive (50 RPM)
    # Sonnet has higher limits (500 RPM) but we use conservative default
    "bedrock": 8,
    # Ollama: Local model, no upstream rate limits
    # Limited only by local hardware resources
    "ollama": 50,
    # Azure OpenAI: Similar to OpenAI, but deployment-specific
    "azure": 15,
}

# Default limit for unknown providers
DEFAULT_PROVIDER_LIMIT = 10


def get_provider_limit(provider: str) -> int:
    """
    Get the concurrency limit for a specific provider.

    Checks environment variables first (e.g., BULKHEAD_ANTHROPIC_LIMIT=5),
    then falls back to the default for that provider. An override that is
    not a valid positive integer is logged and ignored, since a zero or
    negative bulkhead limit would block all traffic for the provider.

    Args:
        provider: Provider name (e.g., "anthropic", "openai", "vertex_ai")

    Returns:
        Concurrency limit for the provider
    """
    # Check for environment variable override
    env_var = f"BULKHEAD_{provider.upper()}_LIMIT"
    env_value = os.getenv(env_var)

    if env_value is not None:
        try:
            limit = int(env_value)
            if limit <= 0:
                # A non-positive concurrency limit would effectively disable
                # the bulkhead (no request could ever acquire a slot), so
                # treat it the same as an unparsable value.
                raise ValueError("concurrency limit must be positive")
            return limit
        except ValueError:
            # Invalid value: warn and fall through to the default.
            # Use a named logger with lazy %-style args rather than the root
            # logger with an eagerly-built f-string.
            import logging

            logging.getLogger(__name__).warning(
                "Invalid %s=%s, using default for %s",
                env_var,
                env_value,
                provider,
                extra={"env_var": env_var, "invalid_value": env_value},
            )

    # Use provider-specific limit or default
    return PROVIDER_CONCURRENCY_LIMITS.get(provider, DEFAULT_PROVIDER_LIMIT)

96 

97 

class JitterStrategy(str, Enum):
    """Supported jitter strategies for randomizing retry backoff delays.

    Background: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
    """

    # Nudge the base delay by +/- 20% (the historical implicit behavior).
    SIMPLE = "simple"
    # Pick uniformly from [0, delay]: high variance, spreads out many clients.
    FULL = "full"
    # min(max, random(base, prev * 3)): best choice under sustained overload.
    DECORRELATED = "decorrelated"

108 

class CircuitBreakerConfig(BaseModel):
    """Circuit breaker configuration for a service.

    One instance describes the breaker for a single downstream dependency:
    how many failures trip it open and how long it stays open.
    """

    # Identifies the protected service (e.g. "llm", "redis").
    name: str = Field(description="Service name")
    # Failure count that trips the breaker into the open state.
    fail_max: int = Field(default=5, description="Max failures before opening")
    # How long (seconds) the breaker remains open before allowing calls again.
    timeout_duration: int = Field(default=60, description="Seconds to stay open")
    # Exception class counted as a failure; defaults to all exceptions.
    expected_exception: type = Field(default=Exception, description="Exception type to track")

    # `expected_exception` holds a `type` object, which pydantic cannot
    # validate natively — arbitrary_types_allowed permits it.
    model_config = ConfigDict(arbitrary_types_allowed=True)

118 

119 

class OverloadRetryConfig(BaseModel):
    """Configuration specific to overload/529 error handling.

    These settings provide more aggressive retry behavior for overload errors,
    which typically require longer wait times and more attempts to recover.
    """

    # More attempts than standard retries (6 vs. RetryConfig's 3).
    max_attempts: int = Field(default=6, description="Max attempts for overload errors")
    # Multiplier applied between successive backoff delays.
    exponential_base: float = Field(default=2.0, description="Backoff base for overload")
    # Ceiling on any single backoff delay.
    exponential_max: float = Field(default=60.0, description="Max backoff for overload (seconds)")
    # First delay before the initial retry.
    initial_delay: float = Field(default=5.0, description="Initial delay for overload (seconds)")
    # Decorrelated jitter is the default here — see JitterStrategy.
    jitter_strategy: JitterStrategy = Field(
        default=JitterStrategy.DECORRELATED,
        description="Jitter strategy for overload retries",
    )
    # When True, a server-provided Retry-After header overrides computed backoff.
    honor_retry_after: bool = Field(default=True, description="Honor Retry-After header")
    # Cap on how long a Retry-After header can make us wait.
    retry_after_max: float = Field(default=120.0, description="Max Retry-After to honor (seconds)")

138 

class RetryConfig(BaseModel):
    """Retry configuration for standard (non-overload) transient failures.

    Overload (529) errors get their own, more aggressive settings via the
    nested ``overload`` field.
    """

    # Total attempts, including the initial call.
    max_attempts: int = Field(default=3, description="Maximum retry attempts")
    # Multiplier applied between successive backoff delays.
    exponential_base: float = Field(default=2.0, description="Exponential backoff base")
    # Ceiling on any single backoff delay.
    exponential_max: float = Field(default=10.0, description="Maximum backoff in seconds")
    # Master switch for jitter; strategy below selects the flavor.
    jitter: bool = Field(default=True, description="Add random jitter to backoff")
    jitter_strategy: JitterStrategy = Field(
        default=JitterStrategy.SIMPLE,
        description="Jitter strategy for standard retries",
    )
    # Separate, more aggressive policy applied only to overload errors.
    overload: OverloadRetryConfig = Field(
        default_factory=OverloadRetryConfig,
        description="Configuration for overload (529) error handling",
    )

155 

class TimeoutConfig(BaseModel):
    """Timeout configuration per operation type.

    All values are in seconds. ``default`` applies to operations that do not
    match a more specific category.
    """

    default: int = Field(default=30, description="Default timeout in seconds")
    # LLM calls get the longest budget — generation is slow.
    llm: int = Field(default=60, description="LLM operation timeout")
    # Auth checks are on the hot path, so keep them tight.
    auth: int = Field(default=5, description="Auth operation timeout")
    db: int = Field(default=10, description="Database operation timeout")
    http: int = Field(default=15, description="HTTP request timeout")

165 

class BulkheadConfig(BaseModel):
    """Bulkhead configuration per resource type.

    Each limit caps the number of concurrent in-flight operations against a
    resource, isolating failures so one slow dependency cannot exhaust all
    workers.
    """

    # LLM calls are the most expensive, so they get the smallest pool.
    llm_limit: int = Field(default=10, description="Max concurrent LLM calls")
    openfga_limit: int = Field(default=50, description="Max concurrent OpenFGA checks")
    # Redis is fast and cheap — allow the largest pool.
    redis_limit: int = Field(default=100, description="Max concurrent Redis operations")
    db_limit: int = Field(default=20, description="Max concurrent DB queries")

174 

class ResilienceConfig(BaseModel):
    """Master resilience configuration.

    Aggregates circuit breaker, retry, timeout, and bulkhead settings.
    Construct directly for tests, or use :meth:`from_env` to load values
    from environment variables with sensible defaults.
    """

    # Global kill switch for all resilience patterns.
    enabled: bool = Field(default=True, description="Enable resilience patterns")

    # Circuit breaker configs per service
    circuit_breakers: dict[str, CircuitBreakerConfig] = Field(
        default_factory=lambda: {
            "llm": CircuitBreakerConfig(name="llm", fail_max=5, timeout_duration=60),
            "openfga": CircuitBreakerConfig(name="openfga", fail_max=10, timeout_duration=30),
            "redis": CircuitBreakerConfig(name="redis", fail_max=5, timeout_duration=30),
            "keycloak": CircuitBreakerConfig(name="keycloak", fail_max=5, timeout_duration=60),
            "prometheus": CircuitBreakerConfig(name="prometheus", fail_max=3, timeout_duration=30),
        }
    )

    # Retry configuration
    retry: RetryConfig = Field(default_factory=RetryConfig)

    # Timeout configuration
    timeout: TimeoutConfig = Field(default_factory=TimeoutConfig)

    # Bulkhead configuration
    bulkhead: BulkheadConfig = Field(default_factory=BulkheadConfig)

    @staticmethod
    def _parse_jitter_strategy(env_var: str, default: JitterStrategy) -> JitterStrategy:
        """Parse a JitterStrategy from an environment variable.

        Returns *default* when the variable is unset or holds a value that is
        not a recognized strategy (comparison is case-insensitive).
        """
        raw = os.getenv(env_var, default.value).lower()
        try:
            return JitterStrategy(raw)
        except ValueError:
            # Unknown strategy name: silently fall back to the safe default.
            return default

    @classmethod
    def from_env(cls) -> "ResilienceConfig":
        """Load configuration from environment variables.

        Every field falls back to its model default when the corresponding
        environment variable is unset. Boolean variables are considered true
        only for the literal string "true" (case-insensitive).
        """
        # Standard retries default to simple jitter; overload retries to
        # decorrelated jitter (better behavior under sustained overload).
        jitter_strategy = cls._parse_jitter_strategy(
            "RETRY_JITTER_STRATEGY", JitterStrategy.SIMPLE
        )
        overload_jitter_strategy = cls._parse_jitter_strategy(
            "RETRY_OVERLOAD_JITTER_STRATEGY", JitterStrategy.DECORRELATED
        )

        return cls(
            enabled=os.getenv("RESILIENCE_ENABLED", "true").lower() == "true",
            retry=RetryConfig(
                max_attempts=int(os.getenv("RETRY_MAX_ATTEMPTS", "3")),
                exponential_base=float(os.getenv("RETRY_EXPONENTIAL_BASE", "2.0")),
                exponential_max=float(os.getenv("RETRY_EXPONENTIAL_MAX", "10.0")),
                jitter=os.getenv("RETRY_JITTER", "true").lower() == "true",
                jitter_strategy=jitter_strategy,
                overload=OverloadRetryConfig(
                    max_attempts=int(os.getenv("RETRY_OVERLOAD_MAX_ATTEMPTS", "6")),
                    exponential_base=float(os.getenv("RETRY_OVERLOAD_EXPONENTIAL_BASE", "2.0")),
                    exponential_max=float(os.getenv("RETRY_OVERLOAD_EXPONENTIAL_MAX", "60.0")),
                    initial_delay=float(os.getenv("RETRY_OVERLOAD_INITIAL_DELAY", "5.0")),
                    jitter_strategy=overload_jitter_strategy,
                    honor_retry_after=os.getenv("RETRY_OVERLOAD_HONOR_RETRY_AFTER", "true").lower() == "true",
                    retry_after_max=float(os.getenv("RETRY_OVERLOAD_RETRY_AFTER_MAX", "120.0")),
                ),
            ),
            timeout=TimeoutConfig(
                default=int(os.getenv("TIMEOUT_DEFAULT", "30")),
                llm=int(os.getenv("TIMEOUT_LLM", "60")),
                auth=int(os.getenv("TIMEOUT_AUTH", "5")),
                db=int(os.getenv("TIMEOUT_DB", "10")),
                http=int(os.getenv("TIMEOUT_HTTP", "15")),
            ),
            bulkhead=BulkheadConfig(
                llm_limit=int(os.getenv("BULKHEAD_LLM_LIMIT", "10")),
                openfga_limit=int(os.getenv("BULKHEAD_OPENFGA_LIMIT", "50")),
                redis_limit=int(os.getenv("BULKHEAD_REDIS_LIMIT", "100")),
                db_limit=int(os.getenv("BULKHEAD_DB_LIMIT", "20")),
            ),
        )

251 

252 

# Module-level singleton; populated lazily by get_resilience_config().
_resilience_config: ResilienceConfig | None = None


def get_resilience_config() -> ResilienceConfig:
    """Return the process-wide resilience configuration (lazy singleton).

    The first call builds the configuration from environment variables;
    every later call returns the same cached instance.
    """
    global _resilience_config
    cfg = _resilience_config
    if cfg is None:
        cfg = ResilienceConfig.from_env()
        _resilience_config = cfg
    return cfg

263 

264 

def set_resilience_config(config: ResilienceConfig) -> None:
    """Set global resilience configuration (for testing).

    Replaces the module-level singleton returned by get_resilience_config(),
    allowing tests to inject a custom configuration.

    Args:
        config: Configuration instance to install as the global singleton.
    """
    global _resilience_config
    _resilience_config = config