Coverage for src / mcp_server_langgraph / auth / api_keys.py: 89%

189 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2API Key Management 

3 

4Manages API key lifecycle including generation, validation, rotation, and revocation. 

5API keys are stored as bcrypt hashes in Keycloak user attributes and exchanged for 

6JWTs on each request. 

7 

8See ADR-0034 for API key to JWT exchange pattern. 

9""" 

10 

11import hashlib 

12import hmac 

13import os 

14import secrets 

15from dataclasses import dataclass 

16from datetime import datetime, timedelta, UTC 

17from typing import TYPE_CHECKING, Any, Optional 

18 

19import bcrypt 

20 

21from mcp_server_langgraph.auth.keycloak import KeycloakClient 

22from mcp_server_langgraph.observability.telemetry import logger 

23 

24if TYPE_CHECKING: 

25 from redis.asyncio import Redis 

26 

27 

@dataclass
class APIKey:
    """Metadata describing one API key (the secret itself is never held here)."""

    # Short identifier of the key.
    key_id: str
    # Human-readable label for the key.
    name: str
    # Creation timestamp, kept as a string.
    created: str
    # Expiration timestamp, kept as a string.
    expires_at: str
    # Timestamp of the most recent use; None when not recorded.
    last_used: Optional[str] = None

37 

38 

39class APIKeyManager: 

40 """Manage API key lifecycle""" 

41 

42 DEFAULT_PREFIX = "mcpkey_live_" 

43 TEST_PREFIX = "mcpkey_test_" 

44 MAX_KEYS_PER_USER = 5 

45 BCRYPT_ROUNDS = 12 

46 

47 def __init__( 

48 self, 

49 keycloak_client: KeycloakClient, 

50 redis_client: Optional["Redis[bytes]"] = None, # type: ignore[type-arg] 

51 cache_ttl: int = 3600, 

52 cache_enabled: bool = True, 

53 ): 

54 """ 

55 Initialize API key manager 

56 

57 Args: 

58 keycloak_client: Keycloak client for user attribute storage 

59 redis_client: Optional Redis client for API key lookup cache 

60 cache_ttl: Cache TTL in seconds (default: 3600 = 1 hour) 

61 cache_enabled: Enable/disable Redis caching (default: True) 

62 """ 

63 self.keycloak = keycloak_client 

64 self.redis = redis_client 

65 self.cache_ttl = cache_ttl 

66 self.cache_enabled = cache_enabled and redis_client is not None 

67 

68 def generate_api_key(self, prefix: str = DEFAULT_PREFIX) -> str: 

69 """ 

70 Generate cryptographically secure API key 

71 

72 Args: 

73 prefix: Prefix for the key (default: "sk_live_") 

74 

75 Returns: 

76 API key string (e.g., "sk_live_abc123xyz...") 

77 """ 

78 # Generate 32 bytes (256 bits) of randomness 

79 random_bytes = secrets.token_urlsafe(32) 

80 return f"{prefix}{random_bytes}" 

81 

82 def hash_api_key(self, api_key: str) -> str: 

83 """ 

84 Hash API key with bcrypt 

85 

86 Args: 

87 api_key: Plain API key 

88 

89 Returns: 

90 bcrypt hash 

91 """ 

92 return bcrypt.hashpw(api_key.encode(), bcrypt.gensalt(rounds=self.BCRYPT_ROUNDS)).decode() 

93 

94 def verify_api_key_hash(self, api_key: str, hashed: str) -> bool: 

95 """ 

96 Verify API key against bcrypt hash 

97 

98 Args: 

99 api_key: Plain API key 

100 hashed: bcrypt hash 

101 

102 Returns: 

103 True if key matches hash, False otherwise 

104 """ 

105 try: 

106 return bcrypt.checkpw(api_key.encode(), hashed.encode()) 

107 except Exception: 

108 return False 

109 

110 async def create_api_key( 

111 self, 

112 user_id: str, 

113 name: str, 

114 expires_days: int = 365, 

115 ) -> dict[str, Any]: 

116 """ 

117 Create new API key for user 

118 

119 Args: 

120 user_id: User identifier (e.g., "user:alice") 

121 name: Human-readable name for the key 

122 expires_days: Days until expiration (default: 365) 

123 

124 Returns: 

125 Dictionary with key_id, api_key, name, expires_at 

126 

127 Raises: 

128 ValueError: If user has reached maximum number of keys 

129 """ 

130 # Get current user attributes 

131 attributes = await self.keycloak.get_user_attributes(user_id) 

132 existing_keys = attributes.get("apiKeys", []) 

133 

134 # Check quota 

135 if len(existing_keys) >= self.MAX_KEYS_PER_USER: 

136 msg = ( 

137 f"Maximum API keys reached ({self.MAX_KEYS_PER_USER}). " 

138 "Please revoke an existing key before creating a new one." 

139 ) 

140 raise ValueError(msg) 

141 

142 # Generate new API key 

143 api_key = self.generate_api_key() 

144 

145 # Hash for storage (bcrypt for security verification) 

146 key_hash = self.hash_api_key(api_key) 

147 

148 # Hash for cache invalidation (SHA256 for fast lookup) 

149 cache_hash = self._hash_api_key_for_cache(api_key) 

150 

151 # Generate key ID 

152 key_id = secrets.token_hex(8) 

153 

154 # Calculate expiration 

155 created_at = datetime.now(UTC) 

156 expires_at = created_at + timedelta(days=expires_days) 

157 

158 # Store in Keycloak attributes 

159 existing_keys.append(f"key:{key_id}:{key_hash}") 

160 attributes["apiKeys"] = existing_keys 

161 attributes[f"apiKey_{key_id}_name"] = name 

162 attributes[f"apiKey_{key_id}_created"] = created_at.isoformat() 

163 attributes[f"apiKey_{key_id}_expiresAt"] = expires_at.isoformat() 

164 attributes[f"apiKey_{key_id}_cacheHash"] = cache_hash # For cache invalidation on revoke 

165 

166 await self.keycloak.update_user_attributes(user_id, attributes) 

167 

168 return { 

169 "key_id": key_id, 

170 "api_key": api_key, # Return once, never stored plaintext 

171 "name": name, 

172 "created": created_at.isoformat(), 

173 "expires_at": expires_at.isoformat(), 

174 } 

175 

176 async def _get_from_cache(self, api_key_hash: str) -> dict[str, Any] | None: 

177 """ 

178 Get user info from Redis cache using API key hash 

179 

180 Args: 

181 api_key_hash: SHA256 hash of the API key (for cache key) 

182 

183 Returns: 

184 Cached user info dict or None if not found 

185 """ 

186 if not self.cache_enabled or not self.redis: 

187 return None 

188 

189 try: 

190 cache_key = f"apikey:{api_key_hash}" 

191 cached_data = await self.redis.get(cache_key) 

192 if cached_data: 

193 import json 

194 

195 logger.debug(f"API key cache hit for hash: {api_key_hash[:16]}...") 

196 return json.loads(cached_data) # type: ignore[no-any-return] 

197 except Exception as e: 

198 logger.warning(f"Redis cache read failed: {e}") 

199 

200 return None 

201 

202 async def _set_in_cache(self, api_key_hash: str, user_info: dict[str, Any]) -> None: 

203 """ 

204 Store user info in Redis cache 

205 

206 Args: 

207 api_key_hash: SHA256 hash of the API key (for cache key) 

208 user_info: User information to cache 

209 """ 

210 if not self.cache_enabled or not self.redis: 

211 return 

212 

213 try: 

214 import json 

215 

216 cache_key = f"apikey:{api_key_hash}" 

217 await self.redis.setex(cache_key, self.cache_ttl, json.dumps(user_info)) 

218 logger.debug(f"API key cached for hash: {api_key_hash[:16]}...") 

219 except Exception as e: 

220 logger.warning(f"Redis cache write failed: {e}") 

221 

222 async def _invalidate_cache(self, api_key_hash: str) -> None: 

223 """ 

224 Invalidate cached user info for an API key 

225 

226 Args: 

227 api_key_hash: SHA256 hash of the API key 

228 """ 

229 if not self.cache_enabled or not self.redis: 229 ↛ 230line 229 didn't jump to line 230 because the condition on line 229 was never true

230 return 

231 

232 try: 

233 cache_key = f"apikey:{api_key_hash}" 

234 await self.redis.delete(cache_key) 

235 logger.debug(f"API key cache invalidated for hash: {api_key_hash[:16]}...") 

236 except Exception as e: 

237 logger.warning(f"Redis cache invalidation failed: {e}") 

238 

239 def _hash_api_key_for_cache(self, api_key: str) -> str: 

240 """ 

241 Create a deterministic keyed hash of API key for cache lookup. 

242 

243 Security: Uses HMAC-SHA256 with a secret key to prevent offline brute-force 

244 attacks if the cache is leaked. Without the secret key, attackers cannot 

245 reverse the hash to recover API keys. 

246 

247 The secret key is loaded from API_KEY_CACHE_SECRET environment variable. 

248 Falls back to a derived key from JWT_SECRET_KEY if not set. 

249 

250 bcrypt is still used for secure storage verification. 

251 

252 Args: 

253 api_key: Plain API key 

254 

255 Returns: 

256 HMAC-SHA256 hex digest 

257 """ 

258 # Get or derive the HMAC secret key 

259 cache_secret = os.getenv("API_KEY_CACHE_SECRET") 

260 if not cache_secret: 260 ↛ 275line 260 didn't jump to line 275 because the condition on line 260 was always true

261 # Fall back to deriving from JWT secret (better than no secret) 

262 jwt_secret = os.getenv("JWT_SECRET_KEY", "") 

263 # Use HKDF-like derivation: HMAC(jwt_secret, "api-key-cache") 

264 cache_secret = hmac.new( 

265 jwt_secret.encode() if jwt_secret else b"default-insecure-key", 

266 b"api-key-cache-derivation", 

267 hashlib.sha256, 

268 ).hexdigest() 

269 

270 # Create HMAC with the secret key 

271 # nosemgrep: python.cryptography.security.insecure-hash-function.insecure-hash-function-sha256 

272 # Security: SHA256 is used for CACHE KEY derivation (not password storage). 

273 # This is a keyed HMAC for lookup optimization, not cryptographic password hashing. 

274 # Actual API key verification uses bcrypt via verify_api_key(). 

275 return hmac.new( 

276 cache_secret.encode(), 

277 api_key.encode(), 

278 hashlib.sha256, 

279 ).hexdigest() 

280 

281 async def validate_and_get_user(self, api_key: str) -> dict[str, Any] | None: 

282 """ 

283 Validate API key and return user information 

284 

285 Args: 

286 api_key: Plain API key to validate 

287 

288 Returns: 

289 Dictionary with user_id, username, email, key_id if valid, None otherwise 

290 

291 Note: 

292 This implementation uses Redis cache for O(1) lookups when enabled. 

293 Falls back to paginating through all users if cache miss occurs. 

294 See ADR-0034 for Redis-backed API key cache design. 

295 """ 

296 # Try cache first (O(1) lookup) 

297 api_key_hash = self._hash_api_key_for_cache(api_key) 

298 cached_user = await self._get_from_cache(api_key_hash) 

299 if cached_user: 

300 # Verify expiration from cache 

301 expires_at_str = cached_user.get("expires_at") 

302 if expires_at_str: 302 ↛ 314line 302 didn't jump to line 314 because the condition on line 302 was always true

303 expires_at = datetime.fromisoformat(expires_at_str) 

304 # Ensure timezone-aware comparison (handle both naive and aware datetimes) 

305 if expires_at.tzinfo is None: 305 ↛ 307line 305 didn't jump to line 307 because the condition on line 305 was always true

306 expires_at = expires_at.replace(tzinfo=UTC) 

307 if datetime.now(UTC) > expires_at: 307 ↛ 309line 307 didn't jump to line 309 because the condition on line 307 was never true

308 # Expired, invalidate cache and continue to full search 

309 await self._invalidate_cache(api_key_hash) 

310 else: 

311 return cached_user 

312 

313 # No expiration or still valid 

314 return cached_user 

315 # PERFORMANCE WARNING (OpenAI Codex Finding #5): 

316 # This O(n) pagination fallback is inefficient for large user bases. 

317 # The Redis cache mitigates this (ADR-0034), but cold starts are slow. 

318 # 

319 # RECOMMENDED FUTURE OPTIMIZATION: 

320 # 1. Add indexed Keycloak user attribute: api_key_hash 

321 # 2. Use: keycloak.search_users(query=f"api_key_hash:{hash}") 

322 # 3. This provides O(1) lookup instead of O(n) enumeration 

323 # 

324 # Until then, monitor cache hit rate and user count: 

325 logger.warning( 

326 "API key validation: Cache miss triggered user enumeration (O(n) fallback). " 

327 "Redis cache provides primary mitigation (ADR-0034). " 

328 "For production deployments with >1000 users, consider implementing Keycloak " 

329 "indexed attribute search (see OpenAI Codex Finding #5).", 

330 extra={ 

331 "cache_enabled": self.cache_enabled, 

332 "mitigation": "Redis cache provides O(1) for cache hits (ADR-0034)", 

333 "recommendation": "Implement indexed Keycloak attribute search for cold starts", 

334 }, 

335 ) 

336 

337 # Paginate through all users to find matching key hash 

338 first = 0 

339 max_per_page = 100 

340 users_scanned = 0 

341 

342 while True: 

343 # Fetch page of users 

344 users = await self.keycloak.search_users(first=first, max=max_per_page) 

345 

346 # No more users, key not found 

347 if not users: 

348 break 

349 

350 users_scanned += len(users) 

351 

352 # Search this page for matching key 

353 for user in users: 

354 attributes = user.get("attributes", {}) 

355 api_keys = attributes.get("apiKeys", []) 

356 

357 for key_entry in api_keys: 

358 # Format: "key:key_id:hash" 

359 parts = key_entry.split(":") 

360 if len(parts) != 3: 360 ↛ 361line 360 didn't jump to line 361 because the condition on line 360 was never true

361 continue # Invalid format 

362 

363 _, key_id, stored_hash = parts 

364 

365 # Check if hash matches 

366 if self.verify_api_key_hash(api_key, stored_hash): 366 ↛ 357line 366 didn't jump to line 357 because the condition on line 366 was always true

367 # Check expiration 

368 expires_at_str = attributes.get(f"apiKey_{key_id}_expiresAt") 

369 if expires_at_str: 369 ↛ 378line 369 didn't jump to line 378 because the condition on line 369 was always true

370 expires_at = datetime.fromisoformat(expires_at_str) 

371 # Ensure timezone-aware comparison (handle both naive and aware datetimes) 

372 if expires_at.tzinfo is None: 

373 expires_at = expires_at.replace(tzinfo=UTC) 

374 if datetime.now(UTC) > expires_at: 

375 continue # Expired 

376 

377 # Update last used timestamp 

378 attributes[f"apiKey_{key_id}_lastUsed"] = datetime.now(UTC).isoformat() 

379 await self.keycloak.update_user_attributes(user["id"], attributes) 

380 

381 user_info = { 

382 "user_id": f"user:{user['username']}", # OpenFGA format 

383 "keycloak_id": user["id"], # Raw UUID for Keycloak Admin API 

384 "username": user["username"], 

385 "email": user.get("email"), 

386 "key_id": key_id, 

387 "expires_at": expires_at_str, # Store for cache validation 

388 } 

389 

390 # Cache for future lookups (O(1) next time) 

391 await self._set_in_cache(api_key_hash, user_info) 

392 

393 return user_info 

394 

395 # Move to next page 

396 first += max_per_page 

397 

398 # PERFORMANCE MONITORING (OpenAI Codex Finding #5): 

399 # Log how many users were scanned to identify performance issues 

400 logger.info( 

401 "API key validation: User enumeration completed (key not found)", 

402 extra={ 

403 "users_scanned": users_scanned, 

404 "performance_impact": "HIGH" if users_scanned > 1000 else "MEDIUM" if users_scanned > 100 else "LOW", 

405 "recommendation": "Implement Keycloak indexed search if users_scanned > 1000", 

406 }, 

407 ) 

408 

409 return None # Invalid key 

410 

411 async def revoke_api_key(self, user_id: str, key_id: str) -> None: 

412 """ 

413 Revoke specific API key 

414 

415 Args: 

416 user_id: User identifier 

417 key_id: Key identifier to revoke 

418 """ 

419 # Get current attributes 

420 attributes = await self.keycloak.get_user_attributes(user_id) 

421 api_keys = attributes.get("apiKeys", []) 

422 

423 # Invalidate cache if hash is stored 

424 cache_hash = attributes.get(f"apiKey_{key_id}_cacheHash") 

425 if cache_hash: 

426 await self._invalidate_cache(cache_hash) 

427 

428 # Remove key entry 

429 attributes["apiKeys"] = [key for key in api_keys if not key.startswith(f"key:{key_id}:")] 

430 

431 # Remove metadata 

432 attributes.pop(f"apiKey_{key_id}_name", None) 

433 attributes.pop(f"apiKey_{key_id}_created", None) 

434 attributes.pop(f"apiKey_{key_id}_expiresAt", None) 

435 attributes.pop(f"apiKey_{key_id}_lastUsed", None) 

436 attributes.pop(f"apiKey_{key_id}_cacheHash", None) 

437 

438 await self.keycloak.update_user_attributes(user_id, attributes) 

439 

440 async def list_api_keys(self, user_id: str) -> list[dict[str, Any]]: 

441 """ 

442 List all API keys for user (without showing actual keys) 

443 

444 Args: 

445 user_id: User identifier 

446 

447 Returns: 

448 List of dictionaries with key_id, name, created, expires_at, last_used 

449 """ 

450 attributes = await self.keycloak.get_user_attributes(user_id) 

451 api_keys = attributes.get("apiKeys", []) 

452 

453 keys = [] 

454 for key_entry in api_keys: 

455 parts = key_entry.split(":") 

456 if len(parts) != 3: 456 ↛ 457line 456 didn't jump to line 457 because the condition on line 456 was never true

457 continue 

458 

459 _, key_id, _ = parts 

460 

461 key_info = { 

462 "key_id": key_id, 

463 "name": attributes.get(f"apiKey_{key_id}_name", ""), 

464 "created": attributes.get(f"apiKey_{key_id}_created", ""), 

465 "expires_at": attributes.get(f"apiKey_{key_id}_expiresAt", ""), 

466 } 

467 

468 # Include last_used if available 

469 last_used = attributes.get(f"apiKey_{key_id}_lastUsed") 

470 if last_used: 

471 key_info["last_used"] = last_used 

472 

473 keys.append(key_info) 

474 

475 return keys 

476 

477 async def rotate_api_key(self, user_id: str, key_id: str, grace_period_days: int = 0) -> dict[str, Any]: 

478 """ 

479 Rotate API key (generate new key, keeping same key_id) 

480 

481 Args: 

482 user_id: User identifier 

483 key_id: Key identifier to rotate 

484 grace_period_days: Days to keep old key valid (default: 0 = immediate) 

485 

486 Returns: 

487 Dictionary with key_id, new_api_key 

488 

489 Raises: 

490 ValueError: If key_id not found 

491 """ 

492 # Get current attributes 

493 attributes = await self.keycloak.get_user_attributes(user_id) 

494 api_keys = attributes.get("apiKeys", []) 

495 

496 # Find the key to rotate 

497 key_found = False 

498 for i, key_entry in enumerate(api_keys): 

499 if key_entry.startswith(f"key:{key_id}:"): 499 ↛ 498line 499 didn't jump to line 498 because the condition on line 499 was always true

500 key_found = True 

501 

502 # Generate new API key 

503 new_api_key = self.generate_api_key() 

504 new_hash = self.hash_api_key(new_api_key) 

505 

506 # Replace with new hash 

507 api_keys[i] = f"key:{key_id}:{new_hash}" 

508 

509 # Keep existing metadata (name, created), update expiration if needed 

510 if grace_period_days > 0: 510 ↛ 512line 510 didn't jump to line 512 because the condition on line 510 was never true

511 # Extend expiration for grace period 

512 new_expires = datetime.now(UTC) + timedelta(days=grace_period_days) 

513 attributes[f"apiKey_{key_id}_expiresAt"] = new_expires.isoformat() 

514 

515 break 

516 

517 if not key_found: 

518 msg = f"API key with ID '{key_id}' not found for user '{user_id}'" 

519 raise ValueError(msg) 

520 

521 # Update attributes 

522 attributes["apiKeys"] = api_keys 

523 await self.keycloak.update_user_attributes(user_id, attributes) 

524 

525 return { 

526 "key_id": key_id, 

527 "new_api_key": new_api_key, 

528 }