Coverage for src / mcp_server_langgraph / resilience / circuit_breaker.py: 93%
159 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
Circuit breaker pattern implementation using pybreaker.

Prevents cascade failures by failing fast when a service is unhealthy.
Automatically recovers by testing the service periodically.

States:
- CLOSED: Normal operation, requests pass through
- OPEN: Service is failing, fail fast without calling service
- HALF_OPEN: Testing if service has recovered

See ADR-0026 for design rationale.
13"""
15import functools
16import logging
17from collections.abc import Callable
18from datetime import datetime
19from enum import Enum
20from typing import Any, ParamSpec, TypeVar, cast
22import pybreaker
23from opentelemetry import trace
25from mcp_server_langgraph.resilience.config import get_resilience_config
# Module-level logger and OpenTelemetry tracer, both named after this module.
logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)

# Typing placeholders for the decorator below: P stands for the wrapped
# callable's parameters, T for its return type.
P = ParamSpec("P")
T = TypeVar("T")
class CircuitBreakerState(str, Enum):
    """String-valued states a circuit breaker can be in."""

    # Normal operation: requests pass through.
    CLOSED = "closed"
    # The service is failing: requests are rejected.
    OPEN = "open"
    # Recovery of the service is being tested.
    HALF_OPEN = "half_open"
class CircuitBreakerMetricsListener(pybreaker.CircuitBreakerListener):
    """
    pybreaker listener that makes one named breaker observable.

    State transitions, successes and failures reported to this listener are
    turned into log records and updates of the matching telemetry instrument.
    The telemetry instruments are imported inside each hook rather than at
    module import time.
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name: Service name attached to every log record and metric
        """
        self.name = name
        # Most recently observed state and the (naive, local) time it was seen.
        self._state = CircuitBreakerState.CLOSED
        self._last_state_change = datetime.now()

    def state_change(
        self,
        breaker: pybreaker.CircuitBreaker,
        old: pybreaker.CircuitBreakerState | None,
        new: pybreaker.CircuitBreakerState | None,
    ) -> None:
        """
        Log and publish a state transition.

        Does nothing when either side of the transition is missing.
        """
        if old is None or new is None:
            return

        previous = self._map_state(old)
        current = self._map_state(new)
        opened = current == CircuitBreakerState.OPEN

        # Opening means a service failure was detected -> WARNING.
        # Moving to HALF_OPEN or CLOSED is recovery/normal operation -> INFO.
        emit = logger.warning if opened else logger.info
        emit(
            f"Circuit breaker state changed: {self.name}",
            extra={
                "service": self.name,
                "old_state": previous.value,
                "new_state": current.value,
                "failure_count": breaker.fail_counter,
            },
        )

        # Remember what we are in now and when we got here.
        self._state = current
        self._last_state_change = datetime.now()

        from mcp_server_langgraph.observability.telemetry import circuit_breaker_state_gauge

        # Gauge reads 1 while OPEN and 0 otherwise.
        # NOTE(review): the attributes include the state value, so each state
        # is presumably reported as its own series - confirm the "open" series
        # is cleared as intended once the breaker recovers.
        circuit_breaker_state_gauge.set(
            1 if opened else 0,
            attributes={"service": self.name, "state": current.value},
        )

    def before_call(self, cb: pybreaker.CircuitBreaker, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        """Hook invoked ahead of the protected function; intentionally a no-op."""

    def success(self, breaker: pybreaker.CircuitBreaker) -> None:
        """Count one successful call for this service."""
        from mcp_server_langgraph.observability.telemetry import circuit_breaker_success_counter

        circuit_breaker_success_counter.add(1, attributes={"service": self.name})

    def failure(self, breaker: pybreaker.CircuitBreaker, exception: BaseException) -> None:
        """Log a failed call and count it, labelled with the exception type."""
        from mcp_server_langgraph.observability.telemetry import circuit_breaker_failure_counter

        exc_name = type(exception).__name__

        logger.warning(
            f"Circuit breaker failure: {self.name}",
            extra={
                "service": self.name,
                "failure_count": breaker.fail_counter,
                "exception_type": exc_name,
            },
        )

        circuit_breaker_failure_counter.add(
            1,
            attributes={
                "service": self.name,
                "exception_type": exc_name,
            },
        )

    @staticmethod
    def _map_state(state: pybreaker.CircuitBreakerState) -> CircuitBreakerState:
        """Translate a pybreaker state object into this module's enum."""
        by_name = {
            pybreaker.STATE_CLOSED: CircuitBreakerState.CLOSED,
            pybreaker.STATE_OPEN: CircuitBreakerState.OPEN,
        }
        # Any name that is neither closed nor open is reported as HALF_OPEN.
        return by_name.get(state.name, CircuitBreakerState.HALF_OPEN)
# Process-wide registry of circuit breaker instances, keyed by service name
_circuit_breakers: dict[str, pybreaker.CircuitBreaker] = {}
def get_circuit_breaker(name: str) -> pybreaker.CircuitBreaker:
    """
    Return the circuit breaker registered for a service, creating it on first use.

    A new breaker takes its settings from the resilience configuration entry
    for ``name``; when there is no such entry a default ``CircuitBreakerConfig``
    is built for it. The new breaker is stored in the module registry so later
    calls with the same name return the same instance.

    Args:
        name: Service name (e.g., "llm", "openfga", "redis")

    Returns:
        CircuitBreaker instance
    """
    existing = _circuit_breakers.get(name)
    if existing is not None:
        return existing

    # Look up the per-service settings
    settings = get_resilience_config().circuit_breakers.get(name)
    if not settings:
        # Nothing configured for this service: fall back to defaults
        from mcp_server_langgraph.resilience.config import CircuitBreakerConfig

        settings = CircuitBreakerConfig(name=name)

    created = pybreaker.CircuitBreaker(
        fail_max=settings.fail_max,
        reset_timeout=settings.timeout_duration,  # pybreaker's name for timeout_duration
        exclude=[],  # every exception type counts as a failure
        listeners=[CircuitBreakerMetricsListener(name)],
        name=name,
    )
    _circuit_breakers[name] = created

    logger.info(
        f"Created circuit breaker: {name}",
        extra={
            "fail_max": settings.fail_max,
            "timeout_duration": settings.timeout_duration,
        },
    )

    return created
def _record_success(breaker: pybreaker.CircuitBreaker) -> None:
    """Run the success bookkeeping for ``breaker``; the caller holds ``breaker._lock``."""
    # A success ends the current failure streak, so the failure counter is
    # reset. (The storage's ``increment_counter`` bumps the *failure* counter
    # and must not be used on this path.)
    breaker._state_storage.reset_counter()
    for listener in cast(list[Any], breaker.listeners):
        cast(Any, listener).success(breaker)
    breaker.state.on_success()


def _record_failure(breaker: pybreaker.CircuitBreaker, exc: BaseException) -> None:
    """
    Run the failure bookkeeping for ``breaker``; the caller holds ``breaker._lock``.

    Raises:
        pybreaker.CircuitBreakerError: If this failure makes the state machine
            open the circuit.
    """
    if not breaker.is_system_error(exc):
        # Excluded exception types do not count against the service.
        logger.debug(
            f"Circuit breaker {breaker.name}: non-system error {type(exc).__name__}, "
            f"treating as success"
        )
        _record_success(breaker)
        return

    logger.debug(
        f"Circuit breaker {breaker.name}: system error {type(exc).__name__}, "
        f"counter before={breaker.fail_counter}, fail_max={breaker.fail_max}, "
        f"state={breaker.state.name}"
    )
    breaker._inc_counter()
    for listener in cast(list[Any], breaker.listeners):
        cast(Any, listener).failure(breaker, exc)
    breaker.state.on_failure(exc)
    logger.debug(
        f"Circuit breaker {breaker.name}: after on_failure, "
        f"counter={breaker.fail_counter}, state={breaker.state.name}"
    )


def circuit_breaker(  # noqa: C901
    name: str,
    fail_max: int | None = None,
    timeout: int | None = None,
    fallback: Callable[..., Any] | None = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator to protect a function with a circuit breaker.

    Coroutine functions get an async wrapper that drives the breaker's state
    machine by hand (before-call check, then success/failure bookkeeping under
    the breaker lock) inside a tracing span. Plain functions are delegated to
    ``breaker.call``.

    Args:
        name: Service name for the circuit breaker
        fail_max: Max failures before opening (optional override)
        timeout: Timeout duration in seconds (optional override)
        fallback: Fallback function to call when circuit is open. It receives
            the same arguments as the protected call and may be sync or async
            for async functions.

    Returns:
        A decorator producing a wrapper with the same signature as the
        decorated function.

    Raises:
        CircuitBreakerOpenError: From the wrapper, when the circuit is open
            and no fallback was given.

    Usage:
        @circuit_breaker(name="llm", fail_max=5, timeout=60)
        async def call_llm(prompt: str) -> str:
            return await llm_client.generate(prompt)

        # With fallback
        @circuit_breaker(name="openfga", fallback=lambda *args: True)
        async def check_permission(user: str, resource: str) -> bool:
            return await openfga_client.check(user, resource)
    """

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        import asyncio

        # Get or create the breaker shared by every user of this service name
        breaker = get_circuit_breaker(name)

        # Overrides are applied to that shared instance, so they affect every
        # function decorated with the same name.
        if fail_max is not None:
            breaker.fail_max = fail_max
        if timeout is not None:
            breaker.reset_timeout = timeout  # seconds

        @functools.wraps(func)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            """Async wrapper with circuit breaker"""
            with tracer.start_as_current_span(
                f"circuit_breaker.{name}",
                attributes={
                    "circuit_breaker.name": name,
                    "circuit_breaker.state": breaker.current_state,
                },
            ) as span:
                # Let the current state decide whether the call may proceed:
                # transitions to HALF_OPEN if the timeout elapsed, or raises
                # if the circuit is still OPEN. Listener ``before_call`` hooks
                # are intentionally not invoked here.
                # NOTE(review): depending on the pybreaker version, the OPEN
                # state's before_call may itself invoke ``func`` once the
                # timeout has elapsed - confirm this does not leave an
                # un-awaited coroutine behind.
                try:
                    with breaker._lock:
                        breaker.state.before_call(func, *args, **kwargs)
                except pybreaker.CircuitBreakerError as exc:
                    # Circuit is still OPEN (timeout not elapsed): fail fast
                    span.set_attribute("circuit_breaker.success", False)
                    span.set_attribute("circuit_breaker.fallback_used", fallback is not None)

                    logger.warning(
                        f"Circuit breaker open for {name}, failing fast",
                        extra={"service": name, "fallback": fallback is not None},
                    )

                    if not fallback:
                        from mcp_server_langgraph.core.exceptions import CircuitBreakerOpenError

                        raise CircuitBreakerOpenError(
                            message=f"Circuit breaker open for {name}",
                            metadata={"service": name, "state": breaker.current_state},
                        ) from exc

                    logger.info(f"Using fallback for {name}")
                    outcome = fallback(*args, **kwargs)
                    if asyncio.iscoroutinefunction(fallback):
                        outcome = await outcome
                    return cast(T, outcome)

                try:
                    result: T = await func(*args, **kwargs)  # type: ignore[misc]

                    with breaker._lock:
                        _record_success(breaker)

                    span.set_attribute("circuit_breaker.success", True)
                    return result

                except Exception as exc:
                    try:
                        with breaker._lock:
                            _record_failure(breaker, exc)
                    except pybreaker.CircuitBreakerError:
                        # The circuit opened because of THIS failure. The
                        # fallback is only for calls made while the circuit is
                        # already open, so the original error is surfaced.
                        pass

                    span.set_attribute("circuit_breaker.success", False)
                    raise

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            """Sync wrapper with circuit breaker"""
            try:
                result: T = breaker.call(func, *args, **kwargs)
                return result
            except pybreaker.CircuitBreakerError as e:
                if fallback:
                    res = fallback(*args, **kwargs)
                    return cast(T, res)
                else:
                    from mcp_server_langgraph.core.exceptions import CircuitBreakerOpenError

                    raise CircuitBreakerOpenError(
                        message=f"Circuit breaker open for {name}",
                        metadata={"service": name},
                    ) from e

        # Pick the wrapper matching the decorated function's kind
        if asyncio.iscoroutinefunction(func):
            return async_wrapper  # type: ignore[return-value]
        else:
            return sync_wrapper

    return decorator
def reset_circuit_breaker(name: str) -> None:
    """
    Force a registered circuit breaker back to the closed state.

    Unknown names are ignored silently.

    Args:
        name: Service name

    Usage:
        reset_circuit_breaker("llm")  # Force close circuit
    """
    target = _circuit_breakers.get(name)
    if target is None:
        return

    # close() is pybreaker's own API for closing the circuit
    target.close()
    logger.info(f"Circuit breaker manually reset: {name}")
def get_circuit_breaker_state(name: str) -> CircuitBreakerState:
    """
    Report the current state of a service's circuit breaker.

    Args:
        name: Service name

    Returns:
        Current circuit breaker state; CLOSED when no breaker has been
        registered under that name
    """
    registered = _circuit_breakers.get(name)
    if registered is None:
        return CircuitBreakerState.CLOSED

    return CircuitBreakerMetricsListener._map_state(registered.state)
def get_all_circuit_breaker_states() -> dict[str, CircuitBreakerState]:
    """
    Collect the state of every registered circuit breaker.

    Returns:
        Dictionary mapping service name to state
    """
    states: dict[str, CircuitBreakerState] = {}
    for service in _circuit_breakers:
        states[service] = get_circuit_breaker_state(service)
    return states
def reset_all_circuit_breakers() -> None:
    """
    Close every registered circuit breaker (for testing).

    Each breaker is reset to CLOSED in place and stays in the registry, so
    decorators that captured a breaker in their closure keep pointing at the
    same, now-closed, instance.

    Important: the registry itself is NOT emptied. Calling
    _circuit_breakers.clear() directly (e.g., in teardown) does that, but
    decorator closures will then hold stale references.

    See: tests/resilience/test_circuit_breaker_decorator_isolation.py
    See: adr/ADR-0054-circuit-breaker-decorator-closure-isolation.md
    """
    # Work on a snapshot of the registry entries; reset state, never remove.
    snapshot = tuple(_circuit_breakers.items())
    for service, instance in snapshot:
        instance.close()  # back to CLOSED
        logger.debug(f"Circuit breaker state reset: {service}")

    logger.info(f"All circuit breakers reset ({len(_circuit_breakers)} total)")