Coverage for src/mcp_server_langgraph/monitoring/cost_tracker.py: 71% (165 statements)
coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Cost Metrics Collector 

3 

4Tracks token usage and costs for LLM API calls with async recording, 

5Prometheus metrics integration, and PostgreSQL persistence. 

6 

7Example: 

8 >>> from mcp_server_langgraph.monitoring.cost_tracker import CostMetricsCollector 

9 >>> collector = CostMetricsCollector() 

10 >>> await collector.record_usage( 

11 ... user_id="user123", 

12 ... session_id="session456", 

13 ... model="claude-sonnet-4-5-20250929", 

14 ... provider="anthropic", 

15 ... prompt_tokens=1000, 

16 ... completion_tokens=500 

17 ... ) 

18""" 

import asyncio
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from decimal import Decimal
from typing import Any, cast

from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator

from .pricing import calculate_cost

# Module-level logger shared by the collector methods below
logger = logging.getLogger(__name__)

# ==============================================================================
# Data Models
# ==============================================================================


class TokenUsage(BaseModel):
    """Token usage record for a single LLM call."""

    timestamp: datetime = Field(description="When the call was made")
    user_id: str = Field(description="User who made the call")
    session_id: str = Field(description="Session identifier")
    model: str = Field(description="Model name")
    provider: str = Field(description="Provider (anthropic, openai, google)")
    prompt_tokens: int = Field(description="Number of input tokens", ge=0)
    completion_tokens: int = Field(description="Number of output tokens", ge=0)
    total_tokens: int = Field(description="Total tokens (prompt + completion)", ge=0)
    estimated_cost_usd: Decimal = Field(description="Estimated cost in USD")
    feature: str | None = Field(default=None, description="Feature that triggered the call")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")

    model_config = ConfigDict()

    @field_serializer("estimated_cost_usd")
    def serialize_decimal(self, value: Decimal) -> str:
        """Serialize Decimal as string for JSON compatibility."""
        return str(value)

    @field_serializer("timestamp")
    def serialize_timestamp(self, value: datetime) -> str:
        """Serialize datetime as ISO 8601 string."""
        return value.isoformat()

    # Compute total_tokens before validation; a before-validator (rather than a
    # custom __init__ override) also covers the model_validate() construction path.
    @model_validator(mode="before")
    @classmethod
    def _default_total_tokens(cls, data: Any) -> Any:
        """Calculate total_tokens if not provided."""
        if isinstance(data, dict) and "total_tokens" not in data:
            data["total_tokens"] = data.get("prompt_tokens", 0) + data.get("completion_tokens", 0)
        return data
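
# A hedged usage sketch (not part of the original module): total_tokens is
# filled in automatically when omitted, and the field serializers render
# Decimal and datetime as strings in model_dump(). The cost value below is
# illustrative, not a real price.
#
# >>> usage = TokenUsage(
# ...     timestamp=datetime.now(UTC),
# ...     user_id="user123",
# ...     session_id="session456",
# ...     model="claude-sonnet-4-5-20250929",
# ...     provider="anthropic",
# ...     prompt_tokens=1000,
# ...     completion_tokens=500,
# ...     estimated_cost_usd=Decimal("0.0105"),
# ... )
# >>> usage.total_tokens
# 1500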

# ==============================================================================
# Prometheus Metrics
# ==============================================================================

try:
    from prometheus_client import Counter

    # Real Prometheus metrics for production
    llm_token_usage = Counter(
        name="llm_token_usage_total",
        documentation="Total tokens used by LLM calls",
        labelnames=["provider", "model", "token_type"],
    )

    llm_cost = Counter(
        name="llm_cost_usd_total",
        documentation="Total estimated cost in USD (cumulative)",
        labelnames=["provider", "model"],
    )

    PROMETHEUS_AVAILABLE = True

except ImportError:
    # Fall back to mock counters if prometheus_client is not available
    import warnings

    warnings.warn(
        "prometheus_client not available, using mock counters. "
        "Install prometheus-client for production metrics: pip install prometheus-client"
    )

    @dataclass
    class MockPrometheusCounter:
        """Mock Prometheus counter for testing."""

        name: str
        description: str
        labelnames: list[str]
        _values: dict[tuple[str, ...], float] = field(default_factory=lambda: defaultdict(float))

        def labels(self, **labels: str) -> "MockCounterChild":
            """Return label-specific counter."""
            label_tuple = tuple(labels.get(name, "") for name in self.labelnames)
            return MockCounterChild(self, label_tuple)

    @dataclass
    class MockCounterChild:
        """Child counter with specific labels."""

        parent: MockPrometheusCounter
        label_values: tuple[str, ...]

        def inc(self, amount: float = 1.0) -> None:
            """Increment counter."""
            self.parent._values[self.label_values] += amount

    # Mock Prometheus metrics (fallback when prometheus_client is not installed).
    # NOTE: the cast target must be the string "Counter": when the import above
    # failed, the name Counter is unbound, so cast(Counter, ...) would raise
    # NameError at module import time.
    llm_token_usage = cast(
        "Counter",
        MockPrometheusCounter(
            name="llm_token_usage_total",
            description="Total tokens used by LLM calls",
            labelnames=["provider", "model", "token_type"],
        ),
    )

    llm_cost = cast(
        "Counter",
        MockPrometheusCounter(
            name="llm_cost_usd_total",
            description="Total estimated cost in USD",
            labelnames=["provider", "model"],
        ),
    )

    PROMETHEUS_AVAILABLE = False
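
# Both branches expose the same counter interface, so callers never need to
# check PROMETHEUS_AVAILABLE. A hedged sketch (label and amount values are
# illustrative):
#
# >>> llm_token_usage.labels(
# ...     provider="anthropic", model="claude-sonnet-4-5-20250929", token_type="input"
# ... ).inc(1000)
# >>> llm_cost.labels(provider="anthropic", model="claude-sonnet-4-5-20250929").inc(0.003)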

# ==============================================================================
# Cost Metrics Collector
# ==============================================================================


class CostMetricsCollector:
    """
    Collects and persists LLM cost metrics.

    Features:
    - Async recording to avoid blocking API calls
    - Automatic cost calculation
    - Prometheus metrics integration
    - PostgreSQL persistence with retention policy
    - In-memory fallback when database unavailable
    """

    def __init__(
        self,
        database_url: str | None = None,
        retention_days: int = 90,
        enable_persistence: bool = True,
    ) -> None:
        """
        Initialize the cost metrics collector.

        Args:
            database_url: PostgreSQL connection URL (postgresql+asyncpg://...).
                If None, uses in-memory storage only.
            retention_days: Number of days to retain records (default: 90)
            enable_persistence: Whether to enable PostgreSQL persistence
        """
        self._records: list[TokenUsage] = []
        self._lock = asyncio.Lock()
        self._database_url = database_url
        self._retention_days = retention_days
        self._enable_persistence = enable_persistence and database_url is not None

    @property
    def total_records(self) -> int:
        """Get total number of records."""
        return len(self._records)

    async def record_usage(
        self,
        timestamp: datetime,
        user_id: str,
        session_id: str,
        model: str,
        provider: str,
        prompt_tokens: int,
        completion_tokens: int,
        estimated_cost_usd: Decimal | None = None,
        feature: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> TokenUsage:
        """
        Record token usage for an LLM call.

        Args:
            timestamp: When the call was made
            user_id: User identifier
            session_id: Session identifier
            model: Model name
            provider: Provider name
            prompt_tokens: Input token count
            completion_tokens: Output token count
            estimated_cost_usd: Pre-calculated cost (optional)
            feature: Feature name (optional)
            metadata: Additional metadata (optional)

        Returns:
            TokenUsage record

        Example:
            >>> collector = CostMetricsCollector()
            >>> usage = await collector.record_usage(
            ...     timestamp=datetime.now(UTC),
            ...     user_id="user123",
            ...     session_id="session456",
            ...     model="claude-sonnet-4-5-20250929",
            ...     provider="anthropic",
            ...     prompt_tokens=1000,
            ...     completion_tokens=500
            ... )
        """

        # Calculate cost if not provided
        if estimated_cost_usd is None:
            estimated_cost_usd = calculate_cost(
                model=model,
                provider=provider,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )

        # Create usage record
        usage = TokenUsage(
            timestamp=timestamp,
            user_id=user_id,
            session_id=session_id,
            model=model,
            provider=provider,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            estimated_cost_usd=estimated_cost_usd,
            feature=feature,
            metadata=metadata or {},
        )

        # Store record in-memory (thread-safe)
        async with self._lock:
            self._records.append(usage)

        # Persist to PostgreSQL if enabled
        if self._enable_persistence:
            try:
                await self._persist_to_database(usage)
            except Exception:
                # Log the error (with traceback) but don't fail the recording
                logger.exception("Failed to persist usage record to database")

        # Update Prometheus metrics
        llm_token_usage.labels(
            provider=provider,
            model=model,
            token_type="input",
        ).inc(prompt_tokens)

        llm_token_usage.labels(
            provider=provider,
            model=model,
            token_type="output",
        ).inc(completion_tokens)

        llm_cost.labels(
            provider=provider,
            model=model,
        ).inc(float(estimated_cost_usd))

        return usage
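
    # A hedged sketch (not in the original docstring) of the optional
    # parameters; the feature name and metadata keys are illustrative:
    #
    # >>> await collector.record_usage(
    # ...     timestamp=datetime.now(UTC),
    # ...     user_id="user123", session_id="session456",
    # ...     model="claude-sonnet-4-5-20250929", provider="anthropic",
    # ...     prompt_tokens=200, completion_tokens=50,
    # ...     estimated_cost_usd=Decimal("0.001"),  # skips calculate_cost()
    # ...     feature="summarize", metadata={"endpoint": "/chat"},
    # ... )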

    async def _persist_to_database(self, usage: TokenUsage) -> None:
        """
        Persist usage record to PostgreSQL.

        Args:
            usage: TokenUsage record to persist
        """
        if not self._database_url:
            return

        from mcp_server_langgraph.database import get_async_session
        from mcp_server_langgraph.database.models import TokenUsageRecord

        async with get_async_session(self._database_url) as session:
            # Create database record
            db_record = TokenUsageRecord(
                timestamp=usage.timestamp,
                user_id=usage.user_id,
                session_id=usage.session_id,
                model=usage.model,
                provider=usage.provider,
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
                estimated_cost_usd=usage.estimated_cost_usd,
                feature=usage.feature,
                metadata_=usage.metadata,
            )

            session.add(db_record)
            # Session commits automatically via context manager

    async def cleanup_old_records(self) -> int:
        """
        Remove records older than the retention period.

        This method removes both in-memory and PostgreSQL records that exceed
        the configured retention period (default: 90 days).

        Returns:
            Number of records deleted

        Example:
            >>> collector = CostMetricsCollector(database_url="...", retention_days=90)
            >>> deleted = await collector.cleanup_old_records()
            >>> print(f"Deleted {deleted} old records")
        """
        cutoff_time = datetime.now(UTC) - timedelta(days=self._retention_days)
        deleted_count = 0

        # Clean up in-memory records
        async with self._lock:
            initial_count = len(self._records)
            self._records = [r for r in self._records if r.timestamp >= cutoff_time]
            deleted_count = initial_count - len(self._records)

        # Clean up PostgreSQL records
        if self._enable_persistence:
            try:
                from sqlalchemy import delete

                from mcp_server_langgraph.database import get_async_session
                from mcp_server_langgraph.database.models import TokenUsageRecord

                # Type guard: _database_url is guaranteed non-None when _enable_persistence is True
                assert self._database_url is not None, "database_url must be set when persistence is enabled"

                async with get_async_session(self._database_url) as session:
                    stmt = delete(TokenUsageRecord).where(TokenUsageRecord.timestamp < cutoff_time)
                    result = await session.execute(stmt)
                    db_deleted = result.rowcount or 0  # type: ignore[attr-defined]
                    deleted_count += db_deleted

                logger.info(
                    f"Cleaned up {deleted_count} records older than {self._retention_days} days "
                    f"(cutoff: {cutoff_time.isoformat()})"
                )
            except Exception:
                logger.exception("Failed to clean up database records")

        return deleted_count
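
    # A hedged sketch (not part of the original module) of running retention
    # cleanup on a schedule; the daily interval is an illustrative assumption:
    #
    # async def cleanup_loop(collector: CostMetricsCollector) -> None:
    #     while True:
    #         await collector.cleanup_old_records()
    #         await asyncio.sleep(24 * 60 * 60)  # once per day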

    async def get_latest_record(self) -> TokenUsage | None:
        """Get the most recent usage record."""
        async with self._lock:
            return self._records[-1] if self._records else None

    async def get_records(
        self,
        period: str | None = None,
        user_id: str | None = None,
        model: str | None = None,
    ) -> list[TokenUsage]:
        """
        Get usage records with optional filtering.

        Args:
            period: Time period ("day", "week", "month"); None returns all records
            user_id: Filter by user (optional)
            model: Filter by model (optional)

        Returns:
            List of TokenUsage records
        """
        async with self._lock:
            records = self._records.copy()

        # Apply filters
        if user_id:
            records = [r for r in records if r.user_id == user_id]

        if model:
            records = [r for r in records if r.model == model]

        # Apply time period filter ("month" is approximated as 30 days)
        period_days = {"day": 1, "week": 7, "month": 30}
        if period is not None and period in period_days:
            cutoff = datetime.now(UTC) - timedelta(days=period_days[period])
            records = [r for r in records if r.timestamp >= cutoff]

        return records
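
    # A hedged example of the period filter above (identifiers are
    # illustrative):
    #
    # >>> recent = await collector.get_records(period="week", user_id="user123")
    # >>> everything = await collector.get_records()  # period=None: no time filter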

    async def get_total_cost(
        self,
        user_id: str | None = None,
        period: str | None = None,
    ) -> Decimal:
        """
        Calculate total cost with optional filters.

        Args:
            user_id: Filter by user (optional)
            period: Time period (optional; None sums all records)

        Returns:
            Total cost in USD
        """
        records = await self.get_records(period=period, user_id=user_id)
        return sum((r.estimated_cost_usd for r in records), Decimal("0"))
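
# A hedged sketch of a cost query; values are illustrative. Costs are summed
# as Decimal (with an explicit Decimal("0") start) so the result stays exact:
#
# >>> total = await collector.get_total_cost(user_id="user123", period="month")
# >>> print(f"user123 spent ${total} this month")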

# ==============================================================================
# Cost Aggregator
# ==============================================================================


class CostAggregator:
    """
    Aggregates cost data by multiple dimensions.

    Provides:
    - Cost by model
    - Cost by user
    - Cost by feature
    - Total cost calculation
    """

451 

452 async def aggregate_by_model(self, records: list[dict[str, Any]]) -> dict[str, Decimal]: 

453 """ 

454 Aggregate costs by model. 

455 

456 Args: 

457 records: List of cost records 

458 

459 Returns: 

460 Dict mapping model names to total costs 

461 """ 

462 aggregated: dict[str, Decimal] = defaultdict(lambda: Decimal("0")) 

463 

464 for record in records: 

465 model = record["model"] 

466 cost = record["cost"] if isinstance(record["cost"], Decimal) else Decimal(str(record["cost"])) 

467 aggregated[model] += cost 

468 

469 return dict(aggregated) 

470 

471 async def aggregate_by_user(self, records: list[dict[str, Any]]) -> dict[str, Decimal]: 

472 """ 

473 Aggregate costs by user. 

474 

475 Args: 

476 records: List of cost records 

477 

478 Returns: 

479 Dict mapping user IDs to total costs 

480 """ 

481 aggregated: dict[str, Decimal] = defaultdict(lambda: Decimal("0")) 

482 

483 for record in records: 

484 user_id = record["user_id"] 

485 cost = record["cost"] if isinstance(record["cost"], Decimal) else Decimal(str(record["cost"])) 

486 aggregated[user_id] += cost 

487 

488 return dict(aggregated) 

489 

490 async def aggregate_by_feature(self, records: list[dict[str, Any]]) -> dict[str, Decimal]: 

491 """ 

492 Aggregate costs by feature. 

493 

494 Args: 

495 records: List of cost records 

496 

497 Returns: 

498 Dict mapping feature names to total costs 

499 """ 

500 aggregated: dict[str, Decimal] = defaultdict(lambda: Decimal("0")) 

501 

502 for record in records: 

503 feature = record.get("feature", "unknown") 

504 cost = record["cost"] if isinstance(record["cost"], Decimal) else Decimal(str(record["cost"])) 

505 aggregated[feature] += cost 

506 

507 return dict(aggregated) 

508 

509 async def calculate_total(self, records: list[dict[str, Any]]) -> Decimal: 

510 """ 

511 Calculate total cost across all records. 

512 

513 Args: 

514 records: List of cost records 

515 

516 Returns: 

517 Total cost in USD 

518 """ 

519 total = Decimal("0") 

520 

521 for record in records: 

522 cost = record["cost"] if isinstance(record["cost"], Decimal) else Decimal(str(record["cost"])) 

523 total += cost 

524 

525 return total 

526 
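
# A hedged sketch of the aggregator input shape: it expects plain dicts with
# "model", "user_id", "cost", and optionally "feature" keys; non-Decimal costs
# are coerced via Decimal(str(...)). Sample values are illustrative:
#
# >>> aggregator = CostAggregator()
# >>> records = [
# ...     {"model": "claude-sonnet-4-5", "user_id": "u1", "cost": "0.003", "feature": "chat"},
# ...     {"model": "gpt-4o", "user_id": "u1", "cost": Decimal("0.002")},
# ... ]
# >>> await aggregator.aggregate_by_model(records)
# {'claude-sonnet-4-5': Decimal('0.003'), 'gpt-4o': Decimal('0.002')}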

# ==============================================================================
# Singleton Instance
# ==============================================================================

_collector_instance: CostMetricsCollector | None = None


def get_cost_collector() -> CostMetricsCollector:
    """Get or create singleton cost collector instance."""
    global _collector_instance
    if _collector_instance is None:
        _collector_instance = CostMetricsCollector()
    return _collector_instance
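
# A hedged note on the singleton: get_cost_collector() always returns the same
# in-memory-only instance; callers needing PostgreSQL persistence construct
# their own collector with a database_url. Sketch:
#
# >>> collector = get_cost_collector()
# >>> collector is get_cost_collector()
# True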