Coverage for src / mcp_server_langgraph / resilience / fallback.py: 82%
113 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2Fallback strategies for graceful degradation.
4Provides fallback responses when primary services are unavailable.
5Enables fail-open vs fail-closed behavior per service.
7See ADR-0026 for design rationale.
8"""
10import functools
11import logging
12from collections.abc import Callable
13from typing import Any, ParamSpec, TypeVar
15from opentelemetry import trace
17from mcp_server_langgraph.observability.telemetry import fallback_used_counter
# Module-wide logger and tracer, both named after this module.
logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)

# Generic placeholders used to keep decorated callables' signatures intact:
# P captures the parameters, T the return type.
P = ParamSpec("P")
T = TypeVar("T")
class FallbackStrategy:
    """
    Common interface for fallback strategies.

    Derive from this class and override ``get_fallback_value`` to define a
    custom fallback behavior.
    """

    def get_fallback_value(self, *args: Any, **kwargs: Any) -> Any:
        """
        Produce the value to use when the primary operation fails.

        Args:
            *args: Positional arguments of the original call
            **kwargs: Keyword arguments of the original call

        Returns:
            Fallback value

        Raises:
            NotImplementedError: Always, in this base class; subclasses
                are expected to override the method.
        """
        message = "Subclasses must implement get_fallback_value()"
        raise NotImplementedError(message)
class DefaultValueFallback(FallbackStrategy):
    """Fallback strategy that always yields one fixed value"""

    def __init__(self, default_value: Any) -> None:
        # Kept by reference (no copy): every fallback hands back this object.
        self.default_value = default_value

    def get_fallback_value(self, *args: Any, **kwargs: Any) -> Any:
        """Return the configured default; the call arguments are ignored"""
        return self.default_value
class CachedValueFallback(FallbackStrategy):
    """Fallback strategy that replays a previously cached value"""

    def __init__(self, cache_key_fn: Callable[..., str] | None = None) -> None:
        self._cache: dict[str, Any] = {}
        # A missing (falsy) key function selects the argument-based default.
        self.cache_key_fn = cache_key_fn if cache_key_fn else self._default_cache_key

    def _default_cache_key(self, *args: Any, **kwargs: Any) -> str:
        """Build a cache key from the textual form of the call arguments"""
        return str(args) + ":" + str(kwargs)

    def cache_value(self, value: Any, *args: Any, **kwargs: Any) -> None:
        """Store ``value`` under the key derived from the call arguments"""
        self._cache[self.cache_key_fn(*args, **kwargs)] = value

    def get_fallback_value(self, *args: Any, **kwargs: Any) -> Any:
        """Return the value cached for these arguments, or None if absent"""
        return self._cache.get(self.cache_key_fn(*args, **kwargs))
class FunctionFallback(FallbackStrategy):
    """Fallback strategy that delegates to a user-supplied function"""

    def __init__(self, fallback_fn: Callable[..., Any]) -> None:
        self.fallback_fn = fallback_fn

    def get_fallback_value(self, *args: Any, **kwargs: Any) -> Any:
        """Call the fallback function with the original call's arguments"""
        compute = self.fallback_fn
        return compute(*args, **kwargs)
class StaleDataFallback(FallbackStrategy):
    """Fallback strategy that serves cached data up to a maximum age, logging a warning"""

    def __init__(self, max_staleness_seconds: int = 3600) -> None:
        self._cache: dict[str, tuple[Any, float]] = {}  # key -> (value, timestamp)
        self.max_staleness_seconds = max_staleness_seconds

    def cache_value(self, value: Any, key: str) -> None:
        """Remember ``value`` under ``key`` together with the current time"""
        import time

        stored_at = time.time()
        self._cache[key] = (value, stored_at)

    def get_fallback_value(self, *args: Any, **kwargs: Any) -> Any | None:
        """Return the cached value if it is younger than the staleness limit, else None"""
        import time

        # A lone string argument is used directly as the key; any other call
        # shape is turned into a key from the textual form of its arguments.
        is_direct_key = len(args) == 1 and not kwargs and isinstance(args[0], str)
        if is_direct_key:
            key = args[0]
        else:
            key = str(args) + str(kwargs)

        try:
            value, stored_at = self._cache[key]
        except KeyError:
            return None

        age = time.time() - stored_at
        if not age < self.max_staleness_seconds:
            # Entry exists but is too old to serve.
            return None

        logger.warning(
            f"Using stale data (age: {age:.1f}s)",
            extra={"staleness_seconds": age, "max_staleness": self.max_staleness_seconds},
        )
        return value
# Ready-made strategies for common scenarios, looked up by name.
# NOTE: each entry is a single shared instance, so the "empty_list" and
# "empty_dict" strategies hand back the same mutable object on every use.
FALLBACK_STRATEGIES: dict[str, DefaultValueFallback] = {
    # Authorization, fail-open: allow when the check cannot be performed
    "openfga_allow": DefaultValueFallback(True),
    # Authorization, fail-closed: deny when the check cannot be performed
    "openfga_deny": DefaultValueFallback(False),
    # Cache: behave like a cache miss and return None
    "cache_miss": DefaultValueFallback(None),
    # Collections: fall back to an empty list
    "empty_list": DefaultValueFallback([]),
    # Collections: fall back to an empty dict
    "empty_dict": DefaultValueFallback({}),
}
def with_fallback(  # noqa: C901
    fallback: Any | None = None,
    fallback_fn: Callable[..., Any] | None = None,
    fallback_strategy: FallbackStrategy | None = None,
    fallback_on: tuple[type, ...] | None = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator to provide fallback value when function raises exception.

    Works for both coroutine functions and plain functions. In either case a
    ``fallback.<function name>`` span is opened around the call and, when the
    fallback is used, a warning is logged and ``fallback_used_counter`` is
    incremented.

    Args:
        fallback: Static fallback value. ``None`` means "not provided", so a
            static ``None`` fallback must be expressed through ``fallback_fn``
            or ``fallback_strategy`` instead.
        fallback_fn: Function to call for fallback value; it receives the
            original call arguments. A coroutine function is awaited when
            the decorated function is itself a coroutine function.
        fallback_strategy: FallbackStrategy instance for advanced fallback
        fallback_on: Exception types to catch (default: all exceptions
            deriving from ``Exception``)

    Returns:
        A decorator that wraps the target function with fallback handling.

    Raises:
        ValueError: If not exactly one of ``fallback``, ``fallback_fn`` and
            ``fallback_strategy`` is provided.

    Usage:
        # Static fallback
        @with_fallback(fallback=True)
        async def check_permission(user: str, resource: str) -> bool:
            # Returns True if OpenFGA is down (fail-open)
            return await openfga_client.check(user, resource)

        # Function fallback
        @with_fallback(fallback_fn=lambda user, res: user == "admin")
        async def check_permission(user: str, resource: str) -> bool:
            # Admins always allowed if OpenFGA is down
            return await openfga_client.check(user, resource)

        # Cached value fallback
        cache_strategy = CachedValueFallback()
        @with_fallback(fallback_strategy=cache_strategy)
        async def get_user_profile(user_id: str) -> dict[str, Any]:
            profile = await db.get_user(user_id)
            cache_strategy.cache_value(profile, user_id)
            return profile

        # Specific exception types
        @with_fallback(fallback=[], fallback_on=(httpx.TimeoutError, redis.ConnectionError))
        async def fetch_items() -> list:
            # Returns [] only on timeout/connection errors
            return await get_items()
    """
    # Imported up front so the name is bound before any wrapper can run.
    import asyncio

    # Validate arguments
    provided = sum(option is not None for option in (fallback, fallback_fn, fallback_strategy))
    if provided != 1:
        msg = "Exactly one of fallback, fallback_fn, or fallback_strategy must be provided"
        raise ValueError(msg)

    # Determine exception types to catch
    exception_types = fallback_on or (Exception,)

    # Label used in log records; fixed for the lifetime of the decorator.
    if fallback is not None:
        fallback_type = "static"
    elif fallback_fn is not None:
        fallback_type = "function"
    else:
        fallback_type = "strategy"

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        # Not every callable has __name__ (e.g. functools.partial objects).
        func_name = getattr(func, "__name__", type(func).__name__)
        span_name = f"fallback.{func_name}"

        def record_fallback(span: Any, exc: BaseException) -> None:
            """Annotate the span, log a warning and emit the fallback metric"""
            exception_name = type(exc).__name__

            span.set_attribute("fallback.used", True)
            span.set_attribute("fallback.exception_type", exception_name)

            logger.warning(
                f"Using fallback for {func_name}",
                exc_info=exc,
                extra={
                    "function": func_name,
                    "exception_type": exception_name,
                    "fallback_type": fallback_type,
                },
            )

            # Emit metric
            fallback_used_counter.add(
                1,
                attributes={
                    "function": func_name,
                    "exception_type": exception_name,
                },
            )

        @functools.wraps(func)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            """Async wrapper with fallback"""
            with tracer.start_as_current_span(
                span_name,
                attributes={"fallback.enabled": True},
            ) as span:
                try:
                    # Try primary operation
                    result: T = await func(*args, **kwargs)  # type: ignore[misc]
                except exception_types as exc:
                    # Primary operation failed, use fallback
                    record_fallback(span, exc)

                    if fallback is not None:
                        return fallback  # type: ignore[no-any-return]
                    if fallback_fn is not None:
                        if asyncio.iscoroutinefunction(fallback_fn):
                            return await fallback_fn(*args, **kwargs)  # type: ignore[no-any-return]
                        return fallback_fn(*args, **kwargs)  # type: ignore[no-any-return]
                    # fallback_strategy is not None
                    return fallback_strategy.get_fallback_value(*args, **kwargs)  # type: ignore[no-any-return,union-attr]
                else:
                    span.set_attribute("fallback.used", False)
                    return result

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            """Sync wrapper with fallback"""
            with tracer.start_as_current_span(
                span_name,
                attributes={"fallback.enabled": True},
            ) as span:
                try:
                    # Try primary operation
                    result = func(*args, **kwargs)
                except exception_types as exc:
                    # Primary operation failed, use fallback
                    record_fallback(span, exc)

                    if fallback is not None:
                        return fallback  # type: ignore[no-any-return]
                    if fallback_fn is not None:
                        return fallback_fn(*args, **kwargs)  # type: ignore[no-any-return]
                    # fallback_strategy is not None
                    return fallback_strategy.get_fallback_value(*args, **kwargs)  # type: ignore[no-any-return,union-attr]
                else:
                    span.set_attribute("fallback.used", False)
                    return result

        # Return appropriate wrapper
        if asyncio.iscoroutinefunction(func):
            return async_wrapper  # type: ignore[return-value]
        return sync_wrapper

    return decorator
271# Convenience decorators for common fallback scenarios
def fail_open(func: Callable[P, bool]) -> Callable[P, bool]:
    """
    Fail-open authorization decorator: the wrapped check yields True on error.

    Shorthand for ``with_fallback(fallback=True)``.

    Usage:
        @fail_open
        async def check_permission(user: str, resource: str) -> bool:
            return await openfga_client.check(user, resource)
    """
    allow_on_error = with_fallback(fallback=True)
    return allow_on_error(func)
def fail_closed(func: Callable[P, bool]) -> Callable[P, bool]:
    """
    Fail-closed authorization decorator: the wrapped check yields False on error.

    Shorthand for ``with_fallback(fallback=False)``.

    Usage:
        @fail_closed
        async def check_admin_permission(user: str) -> bool:
            return await openfga_client.check(user, "admin")
    """
    deny_on_error = with_fallback(fallback=False)
    return deny_on_error(func)
def return_empty_on_error(func: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to return empty value on error (list/dict/None).

    The empty value is chosen from the return annotation of ``func``:

    - a list annotation (``list``, ``list[X]``, ``typing.List[X]``, a list
      subclass, or a union/Optional containing one) gives ``[]``
    - a dict annotation (same forms) gives ``{}``
    - anything else, including a missing annotation, gives ``None``

    A new empty container is built on every failure, so callers never share
    (and cannot accidentally mutate) a single fallback object.

    Usage:
        @return_empty_on_error
        async def get_user_list() -> list:
            return await db.query_users()  # Returns [] on error
    """
    import inspect
    import types
    import typing

    def container_factory(annotation: Any) -> Callable[[], Any] | None:
        """Return ``list`` or ``dict`` if the annotation denotes one, else None"""
        if isinstance(annotation, str):
            # Postponed annotation that could not be evaluated: look only at
            # the outermost type name, e.g. "typing.List[int]" -> "List".
            head = annotation.split("[", 1)[0].strip().rsplit(".", 1)[-1]
            return {"list": list, "List": list, "dict": dict, "Dict": dict}.get(head)

        origin = typing.get_origin(annotation)
        if origin is typing.Union or origin is types.UnionType:
            # Optional[X] / X | None: use the first list- or dict-like member.
            for member in typing.get_args(annotation):
                factory = container_factory(member)
                if factory is not None:
                    return factory
            return None

        base = annotation if origin is None else origin
        if isinstance(base, type):
            if issubclass(base, list):
                return list
            if issubclass(base, dict):
                return dict
        return None

    try:
        # eval_str resolves string annotations where the names are available.
        return_type = inspect.signature(func, eval_str=True).return_annotation
    except Exception:
        # Evaluating annotation strings can raise anything (NameError, ...),
        # so retry without evaluation before giving up on the annotation.
        try:
            return_type = inspect.signature(func).return_annotation
        except (TypeError, ValueError):
            return_type = inspect.Signature.empty

    make_empty = container_factory(return_type)

    def empty_value(*_args: Any, **_kwargs: Any) -> Any:
        """Build the empty fallback value; the call arguments are ignored"""
        return make_empty() if make_empty is not None else None

    # Routed through fallback_fn because with_fallback treats fallback=None
    # as "not provided" and would reject a static None value.
    return with_fallback(fallback_fn=empty_value)(func)