Coverage for src/mcp_server_langgraph/resilience/config.py: 98%

59 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Resilience configuration with environment variable support. 

3 

4Centralized configuration for all resilience patterns: 

5- Circuit breaker thresholds and timeouts 

6- Retry policies and backoff strategies 

7- Timeout values per operation type 

8- Bulkhead concurrency limits 

9""" 

10 

11import os 

12from enum import Enum 

13 

14from pydantic import BaseModel, ConfigDict, Field 

15 

16 

class JitterStrategy(str, Enum):
    """Ways of randomising retry backoff delays.

    Members:
        SIMPLE: +/- 20% of the base delay (the current implicit behavior).
        FULL: random(0, delay); high variance, good when many clients retry.
        DECORRELATED: min(max, random(base, prev * 3)); best under overload.

    Background:
        https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
    """

    SIMPLE = "simple"
    FULL = "full"
    DECORRELATED = "decorrelated"

26 

27 

class CircuitBreakerConfig(BaseModel):
    """Circuit breaker settings for one named downstream service."""

    name: str = Field(description="Service name")
    # Failures tolerated before the breaker opens.
    fail_max: int = Field(default=5, description="Max failures before opening")
    # Seconds the breaker stays open.
    timeout_duration: int = Field(default=60, description="Seconds to stay open")
    # Restricted to exception classes: a plain ``type`` annotation accepted any
    # class (e.g. ``str``), which can never be the type of a raised error.
    # Pydantic validates ``type[BaseException]`` with a subclass check, so such
    # a misconfiguration is rejected when the config is constructed.
    expected_exception: type[BaseException] = Field(
        default=Exception,
        description="Exception type to track",
    )

    # Kept from the original: permits field types pydantic does not natively support.
    model_config = ConfigDict(arbitrary_types_allowed=True)

37 

38 

class OverloadRetryConfig(BaseModel):
    """Retry tuning used for overload (529) errors.

    Overload errors usually need longer pauses and more attempts before the
    upstream recovers. These defaults are therefore more patient than the
    standard retry defaults: more attempts and a higher backoff ceiling.
    """

    # Attempt budget and exponential backoff shape.
    max_attempts: int = Field(6, description="Max attempts for overload errors")
    exponential_base: float = Field(2.0, description="Backoff base for overload")
    exponential_max: float = Field(60.0, description="Max backoff for overload (seconds)")
    initial_delay: float = Field(5.0, description="Initial delay for overload (seconds)")

    # Decorrelated jitter unless overridden.
    jitter_strategy: JitterStrategy = Field(
        JitterStrategy.DECORRELATED,
        description="Jitter strategy for overload retries",
    )

    # Server-supplied Retry-After handling, with an upper bound on honored values.
    honor_retry_after: bool = Field(True, description="Honor Retry-After header")
    retry_after_max: float = Field(120.0, description="Max Retry-After to honor (seconds)")

56 

57 

class RetryConfig(BaseModel):
    """Retry policy for ordinary transient failures.

    Overload (529) failures have their own tuning in the nested ``overload``
    section.
    """

    # Attempt budget and exponential backoff shape.
    max_attempts: int = Field(3, description="Maximum retry attempts")
    exponential_base: float = Field(2.0, description="Exponential backoff base")
    exponential_max: float = Field(10.0, description="Maximum backoff in seconds")

    # Randomisation applied to the computed delays.
    jitter: bool = Field(True, description="Add random jitter to backoff")
    jitter_strategy: JitterStrategy = Field(
        JitterStrategy.SIMPLE,
        description="Jitter strategy for standard retries",
    )

    # Created per instance through default_factory.
    overload: OverloadRetryConfig = Field(
        description="Configuration for overload (529) error handling",
        default_factory=OverloadRetryConfig,
    )

73 

74 

class TimeoutConfig(BaseModel):
    """Timeout values, one per operation type.

    ``default`` is documented in seconds. The other fields presumably use the
    same unit; confirm against the consumers.
    """

    default: int = Field(30, description="Default timeout in seconds")
    llm: int = Field(60, description="LLM operation timeout")
    auth: int = Field(5, description="Auth operation timeout")
    db: int = Field(10, description="Database operation timeout")
    http: int = Field(15, description="HTTP request timeout")

83 

84 

class BulkheadConfig(BaseModel):
    """Concurrency caps (bulkheads), one per downstream resource type."""

    llm_limit: int = Field(10, description="Max concurrent LLM calls")
    openfga_limit: int = Field(50, description="Max concurrent OpenFGA checks")
    redis_limit: int = Field(100, description="Max concurrent Redis operations")
    db_limit: int = Field(20, description="Max concurrent DB queries")

92 

93 

# Strings (after strip + lower-casing) that boolean env vars treat as true.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_bool(name: str, default: bool) -> bool:
    """Return env var *name* as a bool, or *default* when it is unset.

    Matching is case-insensitive and ignores surrounding whitespace. Only
    "1", "true", "yes" and "on" are true; any other value is false.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY_ENV_VALUES


def _env_int(name: str, default: int) -> int:
    """Return env var *name* parsed as an int, or *default* when it is unset.

    Raises:
        ValueError: if the variable is set but is not an integer literal. The
            message names the offending variable.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(f"Environment variable {name} must be an integer, got {raw!r}") from err


def _env_float(name: str, default: float) -> float:
    """Return env var *name* parsed as a float, or *default* when it is unset.

    Raises:
        ValueError: if the variable is set but is not a number. The message
            names the offending variable.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(f"Environment variable {name} must be a number, got {raw!r}") from err


def _env_jitter_strategy(name: str, default: JitterStrategy) -> JitterStrategy:
    """Return the JitterStrategy named by env var *name*, or *default*.

    Matching is case-insensitive and ignores surrounding whitespace. An
    unrecognised value logs a warning and falls back to *default*, so a typo
    does not abort startup.
    """
    import logging  # function-scope import keeps this helper self-contained

    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return JitterStrategy(raw.strip().lower())
    except ValueError:
        logging.getLogger(__name__).warning(
            "Ignoring invalid %s=%r (expected one of %s); using %r",
            name,
            raw,
            [s.value for s in JitterStrategy],
            default.value,
        )
        return default


class ResilienceConfig(BaseModel):
    """Master resilience configuration: on/off switch plus per-pattern settings."""

    enabled: bool = Field(default=True, description="Enable resilience patterns")

    # Circuit breaker configs, keyed by service name.
    circuit_breakers: dict[str, CircuitBreakerConfig] = Field(
        default_factory=lambda: {
            "llm": CircuitBreakerConfig(name="llm", fail_max=5, timeout_duration=60),
            "openfga": CircuitBreakerConfig(name="openfga", fail_max=10, timeout_duration=30),
            "redis": CircuitBreakerConfig(name="redis", fail_max=5, timeout_duration=30),
            "keycloak": CircuitBreakerConfig(name="keycloak", fail_max=5, timeout_duration=60),
            "prometheus": CircuitBreakerConfig(name="prometheus", fail_max=3, timeout_duration=30),
        }
    )

    # Retry configuration
    retry: RetryConfig = Field(default_factory=RetryConfig)

    # Timeout configuration
    timeout: TimeoutConfig = Field(default_factory=TimeoutConfig)

    # Bulkhead configuration
    bulkhead: BulkheadConfig = Field(default_factory=BulkheadConfig)

    @classmethod
    def from_env(cls) -> "ResilienceConfig":
        """Build a configuration from environment variables.

        Unset variables use the same values as the model field defaults.
        Circuit breakers are not read from the environment; they always use the
        built-in per-service defaults.

        Parsing rules:
          * Booleans (RESILIENCE_ENABLED, RETRY_JITTER,
            RETRY_OVERLOAD_HONOR_RETRY_AFTER) accept 1/true/yes/on,
            case-insensitively; any other value means false.
          * Jitter strategies (RETRY_JITTER_STRATEGY,
            RETRY_OVERLOAD_JITTER_STRATEGY) fall back to their defaults, with a
            logged warning, when unrecognised.

        Raises:
            ValueError: if a numeric variable is set to a non-numeric value.
        """
        return cls(
            enabled=_env_bool("RESILIENCE_ENABLED", True),
            retry=RetryConfig(
                max_attempts=_env_int("RETRY_MAX_ATTEMPTS", 3),
                exponential_base=_env_float("RETRY_EXPONENTIAL_BASE", 2.0),
                exponential_max=_env_float("RETRY_EXPONENTIAL_MAX", 10.0),
                jitter=_env_bool("RETRY_JITTER", True),
                jitter_strategy=_env_jitter_strategy("RETRY_JITTER_STRATEGY", JitterStrategy.SIMPLE),
                overload=OverloadRetryConfig(
                    max_attempts=_env_int("RETRY_OVERLOAD_MAX_ATTEMPTS", 6),
                    exponential_base=_env_float("RETRY_OVERLOAD_EXPONENTIAL_BASE", 2.0),
                    exponential_max=_env_float("RETRY_OVERLOAD_EXPONENTIAL_MAX", 60.0),
                    initial_delay=_env_float("RETRY_OVERLOAD_INITIAL_DELAY", 5.0),
                    jitter_strategy=_env_jitter_strategy(
                        "RETRY_OVERLOAD_JITTER_STRATEGY", JitterStrategy.DECORRELATED
                    ),
                    honor_retry_after=_env_bool("RETRY_OVERLOAD_HONOR_RETRY_AFTER", True),
                    retry_after_max=_env_float("RETRY_OVERLOAD_RETRY_AFTER_MAX", 120.0),
                ),
            ),
            timeout=TimeoutConfig(
                default=_env_int("TIMEOUT_DEFAULT", 30),
                llm=_env_int("TIMEOUT_LLM", 60),
                auth=_env_int("TIMEOUT_AUTH", 5),
                db=_env_int("TIMEOUT_DB", 10),
                http=_env_int("TIMEOUT_HTTP", 15),
            ),
            bulkhead=BulkheadConfig(
                llm_limit=_env_int("BULKHEAD_LLM_LIMIT", 10),
                openfga_limit=_env_int("BULKHEAD_OPENFGA_LIMIT", 50),
                redis_limit=_env_int("BULKHEAD_REDIS_LIMIT", 100),
                db_limit=_env_int("BULKHEAD_DB_LIMIT", 20),
            ),
        )

170 

171 

# Module-wide cached ResilienceConfig returned by get_resilience_config().
# None until it is first requested or set via set_resilience_config().
_resilience_config: ResilienceConfig | None = None


def get_resilience_config() -> ResilienceConfig:
    """Return the shared ResilienceConfig, creating it from the environment on first use.

    The first call builds the config with ``ResilienceConfig.from_env()`` and
    stores it in the module global. Later calls return that same object until
    it is replaced.
    """
    global _resilience_config
    if _resilience_config is not None:
        return _resilience_config
    _resilience_config = ResilienceConfig.from_env()
    return _resilience_config

182 

183 

def set_resilience_config(config: ResilienceConfig | None) -> None:
    """Replace the shared resilience configuration (intended for tests).

    Args:
        config: The configuration that get_resilience_config() should return
            from now on. Pass ``None`` to clear the cached instance; the next
            get_resilience_config() call then rebuilds it from the current
            environment.
    """
    global _resilience_config
    _resilience_config = config