Coverage for src/mcp_server_langgraph/resilience/adaptive.py: 88%

96 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-08 06:31 +0000

1""" 

2Adaptive bulkhead with self-tuning concurrency limits. 

3 

4Implements AIMD (Additive Increase Multiplicative Decrease) algorithm to 

5automatically adjust concurrency limits based on observed error rates. 

6 

7Key features: 

8- Monitors 429/529 error rates in sliding window 

9- Multiplicatively decreases limit on errors (fast decrease) 

10- Additively increases limit on success streaks (slow recovery) 

11- Respects floor and ceiling limits 

12 

13This is similar to TCP congestion control (Jacobson 1988) and provides 

14self-healing behavior for rate limit issues. 

15 

16See ADR-0026 for design rationale. 

17""" 

18 

19import asyncio 

20import logging 

21import threading 

22import time 

23from collections import deque 

24 

25from opentelemetry import trace 

26 

27from mcp_server_langgraph.resilience.config import get_provider_limit 

28 

29logger = logging.getLogger(__name__) 

30tracer = trace.get_tracer(__name__) 

31 

32 

33# ============================================================================= 

34# Adaptive Bulkhead Configuration 

35# ============================================================================= 

36 

# Default AIMD tuning parameters; used when AdaptiveBulkhead is constructed
# without explicit overrides (see ADR-0026 for rationale).
DEFAULT_MIN_LIMIT = 2  # Floor for the concurrency limit
DEFAULT_MAX_LIMIT = 50  # Ceiling for the concurrency limit
# NOTE(review): DEFAULT_INITIAL_LIMIT appears unused — __init__ falls back to
# the (min + max) // 2 midpoint when initial_limit is None. Confirm intent.
DEFAULT_INITIAL_LIMIT = 10
DEFAULT_ERROR_THRESHOLD = 0.1 # 10% error rate triggers decrease
DEFAULT_DECREASE_FACTOR = 0.75 # Reduce by 25% on error
DEFAULT_INCREASE_AMOUNT = 1 # Add 1 on success streak
DEFAULT_SUCCESS_STREAK_THRESHOLD = 10 # 10 successes to trigger increase
DEFAULT_WINDOW_SIZE = 100 # Sliding window for error rate calculation

46 

47 

48# ============================================================================= 

49# Adaptive Bulkhead Implementation 

50# ============================================================================= 

51 

52 

class AdaptiveBulkhead:
    """
    Self-tuning bulkhead with AIMD algorithm.

    Automatically adjusts concurrency limit based on error rates:
    - On error: limit = max(min_limit, limit * decrease_factor)
    - On success streak: limit = min(max_limit, limit + increase_amount)

    Thread safety: all mutable state is guarded by a non-reentrant
    ``threading.Lock``. Internal helpers whose names end in ``_locked``
    assume the caller already holds that lock and must never acquire it.

    Attributes:
        min_limit: Floor for concurrency limit
        max_limit: Ceiling for concurrency limit
        current_limit: Current concurrency limit
        error_threshold: Error rate that triggers decrease

    Example:
        >>> bulkhead = AdaptiveBulkhead(min_limit=5, max_limit=50, initial_limit=10)
        >>> semaphore = bulkhead.get_semaphore()
        >>> async with semaphore:
        ...     try:
        ...         result = await call_api()
        ...         bulkhead.record_success()
        ...     except RateLimitError:
        ...         bulkhead.record_error()
    """

    def __init__(
        self,
        min_limit: int = DEFAULT_MIN_LIMIT,
        max_limit: int = DEFAULT_MAX_LIMIT,
        initial_limit: int | None = None,
        error_threshold: float = DEFAULT_ERROR_THRESHOLD,
        decrease_factor: float = DEFAULT_DECREASE_FACTOR,
        increase_amount: int = DEFAULT_INCREASE_AMOUNT,
        success_streak_threshold: int = DEFAULT_SUCCESS_STREAK_THRESHOLD,
        window_size: int = DEFAULT_WINDOW_SIZE,
    ):
        """
        Initialize adaptive bulkhead.

        Args:
            min_limit: Minimum concurrency limit (floor)
            max_limit: Maximum concurrency limit (ceiling)
            initial_limit: Starting limit (defaults to (min + max) / 2)
            error_threshold: Error rate that triggers limit decrease
            decrease_factor: Multiplicative decrease factor (0-1)
            increase_amount: Additive increase on success streak
            success_streak_threshold: Successes needed to trigger increase
            window_size: Size of sliding window for error rate
        """
        self.min_limit = min_limit
        self.max_limit = max_limit
        # NOTE(review): error_threshold is stored but never read by this class;
        # record_error() decreases the limit on *every* recorded error. Confirm
        # against ADR-0026 whether it should gate the decrease decision.
        self.error_threshold = error_threshold
        self.decrease_factor = decrease_factor
        self.increase_amount = increase_amount
        self.success_streak_threshold = success_streak_threshold
        self.window_size = window_size

        # Set initial limit, clamped into [min_limit, max_limit]. Stored as a
        # float so repeated multiplicative decreases are not lost to rounding.
        if initial_limit is not None:
            self._current_limit = float(max(min_limit, min(max_limit, initial_limit)))
        else:
            self._current_limit = float((min_limit + max_limit) // 2)

        # Tracking state
        self._success_streak = 0
        self._samples: deque[bool] = deque(maxlen=window_size)  # True = success, False = error
        self._lock = threading.Lock()  # non-reentrant: never re-acquire on the same thread
        self._semaphore: asyncio.Semaphore | None = None
        self._last_update = time.monotonic()

        logger.info(
            "Initialized adaptive bulkhead",
            extra={
                "min_limit": min_limit,
                "max_limit": max_limit,
                "initial_limit": self._current_limit,
                "error_threshold": error_threshold,
            },
        )

    @property
    def current_limit(self) -> int:
        """Get current concurrency limit (integer)."""
        return int(self._current_limit)

    def record_success(self) -> None:
        """Record a successful operation."""
        with self._lock:
            self._samples.append(True)
            self._success_streak += 1

            # Additive increase: after enough consecutive successes, raise the
            # limit by increase_amount (capped at max_limit).
            if self._success_streak >= self.success_streak_threshold:
                old_limit = self._current_limit
                self._current_limit = min(self.max_limit, self._current_limit + self.increase_amount)
                self._success_streak = 0  # Reset streak after increase

                if self._current_limit > old_limit:
                    # Only discard the semaphore when the limit actually moved;
                    # rebuilding it needlessly forgets currently-held permits.
                    self._invalidate_semaphore()
                    logger.info(
                        f"Adaptive bulkhead increased limit: {old_limit:.0f} -> {self._current_limit:.0f}",
                        extra={
                            "old_limit": old_limit,
                            "new_limit": self._current_limit,
                            "reason": "success_streak",
                        },
                    )

    def record_error(self) -> None:
        """Record an error (429/529 or similar)."""
        with self._lock:
            self._samples.append(False)
            self._success_streak = 0  # Reset streak on error

            # Multiplicative decrease (floored at min_limit) for fast backoff.
            old_limit = self._current_limit
            self._current_limit = max(self.min_limit, self._current_limit * self.decrease_factor)
            if self._current_limit < old_limit:
                # Skip invalidation when already clamped at the floor so a
                # live semaphore (and its held permits) is not thrown away.
                self._invalidate_semaphore()

            logger.warning(
                f"Adaptive bulkhead decreased limit: {old_limit:.0f} -> {self._current_limit:.0f}",
                extra={
                    "old_limit": old_limit,
                    "new_limit": self._current_limit,
                    "reason": "error",
                },
            )

    def get_error_rate(self) -> float:
        """
        Get current error rate from sliding window.

        Returns:
            Error rate as float (0.0 - 1.0)
        """
        with self._lock:
            return self._error_rate_locked()

    def _error_rate_locked(self) -> float:
        """Compute the window error rate. Caller must hold self._lock."""
        if not self._samples:
            return 0.0
        errors = sum(1 for s in self._samples if not s)
        return errors / len(self._samples)

    def get_semaphore(self) -> asyncio.Semaphore:
        """
        Get asyncio semaphore with current limit.

        Returns a new semaphore if limit has changed since last call.
        NOTE: a freshly created semaphore does not carry over permits held
        on the previous one, so in-flight work briefly exceeds the limit.

        Returns:
            asyncio.Semaphore configured for current limit
        """
        with self._lock:
            # _needs_new_semaphore() already handles the None case.
            if self._needs_new_semaphore():
                self._semaphore = asyncio.Semaphore(self.current_limit)
                self._last_update = time.monotonic()
            return self._semaphore

    def _invalidate_semaphore(self) -> None:
        """Mark semaphore as needing update. Caller must hold self._lock."""
        self._semaphore = None

    def _needs_new_semaphore(self) -> bool:
        """Check if semaphore needs to be recreated. Caller must hold self._lock."""
        if self._semaphore is None:
            return True
        # Relies on asyncio.Semaphore's private counter. NOTE(review): _value
        # also drops while permits are held, so a busy semaphore can compare
        # unequal and be rebuilt even when the limit is unchanged.
        return self._semaphore._value != self.current_limit

    def get_stats(self) -> dict[str, int | float]:
        """Get current bulkhead statistics (point-in-time snapshot)."""
        with self._lock:
            # Must use the lock-free helper here: calling get_error_rate()
            # would re-acquire the non-reentrant lock and deadlock.
            return {
                "current_limit": self.current_limit,
                "min_limit": self.min_limit,
                "max_limit": self.max_limit,
                "error_rate": self._error_rate_locked(),
                "success_streak": self._success_streak,
                "samples_count": len(self._samples),
            }

231 

232 

233# ============================================================================= 

234# Provider Adaptive Bulkhead Factory 

235# ============================================================================= 

236 

# Process-wide registry of per-provider adaptive bulkheads, keyed by provider
# name. Populated lazily by get_provider_adaptive_bulkhead(); all access is
# guarded by _bulkhead_lock.
_provider_adaptive_bulkheads: dict[str, AdaptiveBulkhead] = {}
_bulkhead_lock = threading.Lock()

239 

240 

def get_provider_adaptive_bulkhead(provider: str) -> AdaptiveBulkhead:
    """
    Get or create an adaptive bulkhead for a specific LLM provider.

    Bulkheads are cached per provider; the first call for a given provider
    builds one sized from the provider's configured base limit (floor at 25%
    of base, ceiling at 200% of base, starting at base).

    Args:
        provider: LLM provider name (e.g., "anthropic", "openai")

    Returns:
        AdaptiveBulkhead configured for the provider
    """
    with _bulkhead_lock:
        existing = _provider_adaptive_bulkheads.get(provider)
        if existing is not None:
            return existing

        # First request for this provider: derive AIMD bounds from config.
        base_limit = get_provider_limit(provider)
        bulkhead = AdaptiveBulkhead(
            min_limit=max(2, base_limit // 4),  # Floor at 25% of base
            max_limit=base_limit * 2,  # Ceiling at 200% of base
            initial_limit=base_limit,  # Start at base limit
            error_threshold=0.1,  # 10% error rate triggers decrease
        )
        _provider_adaptive_bulkheads[provider] = bulkhead

        logger.info(
            f"Created adaptive bulkhead for {provider}",
            extra={
                "provider": provider,
                "initial_limit": base_limit,
                "min_limit": bulkhead.min_limit,
                "max_limit": bulkhead.max_limit,
            },
        )

        return bulkhead

281 

282 

def reset_all_adaptive_bulkheads() -> None:
    """
    Drop every cached provider bulkhead.

    Warning: Only use for testing! In production, bulkheads should not be
    reset, since resetting discards the limits learned so far.
    """
    with _bulkhead_lock:
        _provider_adaptive_bulkheads.clear()
        logger.warning("All adaptive bulkheads reset (testing only)")

292 

293 

def get_all_adaptive_bulkhead_stats() -> dict[str, dict[str, int | float]]:
    """
    Get statistics for all adaptive bulkheads.

    Returns:
        Dict mapping provider to bulkhead stats
    """
    with _bulkhead_lock:
        stats: dict[str, dict[str, int | float]] = {}
        for name, bulkhead in _provider_adaptive_bulkheads.items():
            stats[name] = bulkhead.get_stats()
        return stats