Coverage for src / mcp_server_langgraph / resilience / fallback.py: 82%

113 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Fallback strategies for graceful degradation. 

3 

4Provides fallback responses when primary services are unavailable. 

5Enables fail-open vs fail-closed behavior per service. 

6 

7See ADR-0026 for design rationale. 

8""" 

9 

10import functools 

11import logging 

12from collections.abc import Callable 

13from typing import Any, ParamSpec, TypeVar 

14 

15from opentelemetry import trace 

16 

17from mcp_server_langgraph.observability.telemetry import fallback_used_counter 

18 

# Generic placeholders used by the decorators below to describe the wrapped
# callable's parameters (P) and return type (T)
P = ParamSpec("P")
T = TypeVar("T")

# Module-scoped logger and OpenTelemetry tracer, both named after this module
logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)

24 

25 

class FallbackStrategy:
    """
    Common interface shared by every fallback strategy.

    Derive from this class and override ``get_fallback_value`` to supply a
    custom fallback behavior.
    """

    def get_fallback_value(self, *args: Any, **kwargs: Any) -> Any:
        """
        Produce the value to use once the primary operation has failed.

        Args:
            *args: Positional arguments of the original call
            **kwargs: Keyword arguments of the original call

        Returns:
            The fallback value

        Raises:
            NotImplementedError: Always, in this base implementation
        """
        raise NotImplementedError("Subclasses must implement get_fallback_value()")

46 

47 

class DefaultValueFallback(FallbackStrategy):
    """Fallback strategy that always answers with one fixed value."""

    def __init__(self, default_value: Any) -> None:
        # Stored as-is: the very same object is handed back on every fallback,
        # so a mutable default is shared between all callers.
        self.default_value = default_value

    def get_fallback_value(self, *args: Any, **kwargs: Any) -> Any:
        """Return the configured default; the call arguments are ignored."""
        configured_default = self.default_value
        return configured_default

56 

57 

class CachedValueFallback(FallbackStrategy):
    """
    Return a previously cached value on failure.

    Values are stored explicitly through ``cache_value`` and looked up with a
    key derived from the call arguments. The cache is a plain in-memory dict
    with no size limit and no expiry.
    """

    def __init__(self, cache_key_fn: Callable[..., str] | None = None) -> None:
        """
        Args:
            cache_key_fn: Optional callable that maps call arguments to a
                cache key. Defaults to ``_default_cache_key``.
        """
        self.cache_key_fn = cache_key_fn or self._default_cache_key
        self._cache: dict[str, Any] = {}

    def _default_cache_key(self, *args: Any, **kwargs: Any) -> str:
        """
        Generate a cache key from the call arguments.

        Keyword arguments are ordered by name before formatting, so
        ``f(a=1, b=2)`` and ``f(b=2, a=1)`` map to the same key.
        """
        # kwargs preserve call-site order; normalise it so that the same
        # logical call always produces the same key.
        ordered_kwargs = {name: kwargs[name] for name in sorted(kwargs)}
        return f"{args}:{ordered_kwargs}"

    def cache_value(self, value: Any, *args: Any, **kwargs: Any) -> None:
        """
        Cache a value for future fallback.

        Args:
            value: Value to remember
            *args: Positional arguments identifying the call
            **kwargs: Keyword arguments identifying the call
        """
        key = self.cache_key_fn(*args, **kwargs)
        self._cache[key] = value

    def get_fallback_value(self, *args: Any, **kwargs: Any) -> Any:
        """
        Get the value cached for these arguments.

        Returns:
            The cached value, or None when nothing is cached under the key
        """
        key = self.cache_key_fn(*args, **kwargs)
        return self._cache.get(key)

78 

79 

class FunctionFallback(FallbackStrategy):
    """Fallback strategy that delegates to a caller-supplied function."""

    def __init__(self, fallback_fn: Callable[..., Any]) -> None:
        # Callable invoked with the original call arguments on failure
        self.fallback_fn = fallback_fn

    def get_fallback_value(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the fallback function with the original call arguments."""
        compute_fallback = self.fallback_fn
        return compute_fallback(*args, **kwargs)

88 

89 

class StaleDataFallback(FallbackStrategy):
    """
    Return stale data with a warning on failure.

    Values are stored together with the wall-clock time at which they were
    cached and are only served while younger than ``max_staleness_seconds``.
    """

    def __init__(self, max_staleness_seconds: int = 3600) -> None:
        """
        Args:
            max_staleness_seconds: Maximum age, in seconds, at which a cached
                value may still be served
        """
        self.max_staleness_seconds = max_staleness_seconds
        self._cache: dict[str, tuple[Any, float]] = {}  # key -> (value, timestamp)

    def cache_value(self, value: Any, key: str) -> None:
        """
        Cache a value with the current timestamp.

        Args:
            value: Value to remember
            key: Key under which the value is stored
        """
        import time

        self._cache[key] = (value, time.time())

    def get_fallback_value(self, *args: Any, **kwargs: Any) -> Any | None:
        """
        Get stale data if it is within the staleness limit.

        The lookup key is the single positional argument when it is a string
        and no keyword arguments are given. Otherwise it is
        ``str(args) + str(kwargs)``, so values meant for such calls must have
        been cached under that same generated key.

        Returns:
            The cached value, or None when nothing usable is cached. An entry
            found to be past the staleness limit is removed from the cache.
        """
        import time

        # Support both direct key (single arg) and generated key (multiple args/kwargs)
        if len(args) == 1 and not kwargs and isinstance(args[0], str):
            key = args[0]
        else:
            key = str(args) + str(kwargs)

        entry = self._cache.get(key)
        if entry is None:
            return None

        value, timestamp = entry
        age = time.time() - timestamp
        if age >= self.max_staleness_seconds:
            # Too old to serve: evict it so expired data is not retained forever.
            self._cache.pop(key, None)
            return None

        logger.warning(
            f"Using stale data (age: {age:.1f}s)",
            extra={"staleness_seconds": age, "max_staleness": self.max_staleness_seconds},
        )
        return value

121 

122 

# Predefined fallback strategies for common scenarios, keyed by name.
# Every entry is a DefaultValueFallback wrapping the listed static value.
FALLBACK_STRATEGIES = {
    strategy_name: DefaultValueFallback(default_value=static_value)
    for strategy_name, static_value in (
        # Authorization: Fail-open (allow by default)
        ("openfga_allow", True),
        # Authorization: Fail-closed (deny by default)
        ("openfga_deny", False),
        # Cache: Return None on cache miss
        ("cache_miss", None),
        # Empty list fallback
        ("empty_list", []),
        # Empty dict fallback
        ("empty_dict", {}),
    )
}

136 

137 

138def with_fallback( # noqa: C901 

139 fallback: Any | None = None, 

140 fallback_fn: Callable[..., Any] | None = None, 

141 fallback_strategy: FallbackStrategy | None = None, 

142 fallback_on: tuple[type, ...] | None = None, 

143) -> Callable[[Callable[P, T]], Callable[P, T]]: 

144 """ 

145 Decorator to provide fallback value when function raises exception. 

146 

147 Args: 

148 fallback: Static fallback value 

149 fallback_fn: Function to call for fallback value 

150 fallback_strategy: FallbackStrategy instance for advanced fallback 

151 fallback_on: Exception types to catch (default: all exceptions) 

152 

153 Usage: 

154 # Static fallback 

155 @with_fallback(fallback=True) 

156 async def check_permission(user: str, resource: str) -> bool: 

157 # Returns True if OpenFGA is down (fail-open) 

158 return await openfga_client.check(user, resource) 

159 

160 # Function fallback 

161 @with_fallback(fallback_fn=lambda user, res: user == "admin") 

162 async def check_permission(user: str, resource: str) -> bool: 

163 # Admins always allowed if OpenFGA is down 

164 return await openfga_client.check(user, resource) 

165 

166 # Cached value fallback 

167 cache_strategy = CachedValueFallback() 

168 @with_fallback(fallback_strategy=cache_strategy) 

169 async def get_user_profile(user_id: str) -> dict[str, Any]: 

170 profile = await db.get_user(user_id) 

171 cache_strategy.cache_value(profile, user_id) 

172 return profile 

173 

174 # Specific exception types 

175 @with_fallback(fallback=[], fallback_on=(httpx.TimeoutError, redis.ConnectionError)) 

176 async def fetch_items() -> list: 

177 # Returns [] only on timeout/connection errors 

178 return await get_items() 

179 """ 

180 # Validate arguments 

181 if sum([fallback is not None, fallback_fn is not None, fallback_strategy is not None]) != 1: 

182 msg = "Exactly one of fallback, fallback_fn, or fallback_strategy must be provided" 

183 raise ValueError(msg) 

184 

185 # Determine exception types to catch 

186 exception_types = fallback_on or (Exception,) 

187 

188 def decorator(func: Callable[P, T]) -> Callable[P, T]: 

189 @functools.wraps(func) 

190 async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: 

191 """Async wrapper with fallback""" 

192 with tracer.start_as_current_span( 

193 f"fallback.{func.__name__}", 

194 attributes={"fallback.enabled": True}, 

195 ) as span: 

196 try: 

197 # Try primary operation 

198 result: T = await func(*args, **kwargs) # type: ignore[misc] 

199 span.set_attribute("fallback.used", False) 

200 return result 

201 

202 except BaseException as e: 

203 # Check if this is an exception type we should handle 

204 if not isinstance(e, exception_types): 

205 raise 

206 

207 # Primary operation failed, use fallback 

208 span.set_attribute("fallback.used", True) 

209 span.set_attribute("fallback.exception_type", type(e).__name__) 

210 

211 logger.warning( 

212 f"Using fallback for {func.__name__}", 

213 exc_info=True, 

214 extra={ 

215 "function": func.__name__, 

216 "exception_type": type(e).__name__, 

217 "fallback_type": "static" if fallback is not None else "function" if fallback_fn else "strategy", 

218 }, 

219 ) 

220 

221 # Emit metric 

222 fallback_used_counter.add( 

223 1, 

224 attributes={ 

225 "function": func.__name__, 

226 "exception_type": type(e).__name__, 

227 }, 

228 ) 

229 

230 # Determine fallback value 

231 if fallback is not None: 

232 return fallback # type: ignore[no-any-return] 

233 elif fallback_fn is not None: 233 ↛ 239line 233 didn't jump to line 239 because the condition on line 233 was always true

234 if asyncio.iscoroutinefunction(fallback_fn): 

235 return await fallback_fn(*args, **kwargs) # type: ignore[no-any-return] 

236 else: 

237 return fallback_fn(*args, **kwargs) # type: ignore[no-any-return] 

238 else: # fallback_strategy is not None 

239 return fallback_strategy.get_fallback_value(*args, **kwargs) # type: ignore[no-any-return,union-attr] 

240 

241 @functools.wraps(func) 

242 def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: 

243 """Sync wrapper with fallback""" 

244 try: 

245 return func(*args, **kwargs) 

246 except BaseException as e: 

247 # Check if this is an exception type we should handle 

248 if not isinstance(e, exception_types): 

249 raise 

250 

251 logger.warning(f"Using fallback for {func.__name__}", exc_info=True) 

252 

253 if fallback is not None: 

254 return fallback # type: ignore[no-any-return] 

255 elif fallback_fn is not None: 

256 return fallback_fn(*args, **kwargs) # type: ignore[no-any-return] 

257 else: # fallback_strategy is not None 

258 return fallback_strategy.get_fallback_value(*args, **kwargs) # type: ignore[no-any-return,union-attr] 

259 

260 # Return appropriate wrapper 

261 import asyncio 

262 

263 if asyncio.iscoroutinefunction(func): 263 ↛ 266line 263 didn't jump to line 266 because the condition on line 263 was always true

264 return async_wrapper # type: ignore[return-value] 

265 else: 

266 return sync_wrapper 

267 

268 return decorator 

269 

270 

271# Convenience decorators for common fallback scenarios 

def fail_open(func: Callable[P, bool]) -> Callable[P, bool]:
    """
    Decorator for fail-open authorization: the wrapped check answers True
    (allow) when it raises.

    Usage:
        @fail_open
        async def check_permission(user: str, resource: str) -> bool:
            return await openfga_client.check(user, resource)
    """
    allow_on_error = with_fallback(fallback=True)
    return allow_on_error(func)

282 

283 

def fail_closed(func: Callable[P, bool]) -> Callable[P, bool]:
    """
    Decorator for fail-closed authorization: the wrapped check answers False
    (deny) when it raises.

    Usage:
        @fail_closed
        async def check_admin_permission(user: str) -> bool:
            return await openfga_client.check(user, "admin")
    """
    deny_on_error = with_fallback(fallback=False)
    return deny_on_error(func)

294 

295 

def return_empty_on_error(func: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to return an empty value on error (list/dict/None).

    The empty value is chosen from the function's return annotation:
    list-like annotations (``list``, ``list[str]``, ``typing.List[...]``)
    give ``[]``, dict-like annotations (``dict``, ``dict[str, int]``,
    ``typing.Dict[...]``) give ``{}``, and anything else, including a
    missing annotation, gives ``None``. A fresh list or dict is created for
    every failure, so callers may mutate the result freely.

    Usage:
        @return_empty_on_error
        async def get_user_list() -> list:
            return await db.query_users()  # Returns [] on error
    """
    import inspect
    import typing

    try:
        # Resolve string annotations (e.g. under `from __future__ import annotations`)
        return_type = inspect.signature(func, eval_str=True).return_annotation
    except Exception:
        # The annotation could not be evaluated: fall back to the raw
        # annotation and rely on the textual match below.
        return_type = inspect.signature(func).return_annotation

    # Unwrap parameterized generics: list[str] -> list, typing.Dict[str, int] -> dict
    origin = typing.get_origin(return_type) or return_type
    annotation_text = str(return_type)

    empty_factory: Callable[[], Any] | None
    if origin is list or "List" in annotation_text:
        empty_factory = list
    elif origin is dict or "Dict" in annotation_text:
        empty_factory = dict
    else:
        empty_factory = None

    def empty_value(*_args: Any, **_kwargs: Any) -> Any:
        """Build the empty result; the original call arguments are ignored."""
        return None if empty_factory is None else empty_factory()

    # A static `fallback=None` is rejected by with_fallback as "not provided",
    # so the empty value is always supplied through fallback_fn instead.
    return with_fallback(fallback_fn=empty_value)(func)