Coverage for src/mcp_server_langgraph/resilience/retry.py: 82%
203 statements
coverage.py v7.12.0, created at 2025-12-08 06:31 +0000
1"""
2Retry logic with exponential backoff and jitter.
4Automatically retries transient failures with configurable policies.
5Uses tenacity library for declarative retry specifications.
7See ADR-0026 for design rationale.
8"""
10import functools
11import logging
12import random
13from collections.abc import Callable
14from datetime import datetime
15from email.utils import parsedate_to_datetime
16from enum import Enum
17from typing import ParamSpec, TypeVar
19from opentelemetry import trace
20from tenacity import (
21 RetryCallState,
22 RetryError,
23 Retrying,
24 retry_if_exception_type,
25 stop_after_attempt,
26 wait_exponential,
27)
29from mcp_server_langgraph.observability.telemetry import retry_attempt_counter, retry_exhausted_counter
30from mcp_server_langgraph.resilience.config import JitterStrategy, OverloadRetryConfig, get_resilience_config
32logger = logging.getLogger(__name__)
33tracer = trace.get_tracer(__name__)
35P = ParamSpec("P")
36T = TypeVar("T")
38# Check if redis is available (optional dependency)
39_REDIS_AVAILABLE = False
40try:
41 import redis as _redis_module
43 _REDIS_AVAILABLE = True
44except ImportError:
45 _redis_module = None # type: ignore[assignment]
46 logger.debug("Redis module not available. Redis error retry logic will be skipped.")
49# =============================================================================
50# Jitter and Retry-After Utilities
51# =============================================================================
54def calculate_jitter_delay(
55 base_delay: float,
56 prev_delay: float | None,
57 max_delay: float,
58 strategy: JitterStrategy,
59) -> float:
60 """Calculate delay with jitter based on strategy.
62 Args:
63 base_delay: Base delay without jitter
64 prev_delay: Previous delay (for decorrelated jitter)
65 max_delay: Maximum allowed delay
66 strategy: Jitter strategy to use
68 Returns:
69 Delay with jitter applied
71 See: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
72 """
73 if strategy == JitterStrategy.SIMPLE:
74 # Simple jitter: +/- 20% of base delay
75 jitter_factor = random.uniform(0.8, 1.2)
76 return min(base_delay * jitter_factor, max_delay)
78 elif strategy == JitterStrategy.FULL:
79 # Full jitter: random(0, delay)
80 return random.uniform(0, min(base_delay, max_delay))
82 else: # JitterStrategy.DECORRELATED
83 # Decorrelated jitter: min(cap, random(base, prev * 3))
84 # Reference: AWS Architecture Blog
85 if prev_delay is None:
86 prev_delay = base_delay
87 return min(max_delay, random.uniform(base_delay, prev_delay * 3))
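# Illustrative sketch (not part of the original module): how the three jitter
# strategies compare for the same inputs. The base delay, cap, and previous
# delay below are assumptions chosen for demonstration, not config values.
def _example_jitter_strategies() -> None:
    base, cap = 2.0, 30.0
    simple = calculate_jitter_delay(base, None, cap, JitterStrategy.SIMPLE)  # ~1.6-2.4s
    full = calculate_jitter_delay(base, None, cap, JitterStrategy.FULL)  # 0-2.0s
    # Decorrelated jitter draws from [base, prev * 3]; with prev=6.0 that is [2.0, 18.0], capped at 30s
    decorrelated = calculate_jitter_delay(base, 6.0, cap, JitterStrategy.DECORRELATED)
    print(f"simple={simple:.2f}s full={full:.2f}s decorrelated={decorrelated:.2f}s")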
90def parse_retry_after(value: str | int | float | None) -> float | None:
91 """Parse Retry-After header value (RFC 7231).
93 Args:
94 value: Either seconds (int/float) or HTTP-date string
96 Returns:
97 Seconds to wait, or None if unparseable
98 """
99 if value is None:
100 return None
102 # Integer or float seconds
103 if isinstance(value, (int, float)):
104 return float(value)
106 # String: try as number first
107 try:
108 return float(value)
109 except ValueError:
110 pass
112 # String: try as HTTP-date
113 try:
114 retry_date = parsedate_to_datetime(value)
115 delta = retry_date - datetime.now(retry_date.tzinfo)
116 return max(0.0, delta.total_seconds())
117 except Exception:
118 return None
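# Illustrative sketch (not part of the original module): the two Retry-After
# forms accepted by RFC 7231. The HTTP-date below is an arbitrary past date.
def _example_parse_retry_after() -> None:
    assert parse_retry_after(120) == 120.0  # numeric seconds
    assert parse_retry_after("120") == 120.0  # numeric string
    assert parse_retry_after("Wed, 21 Oct 2015 07:28:00 GMT") == 0.0  # past HTTP-date clamps to 0
    assert parse_retry_after("not-a-date") is None  # unparseable values return None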
121def extract_retry_after_from_exception(exception: Exception) -> float | None:
122 """Extract Retry-After value from LiteLLM/httpx exception.
124 Args:
125 exception: The caught exception
127 Returns:
128 Seconds to wait, or None if not available
129 """
130 # Check for retry_after attribute (LiteLLM may add this)
131 if hasattr(exception, "retry_after"):
132 return parse_retry_after(exception.retry_after)
134 # Check for response headers (httpx exceptions)
135 if hasattr(exception, "response") and exception.response is not None:
136 headers = getattr(exception.response, "headers", {})
137 if headers:  # coverage: 137 ↛ 143 (condition was always true)
138 retry_after = headers.get("Retry-After") or headers.get("retry-after")
139 if retry_after:  # coverage: 139 ↛ 143 (condition was always true)
140 return parse_retry_after(retry_after)
142 # Check LiteLLM's llm_provider_response_headers
143 if hasattr(exception, "llm_provider_response_headers"):  # coverage: 143 ↛ 149 (condition was always true)
144 headers = exception.llm_provider_response_headers or {}
145 retry_after = headers.get("Retry-After") or headers.get("retry-after")
146 if retry_after:
147 return parse_retry_after(retry_after)
149 return None
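# Illustrative sketch (not part of the original module): hypothetical exception
# stubs showing the attribute shapes this helper looks for. Real httpx/LiteLLM
# exceptions carry these fields; the stubs exist only for demonstration.
def _example_extract_retry_after() -> None:
    class _FakeResponse:
        headers = {"Retry-After": "30"}

    class _FakeHTTPError(Exception):
        response = _FakeResponse()

    class _FakeOverloadError(Exception):
        retry_after = 10  # a direct attribute takes precedence over response headers

    assert extract_retry_after_from_exception(_FakeHTTPError("throttled")) == 30.0
    assert extract_retry_after_from_exception(_FakeOverloadError("overloaded")) == 10.0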
152def is_overload_error(exception: Exception) -> bool:
153 """Determine if exception indicates service overload.
155 Checks for:
156 - HTTP 529 status code
157 - "overloaded" in error message
158 - LLMOverloadError type
160 Args:
161 exception: The exception to check
163 Returns:
164 True if this is an overload error, False otherwise
165 """
166 # Check our custom exception type
167 try:
168 from mcp_server_langgraph.core.exceptions import LLMOverloadError
170 if isinstance(exception, LLMOverloadError):
171 return True
172 except ImportError:
173 pass
175 # Check status code (LiteLLM exceptions have status_code attribute)
176 status_code = getattr(exception, "status_code", None)
177 if status_code == 529:
178 return True
180 # Check error message patterns
181 error_msg = str(exception).lower()
182 overload_patterns = [
183 "overload",
184 "service is overloaded",
185 "capacity",
186 ]
188 # 503 + overload message = treat as overload
189 if status_code == 503 and any(p in error_msg for p in overload_patterns):
190 return True
192 # Pure message-based detection (including "529" in message)
193 return "overload" in error_msg or "529" in error_msg
196class RetryPolicy(str, Enum):
197 """Retry policies for different error types"""
199 NEVER = "never" # Never retry (client errors)
200 ALWAYS = "always" # Always retry (transient failures)
201 CONDITIONAL = "conditional" # Retry with conditions
204class RetryStrategy(str, Enum):
205 """Retry backoff strategies"""
207 EXPONENTIAL = "exponential" # Exponential backoff: 1s, 2s, 4s, 8s...
208 LINEAR = "linear" # Linear backoff: 1s, 2s, 3s, 4s...
209 FIXED = "fixed" # Fixed interval: 1s, 1s, 1s, 1s...
210 RANDOM = "random" # Random jitter: 0-1s, 0-2s, 0-4s...
213def should_retry_exception(exception: Exception) -> bool:
214 """
215 Determine if an exception is retryable.
217 Args:
218 exception: The exception that occurred
220 Returns:
221 True if should retry, False otherwise
222 """
223 # Import here to avoid circular dependency
224 try:
225 from mcp_server_langgraph.core.exceptions import (
226 AuthenticationError,
227 AuthorizationError,
228 ExternalServiceError,
229 RateLimitError,
230 ResilienceError,
231 ValidationError,
232 )
234 # Never retry client errors
235 if isinstance(exception, (ValidationError, AuthorizationError)):
236 return False
238 # Conditionally retry auth errors (e.g., token refresh)
239 if isinstance(exception, AuthenticationError):
240 # Only retry token expiration (can refresh)
241 return exception.error_code == "auth.token_expired"
243 # Never retry rate limits from our own service
244 if isinstance(exception, RateLimitError):  # coverage: 244 ↛ 245 (condition was never true)
245 return False
247 # Always retry external service errors
248 if isinstance(exception, ExternalServiceError):
249 return True
251 # Always retry resilience errors (timeout, circuit breaker)
252 if isinstance(exception, ResilienceError):  # coverage: 252 ↛ 253 (condition was never true)
253 return True
255 except ImportError:
256 # Exceptions not yet defined, fall back to generic logic
257 pass
259 # Generic logic: retry network errors, timeouts
260 import httpx
262 if isinstance(exception, (httpx.TimeoutException, httpx.ConnectError, httpx.NetworkError)):
263 return True
265 # Check redis errors if redis is available (optional dependency)
266 if _REDIS_AVAILABLE and _redis_module is not None:  # coverage: 266 ↛ 271 (condition was always true)
267 if isinstance(exception, (_redis_module.ConnectionError, _redis_module.TimeoutError)):  # coverage: 267 ↛ 268 (condition was never true)
268 return True
270 # Don't retry by default
271 return False
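# Illustrative sketch (not part of the original module): the generic fallback
# behaviour for common exception types. Project-specific exceptions follow the
# rules documented in the docstring above.
def _example_should_retry() -> None:
    import httpx

    assert should_retry_exception(httpx.ConnectError("connection refused")) is True  # network error
    assert should_retry_exception(httpx.ReadTimeout("read timed out")) is True  # timeout
    assert should_retry_exception(ValueError("bad input")) is False  # unknown errors are not retried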
274def log_retry_attempt(retry_state: RetryCallState) -> None:
275 """Log retry attempts for observability"""
276 exception = retry_state.outcome.exception() if retry_state.outcome else None
278 logger.warning(
279 f"Retrying after failure (attempt {retry_state.attempt_number})",
280 extra={
281 "attempt_number": retry_state.attempt_number,
282 "exception_type": type(exception).__name__ if exception else None,
283 "next_action": str(retry_state.next_action),
284 },
285 )
287 # Emit metric
288 retry_attempt_counter.add(
289 1,
290 attributes={
291 "attempt_number": retry_state.attempt_number,
292 "exception_type": type(exception).__name__ if exception else "unknown",
293 },
294 )
297def _calculate_retry_delay(
298 attempt: int,
299 exception: Exception,
300 overload_aware: bool,
301 overload_config: OverloadRetryConfig,
302 exp_base: float,
303 exp_max: float,
304 jitter_strategy: JitterStrategy,
305 prev_delay: float | None,
306 strategy: RetryStrategy,
307) -> float:
308 """Calculate retry delay with overload awareness.
310 When overload_aware=True and the exception is an overload error,
311 honors the Retry-After header if present and configured.
312 Otherwise uses exponential backoff with jitter.
314 Args:
315 attempt: Current attempt number (1-indexed)
316 exception: The exception that caused the retry
317 overload_aware: Whether overload-aware behavior is enabled
318 overload_config: Configuration for overload-specific retry behavior
319 exp_base: Exponential backoff base
320 exp_max: Maximum delay
321 jitter_strategy: Jitter strategy to use
322 prev_delay: Previous delay (for decorrelated jitter)
323 strategy: Overall retry strategy
325 Returns:
326 Delay in seconds before next attempt
327 """
328 # Check for Retry-After header (overload errors)
329 if overload_aware and is_overload_error(exception) and overload_config.honor_retry_after:
330 retry_after = extract_retry_after_from_exception(exception)
331 if retry_after is not None:  # coverage: 331 ↛ 336 (condition was always true)
332 # Honor Retry-After but cap at configured max
333 return min(retry_after, overload_config.retry_after_max)
335 # Calculate base delay based on strategy
336 if strategy == RetryStrategy.EXPONENTIAL:  # coverage: 336 ↛ 338 (condition was always true)
337 base_delay = exp_base**attempt
338 elif strategy == RetryStrategy.LINEAR:
339 base_delay = exp_base * attempt
340 elif strategy == RetryStrategy.FIXED:
341 base_delay = exp_base
342 else: # RANDOM - handled by jitter
343 base_delay = exp_base**attempt
345 # Apply jitter
346 return calculate_jitter_delay(
347 base_delay=base_delay,
348 prev_delay=prev_delay,
349 max_delay=exp_max,
350 strategy=jitter_strategy,
351 )
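# Illustrative sketch (not part of the original module): when overload handling
# is enabled and the server supplies Retry-After, that hint wins (capped at the
# configured maximum); otherwise the delay falls back to jittered backoff. The
# stub exception and the 45-second hint are assumptions for demonstration.
def _example_calculate_retry_delay() -> None:
    overload_cfg = get_resilience_config().retry.overload

    class _FakeOverload(Exception):
        status_code = 529
        retry_after = 45  # server-provided hint, in seconds

    delay = _calculate_retry_delay(
        attempt=1,
        exception=_FakeOverload("overloaded"),
        overload_aware=True,
        overload_config=overload_cfg,
        exp_base=2.0,
        exp_max=60.0,
        jitter_strategy=JitterStrategy.FULL,
        prev_delay=None,
        strategy=RetryStrategy.EXPONENTIAL,
    )
    if overload_cfg.honor_retry_after:
        assert delay == min(45.0, overload_cfg.retry_after_max)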
354def log_retry_attempt_manual(
355 attempt_number: int,
356 exception: Exception,
357 delay: float,
358 func_name: str,
359) -> None:
360 """Log retry attempt for observability (manual retry loop version).
362 Args:
363 attempt_number: Current attempt number
364 exception: The exception that caused the retry
365 delay: Delay before next attempt
366 func_name: Name of the function being retried
367 """
368 logger.warning(
369 f"Retrying after failure (attempt {attempt_number})",
370 extra={
371 "attempt_number": attempt_number,
372 "exception_type": type(exception).__name__,
373 "delay_seconds": delay,
374 "function": func_name,
375 },
376 )
378 # Emit metric
379 retry_attempt_counter.add(
380 1,
381 attributes={
382 "attempt_number": attempt_number,
383 "exception_type": type(exception).__name__,
384 },
385 )
388def retry_with_backoff( # noqa: C901
389 max_attempts: int | None = None,
390 exponential_base: float | None = None,
391 exponential_max: float | None = None,
392 retry_on: type[Exception] | tuple[type[Exception], ...] | None = None,
393 strategy: RetryStrategy = RetryStrategy.EXPONENTIAL,
394 jitter_strategy: JitterStrategy | None = None,
395 overload_aware: bool = False,
396) -> Callable[[Callable[P, T]], Callable[P, T]]:
397 """
398 Decorator to retry a function with exponential backoff.
400 Args:
401 max_attempts: Maximum number of retry attempts (default: from config)
402 exponential_base: Base for exponential backoff (default: from config)
403 exponential_max: Maximum backoff time in seconds (default: from config)
404 retry_on: Exception type(s) to retry on (default: auto-detect)
405 strategy: Retry strategy (exponential, linear, fixed, random)
406 jitter_strategy: Jitter strategy for randomizing delays (default: from config)
407 overload_aware: Enable extended retry behavior for 529/overload errors
409 Usage:
410 @retry_with_backoff(max_attempts=3, exponential_base=2)
411 async def call_external_api() -> dict[str, Any]:
412 async with httpx.AsyncClient() as client:
413 response = await client.get("https://api.example.com")
414 return response.json()
416 # With custom exception types
417 @retry_with_backoff(retry_on=(httpx.TimeoutException, redis.ConnectionError))
418 async def fetch_data() -> str:
419 # Will retry on timeout or connection errors
420 return await get_data()
422 # With overload awareness for 529 errors
423 @retry_with_backoff(max_attempts=3, overload_aware=True)
424 async def call_llm() -> str:
425 # Will use extended retry config for 529 overload errors
426 return await llm_client.generate(prompt)
427 """
428 # Load configuration
429 config = get_resilience_config()
430 retry_config = config.retry
431 overload_config = retry_config.overload
433 # Use config defaults if not specified
434 standard_max_attempts = max_attempts or retry_config.max_attempts
435 standard_exponential_base = exponential_base or retry_config.exponential_base
436 standard_exponential_max = exponential_max or retry_config.exponential_max
437 standard_jitter_strategy = jitter_strategy or retry_config.jitter_strategy
439 def decorator(func: Callable[P, T]) -> Callable[P, T]:
440 @functools.wraps(func)
441 async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
442 """Async wrapper with retry logic.
444 When overload_aware=True and an overload error (529) is detected,
445 the retry logic dynamically switches from standard config to the
446 more aggressive overload config (more attempts, longer backoff).
447 """
448 import asyncio as aio
450 # Initialize with standard config
451 current_max_attempts = standard_max_attempts
452 current_exp_base = standard_exponential_base
453 current_exp_max = standard_exponential_max
454 current_jitter = standard_jitter_strategy
455 switched_to_overload_config = False
457 with tracer.start_as_current_span(
458 f"retry.{func.__name__}",
459 attributes={
460 "retry.max_attempts": current_max_attempts,
461 "retry.strategy": strategy.value,
462 "retry.overload_aware": overload_aware,
463 },
464 ) as span:
465 attempt_number = 0
466 last_exception: Exception | None = None
467 prev_delay: float | None = None
469 while attempt_number < current_max_attempts:  # coverage: 469 ↛ 532 (condition was always true)
470 attempt_number += 1
472 try:
473 result: T = await func(*args, **kwargs) # type: ignore[misc]
474 span.set_attribute("retry.success", True)
475 span.set_attribute("retry.attempts", attempt_number)
476 span.set_attribute("retry.switched_to_overload", switched_to_overload_config)
477 return result
479 except Exception as e:
480 last_exception = e
482 # Check if we should filter this exception type
483 if retry_on and not isinstance(e, retry_on):
484 raise
486 # Check if we should switch to overload config
487 if overload_aware and not switched_to_overload_config and is_overload_error(e):
488 # Switch to overload config for extended retry
489 current_max_attempts = overload_config.max_attempts
490 current_exp_base = overload_config.exponential_base
491 current_exp_max = overload_config.exponential_max
492 current_jitter = overload_config.jitter_strategy
493 switched_to_overload_config = True
494 logger.info(
495 f"Overload detected, switching to extended retry config "
496 f"(max_attempts={current_max_attempts}, "
497 f"exponential_max={current_exp_max}s)",
498 extra={
499 "function": func.__name__,
500 "attempt": attempt_number,
501 "overload_max_attempts": current_max_attempts,
502 },
503 )
504 span.set_attribute("retry.overload_detected", True)
505 span.set_attribute("retry.new_max_attempts", current_max_attempts)
507 # Check if we've exhausted attempts
508 if attempt_number >= current_max_attempts:
509 break
511 # Calculate delay with overload awareness
512 delay = _calculate_retry_delay(
513 attempt=attempt_number,
514 exception=e,
515 overload_aware=overload_aware and switched_to_overload_config,
516 overload_config=overload_config,
517 exp_base=current_exp_base,
518 exp_max=current_exp_max,
519 jitter_strategy=current_jitter,
520 prev_delay=prev_delay,
521 strategy=strategy,
522 )
523 prev_delay = delay
525 # Log retry attempt
526 log_retry_attempt_manual(attempt_number, e, delay, func.__name__)
528 # Wait before next attempt
529 await aio.sleep(delay)
531 # Retry exhausted
532 span.set_attribute("retry.success", False)
533 span.set_attribute("retry.exhausted", True)
534 span.set_attribute("retry.final_attempts", attempt_number)
536 logger.error(
537 f"Retry exhausted after {attempt_number} attempts",
538 exc_info=last_exception,
539 extra={
540 "max_attempts": current_max_attempts,
541 "function": func.__name__,
542 "switched_to_overload": switched_to_overload_config,
543 },
544 )
546 # Emit metric
547 retry_exhausted_counter.add(1, attributes={"function": func.__name__})
549 # Wrap in our custom exception
550 from mcp_server_langgraph.core.exceptions import RetryExhaustedError
552 raise RetryExhaustedError(
553 message=f"Retry exhausted after {attempt_number} attempts",
554 metadata={
555 "max_attempts": current_max_attempts,
556 "function": func.__name__,
557 "switched_to_overload": switched_to_overload_config,
558 },
559 ) from last_exception
561 @functools.wraps(func)
562 def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
563 """Sync wrapper with retry logic"""
564 # Use computed values from decorator (guaranteed non-None)
565 retry_kwargs = {
566 "stop": stop_after_attempt(standard_max_attempts),
567 "reraise": False, # Raise RetryError instead of original exception
568 "before_sleep": log_retry_attempt,
569 }
571 if strategy == RetryStrategy.EXPONENTIAL:
572 retry_kwargs["wait"] = wait_exponential(
573 multiplier=standard_exponential_base,
574 max=standard_exponential_max,
575 )
577 if retry_on:
578 retry_kwargs["retry"] = retry_if_exception_type(retry_on)
580 try:
581 for attempt in Retrying(**retry_kwargs): # type: ignore[arg-type]
582 with attempt:
583 return func(*args, **kwargs)
584 except RetryError as e:
585 from mcp_server_langgraph.core.exceptions import RetryExhaustedError
587 raise RetryExhaustedError(
588 message=f"Retry exhausted after {standard_max_attempts} attempts",
589 metadata={"max_attempts": standard_max_attempts},
590 ) from e.last_attempt.exception()
591 # This should never be reached, but mypy needs an explicit return path
592 msg = "Unreachable code"
593 raise RuntimeError(msg) # pragma: no cover
595 # Return appropriate wrapper
596 import asyncio
598 if asyncio.iscoroutinefunction(func):  # coverage: 598 ↛ 601 (condition was always true)
599 return async_wrapper # type: ignore[return-value]
600 else:
601 return sync_wrapper
603 return decorator
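# Illustrative usage sketch (not part of the original module). The URL below is
# a placeholder, and _example_call_llm only marks where a real LLM call would
# go; both exist to show how the decorator is applied.
@retry_with_backoff(max_attempts=3, strategy=RetryStrategy.EXPONENTIAL)
async def _example_fetch_status() -> int:
    import httpx

    async with httpx.AsyncClient() as client:
        response = await client.get("https://api.example.com/status")  # placeholder URL
        return response.status_code


@retry_with_backoff(max_attempts=3, overload_aware=True)
async def _example_call_llm(prompt: str) -> str:
    # A real caller would invoke its LLM client here; a 529 response raised by
    # that call would switch the retry loop to the extended overload config.
    raise NotImplementedError("illustrative only")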