1"""
2Adaptive bulkhead with self-tuning concurrency limits.
4Implements AIMD (Additive Increase Multiplicative Decrease) algorithm to
5automatically adjust concurrency limits based on observed error rates.
7Key features:
8- Monitors 429/529 error rates in sliding window
9- Multiplicatively decreases limit on errors (fast decrease)
10- Additively increases limit on success streaks (slow recovery)
11- Respects floor and ceiling limits
13This is similar to TCP congestion control (Jacobson 1988) and provides
14self-healing behavior for rate limit issues.
16See ADR-0026 for design rationale.
17"""

import asyncio
import logging
import threading
import time
from collections import deque

from opentelemetry import trace

from mcp_server_langgraph.resilience.config import get_provider_limit

logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)


# =============================================================================
# Adaptive Bulkhead Configuration
# =============================================================================

# Default AIMD parameters
DEFAULT_MIN_LIMIT = 2
DEFAULT_MAX_LIMIT = 50
DEFAULT_INITIAL_LIMIT = 10
DEFAULT_ERROR_THRESHOLD = 0.1  # 10% error rate triggers decrease
DEFAULT_DECREASE_FACTOR = 0.75  # Reduce by 25% on error
DEFAULT_INCREASE_AMOUNT = 1  # Add 1 on success streak
DEFAULT_SUCCESS_STREAK_THRESHOLD = 10  # 10 successes to trigger increase
DEFAULT_WINDOW_SIZE = 100  # Sliding window for error rate calculation
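
# For scale (illustrative arithmetic on the defaults above): falling from the
# ceiling (50) to the floor (2) takes about 12 consecutive decreases, since
# 50 * 0.75**12 ≈ 1.6, while climbing back from 2 to 50 takes 48 increases of
# +1, i.e. roughly 480 consecutive successes. Backoff is deliberately much
# faster than recovery, as in TCP congestion control.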


# =============================================================================
# Adaptive Bulkhead Implementation
# =============================================================================


class AdaptiveBulkhead:
    """
    Self-tuning bulkhead with AIMD algorithm.

    Automatically adjusts concurrency limit based on error rates:
    - On error, once the windowed error rate reaches error_threshold:
      limit = max(min_limit, limit * decrease_factor)
    - On success streak: limit = min(max_limit, limit + increase_amount)

    Attributes:
        min_limit: Floor for concurrency limit
        max_limit: Ceiling for concurrency limit
        current_limit: Current concurrency limit
        error_threshold: Error rate that triggers decrease

    Example:
        >>> bulkhead = AdaptiveBulkhead(min_limit=5, max_limit=50, initial_limit=10)
        >>> semaphore = bulkhead.get_semaphore()
        >>> async with semaphore:
        ...     try:
        ...         result = await call_api()
        ...         bulkhead.record_success()
        ...     except RateLimitError:
        ...         bulkhead.record_error()
    """

    def __init__(
        self,
        min_limit: int = DEFAULT_MIN_LIMIT,
        max_limit: int = DEFAULT_MAX_LIMIT,
        initial_limit: int | None = None,
        error_threshold: float = DEFAULT_ERROR_THRESHOLD,
        decrease_factor: float = DEFAULT_DECREASE_FACTOR,
        increase_amount: int = DEFAULT_INCREASE_AMOUNT,
        success_streak_threshold: int = DEFAULT_SUCCESS_STREAK_THRESHOLD,
        window_size: int = DEFAULT_WINDOW_SIZE,
    ):
        """
        Initialize adaptive bulkhead.

        Args:
            min_limit: Minimum concurrency limit (floor)
            max_limit: Maximum concurrency limit (ceiling)
            initial_limit: Starting limit (defaults to (min_limit + max_limit) // 2)
            error_threshold: Error rate that triggers limit decrease
            decrease_factor: Multiplicative decrease factor (0-1)
            increase_amount: Additive increase on success streak
            success_streak_threshold: Successes needed to trigger increase
            window_size: Size of sliding window for error rate
        """
        self.min_limit = min_limit
        self.max_limit = max_limit
        self.error_threshold = error_threshold
        self.decrease_factor = decrease_factor
        self.increase_amount = increase_amount
        self.success_streak_threshold = success_streak_threshold
        self.window_size = window_size

        # Set initial limit, clamped to [min_limit, max_limit]
        if initial_limit is not None:
            self._current_limit = float(max(min_limit, min(max_limit, initial_limit)))
        else:
            self._current_limit = float((min_limit + max_limit) // 2)

        # Tracking state
        self._success_streak = 0
        self._samples: deque[bool] = deque(maxlen=window_size)  # True = success, False = error
        self._lock = threading.Lock()
        self._semaphore: asyncio.Semaphore | None = None
        self._semaphore_limit = 0  # Limit the current semaphore was created with
        self._last_update = time.monotonic()

        logger.info(
            "Initialized adaptive bulkhead",
            extra={
                "min_limit": min_limit,
                "max_limit": max_limit,
                "initial_limit": self._current_limit,
                "error_threshold": error_threshold,
            },
        )

    @property
    def current_limit(self) -> int:
        """Get current concurrency limit (integer)."""
        return int(self._current_limit)

    def record_success(self) -> None:
        """Record a successful operation."""
        with self._lock:
            self._samples.append(True)
            self._success_streak += 1

            # Check if we should increase limit
            if self._success_streak >= self.success_streak_threshold:
                old_limit = self._current_limit
                self._current_limit = min(self.max_limit, self._current_limit + self.increase_amount)
                self._success_streak = 0  # Reset streak after increase

                if self._current_limit > old_limit:
                    # Only recreate the semaphore if the limit actually moved
                    # (it may already be pinned at max_limit).
                    self._invalidate_semaphore()
                    logger.info(
                        f"Adaptive bulkhead increased limit: {old_limit:.0f} -> {self._current_limit:.0f}",
                        extra={
                            "old_limit": old_limit,
                            "new_limit": self._current_limit,
                            "reason": "success_streak",
                        },
                    )

    def record_error(self) -> None:
        """Record an error (429/529 or similar)."""
        with self._lock:
            self._samples.append(False)
            self._success_streak = 0  # Reset streak on error

            # Only back off once the windowed error rate reaches the
            # configured threshold; otherwise error_threshold would be unused.
            if self._error_rate_locked() < self.error_threshold:
                return

            # Decrease limit multiplicatively
            old_limit = self._current_limit
            self._current_limit = max(self.min_limit, self._current_limit * self.decrease_factor)
            self._invalidate_semaphore()

            logger.warning(
                f"Adaptive bulkhead decreased limit: {old_limit:.0f} -> {self._current_limit:.0f}",
                extra={
                    "old_limit": old_limit,
                    "new_limit": self._current_limit,
                    "reason": "error",
                },
            )

    def get_error_rate(self) -> float:
        """
        Get current error rate from sliding window.

        Returns:
            Error rate as float (0.0 - 1.0)
        """
        with self._lock:
            return self._error_rate_locked()

    def _error_rate_locked(self) -> float:
        """Compute the windowed error rate. Caller must hold self._lock."""
        if not self._samples:
            return 0.0
        errors = sum(1 for s in self._samples if not s)
        return errors / len(self._samples)

    def get_semaphore(self) -> asyncio.Semaphore:
        """
        Get asyncio semaphore with current limit.

        Returns a new semaphore if limit has changed since last call. Tasks
        still holding the previous semaphore keep their permits; the new limit
        applies to subsequent acquisitions.

        Returns:
            asyncio.Semaphore configured for current limit
        """
        with self._lock:
            if self._needs_new_semaphore():
                self._semaphore = asyncio.Semaphore(self.current_limit)
                self._semaphore_limit = self.current_limit
                self._last_update = time.monotonic()
            return self._semaphore

    def _invalidate_semaphore(self) -> None:
        """Mark semaphore as needing update. Caller must hold self._lock."""
        self._semaphore = None

    def _needs_new_semaphore(self) -> bool:
        """Check if semaphore needs to be recreated."""
        if self._semaphore is None:
            return True
        # Compare against the limit the semaphore was created with rather than
        # its live internal counter, which drops while permits are held.
        return self._semaphore_limit != self.current_limit

    def get_stats(self) -> dict[str, int | float]:
        """Get current bulkhead statistics."""
        with self._lock:
            return {
                "current_limit": self.current_limit,
                "min_limit": self.min_limit,
                "max_limit": self.max_limit,
                # self._lock is not reentrant, so use the lock-free helper
                # instead of get_error_rate() to avoid deadlocking.
                "error_rate": self._error_rate_locked(),
                "success_streak": self._success_streak,
                "samples_count": len(self._samples),
            }
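

# Example get_stats() snapshot (hypothetical values) after a burst of
# rate-limit errors: the limit sits below its starting point and the streak
# counter has restarted.
#   {"current_limit": 8, "min_limit": 2, "max_limit": 50,
#    "error_rate": 0.12, "success_streak": 3, "samples_count": 100}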


# =============================================================================
# Provider Adaptive Bulkhead Factory
# =============================================================================

_provider_adaptive_bulkheads: dict[str, AdaptiveBulkhead] = {}
_bulkhead_lock = threading.Lock()


def get_provider_adaptive_bulkhead(provider: str) -> AdaptiveBulkhead:
    """
    Get or create an adaptive bulkhead for a specific LLM provider.

    Uses provider-specific limits from configuration as initial/max values.

    Args:
        provider: LLM provider name (e.g., "anthropic", "openai")

    Returns:
        AdaptiveBulkhead configured for the provider
    """
    with _bulkhead_lock:
        if provider in _provider_adaptive_bulkheads:
            return _provider_adaptive_bulkheads[provider]

        # Get provider-specific limits
        base_limit = get_provider_limit(provider)

        # Configure adaptive bulkhead based on provider limits
        bulkhead = AdaptiveBulkhead(
            min_limit=max(2, base_limit // 4),  # Floor at 25% of base
            max_limit=base_limit * 2,  # Ceiling at 200% of base
            initial_limit=base_limit,  # Start at base limit
            error_threshold=0.1,  # 10% error rate triggers decrease
        )

        _provider_adaptive_bulkheads[provider] = bulkhead

        logger.info(
            f"Created adaptive bulkhead for {provider}",
            extra={
                "provider": provider,
                "initial_limit": base_limit,
                "min_limit": bulkhead.min_limit,
                "max_limit": bulkhead.max_limit,
            },
        )

        return bulkhead
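

# Example wiring (sketch only): guarding an LLM call with the provider
# bulkhead. `call_provider` and `RateLimitError` below are hypothetical
# stand-ins for the caller's client code.
#
#     bulkhead = get_provider_adaptive_bulkhead("anthropic")
#     async with bulkhead.get_semaphore():
#         try:
#             result = await call_provider(prompt)
#             bulkhead.record_success()
#         except RateLimitError:
#             bulkhead.record_error()
#             raise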


def reset_all_adaptive_bulkheads() -> None:
    """
    Reset all adaptive bulkheads (for testing).

    Warning: Only use for testing! In production, bulkheads should not be reset.
    """
    with _bulkhead_lock:
        _provider_adaptive_bulkheads.clear()
        logger.warning("All adaptive bulkheads reset (testing only)")


def get_all_adaptive_bulkhead_stats() -> dict[str, dict[str, int | float]]:
    """
    Get statistics for all adaptive bulkheads.

    Returns:
        Dict mapping provider to bulkhead stats
    """
    with _bulkhead_lock:
        return {provider: bulkhead.get_stats() for provider, bulkhead in _provider_adaptive_bulkheads.items()}
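

# =============================================================================
# Minimal self-contained demo (a sketch, not part of the public API): simulates
# an error burst followed by a recovery streak and prints the resulting stats.
# =============================================================================
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    demo_bulkhead = AdaptiveBulkhead(min_limit=2, max_limit=20, initial_limit=10)

    # Error burst: each error keeps the windowed error rate over the threshold,
    # so the limit shrinks multiplicatively (10 -> 7.5 -> 5.6 -> 4.2).
    for _ in range(3):
        demo_bulkhead.record_error()

    # Recovery: every 10 consecutive successes adds +1 (4.2 -> 5.2 -> 6.2 -> 7.2).
    for _ in range(30):
        demo_bulkhead.record_success()

    print(demo_bulkhead.get_stats())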