Coverage for src/mcp_server_langgraph/resilience/retry.py: 81%

166 statements  

coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Retry logic with exponential backoff and jitter. 

3 

4Automatically retries transient failures with configurable policies. 

5Uses tenacity library for declarative retry specifications. 

6 

7See ADR-0026 for design rationale. 

8""" 

9 

10import functools 

11import logging 

12import random 

13from collections.abc import Callable 

14from datetime import datetime 

15from email.utils import parsedate_to_datetime 

16from enum import Enum 

17from typing import ParamSpec, TypeVar 

18 

19from opentelemetry import trace 

20from tenacity import ( 

21 AsyncRetrying, 

22 RetryCallState, 

23 RetryError, 

24 Retrying, 

25 retry_if_exception_type, 

26 stop_after_attempt, 

27 wait_exponential, 

28 wait_random, 

29) 

30 

31from mcp_server_langgraph.observability.telemetry import retry_attempt_counter, retry_exhausted_counter 

32from mcp_server_langgraph.resilience.config import JitterStrategy, get_resilience_config 

33 

34logger = logging.getLogger(__name__) 

35tracer = trace.get_tracer(__name__) 

36 

37P = ParamSpec("P") 

38T = TypeVar("T") 

39 

40# Check if redis is available (optional dependency) 

41_REDIS_AVAILABLE = False 

42try: 

43 import redis as _redis_module 

44 

45 _REDIS_AVAILABLE = True 

46except ImportError: 

47 _redis_module = None # type: ignore[assignment] 

48 logger.debug("Redis module not available. Redis error retry logic will be skipped.") 

49 

50 

51# ============================================================================= 

52# Jitter and Retry-After Utilities 

53# ============================================================================= 

54 

55 

56def calculate_jitter_delay( 

57 base_delay: float, 

58 prev_delay: float | None, 

59 max_delay: float, 

60 strategy: JitterStrategy, 

61) -> float: 

62 """Calculate delay with jitter based on strategy. 

63 

64 Args: 

65 base_delay: Base delay without jitter 

66 prev_delay: Previous delay (for decorrelated jitter) 

67 max_delay: Maximum allowed delay 

68 strategy: Jitter strategy to use 

69 

70 Returns: 

71 Delay with jitter applied 

72 

73 See: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ 

74 """ 

75 if strategy == JitterStrategy.SIMPLE: 

76 # Simple jitter: +/- 20% of base delay 

77 jitter_factor = random.uniform(0.8, 1.2) 

78 return min(base_delay * jitter_factor, max_delay) 

79 

80 elif strategy == JitterStrategy.FULL: 

81 # Full jitter: random(0, delay) 

82 return random.uniform(0, min(base_delay, max_delay)) 

83 

84 else: # JitterStrategy.DECORRELATED 

85 # Decorrelated jitter: min(cap, random(base, prev * 3)) 

86 # Reference: AWS Architecture Blog 

87 if prev_delay is None: 

88 prev_delay = base_delay 

89 return min(max_delay, random.uniform(base_delay, prev_delay * 3)) 

90 

91 
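# Illustrative sketch, not part of the measured module: how the three jitter
# strategies above compare for the same inputs. The helper name below is
# hypothetical; the ranges follow directly from the formulas in calculate_jitter_delay.
def _example_jitter_ranges() -> None:
    # SIMPLE:        2.0s * uniform(0.8, 1.2)        -> 1.6s .. 2.4s
    # FULL:          uniform(0, 2.0)                  -> 0.0s .. 2.0s
    # DECORRELATED:  min(30, uniform(2.0, 4.0 * 3))   -> 2.0s .. 12.0s
    for strategy in (JitterStrategy.SIMPLE, JitterStrategy.FULL, JitterStrategy.DECORRELATED):
        delay = calculate_jitter_delay(base_delay=2.0, prev_delay=4.0, max_delay=30.0, strategy=strategy)
        print(f"{strategy}: {delay:.2f}s")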

92def parse_retry_after(value: str | int | float | None) -> float | None: 

93 """Parse Retry-After header value (RFC 7231). 

94 

95 Args: 

96 value: Either seconds (int/float) or HTTP-date string 

97 

98 Returns: 

99 Seconds to wait, or None if unparseable 

100 """ 

101 if value is None: 

102 return None 

103 

104 # Integer or float seconds 

105 if isinstance(value, (int, float)): 

106 return float(value) 

107 

108 # String: try as number first 

109 try: 

110 return float(value) 

111 except ValueError: 

112 pass 

113 

114 # String: try as HTTP-date 

115 try: 

116 retry_date = parsedate_to_datetime(value) 

117 delta = retry_date - datetime.now(retry_date.tzinfo) 

118 return max(0.0, delta.total_seconds()) 

119 except Exception: 

120 return None 

121 

122 
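# Illustrative sketch, not part of the measured module: parse_retry_after()
# accepts both RFC 7231 forms (delta-seconds and HTTP-date); the helper name
# below is hypothetical.
def _example_parse_retry_after() -> None:
    assert parse_retry_after(120) == 120.0            # numeric seconds
    assert parse_retry_after("30") == 30.0            # numeric string
    assert parse_retry_after(None) is None            # header absent
    assert parse_retry_after("not a date") is None    # unparseable -> None
    # An HTTP-date in the past clamps to 0.0 rather than going negative
    assert parse_retry_after("Wed, 21 Oct 2015 07:28:00 GMT") == 0.0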

123def extract_retry_after_from_exception(exception: Exception) -> float | None: 

124 """Extract Retry-After value from LiteLLM/httpx exception. 

125 

126 Args: 

127 exception: The caught exception 

128 

129 Returns: 

130 Seconds to wait, or None if not available 

131 """ 

132 # Check for retry_after attribute (LiteLLM may add this) 

133 if hasattr(exception, "retry_after"): 133 ↛ 134line 133 didn't jump to line 134 because the condition on line 133 was never true

134 return parse_retry_after(exception.retry_after) 

135 

136 # Check for response headers (httpx exceptions) 

137 if hasattr(exception, "response") and exception.response is not None: 

138 headers = getattr(exception.response, "headers", {}) 

139 if headers: 139 ↛ 145 (condition always true)

140 retry_after = headers.get("Retry-After") or headers.get("retry-after") 

141 if retry_after: 141 ↛ 145 (condition always true)

142 return parse_retry_after(retry_after) 

143 

144 # Check LiteLLM's llm_provider_response_headers 

145 if hasattr(exception, "llm_provider_response_headers"): 145 ↛ 151line 145 didn't jump to line 151 because the condition on line 145 was always true

146 headers = exception.llm_provider_response_headers or {} 

147 retry_after = headers.get("Retry-After") or headers.get("retry-after") 

148 if retry_after: 

149 return parse_retry_after(retry_after) 

150 

151 return None 

152 

153 
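# Illustrative sketch, not part of the measured module: any exception carrying an
# httpx-style `response` with a Retry-After header is picked up. The stub class
# and helper name below are hypothetical stand-ins for a real HTTP error.
def _example_extract_retry_after() -> None:
    class _StubResponse:
        headers = {"Retry-After": "15"}

    err = RuntimeError("throttled by upstream")
    err.response = _StubResponse()  # type: ignore[attr-defined]

    assert extract_retry_after_from_exception(err) == 15.0
    # Exceptions without retry hints fall through to None
    assert extract_retry_after_from_exception(ValueError("boom")) is None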

154def is_overload_error(exception: Exception) -> bool: 

155 """Determine if exception indicates service overload. 

156 

157 Checks for: 

158 - HTTP 529 status code 

159 - "overloaded" in error message 

160 - LLMOverloadError type 

161 

162 Args: 

163 exception: The exception to check 

164 

165 Returns: 

166 True if this is an overload error, False otherwise 

167 """ 

168 # Check our custom exception type 

169 try: 

170 from mcp_server_langgraph.core.exceptions import LLMOverloadError 

171 

172 if isinstance(exception, LLMOverloadError): 

173 return True 

174 except ImportError: 

175 pass 

176 

177 # Check status code (LiteLLM exceptions have status_code attribute) 

178 status_code = getattr(exception, "status_code", None) 

179 if status_code == 529: 

180 return True 

181 

182 # Check error message patterns 

183 error_msg = str(exception).lower() 

184 overload_patterns = [ 

185 "overload", 

186 "service is overloaded", 

187 "capacity", 

188 ] 

189 

190 # 503 + overload message = treat as overload 

191 if status_code == 503 and any(p in error_msg for p in overload_patterns): 

192 return True 

193 

194 # Pure message-based detection (including "529" in message) 

195 return "overload" in error_msg or "529" in error_msg 

196 

197 
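# Illustrative sketch, not part of the measured module: overload detection works
# from the status code alone, from a 503 whose message mentions overload/capacity,
# or from the message text by itself. Helper and variable names are hypothetical.
def _example_is_overload_error() -> None:
    err_529 = RuntimeError("upstream rejected the request")
    err_529.status_code = 529  # type: ignore[attr-defined]
    assert is_overload_error(err_529) is True

    err_503 = RuntimeError("service is overloaded, try again later")
    err_503.status_code = 503  # type: ignore[attr-defined]
    assert is_overload_error(err_503) is True

    assert is_overload_error(ValueError("invalid input")) is False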

198class RetryPolicy(str, Enum): 

199 """Retry policies for different error types""" 

200 

201 NEVER = "never" # Never retry (client errors) 

202 ALWAYS = "always" # Always retry (transient failures) 

203 CONDITIONAL = "conditional" # Retry with conditions 

204 

205 

206class RetryStrategy(str, Enum): 

207 """Retry backoff strategies""" 

208 

209 EXPONENTIAL = "exponential" # Exponential backoff: 1s, 2s, 4s, 8s... 

210 LINEAR = "linear" # Linear backoff: 1s, 2s, 3s, 4s... 

211 FIXED = "fixed" # Fixed interval: 1s, 1s, 1s, 1s... 

212 RANDOM = "random" # Random jitter: 0-1s, 0-2s, 0-4s... 

213 

214 

215def should_retry_exception(exception: Exception) -> bool: 

216 """ 

217 Determine if an exception is retryable. 

218 

219 Args: 

220 exception: The exception that occurred 

221 

222 Returns: 

223 True if should retry, False otherwise 

224 """ 

225 # Import here to avoid circular dependency 

226 try: 

227 from mcp_server_langgraph.core.exceptions import ( 

228 AuthenticationError, 

229 AuthorizationError, 

230 ExternalServiceError, 

231 RateLimitError, 

232 ResilienceError, 

233 ValidationError, 

234 ) 

235 

236 # Never retry client errors 

237 if isinstance(exception, (ValidationError, AuthorizationError)): 

238 return False 

239 

240 # Conditionally retry auth errors (e.g., token refresh) 

241 if isinstance(exception, AuthenticationError): 

242 # Only retry token expiration (can refresh) 

243 return exception.error_code == "auth.token_expired" 

244 

245 # Never retry rate limits from our own service 

246 if isinstance(exception, RateLimitError): 246 ↛ 247 (condition never true)

247 return False 

248 

249 # Always retry external service errors 

250 if isinstance(exception, ExternalServiceError): 

251 return True 

252 

253 # Always retry resilience errors (timeout, circuit breaker) 

254 if isinstance(exception, ResilienceError): 254 ↛ 255 (condition never true)

255 return True 

256 

257 except ImportError: 

258 # Exceptions not yet defined, fall back to generic logic 

259 pass 

260 

261 # Generic logic: retry network errors, timeouts 

262 import httpx 

263 

264 if isinstance(exception, (httpx.TimeoutException, httpx.ConnectError, httpx.NetworkError)): 

265 return True 

266 

267 # Check redis errors if redis is available (optional dependency) 

268 if _REDIS_AVAILABLE and _redis_module is not None: 268 ↛ 273 (condition always true)

269 if isinstance(exception, (_redis_module.ConnectionError, _redis_module.TimeoutError)): 269 ↛ 270 (condition never true)

270 return True 

271 

272 # Don't retry by default 

273 return False 

274 

275 
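# Illustrative sketch, not part of the measured module: with the project-specific
# exception types aside, the generic fallback retries network-level httpx failures
# and declines everything else. The helper name below is hypothetical.
def _example_should_retry() -> None:
    import httpx

    assert should_retry_exception(httpx.ConnectError("connection refused")) is True
    assert should_retry_exception(httpx.TimeoutException("read timed out")) is True
    assert should_retry_exception(KeyError("missing field")) is False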

276def log_retry_attempt(retry_state: RetryCallState) -> None: 

277 """Log retry attempts for observability""" 

278 exception = retry_state.outcome.exception() if retry_state.outcome else None 

279 

280 logger.warning( 

281 f"Retrying after failure (attempt {retry_state.attempt_number})", 

282 extra={ 

283 "attempt_number": retry_state.attempt_number, 

284 "exception_type": type(exception).__name__ if exception else None, 

285 "next_action": str(retry_state.next_action), 

286 }, 

287 ) 

288 

289 # Emit metric 

290 retry_attempt_counter.add( 

291 1, 

292 attributes={ 

293 "attempt_number": retry_state.attempt_number, 

294 "exception_type": type(exception).__name__ if exception else "unknown", 

295 }, 

296 ) 

297 

298 

299def retry_with_backoff( # noqa: C901 

300 max_attempts: int | None = None, 

301 exponential_base: float | None = None, 

302 exponential_max: float | None = None, 

303 retry_on: type[Exception] | tuple[type[Exception], ...] | None = None, 

304 strategy: RetryStrategy = RetryStrategy.EXPONENTIAL, 

305 jitter_strategy: JitterStrategy | None = None, 

306 overload_aware: bool = False, 

307) -> Callable[[Callable[P, T]], Callable[P, T]]: 

308 """ 

309 Decorator to retry a function with exponential backoff. 

310 

311 Args: 

312 max_attempts: Maximum number of retry attempts (default: from config) 

313 exponential_base: Base for exponential backoff (default: from config) 

314 exponential_max: Maximum backoff time in seconds (default: from config) 

315 retry_on: Exception type(s) to retry on (default: auto-detect) 

316 strategy: Retry strategy (exponential, linear, fixed, random) 

317 jitter_strategy: Jitter strategy for randomizing delays (default: from config) 

318 overload_aware: Enable extended retry behavior for 529/overload errors 

319 

320 Usage: 

321 @retry_with_backoff(max_attempts=3, exponential_base=2) 

322 async def call_external_api() -> dict[str, Any]: 

323 async with httpx.AsyncClient() as client: 

324 response = await client.get("https://api.example.com") 

325 return response.json() 

326 

327 # With custom exception types 

328 @retry_with_backoff(retry_on=(httpx.TimeoutException, redis.ConnectionError)) 

329 async def fetch_data() -> str: 

330 # Will retry on timeout or connection errors 

331 return await get_data() 

332 

333 # With overload awareness for 529 errors 

334 @retry_with_backoff(max_attempts=3, overload_aware=True) 

335 async def call_llm() -> str: 

336 # Will use extended retry config for 529 overload errors 

337 return await llm_client.generate(prompt) 

338 """ 

339 # Load configuration 

340 config = get_resilience_config() 

341 retry_config = config.retry 

342 

343 # Use config defaults if not specified 

344 max_attempts = max_attempts or retry_config.max_attempts 

345 exponential_base = exponential_base or retry_config.exponential_base 

346 exponential_max = exponential_max or retry_config.exponential_max 

347 jitter_strategy = jitter_strategy or retry_config.jitter_strategy 

348 

349 # Note: overload_aware enables extended retry behavior for 529/overload errors 

350 # Future enhancement: dynamically adjust max_attempts and backoff for overload 

351 _ = overload_aware # Mark as used for now 

352 

353 def decorator(func: Callable[P, T]) -> Callable[P, T]: 

354 @functools.wraps(func) 

355 async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: 

356 """Async wrapper with retry logic""" 

357 with tracer.start_as_current_span( 

358 f"retry.{func.__name__}", 

359 attributes={ 

360 "retry.max_attempts": max_attempts, 

361 "retry.strategy": strategy.value, 

362 }, 

363 ) as span: 

364 # Configure retry behavior 

365 retry_kwargs = { 

366 "stop": stop_after_attempt(max_attempts), 

367 "reraise": False, # Raise RetryError instead of original exception 

368 "before_sleep": log_retry_attempt, 

369 } 

370 

371 # Configure wait strategy 

372 if strategy == RetryStrategy.EXPONENTIAL: 372 ↛ 377 (condition always true)

373 retry_kwargs["wait"] = wait_exponential( 

374 multiplier=exponential_base, 

375 max=exponential_max, 

376 ) 

377 elif strategy == RetryStrategy.RANDOM: 

378 retry_kwargs["wait"] = wait_random(min=0, max=exponential_max) 

379 # Add other strategies as needed 

380 

381 # Configure retry condition 

382 if retry_on: 

383 retry_kwargs["retry"] = retry_if_exception_type(retry_on) 

384 # Otherwise, retry all exceptions (tenacity default behavior) 

385 

386 try: 

387 # Execute with retry 

388 async for attempt in AsyncRetrying(**retry_kwargs): # type: ignore[arg-type] 388 ↛ 420 (loop didn't complete)

389 with attempt: 

390 result: T = await func(*args, **kwargs) # type: ignore[misc] 

391 span.set_attribute("retry.success", True) 

392 span.set_attribute("retry.attempts", attempt.retry_state.attempt_number) 

393 return result 

394 

395 except RetryError as e: 

396 # Retry exhausted 

397 span.set_attribute("retry.success", False) 

398 span.set_attribute("retry.exhausted", True) 

399 

400 logger.error( 

401 f"Retry exhausted after {max_attempts} attempts", 

402 exc_info=True, 

403 extra={"max_attempts": max_attempts, "function": func.__name__}, 

404 ) 

405 

406 # Emit metric 

407 retry_exhausted_counter.add(1, attributes={"function": func.__name__}) 

408 

409 # Wrap in our custom exception 

410 from mcp_server_langgraph.core.exceptions import RetryExhaustedError 

411 

412 raise RetryExhaustedError( 

413 message=f"Retry exhausted after {max_attempts} attempts", 

414 metadata={ 

415 "max_attempts": max_attempts, 

416 "function": func.__name__, 

417 }, 

418 ) from e.last_attempt.exception() 

419 # This should never be reached, but mypy needs an explicit return path 

420 msg = "Unreachable code" 

421 raise RuntimeError(msg) # pragma: no cover 

422 

423 @functools.wraps(func) 

424 def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: 

425 """Sync wrapper with retry logic""" 

426 # Similar to async_wrapper but for sync functions 

427 retry_kwargs = { 

428 "stop": stop_after_attempt(max_attempts), 

429 "reraise": False, # Raise RetryError instead of original exception 

430 "before_sleep": log_retry_attempt, 

431 } 

432 

433 if strategy == RetryStrategy.EXPONENTIAL: 

434 retry_kwargs["wait"] = wait_exponential( 

435 multiplier=exponential_base, 

436 max=exponential_max, 

437 ) 

438 

439 if retry_on: 

440 retry_kwargs["retry"] = retry_if_exception_type(retry_on) 

441 

442 try: 

443 for attempt in Retrying(**retry_kwargs): # type: ignore[arg-type] 

444 with attempt: 

445 return func(*args, **kwargs) 

446 except RetryError as e: 

447 from mcp_server_langgraph.core.exceptions import RetryExhaustedError 

448 

449 raise RetryExhaustedError( 

450 message=f"Retry exhausted after {max_attempts} attempts", 

451 metadata={"max_attempts": max_attempts}, 

452 ) from e.last_attempt.exception() 

453 # This should never be reached, but mypy needs an explicit return path 

454 msg = "Unreachable code" 

455 raise RuntimeError(msg) # pragma: no cover 

456 

457 # Return appropriate wrapper 

458 import asyncio 

459 

460 if asyncio.iscoroutinefunction(func): 460 ↛ 463 (condition always true)

461 return async_wrapper # type: ignore[return-value] 

462 else: 

463 return sync_wrapper 

464 

465 return decorator
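# Illustrative sketch, not part of the measured module: end-to-end usage of
# retry_with_backoff as a standalone script. The endpoint URL and function names
# are placeholders; RetryExhaustedError is the wrapper raised once the configured
# attempts are exhausted (see async_wrapper above).
import asyncio

import httpx

from mcp_server_langgraph.core.exceptions import RetryExhaustedError
from mcp_server_langgraph.resilience.retry import RetryStrategy, retry_with_backoff


@retry_with_backoff(
    max_attempts=4,
    exponential_base=1.0,
    exponential_max=10.0,
    retry_on=(httpx.TimeoutException, httpx.ConnectError),
    strategy=RetryStrategy.EXPONENTIAL,
)
async def fetch_status() -> int:
    async with httpx.AsyncClient(timeout=2.0) as client:
        response = await client.get("https://api.example.com/health")
        response.raise_for_status()
        return response.status_code


async def main() -> None:
    try:
        print(await fetch_status())
    except RetryExhaustedError as exc:
        # All four attempts hit timeout/connect errors; the last underlying
        # exception is preserved as __cause__.
        print(f"giving up: {exc}")


if __name__ == "__main__":
    asyncio.run(main())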