Coverage for src/mcp_server_langgraph/resilience/retry.py: 82%

203 statements  

coverage.py v7.12.0, created at 2025-12-08 06:31 +0000

1""" 

2Retry logic with exponential backoff and jitter. 

3 

4Automatically retries transient failures with configurable policies. 

5Uses tenacity library for declarative retry specifications. 

6 

7See ADR-0026 for design rationale. 

8""" 

9 

10import functools 

11import logging 

12import random 

13from collections.abc import Callable 

14from datetime import datetime 

15from email.utils import parsedate_to_datetime 

16from enum import Enum 

17from typing import ParamSpec, TypeVar 

18 

19from opentelemetry import trace 

20from tenacity import ( 

21 RetryCallState, 

22 RetryError, 

23 Retrying, 

24 retry_if_exception_type, 

25 stop_after_attempt, 

26 wait_exponential, 

27) 

28 

29from mcp_server_langgraph.observability.telemetry import retry_attempt_counter, retry_exhausted_counter 

30from mcp_server_langgraph.resilience.config import JitterStrategy, OverloadRetryConfig, get_resilience_config 

31 

32logger = logging.getLogger(__name__) 

33tracer = trace.get_tracer(__name__) 

34 

35P = ParamSpec("P") 

36T = TypeVar("T") 

37 

38# Check if redis is available (optional dependency) 

39_REDIS_AVAILABLE = False 

40try: 

41 import redis as _redis_module 

42 

43 _REDIS_AVAILABLE = True 

44except ImportError: 

45 _redis_module = None # type: ignore[assignment] 

46 logger.debug("Redis module not available. Redis error retry logic will be skipped.") 

47 

48 

49# ============================================================================= 

50# Jitter and Retry-After Utilities 

51# ============================================================================= 

52 

53 

def calculate_jitter_delay(
    base_delay: float,
    prev_delay: float | None,
    max_delay: float,
    strategy: JitterStrategy,
) -> float:
    """Calculate delay with jitter based on strategy.

    Args:
        base_delay: Base delay without jitter
        prev_delay: Previous delay (for decorrelated jitter)
        max_delay: Maximum allowed delay
        strategy: Jitter strategy to use

    Returns:
        Delay with jitter applied

    See: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
    """
    if strategy == JitterStrategy.SIMPLE:
        # Simple jitter: +/- 20% of base delay
        jitter_factor = random.uniform(0.8, 1.2)
        return min(base_delay * jitter_factor, max_delay)

    elif strategy == JitterStrategy.FULL:
        # Full jitter: random(0, delay)
        return random.uniform(0, min(base_delay, max_delay))

    else:  # JitterStrategy.DECORRELATED
        # Decorrelated jitter: min(cap, random(base, prev * 3))
        # Reference: AWS Architecture Blog
        if prev_delay is None:
            prev_delay = base_delay
        return min(max_delay, random.uniform(base_delay, prev_delay * 3))
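
# Usage sketch: how the three strategies differ for the same inputs. The
# helper below is a hypothetical demo (illustrative values); each call
# returns a random delay within the range noted in its comment.
def _demo_jitter_strategies() -> None:
    # SIMPLE: +/-20% of the 4s base -> 3.2-4.8s
    print(calculate_jitter_delay(4.0, None, 30.0, JitterStrategy.SIMPLE))
    # FULL: uniform over 0-4s
    print(calculate_jitter_delay(4.0, None, 30.0, JitterStrategy.FULL))
    # DECORRELATED: uniform over base..prev*3 -> 4-6s, capped at 30s
    print(calculate_jitter_delay(4.0, 2.0, 30.0, JitterStrategy.DECORRELATED))
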

def parse_retry_after(value: str | int | float | None) -> float | None:
    """Parse Retry-After header value (RFC 7231).

    Args:
        value: Either seconds (int/float) or HTTP-date string

    Returns:
        Seconds to wait, or None if unparseable
    """
    if value is None:
        return None

    # Integer or float seconds
    if isinstance(value, (int, float)):
        return float(value)

    # String: try as number first
    try:
        return float(value)
    except ValueError:
        pass

    # String: try as HTTP-date
    try:
        retry_date = parsedate_to_datetime(value)
        delta = retry_date - datetime.now(retry_date.tzinfo)
        return max(0.0, delta.total_seconds())
    except Exception:
        return None
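
# Usage sketch: parse_retry_after handles both RFC 7231 forms. Hypothetical
# demo helper with illustrative values.
def _demo_parse_retry_after() -> None:
    print(parse_retry_after(120))        # 120.0 (numeric seconds)
    print(parse_retry_after("120"))      # 120.0 (numeric string)
    # HTTP-date form: seconds remaining until that instant (0.0 if past)
    print(parse_retry_after("Wed, 21 Oct 2026 07:28:00 GMT"))
    print(parse_retry_after("not-a-date"))  # None (unparseable)
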

def extract_retry_after_from_exception(exception: Exception) -> float | None:
    """Extract Retry-After value from LiteLLM/httpx exception.

    Args:
        exception: The caught exception

    Returns:
        Seconds to wait, or None if not available
    """
    # Check for retry_after attribute (LiteLLM may add this)
    if hasattr(exception, "retry_after"):
        return parse_retry_after(exception.retry_after)

    # Check for response headers (httpx exceptions)
    if hasattr(exception, "response") and exception.response is not None:
        headers = getattr(exception.response, "headers", {})
        if headers:  # coverage: condition always true in tests
            retry_after = headers.get("Retry-After") or headers.get("retry-after")
            if retry_after:  # coverage: condition always true in tests
                return parse_retry_after(retry_after)

    # Check LiteLLM's llm_provider_response_headers
    if hasattr(exception, "llm_provider_response_headers"):  # coverage: condition always true in tests
        headers = exception.llm_provider_response_headers or {}
        retry_after = headers.get("Retry-After") or headers.get("retry-after")
        if retry_after:
            return parse_retry_after(retry_after)

    return None
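
# Usage sketch: extraction from an exception that carries an httpx-style
# response. The stub classes are hypothetical stand-ins for provider errors.
def _demo_extract_retry_after() -> None:
    class _StubResponse:
        headers = {"Retry-After": "30"}

    class _StubError(Exception):
        response = _StubResponse()

    print(extract_retry_after_from_exception(_StubError("503")))  # 30.0
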

def is_overload_error(exception: Exception) -> bool:
    """Determine if exception indicates service overload.

    Checks for:
    - HTTP 529 status code
    - "overloaded" in error message
    - LLMOverloadError type

    Args:
        exception: The exception to check

    Returns:
        True if this is an overload error, False otherwise
    """
    # Check our custom exception type
    try:
        from mcp_server_langgraph.core.exceptions import LLMOverloadError

        if isinstance(exception, LLMOverloadError):
            return True
    except ImportError:
        pass

    # Check status code (LiteLLM exceptions have a status_code attribute)
    status_code = getattr(exception, "status_code", None)
    if status_code == 529:
        return True

    # Check error message patterns
    error_msg = str(exception).lower()
    overload_patterns = [
        "overload",
        "service is overloaded",
        "capacity",
    ]

    # 503 + overload message = treat as overload
    if status_code == 503 and any(p in error_msg for p in overload_patterns):
        return True

    # Pure message-based detection (including "529" in message)
    return "overload" in error_msg or "529" in error_msg
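
# Usage sketch: overload detection on hypothetical stub errors.
def _demo_is_overload_error() -> None:
    class _Overloaded(Exception):
        status_code = 529

    class _Busy(Exception):
        status_code = 503

    print(is_overload_error(_Overloaded("try again later")))  # True (529)
    print(is_overload_error(_Busy("service is overloaded")))  # True (503 + pattern)
    print(is_overload_error(ValueError("bad input")))         # False
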

class RetryPolicy(str, Enum):
    """Retry policies for different error types"""

    NEVER = "never"  # Never retry (client errors)
    ALWAYS = "always"  # Always retry (transient failures)
    CONDITIONAL = "conditional"  # Retry with conditions


class RetryStrategy(str, Enum):
    """Retry backoff strategies"""

    EXPONENTIAL = "exponential"  # Exponential backoff: 1s, 2s, 4s, 8s...
    LINEAR = "linear"  # Linear backoff: 1s, 2s, 3s, 4s...
    FIXED = "fixed"  # Fixed interval: 1s, 1s, 1s, 1s...
    RANDOM = "random"  # Random jitter: 0-1s, 0-2s, 0-4s...

def should_retry_exception(exception: Exception) -> bool:
    """
    Determine if an exception is retryable.

    Args:
        exception: The exception that occurred

    Returns:
        True if should retry, False otherwise
    """
    # Import here to avoid circular dependency
    try:
        from mcp_server_langgraph.core.exceptions import (
            AuthenticationError,
            AuthorizationError,
            ExternalServiceError,
            RateLimitError,
            ResilienceError,
            ValidationError,
        )

        # Never retry client errors
        if isinstance(exception, (ValidationError, AuthorizationError)):
            return False

        # Conditionally retry auth errors (e.g., token refresh)
        if isinstance(exception, AuthenticationError):
            # Only retry token expiration (can refresh)
            return exception.error_code == "auth.token_expired"

        # Never retry rate limits from our own service
        if isinstance(exception, RateLimitError):  # coverage: condition never true in tests
            return False

        # Always retry external service errors
        if isinstance(exception, ExternalServiceError):
            return True

        # Always retry resilience errors (timeout, circuit breaker)
        if isinstance(exception, ResilienceError):  # coverage: condition never true in tests
            return True

    except ImportError:
        # Exceptions not yet defined, fall back to generic logic
        pass

    # Generic logic: retry network errors and timeouts
    import httpx

    if isinstance(exception, (httpx.TimeoutException, httpx.ConnectError, httpx.NetworkError)):
        return True

    # Check redis errors if redis is available (optional dependency)
    if _REDIS_AVAILABLE and _redis_module is not None:  # coverage: condition always true in tests
        if isinstance(exception, (_redis_module.ConnectionError, _redis_module.TimeoutError)):  # coverage: condition never true in tests
            return True

    # Don't retry by default
    return False
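
# Usage sketch: network errors are retryable, generic errors are not.
# Hypothetical demo helper.
def _demo_should_retry() -> None:
    import httpx

    print(should_retry_exception(httpx.ConnectError("connection refused")))  # True
    print(should_retry_exception(ValueError("bad input")))                   # False
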

def log_retry_attempt(retry_state: RetryCallState) -> None:
    """Log retry attempts for observability"""
    exception = retry_state.outcome.exception() if retry_state.outcome else None

    logger.warning(
        f"Retrying after failure (attempt {retry_state.attempt_number})",
        extra={
            "attempt_number": retry_state.attempt_number,
            "exception_type": type(exception).__name__ if exception else None,
            "next_action": str(retry_state.next_action),
        },
    )

    # Emit metric
    retry_attempt_counter.add(
        1,
        attributes={
            "attempt_number": retry_state.attempt_number,
            "exception_type": type(exception).__name__ if exception else "unknown",
        },
    )


def _calculate_retry_delay(
    attempt: int,
    exception: Exception,
    overload_aware: bool,
    overload_config: OverloadRetryConfig,
    exp_base: float,
    exp_max: float,
    jitter_strategy: JitterStrategy,
    prev_delay: float | None,
    strategy: RetryStrategy,
) -> float:
    """Calculate retry delay with overload awareness.

    When overload_aware=True and the exception is an overload error,
    honors the Retry-After header if present and configured.
    Otherwise uses exponential backoff with jitter.

    Args:
        attempt: Current attempt number (1-indexed)
        exception: The exception that caused the retry
        overload_aware: Whether overload-aware behavior is enabled
        overload_config: Configuration for overload-specific retry behavior
        exp_base: Exponential backoff base
        exp_max: Maximum delay
        jitter_strategy: Jitter strategy to use
        prev_delay: Previous delay (for decorrelated jitter)
        strategy: Overall retry strategy

    Returns:
        Delay in seconds before next attempt
    """
    # Check for Retry-After header (overload errors)
    if overload_aware and is_overload_error(exception) and overload_config.honor_retry_after:
        retry_after = extract_retry_after_from_exception(exception)
        if retry_after is not None:  # coverage: condition always true in tests
            # Honor Retry-After but cap at configured max
            return min(retry_after, overload_config.retry_after_max)

    # Calculate base delay based on strategy
    if strategy == RetryStrategy.EXPONENTIAL:  # coverage: condition always true in tests
        base_delay = exp_base**attempt
    elif strategy == RetryStrategy.LINEAR:
        base_delay = exp_base * attempt
    elif strategy == RetryStrategy.FIXED:
        base_delay = exp_base
    else:  # RANDOM - handled by jitter
        base_delay = exp_base**attempt

    # Apply jitter
    return calculate_jitter_delay(
        base_delay=base_delay,
        prev_delay=prev_delay,
        max_delay=exp_max,
        strategy=jitter_strategy,
    )
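
# Worked example (illustrative numbers): with the exponential strategy,
# exp_base=2 and no Retry-After header, attempt 3 gives base_delay = 2**3 = 8s,
# and FULL jitter then draws uniformly from 0-8s (capped at exp_max). If the
# exception is an overload error carrying "Retry-After: 45", the 45s is
# honored instead, capped at overload_config.retry_after_max.
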

def log_retry_attempt_manual(
    attempt_number: int,
    exception: Exception,
    delay: float,
    func_name: str,
) -> None:
    """Log retry attempt for observability (manual retry loop version).

    Args:
        attempt_number: Current attempt number
        exception: The exception that caused the retry
        delay: Delay before next attempt
        func_name: Name of the function being retried
    """
    logger.warning(
        f"Retrying after failure (attempt {attempt_number})",
        extra={
            "attempt_number": attempt_number,
            "exception_type": type(exception).__name__,
            "delay_seconds": delay,
            "function": func_name,
        },
    )

    # Emit metric
    retry_attempt_counter.add(
        1,
        attributes={
            "attempt_number": attempt_number,
            "exception_type": type(exception).__name__,
        },
    )


def retry_with_backoff(  # noqa: C901
    max_attempts: int | None = None,
    exponential_base: float | None = None,
    exponential_max: float | None = None,
    retry_on: type[Exception] | tuple[type[Exception], ...] | None = None,
    strategy: RetryStrategy = RetryStrategy.EXPONENTIAL,
    jitter_strategy: JitterStrategy | None = None,
    overload_aware: bool = False,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator to retry a function with exponential backoff.

    Args:
        max_attempts: Maximum number of retry attempts (default: from config)
        exponential_base: Base for exponential backoff (default: from config)
        exponential_max: Maximum backoff time in seconds (default: from config)
        retry_on: Exception type(s) to retry on (default: auto-detect)
        strategy: Retry strategy (exponential, linear, fixed, random)
        jitter_strategy: Jitter strategy for randomizing delays (default: from config)
        overload_aware: Enable extended retry behavior for 529/overload errors

    Usage:
        @retry_with_backoff(max_attempts=3, exponential_base=2)
        async def call_external_api() -> dict[str, Any]:
            async with httpx.AsyncClient() as client:
                response = await client.get("https://api.example.com")
                return response.json()

        # With custom exception types
        @retry_with_backoff(retry_on=(httpx.TimeoutException, redis.ConnectionError))
        async def fetch_data() -> str:
            # Will retry on timeout or connection errors
            return await get_data()

        # With overload awareness for 529 errors
        @retry_with_backoff(max_attempts=3, overload_aware=True)
        async def call_llm() -> str:
            # Will use extended retry config for 529 overload errors
            return await llm_client.generate(prompt)
    """
    # Load configuration
    config = get_resilience_config()
    retry_config = config.retry
    overload_config = retry_config.overload

    # Use config defaults if not specified
    standard_max_attempts = max_attempts or retry_config.max_attempts
    standard_exponential_base = exponential_base or retry_config.exponential_base
    standard_exponential_max = exponential_max or retry_config.exponential_max
    standard_jitter_strategy = jitter_strategy or retry_config.jitter_strategy

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @functools.wraps(func)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            """Async wrapper with retry logic.

            When overload_aware=True and an overload error (529) is detected,
            the retry logic dynamically switches from standard config to the
            more aggressive overload config (more attempts, longer backoff).
            """
            import asyncio as aio

            # Initialize with standard config
            current_max_attempts = standard_max_attempts
            current_exp_base = standard_exponential_base
            current_exp_max = standard_exponential_max
            current_jitter = standard_jitter_strategy
            switched_to_overload_config = False

            with tracer.start_as_current_span(
                f"retry.{func.__name__}",
                attributes={
                    "retry.max_attempts": current_max_attempts,
                    "retry.strategy": strategy.value,
                    "retry.overload_aware": overload_aware,
                },
            ) as span:
                attempt_number = 0
                last_exception: Exception | None = None
                prev_delay: float | None = None

                while attempt_number < current_max_attempts:  # coverage: condition always true in tests
                    attempt_number += 1

                    try:
                        result: T = await func(*args, **kwargs)  # type: ignore[misc]
                        span.set_attribute("retry.success", True)
                        span.set_attribute("retry.attempts", attempt_number)
                        span.set_attribute("retry.switched_to_overload", switched_to_overload_config)
                        return result

                    except Exception as e:
                        last_exception = e

                        # Check if we should filter this exception type
                        if retry_on and not isinstance(e, retry_on):
                            raise

                        # Check if we should switch to overload config
                        if overload_aware and not switched_to_overload_config and is_overload_error(e):
                            # Switch to overload config for extended retry
                            current_max_attempts = overload_config.max_attempts
                            current_exp_base = overload_config.exponential_base
                            current_exp_max = overload_config.exponential_max
                            current_jitter = overload_config.jitter_strategy
                            switched_to_overload_config = True
                            logger.info(
                                f"Overload detected, switching to extended retry config "
                                f"(max_attempts={current_max_attempts}, "
                                f"exponential_max={current_exp_max}s)",
                                extra={
                                    "function": func.__name__,
                                    "attempt": attempt_number,
                                    "overload_max_attempts": current_max_attempts,
                                },
                            )
                            span.set_attribute("retry.overload_detected", True)
                            span.set_attribute("retry.new_max_attempts", current_max_attempts)

                        # Check if we've exhausted attempts
                        if attempt_number >= current_max_attempts:
                            break

                        # Calculate delay with overload awareness
                        delay = _calculate_retry_delay(
                            attempt=attempt_number,
                            exception=e,
                            overload_aware=overload_aware and switched_to_overload_config,
                            overload_config=overload_config,
                            exp_base=current_exp_base,
                            exp_max=current_exp_max,
                            jitter_strategy=current_jitter,
                            prev_delay=prev_delay,
                            strategy=strategy,
                        )
                        prev_delay = delay

                        # Log retry attempt
                        log_retry_attempt_manual(attempt_number, e, delay, func.__name__)

                        # Wait before next attempt
                        await aio.sleep(delay)

                # Retry exhausted
                span.set_attribute("retry.success", False)
                span.set_attribute("retry.exhausted", True)
                span.set_attribute("retry.final_attempts", attempt_number)

                logger.error(
                    f"Retry exhausted after {attempt_number} attempts",
                    exc_info=last_exception,
                    extra={
                        "max_attempts": current_max_attempts,
                        "function": func.__name__,
                        "switched_to_overload": switched_to_overload_config,
                    },
                )

                # Emit metric
                retry_exhausted_counter.add(1, attributes={"function": func.__name__})

                # Wrap in our custom exception
                from mcp_server_langgraph.core.exceptions import RetryExhaustedError

                raise RetryExhaustedError(
                    message=f"Retry exhausted after {attempt_number} attempts",
                    metadata={
                        "max_attempts": current_max_attempts,
                        "function": func.__name__,
                        "switched_to_overload": switched_to_overload_config,
                    },
                ) from last_exception

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            """Sync wrapper with retry logic"""
            # Use computed values from decorator (guaranteed non-None)
            retry_kwargs = {
                "stop": stop_after_attempt(standard_max_attempts),
                "reraise": False,  # Raise RetryError instead of original exception
                "before_sleep": log_retry_attempt,
            }

            # Note: only the exponential strategy configures a wait here; other
            # strategies fall back to tenacity's default (retry without waiting).
            if strategy == RetryStrategy.EXPONENTIAL:
                retry_kwargs["wait"] = wait_exponential(
                    multiplier=standard_exponential_base,
                    max=standard_exponential_max,
                )

            if retry_on:
                retry_kwargs["retry"] = retry_if_exception_type(retry_on)

            try:
                for attempt in Retrying(**retry_kwargs):  # type: ignore[arg-type]
                    with attempt:
                        return func(*args, **kwargs)
            except RetryError as e:
                from mcp_server_langgraph.core.exceptions import RetryExhaustedError

                raise RetryExhaustedError(
                    message=f"Retry exhausted after {standard_max_attempts} attempts",
                    metadata={"max_attempts": standard_max_attempts},
                ) from e.last_attempt.exception()
            # This should never be reached, but mypy needs an explicit return path
            msg = "Unreachable code"
            raise RuntimeError(msg)  # pragma: no cover


        # Return appropriate wrapper
        import asyncio

        if asyncio.iscoroutinefunction(func):  # coverage: condition always true in tests
            return async_wrapper  # type: ignore[return-value]
        else:
            return sync_wrapper

    return decorator
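
# End-to-end sketch (hypothetical caller `fetch_completion`, shown as a
# comment; illustrative only):
#
#     @retry_with_backoff(max_attempts=3, overload_aware=True)
#     async def fetch_completion(prompt: str) -> str:
#         ...  # provider call; a 529 here switches to the overload config
#
# On an overload error the async wrapper adopts the extended settings from
# get_resilience_config().retry.overload (more attempts, longer caps) and
# honors any Retry-After header the provider returns.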