Coverage for src / mcp_server_langgraph / resilience / circuit_breaker.py: 93%

159 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Circuit breaker pattern implementation using pybreaker. 

3 

4Prevents cascade failures by failing fast when a service is unhealthy. 

5Automatically recovers by testing the service periodically. 

6 

7States: 

8- CLOSED: Normal operation, requests pass through 

9- OPEN: Service is failing, fail fast without calling service 

10- HALF_OPEN: Testing if service has recovered 

11 

12See ADR-0026 for design rationale. 

13""" 

14 

import asyncio
import functools
import logging
from collections.abc import Callable
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, ParamSpec, TypeVar, cast

import pybreaker
from opentelemetry import trace

from mcp_server_langgraph.resilience.config import get_resilience_config

26 

# Module logger; records emitted below attach structured fields via ``extra``.
logger = logging.getLogger(__name__)
# OpenTelemetry tracer; the async decorator wrapper opens one span per call with it.
tracer = trace.get_tracer(__name__)

# Type variables that let ``circuit_breaker`` preserve the decorated function's
# parameter list (P) and return type (T).
P = ParamSpec("P")
T = TypeVar("T")

32 

33 

class CircuitBreakerState(str, Enum):
    """
    Breaker state as exposed by this module.

    Members subclass ``str``, so each one compares equal to its lowercase value.

    - CLOSED: calls pass through to the service normally.
    - OPEN: the service is failing; calls are rejected without reaching it.
    - HALF_OPEN: a trial call is probing whether the service has recovered.
    """

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"

40 

41 

class CircuitBreakerMetricsListener(pybreaker.CircuitBreakerListener):
    """
    Listener for circuit breaker events.

    Logs state changes and failures, and emits metrics through the project's
    telemetry module (a state gauge plus success/failure counters), always
    tagged with the service name.
    """

    def __init__(self, name: str) -> None:
        # Service name used in every log record and metric attribute set.
        self.name = name
        # Last state seen by state_change() and when it was recorded.
        self._state = CircuitBreakerState.CLOSED
        self._last_state_change = datetime.now()

    def state_change(
        self,
        breaker: pybreaker.CircuitBreaker,
        old: pybreaker.CircuitBreakerState | None,
        new: pybreaker.CircuitBreakerState | None,
    ) -> None:
        """
        Called when the circuit breaker changes state.

        Logs the transition (WARNING when entering OPEN, INFO otherwise),
        remembers the new state, and updates ``circuit_breaker_state_gauge``.
        Returns early when either end of the transition is missing.
        """
        if old is None or new is None:
            return
        old_state = self._map_state(old)
        new_state = self._map_state(new)

        # Entering OPEN means a failure was detected; other transitions are
        # recovery or normal operation.
        log_level = logger.warning if new_state == CircuitBreakerState.OPEN else logger.info
        log_level(
            f"Circuit breaker state changed: {self.name}",
            extra={
                "service": self.name,
                "old_state": old_state.value,
                "new_state": new_state.value,
                "failure_count": breaker.fail_counter,
            },
        )

        # Record state change time
        self._state = new_state
        self._last_state_change = datetime.now()

        # Emit metric
        from mcp_server_langgraph.observability.telemetry import circuit_breaker_state_gauge

        # The gauge is labelled by state, so every state has its own series.
        # Zero the series of the state being left; otherwise the "open" series
        # would keep reporting 1 after the breaker has recovered.
        if old_state != new_state:
            circuit_breaker_state_gauge.set(
                0,
                attributes={"service": self.name, "state": old_state.value},
            )
        circuit_breaker_state_gauge.set(
            1 if new_state == CircuitBreakerState.OPEN else 0,
            attributes={"service": self.name, "state": new_state.value},
        )

    def before_call(self, cb: pybreaker.CircuitBreaker, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        """Called before calling the protected function (intentionally a no-op)."""

    def success(self, breaker: pybreaker.CircuitBreaker) -> None:
        """Called on a successful call; increments the success counter."""
        from mcp_server_langgraph.observability.telemetry import circuit_breaker_success_counter

        circuit_breaker_success_counter.add(1, attributes={"service": self.name})

    def failure(self, breaker: pybreaker.CircuitBreaker, exception: BaseException) -> None:
        """Called on a failed call; logs it and increments the failure counter."""
        from mcp_server_langgraph.observability.telemetry import circuit_breaker_failure_counter

        logger.warning(
            f"Circuit breaker failure: {self.name}",
            extra={
                "service": self.name,
                "failure_count": breaker.fail_counter,
                "exception_type": type(exception).__name__,
            },
        )

        circuit_breaker_failure_counter.add(
            1,
            attributes={
                "service": self.name,
                "exception_type": type(exception).__name__,
            },
        )

    @staticmethod
    def _map_state(state: pybreaker.CircuitBreakerState) -> CircuitBreakerState:
        """Map a pybreaker state object to this module's enum by its ``name``."""
        if state.name == pybreaker.STATE_CLOSED:
            return CircuitBreakerState.CLOSED
        elif state.name == pybreaker.STATE_OPEN:
            return CircuitBreakerState.OPEN
        else:  # STATE_HALF_OPEN
            return CircuitBreakerState.HALF_OPEN
63 old_state = self._map_state(old) 

64 new_state = self._map_state(new) 

65 

66 # Log state changes at appropriate level: 

67 # - WARNING when transitioning to OPEN (service failure detected) 

68 # - INFO for normal transitions (HALF_OPEN, CLOSED - recovery/normal operation) 

69 log_level = logger.warning if new_state == CircuitBreakerState.OPEN else logger.info 

70 log_level( 

71 f"Circuit breaker state changed: {self.name}", 

72 extra={ 

73 "service": self.name, 

74 "old_state": old_state.value, 

75 "new_state": new_state.value, 

76 "failure_count": breaker.fail_counter, 

77 }, 

78 ) 

79 

80 # Record state change time 

81 self._state = new_state 

82 self._last_state_change = datetime.now() 

83 

84 # Emit metric 

85 from mcp_server_langgraph.observability.telemetry import circuit_breaker_state_gauge 

86 

87 circuit_breaker_state_gauge.set( 

88 1 if new_state == CircuitBreakerState.OPEN else 0, 

89 attributes={"service": self.name, "state": new_state.value}, 

90 ) 

91 

92 def before_call(self, cb: pybreaker.CircuitBreaker, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None: 

93 """Called before calling the protected function""" 

94 

95 def success(self, breaker: pybreaker.CircuitBreaker) -> None: 

96 """Called on successful call""" 

97 from mcp_server_langgraph.observability.telemetry import circuit_breaker_success_counter 

98 

99 circuit_breaker_success_counter.add(1, attributes={"service": self.name}) 

100 

101 def failure(self, breaker: pybreaker.CircuitBreaker, exception: BaseException) -> None: 

102 """Called on failed call""" 

103 from mcp_server_langgraph.observability.telemetry import circuit_breaker_failure_counter 

104 

105 logger.warning( 

106 f"Circuit breaker failure: {self.name}", 

107 extra={ 

108 "service": self.name, 

109 "failure_count": breaker.fail_counter, 

110 "exception_type": type(exception).__name__, 

111 }, 

112 ) 

113 

114 circuit_breaker_failure_counter.add( 

115 1, 

116 attributes={ 

117 "service": self.name, 

118 "exception_type": type(exception).__name__, 

119 }, 

120 ) 

121 

122 @staticmethod 

123 def _map_state(state: pybreaker.CircuitBreakerState) -> CircuitBreakerState: 

124 """Map pybreaker state to our enum""" 

125 if state.name == pybreaker.STATE_CLOSED: 

126 return CircuitBreakerState.CLOSED 

127 elif state.name == pybreaker.STATE_OPEN: 

128 return CircuitBreakerState.OPEN 

129 else: # STATE_HALF_OPEN 

130 return CircuitBreakerState.HALF_OPEN 

131 

132 

133# Global circuit breaker instances 

134_circuit_breakers: dict[str, pybreaker.CircuitBreaker] = {} 

135 

136 

137def get_circuit_breaker(name: str) -> pybreaker.CircuitBreaker: 

138 """ 

139 Get or create a circuit breaker for a service. 

140 

141 Args: 

142 name: Service name (e.g., "llm", "openfga", "redis") 

143 

144 Returns: 

145 CircuitBreaker instance 

146 """ 

147 if name in _circuit_breakers: 

148 return _circuit_breakers[name] 

149 

150 # Load configuration 

151 config = get_resilience_config() 

152 cb_config = config.circuit_breakers.get(name) 

153 

154 if not cb_config: 

155 # Create default config 

156 from mcp_server_langgraph.resilience.config import CircuitBreakerConfig 

157 

158 cb_config = CircuitBreakerConfig(name=name) 

159 

160 # Create circuit breaker 

161 breaker = pybreaker.CircuitBreaker( 

162 fail_max=cb_config.fail_max, 

163 reset_timeout=cb_config.timeout_duration, # pybreaker uses reset_timeout, not timeout_duration 

164 exclude=[], # Don't exclude any exceptions 

165 listeners=[CircuitBreakerMetricsListener(name)], 

166 name=name, 

167 ) 

168 

169 _circuit_breakers[name] = breaker 

170 logger.info( 

171 f"Created circuit breaker: {name}", 

172 extra={ 

173 "fail_max": cb_config.fail_max, 

174 "timeout_duration": cb_config.timeout_duration, 

175 }, 

176 ) 

177 

178 return breaker 

179 

180 

181def circuit_breaker( # noqa: C901 

182 name: str, 

183 fail_max: int | None = None, 

184 timeout: int | None = None, 

185 fallback: Callable[..., Any] | None = None, 

186) -> Callable[[Callable[P, T]], Callable[P, T]]: 

187 """ 

188 Decorator to protect a function with a circuit breaker. 

189 

190 Args: 

191 name: Service name for the circuit breaker 

192 fail_max: Max failures before opening (optional override) 

193 timeout: Timeout duration in seconds (optional override) 

194 fallback: Fallback function to call when circuit is open 

195 

196 Usage: 

197 @circuit_breaker(name="llm", fail_max=5, timeout=60) 

198 async def call_llm(prompt: str) -> str: 

199 return await llm_client.generate(prompt) 

200 

201 # With fallback 

202 @circuit_breaker(name="openfga", fallback=lambda *args: True) 

203 async def check_permission(user: str, resource: str) -> bool: 

204 return await openfga_client.check(user, resource) 

205 """ 

206 

207 def decorator(func: Callable[P, T]) -> Callable[P, T]: 

208 # Get or create circuit breaker 

209 breaker = get_circuit_breaker(name) 

210 

211 # Override config if provided 

212 if fail_max is not None: 

213 breaker._fail_max = fail_max 

214 if timeout is not None: 

215 breaker._reset_timeout = timeout # pybreaker uses _reset_timeout in seconds 

216 

217 @functools.wraps(func) 

218 async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: 

219 """Async wrapper with circuit breaker""" 

220 import asyncio 

221 

222 with tracer.start_as_current_span( 

223 f"circuit_breaker.{name}", 

224 attributes={ 

225 "circuit_breaker.name": name, 

226 "circuit_breaker.state": breaker.current_state, 

227 }, 

228 ) as span: 

229 try: 

230 # Wrap the async function in a sync callable for pybreaker 

231 # pybreaker will handle the state transitions 

232 def _sync_wrapper() -> None: 

233 # This won't work directly - we need a different approach 

234 msg = "Should not be called directly" 

235 raise RuntimeError(msg) 

236 

237 # Call before_call to handle state transitions (OPEN -> HALF_OPEN after timeout) 

238 try: 

239 with breaker._lock: 

240 # This will transition to HALF_OPEN if timeout elapsed, or raise if still OPEN 

241 state_any = breaker.state 

242 state_any.before_call(func, *args, **kwargs) 

243 # Also notify listeners 

244 for listener in cast(list[Any], breaker.listeners): 

245 # cast(Any, listener).before_call(breaker, func, *args, **kwargs) # Disabled due to MyPy issues 

246 pass 

247 except pybreaker.CircuitBreakerError: 

248 # Circuit is still OPEN (timeout not elapsed) 

249 span.set_attribute("circuit_breaker.success", False) 

250 span.set_attribute("circuit_breaker.fallback_used", fallback is not None) 

251 

252 logger.warning( 

253 f"Circuit breaker open for {name}, failing fast", 

254 extra={"service": name, "fallback": fallback is not None}, 

255 ) 

256 

257 if fallback: 

258 logger.info(f"Using fallback for {name}") 

259 if asyncio.iscoroutinefunction(fallback): 

260 res = await fallback(*args, **kwargs) 

261 return cast(T, res) 

262 else: 

263 res = fallback(*args, **kwargs) 

264 return cast(T, res) 

265 else: 

266 # Raise our custom exception 

267 from mcp_server_langgraph.core.exceptions import CircuitBreakerOpenError 

268 

269 raise CircuitBreakerOpenError( 

270 message=f"Circuit breaker open for {name}", 

271 metadata={"service": name, "state": breaker.current_state}, 

272 ) 

273 

274 # Call the function 

275 try: 

276 result: T = await func(*args, **kwargs) # type: ignore[misc] 

277 

278 # Success - handle via state machine 

279 with breaker._lock: 

280 breaker._state_storage.increment_counter() 

281 for listener in cast(list[Any], breaker.listeners): 

282 cast(Any, listener).success(breaker) 

283 breaker.state.on_success() 

284 

285 span.set_attribute("circuit_breaker.success", True) 

286 return result 

287 

288 except Exception as e: 

289 # Failure - handle via state machine 

290 try: 

291 with breaker._lock: 

292 if breaker.is_system_error(e): 292 ↛ 308line 292 didn't jump to line 308 because the condition on line 292 was always true

293 logger.debug( 

294 f"Circuit breaker {breaker.name}: system error {type(e).__name__}, " 

295 f"counter before={breaker.fail_counter}, fail_max={breaker.fail_max}, " 

296 f"state={breaker.state.name}" 

297 ) 

298 breaker._inc_counter() 

299 for listener in cast(list[Any], breaker.listeners): 

300 cast(Any, listener).failure(breaker, e) 

301 breaker.state.on_failure(e) 

302 logger.debug( 

303 f"Circuit breaker {breaker.name}: after on_failure, " 

304 f"counter={breaker.fail_counter}, state={breaker.state.name}" 

305 ) 

306 else: 

307 # Not a system error, treat as success 

308 logger.debug( 

309 f"Circuit breaker {breaker.name}: non-system error {type(e).__name__}, " 

310 f"treating as success" 

311 ) 

312 breaker._state_storage.increment_counter() 

313 for listener in breaker.listeners: 

314 listener.success(breaker) 

315 breaker.state.on_success() 

316 except pybreaker.CircuitBreakerError: 

317 # Circuit just opened on this failure 

318 # Don't use fallback here - the circuit opened because of THIS failure 

319 # Fallback is only used when circuit is already open BEFORE the call 

320 pass 

321 

322 # Re-raise the original exception 

323 span.set_attribute("circuit_breaker.success", False) 

324 raise 

325 

326 except pybreaker.CircuitBreakerError as e: # noqa: F841 

327 # This shouldn't happen with our manual implementation, but just in case 

328 span.set_attribute("circuit_breaker.success", False) 

329 raise 

330 

331 @functools.wraps(func) 

332 def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: 

333 """Sync wrapper with circuit breaker""" 

334 # Similar to async_wrapper but for sync functions 

335 try: 

336 result: T = breaker.call(func, *args, **kwargs) 

337 return result 

338 except pybreaker.CircuitBreakerError as e: 

339 if fallback: 

340 res = fallback(*args, **kwargs) 

341 return cast(T, res) 

342 else: 

343 from mcp_server_langgraph.core.exceptions import CircuitBreakerOpenError 

344 

345 raise CircuitBreakerOpenError( 

346 message=f"Circuit breaker open for {name}", 

347 metadata={"service": name}, 

348 ) from e 

349 

350 # Return appropriate wrapper 

351 import asyncio 

352 

353 if asyncio.iscoroutinefunction(func): 

354 return async_wrapper # type: ignore[return-value] 

355 else: 

356 return sync_wrapper 

357 

358 return decorator 

359 

360 

361def reset_circuit_breaker(name: str) -> None: 

362 """ 

363 Manually reset a circuit breaker to closed state. 

364 

365 Args: 

366 name: Service name 

367 

368 Usage: 

369 reset_circuit_breaker("llm") # Force close circuit 

370 """ 

371 if name in _circuit_breakers: 

372 breaker = _circuit_breakers[name] 

373 # Use the proper pybreaker API to close the circuit 

374 breaker.close() 

375 logger.info(f"Circuit breaker manually reset: {name}") 

376 

377 

378def get_circuit_breaker_state(name: str) -> CircuitBreakerState: 

379 """ 

380 Get current state of a circuit breaker. 

381 

382 Args: 

383 name: Service name 

384 

385 Returns: 

386 Current circuit breaker state 

387 """ 

388 if name not in _circuit_breakers: 

389 return CircuitBreakerState.CLOSED 

390 

391 breaker = _circuit_breakers[name] 

392 return CircuitBreakerMetricsListener._map_state(breaker.state) 

393 

394 

395def get_all_circuit_breaker_states() -> dict[str, CircuitBreakerState]: 

396 """ 

397 Get states of all circuit breakers. 

398 

399 Returns: 

400 Dictionary mapping service name to state 

401 """ 

402 return {name: get_circuit_breaker_state(name) for name in _circuit_breakers} 

403 

404 

405def reset_all_circuit_breakers() -> None: 

406 """ 

407 Reset all circuit breakers (for testing). 

408 

409 Resets the state of all existing circuit breakers to CLOSED without 

410 removing them from the registry. This preserves decorator closure integrity 

411 by ensuring decorators continue to reference the same circuit breaker instances. 

412 

413 Important: This function does NOT clear the registry. If you need to clear 

414 the registry completely (e.g., in teardown), call _circuit_breakers.clear() 

415 directly, but be aware that decorator closures will hold stale references. 

416 

417 See: tests/resilience/test_circuit_breaker_decorator_isolation.py 

418 See: adr/ADR-0054-circuit-breaker-decorator-closure-isolation.md 

419 """ 

420 # Reset state of all existing circuit breakers instead of clearing registry 

421 # This preserves decorator closure integrity 

422 for name, breaker in list(_circuit_breakers.items()): 

423 breaker.close() # Reset to CLOSED state 

424 logger.debug(f"Circuit breaker state reset: {name}") 

425 

426 logger.info(f"All circuit breakers reset ({len(_circuit_breakers)} total)")