Coverage for src / mcp_server_langgraph / monitoring / cost_api.py: 91%

146 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Cost Monitoring API 

3 

4FastAPI endpoints for retrieving cost data, managing budgets, and exporting reports. 

5 

6Endpoints: 

7- GET /api/cost/summary - Aggregated cost summary 

8- GET /api/cost/usage - Detailed usage records 

9- GET /api/cost/budget/{budget_id} - Budget status 

10- POST /api/cost/budget - Create budget 

11- GET /api/cost/trends - Time-series cost trends 

12- GET /api/cost/export - Export cost data (CSV/JSON) 

13 

14Example: 

15 uvicorn mcp_server_langgraph.monitoring.cost_api:app --reload --port 8003 

16""" 

17 

18import csv 

19import io 

20from datetime import datetime, timedelta, UTC 

21from decimal import Decimal 

22from typing import Any 

23 

24from fastapi import FastAPI, HTTPException, Query, Response, status 

25from pydantic import BaseModel, ConfigDict, Field, field_serializer 

26 

27from .budget_monitor import Budget, BudgetPeriod, BudgetStatus, get_budget_monitor 

28from .cost_tracker import CostAggregator, get_cost_collector 

29 

30# ============================================================================== 

31# FastAPI Application 

32# ============================================================================== 

33 

# Single FastAPI application object; every route below is registered on it.
app = FastAPI(
    version="1.0.0",
    title="Cost Monitoring API",
    description="LLM cost tracking, budget monitoring, and financial analytics",
)

39 

40 

41# ============================================================================== 

42# Request/Response Models 

43# ============================================================================== 

44 

45 

class CostSummaryResponse(BaseModel):
    """Aggregated cost figures for one reporting window.

    Monetary amounts are carried as strings (stringified ``Decimal`` values).
    """

    # Bounds of the reporting window.
    period_start: datetime
    period_end: datetime

    # Totals over every record in the window.
    total_cost_usd: str  # Decimal as string
    total_tokens: int
    request_count: int

    # Cost breakdowns keyed by dimension value; amounts are Decimals as strings.
    by_model: dict[str, str]
    by_user: dict[str, str]
    by_feature: dict[str, str] = Field(default_factory=dict)

    model_config = ConfigDict()

    @field_serializer("period_start", "period_end")
    def serialize_dates(self, value: datetime) -> str:
        """Render a window bound as an ISO 8601 string."""
        return value.isoformat()

64 

65 

class UsageRecordResponse(BaseModel):
    """One usage record as returned by the usage endpoint."""

    # When the usage happened and who/what it belongs to.
    timestamp: datetime
    user_id: str
    session_id: str

    # Which model served the request.
    model: str
    provider: str

    # Token accounting for the request.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

    # Estimated cost, carried as a stringified Decimal.
    estimated_cost_usd: str

    # Optional feature tag attached to the record.
    feature: str | None = None

    model_config = ConfigDict()

    @field_serializer("timestamp")
    def serialize_timestamp(self, value: datetime) -> str:
        """Render the record timestamp as an ISO 8601 string."""
        return value.isoformat()

86 

87 

class CreateBudgetRequest(BaseModel):
    """Payload accepted by the budget-creation endpoint."""

    # Identifier and display name of the budget.
    id: str
    name: str

    # Spending limit, sent as a decimal string (e.g. "1000.00").
    limit_usd: str

    # Budget period, one of the BudgetPeriod values.
    period: BudgetPeriod

    # Optional alert thresholds, each sent as a decimal string (e.g. "0.75").
    alert_thresholds: list[str] | None = None

96 

97 

class TrendDataPoint(BaseModel):
    """A single (timestamp, value) sample of a trend series."""

    # Point in time the sample refers to.
    timestamp: datetime

    # Sample value, carried as a stringified Decimal.
    value: str

    model_config = ConfigDict()

    @field_serializer("timestamp")
    def serialize_timestamp(self, value: datetime) -> str:
        """Render the sample timestamp as an ISO 8601 string."""
        return value.isoformat()

110 

111 

class TrendsResponse(BaseModel):
    """Trend series returned by the trends endpoint."""

    # Name of the metric the series describes, echoed from the request.
    metric: str

    # Requested period string, echoed from the request.
    period: str

    # The samples that make up the series.
    data_points: list[TrendDataPoint]

118 

119 

120# ============================================================================== 

121# Endpoints 

122# ============================================================================== 

123 

124 

@app.get("/")
def root() -> dict[str, Any]:
    """Describe the service: its name, version, and feature list."""
    feature_list = [
        "Real-time cost tracking",
        "Budget monitoring",
        "Multi-dimensional aggregation",
        "Export capabilities",
        "Trend analysis",
    ]
    return {
        "name": "Cost Monitoring API",
        "version": "1.0.0",
        "features": feature_list,
    }

139 

140 

@app.get("/api/cost/summary")
async def get_cost_summary(
    period: str = Query("month", description="Time period (day, week, month)"),
    group_by: str | None = Query(None, description="Group by dimension (model, user, feature)"),
) -> CostSummaryResponse:
    """
    Return the aggregated cost summary for a period.

    Args:
        period: Time period (day, week, month). Any value other than
            "day" or "week" is reported with a 30-day window.
        group_by: Optional grouping dimension. It is accepted but not read
            by this function; all three breakdowns are always returned.

    Returns:
        Totals plus per-model, per-user and per-feature cost breakdowns.

    Example:
        GET /api/cost/summary?period=month&group_by=model
    """

    def _stringify(amounts: dict[str, Any]) -> dict[str, str]:
        # Amounts leave the API as strings.
        return {key: str(amount) for key, amount in amounts.items()}

    collector = get_cost_collector()
    aggregator = CostAggregator()

    usage = await collector.get_records(period=period)

    # One pass collects the totals and the rows handed to the aggregator.
    cost_sum = Decimal("0")
    token_sum = 0
    rows = []
    for entry in usage:
        cost_sum = cost_sum + entry.estimated_cost_usd
        token_sum = token_sum + entry.total_tokens
        rows.append(
            {
                "model": entry.model,
                "user_id": entry.user_id,
                "feature": entry.feature,
                "cost": entry.estimated_cost_usd,
            }
        )

    model_costs = await aggregator.aggregate_by_model(rows)
    user_costs = await aggregator.aggregate_by_user(rows)
    feature_costs = await aggregator.aggregate_by_feature(rows)

    # The window ends now; "day" starts at today's midnight, the other
    # periods are rolling look-backs.
    now = datetime.now(UTC)
    if period == "day":
        window_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
    else:
        lookback_days = 7 if period == "week" else 30
        window_start = now - timedelta(days=lookback_days)

    return CostSummaryResponse(
        period_start=window_start,
        period_end=now,
        total_cost_usd=str(cost_sum),
        total_tokens=token_sum,
        request_count=len(usage),
        by_model=_stringify(model_costs),
        by_user=_stringify(user_costs),
        by_feature=_stringify(feature_costs),
    )

204 

205 

@app.get("/api/cost/usage")
async def get_usage_records(
    user_id: str | None = Query(None, description="Filter by user ID"),
    model: str | None = Query(None, description="Filter by model"),
    start: datetime | None = Query(None, description="Start datetime"),
    end: datetime | None = Query(None, description="End datetime"),
    limit: int = Query(100, description="Max records to return", ge=1, le=1000),
) -> list[UsageRecordResponse]:
    """
    Get detailed usage records.

    Args:
        user_id: Filter by user (optional)
        model: Filter by model (optional)
        start: Inclusive lower bound on the record timestamp (optional).
            A value without a UTC offset is interpreted as UTC.
        end: Inclusive upper bound on the record timestamp (optional).
            A value without a UTC offset is interpreted as UTC.
        limit: Maximum number of records returned

    Returns:
        List of usage records, truncated to ``limit`` entries

    Example:
        GET /api/cost/usage?user_id=user123&limit=50
    """
    # Elsewhere in this module record timestamps are compared against
    # timezone-aware datetimes. Comparing them with a naive bound would raise
    # TypeError, so naive query values are pinned to UTC first.
    if start is not None and start.tzinfo is None:
        start = start.replace(tzinfo=UTC)
    if end is not None and end.tzinfo is None:
        end = end.replace(tzinfo=UTC)

    collector = get_cost_collector()
    records = await collector.get_records(user_id=user_id, model=model)

    # Apply the date window, then cap the result size.
    if start is not None:
        records = [r for r in records if r.timestamp >= start]
    if end is not None:
        records = [r for r in records if r.timestamp <= end]
    records = records[:limit]

    return [
        UsageRecordResponse(
            timestamp=r.timestamp,
            user_id=r.user_id,
            session_id=r.session_id,
            model=r.model,
            provider=r.provider,
            prompt_tokens=r.prompt_tokens,
            completion_tokens=r.completion_tokens,
            total_tokens=r.total_tokens,
            estimated_cost_usd=str(r.estimated_cost_usd),
            feature=r.feature,
        )
        for r in records
    ]

259 

260 

@app.get("/api/cost/budget/{budget_id}")
async def get_budget_status(budget_id: str) -> BudgetStatus:
    """
    Look up the current status of a budget.

    Args:
        budget_id: Budget identifier

    Returns:
        The status object produced by the budget monitor

    Raises:
        HTTPException: 404 when the monitor returns nothing for ``budget_id``

    Example:
        GET /api/cost/budget/dev_team_monthly
    """
    found = await get_budget_monitor().get_budget_status(budget_id)
    if found:
        return found

    # A falsy result from the monitor is reported as "not found".
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Budget '{budget_id}' not found",
    )

285 

286 

@app.post("/api/cost/budget", status_code=status.HTTP_201_CREATED)
async def create_budget(request: CreateBudgetRequest) -> Budget:
    """
    Create a new budget.

    Args:
        request: Budget creation request; ``limit_usd`` and every entry of
            ``alert_thresholds`` are decimal numbers sent as strings

    Returns:
        Created budget

    Raises:
        HTTPException: 400 when ``limit_usd`` or an alert threshold is not a
            finite decimal number

    Example:
        POST /api/cost/budget
        {
            "id": "dev_monthly",
            "name": "Development Team - Monthly",
            "limit_usd": "1000.00",
            "period": "monthly",
            "alert_thresholds": ["0.75", "0.90"]
        }
    """
    # Local import: only this handler needs the exception type.
    from decimal import InvalidOperation

    def _parse_amount(raw: str, field_name: str) -> Decimal:
        # Turn client-supplied text into a Decimal. Malformed input becomes a
        # 400 response instead of an unhandled InvalidOperation (HTTP 500).
        try:
            amount = Decimal(raw)
        except InvalidOperation as exc:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid decimal value for '{field_name}': {raw!r}",
            ) from exc
        # "NaN" and "Infinity" parse successfully but are not usable amounts.
        if not amount.is_finite():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Value for '{field_name}' must be a finite number: {raw!r}",
            )
        return amount

    limit_usd = _parse_amount(request.limit_usd, "limit_usd")
    # A missing or empty threshold list is passed on as None.
    alert_thresholds = (
        [_parse_amount(t, "alert_thresholds") for t in request.alert_thresholds]
        if request.alert_thresholds
        else None
    )

    monitor = get_budget_monitor()
    budget = await monitor.create_budget(
        id=request.id,
        name=request.name,
        limit_usd=limit_usd,
        period=request.period,
        alert_thresholds=alert_thresholds,
    )

    return budget

323 

324 

@app.get("/api/cost/trends")
async def get_cost_trends(
    metric: str = Query("total_cost", description="Metric to track (total_cost, token_usage)"),
    period: str = Query("7d", description="Time period (7d, 30d, 90d)"),
) -> TrendsResponse:
    """
    Get daily cost or token trends from the recorded usage.

    The series covers the last N UTC calendar days ending with today, where
    N is the number in ``period``. Days without records are reported as "0".

    Args:
        metric: "total_cost" reports the daily cost; any other value
            reports the daily token count
        period: Number of days followed by "d" ("7d", "30d", "90d")

    Returns:
        Time-series trend data with one data point per day, oldest first

    Raises:
        HTTPException: 400 when ``period`` is not a positive day count or
            reaches further back than a date can represent

    Example:
        GET /api/cost/trends?metric=total_cost&period=30d
    """
    # Parse "<N>d". Malformed input is a client error, not a server crash.
    try:
        days = int(period.strip().removesuffix("d"))
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid period: {period!r} (expected e.g. '7d', '30d', '90d')",
        ) from exc
    if days < 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid period: {period!r} (must cover at least one day)",
        )

    collector = get_cost_collector()
    now = datetime.now(UTC)

    # The window is whole UTC days: from midnight (days - 1) days ago through
    # today, so the current day's records appear in the series.
    try:
        first_day = now.date() - timedelta(days=days - 1)
    except OverflowError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid period: {period!r} (too far in the past)",
        ) from exc
    window_start = datetime.combine(first_day, datetime.min.time(), tzinfo=UTC)

    # NOTE(review): the original author's comment says this call returns all
    # in-memory records. Whether period="day" narrows the result for other
    # storage backends is not visible from this file - confirm.
    all_records = await collector.get_records(period="day")

    # Sum cost and tokens per UTC calendar day.
    cost_by_day: dict[Any, Decimal] = {}
    tokens_by_day: dict[Any, Any] = {}
    for record in all_records:
        if record.timestamp >= window_start:
            # Bucket by UTC date so keys line up with the UTC days emitted below.
            day = record.timestamp.astimezone(UTC).date()
            cost_by_day[day] = cost_by_day.get(day, Decimal("0")) + record.estimated_cost_usd
            tokens_by_day[day] = tokens_by_day.get(day, 0) + record.total_tokens

    # Emit one point per day in the window, zero-filling days with no data.
    data_points = []
    for offset in range(days):
        day = first_day + timedelta(days=offset)
        if day not in cost_by_day:
            value = Decimal("0")
        elif metric == "total_cost":
            value = cost_by_day[day]
        else:
            value = Decimal(str(tokens_by_day[day]))

        data_points.append(
            TrendDataPoint(
                timestamp=datetime.combine(day, datetime.min.time(), tzinfo=UTC),
                value=str(value),
            )
        )

    return TrendsResponse(
        metric=metric,
        period=period,
        data_points=data_points,
    )

400 

401 

@app.get("/api/cost/export")
async def export_cost_data(
    format: str = Query("csv", description="Export format (csv, json)"),
    period: str = Query("month", description="Time period"),
) -> Response:
    """
    Export cost data as a downloadable attachment.

    Args:
        format: Export format (csv or json)
        period: Time period, passed to the cost collector and used in the
            download file name

    Returns:
        Cost data in requested format, with a Content-Disposition header

    Raises:
        HTTPException: 400 when ``format`` is neither "csv" nor "json"

    Example:
        GET /api/cost/export?format=csv&period=month
    """
    # Reject an unsupported format before any records are fetched.
    if format not in ("csv", "json"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported format: {format}",
        )

    collector = get_cost_collector()
    records = await collector.get_records(period=period)

    # `period` comes straight from the query string and ends up inside a
    # response header. Keep only characters that are safe in a quoted file
    # name, so quotes or non-ASCII text cannot corrupt the header.
    safe_period = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "-_." else "_" for ch in period
    )

    if format == "csv":
        output = io.StringIO()
        writer = csv.writer(output)

        # Header row
        writer.writerow(
            [
                "timestamp",
                "user_id",
                "session_id",
                "model",
                "provider",
                "prompt_tokens",
                "completion_tokens",
                "total_tokens",
                "cost_usd",
                "feature",
            ]
        )

        # One row per record; a missing feature is written as an empty cell.
        writer.writerows(
            [
                r.timestamp.isoformat(),
                r.user_id,
                r.session_id,
                r.model,
                r.provider,
                r.prompt_tokens,
                r.completion_tokens,
                r.total_tokens,
                str(r.estimated_cost_usd),
                r.feature or "",
            ]
            for r in records
        )

        return Response(
            content=output.getvalue(),
            media_type="text/csv",
            headers={"Content-Disposition": f'attachment; filename="cost_export_{safe_period}.csv"'},
        )

    # format == "json"
    import json

    records_dict = [
        {
            "timestamp": r.timestamp.isoformat(),
            "user_id": r.user_id,
            "session_id": r.session_id,
            "model": r.model,
            "provider": r.provider,
            "prompt_tokens": r.prompt_tokens,
            "completion_tokens": r.completion_tokens,
            "total_tokens": r.total_tokens,
            "cost_usd": str(r.estimated_cost_usd),
            "feature": r.feature,
        }
        for r in records
    ]

    return Response(
        content=json.dumps(records_dict, indent=2),
        media_type="application/json",
        headers={"Content-Disposition": f'attachment; filename="cost_export_{safe_period}.json"'},
    )

499 

500 

501# ============================================================================== 

502# Health Check 

503# ============================================================================== 

504 

505 

@app.get("/health")
def health_check() -> dict[str, str]:
    """Report that the service is up, together with its service name."""
    payload = {"status": "healthy", "service": "cost-monitoring-api"}
    return payload

510 

511 

512# ============================================================================== 

513# Run Server 

514# ============================================================================== 

515 

if __name__ == "__main__":
    import uvicorn

    # Startup banner listing the endpoints served on port 8003.
    print("=" * 80)
    print("💰 Cost Monitoring API")
    print("=" * 80)
    print("\nStarting server...")
    print("\n📍 Endpoints:")
    print(" • Summary: GET http://localhost:8003/api/cost/summary")
    print(" • Usage: GET http://localhost:8003/api/cost/usage")
    print(" • Budget: GET http://localhost:8003/api/cost/budget/{id}")
    print(" • Create: POST http://localhost:8003/api/cost/budget")
    print(" • Trends: GET http://localhost:8003/api/cost/trends")
    print(" • Export: GET http://localhost:8003/api/cost/export")
    print(" • Docs: http://localhost:8003/docs")
    print("=" * 80)
    print()

    # With reload=True uvicorn needs the application as an import string;
    # given an app object it logs a warning and exits without serving.
    # The string matches the module path in this file's usage example.
    uvicorn.run(
        "mcp_server_langgraph.monitoring.cost_api:app",
        host="0.0.0.0",  # nosec B104
        port=8003,
        reload=True,
    )