Coverage for src/mcp_server_langgraph/monitoring/cost_tracker.py: 71% (165 statements)
1"""
2Cost Metrics Collector
4Tracks token usage and costs for LLM API calls with async recording,
5Prometheus metrics integration, and PostgreSQL persistence.
7Example:
8 >>> from mcp_server_langgraph.monitoring.cost_tracker import CostMetricsCollector
9 >>> collector = CostMetricsCollector()
10 >>> await collector.record_usage(
11 ... user_id="user123",
12 ... session_id="session456",
13 ... model="claude-sonnet-4-5-20250929",
14 ... provider="anthropic",
15 ... prompt_tokens=1000,
16 ... completion_tokens=500
17 ... )
18"""

import asyncio
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta, UTC
from decimal import Decimal
from typing import Any, cast

from pydantic import BaseModel, ConfigDict, Field, field_serializer

from .pricing import calculate_cost

# ==============================================================================
# Data Models
# ==============================================================================


class TokenUsage(BaseModel):
    """Token usage record for a single LLM call."""

    timestamp: datetime = Field(description="When the call was made")
    user_id: str = Field(description="User who made the call")
    session_id: str = Field(description="Session identifier")
    model: str = Field(description="Model name")
    provider: str = Field(description="Provider (anthropic, openai, google)")
    prompt_tokens: int = Field(description="Number of input tokens", ge=0)
    completion_tokens: int = Field(description="Number of output tokens", ge=0)
    total_tokens: int = Field(description="Total tokens (prompt + completion)", ge=0)
    estimated_cost_usd: Decimal = Field(description="Estimated cost in USD")
    feature: str | None = Field(default=None, description="Feature that triggered the call")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")

    model_config = ConfigDict()

    @field_serializer("estimated_cost_usd")
    def serialize_decimal(self, value: Decimal) -> str:
        """Serialize Decimal as string for JSON compatibility."""
        return str(value)

    @field_serializer("timestamp")
    def serialize_timestamp(self, value: datetime) -> str:
        """Serialize datetime as ISO 8601 string."""
        return value.isoformat()

    def __init__(self, **data: Any):
        # Calculate total_tokens if not provided
        if "total_tokens" not in data:
            data["total_tokens"] = data.get("prompt_tokens", 0) + data.get("completion_tokens", 0)
        super().__init__(**data)
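
# A minimal serialization sketch (illustrative values, not part of the
# module's API): the field serializers above make model_dump() JSON-safe by
# emitting Decimal as a string and datetime as ISO 8601.
#
#   usage = TokenUsage(
#       timestamp=datetime.now(UTC),
#       user_id="user123",
#       session_id="session456",
#       model="claude-sonnet-4-5-20250929",
#       provider="anthropic",
#       prompt_tokens=1000,
#       completion_tokens=500,
#       estimated_cost_usd=Decimal("0.0105"),
#   )
#   usage.total_tokens                          # 1500, derived in __init__
#   usage.model_dump()["estimated_cost_usd"]    # "0.0105" (str, not float)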

# ==============================================================================
# Prometheus Metrics
# ==============================================================================

try:
    from prometheus_client import Counter

    # Real Prometheus metrics for production
    llm_token_usage = Counter(
        name="llm_token_usage_total",
        documentation="Total tokens used by LLM calls",
        labelnames=["provider", "model", "token_type"],
    )

    llm_cost = Counter(
        name="llm_cost_usd_total",
        documentation="Total estimated cost in USD (cumulative)",
        labelnames=["provider", "model"],
    )

    PROMETHEUS_AVAILABLE = True
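
    # Exposition sketch (sample values are hypothetical): once incremented,
    # these counters appear in the Prometheus text format roughly as
    #   llm_token_usage_total{provider="anthropic",model="...",token_type="input"} 1000.0
    #   llm_cost_usd_total{provider="anthropic",model="..."} 0.0105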

except ImportError:
    # Fallback to mock counters if prometheus_client not available
    import warnings

    warnings.warn(
        "prometheus_client not available, using mock counters. "
        "Install prometheus-client for production metrics: pip install prometheus-client"
    )

    @dataclass
    class MockPrometheusCounter:
        """Mock Prometheus counter for testing."""

        name: str
        description: str
        labelnames: list[str]
        _values: dict[tuple[str, ...], float] = field(default_factory=lambda: defaultdict(float))

        def labels(self, **labels: str) -> "MockCounterChild":
            """Return label-specific counter."""
            label_tuple = tuple(labels.get(name, "") for name in self.labelnames)
            return MockCounterChild(self, label_tuple)

    @dataclass
    class MockCounterChild:
        """Child counter with specific labels."""

        parent: MockPrometheusCounter
        label_values: tuple[str, ...]

        def inc(self, amount: float = 1.0) -> None:
            """Increment counter."""
            self.parent._values[self.label_values] += amount

    # Mock Prometheus metrics (fallback when prometheus_client not installed).
    # "Counter" must be a string forward reference here: the name is undefined
    # at runtime in this branch because the import above failed, so a bare
    # cast(Counter, ...) would raise NameError.
    llm_token_usage = cast(
        "Counter",
        MockPrometheusCounter(
            name="llm_token_usage_total",
            description="Total tokens used by LLM calls",
            labelnames=["provider", "model", "token_type"],
        ),
    )

    llm_cost = cast(
        "Counter",
        MockPrometheusCounter(
            name="llm_cost_usd_total",
            description="Total estimated cost in USD",
            labelnames=["provider", "model"],
        ),
    )

    PROMETHEUS_AVAILABLE = False
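
# Test-side sketch (an assumption about how the mock might be inspected; the
# module does not prescribe it): when the fallback is active, assertions can
# read the mock's internal table directly, e.g.
#   llm_cost.labels(provider="anthropic", model="m").inc(0.01)
#   llm_cost._values[("anthropic", "m")]  # -> 0.01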


# ==============================================================================
# Cost Metrics Collector
# ==============================================================================


class CostMetricsCollector:
    """
    Collects and persists LLM cost metrics.

    Features:
    - Async recording to avoid blocking API calls
    - Automatic cost calculation
    - Prometheus metrics integration
    - PostgreSQL persistence with retention policy
    - In-memory fallback when database unavailable
    """

    def __init__(
        self,
        database_url: str | None = None,
        retention_days: int = 90,
        enable_persistence: bool = True,
    ) -> None:
        """
        Initialize the cost metrics collector.

        Args:
            database_url: PostgreSQL connection URL (postgresql+asyncpg://...)
                If None, uses in-memory storage only
            retention_days: Number of days to retain records (default: 90)
            enable_persistence: Whether to enable PostgreSQL persistence
        """
        self._records: list[TokenUsage] = []
        self._lock = asyncio.Lock()
        self._database_url = database_url
        self._retention_days = retention_days
        self._enable_persistence = enable_persistence and database_url is not None

    @property
    def total_records(self) -> int:
        """Get total number of records."""
        return len(self._records)

    async def record_usage(
        self,
        timestamp: datetime,
        user_id: str,
        session_id: str,
        model: str,
        provider: str,
        prompt_tokens: int,
        completion_tokens: int,
        estimated_cost_usd: Decimal | None = None,
        feature: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> TokenUsage:
        """
        Record token usage for an LLM call.

        Args:
            timestamp: When the call was made
            user_id: User identifier
            session_id: Session identifier
            model: Model name
            provider: Provider name
            prompt_tokens: Input token count
            completion_tokens: Output token count
            estimated_cost_usd: Pre-calculated cost (optional)
            feature: Feature name (optional)
            metadata: Additional metadata (optional)

        Returns:
            TokenUsage record

        Example:
            >>> collector = CostMetricsCollector()
            >>> usage = await collector.record_usage(
            ...     timestamp=datetime.now(UTC),
            ...     user_id="user123",
            ...     session_id="session456",
            ...     model="claude-sonnet-4-5-20250929",
            ...     provider="anthropic",
            ...     prompt_tokens=1000,
            ...     completion_tokens=500
            ... )
        """
        # Calculate cost if not provided
        if estimated_cost_usd is None:
            estimated_cost_usd = calculate_cost(
                model=model,
                provider=provider,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )

        # Create usage record
        usage = TokenUsage(
            timestamp=timestamp,
            user_id=user_id,
            session_id=session_id,
            model=model,
            provider=provider,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            estimated_cost_usd=estimated_cost_usd,
            feature=feature,
            metadata=metadata or {},
        )

        # Store record in-memory (thread-safe)
        async with self._lock:
            self._records.append(usage)

        # Persist to PostgreSQL if enabled
        if self._enable_persistence:
            try:
                await self._persist_to_database(usage)
            except Exception as e:
                # Log error but don't fail the recording
                import logging

                logger = logging.getLogger(__name__)
                logger.exception(f"Failed to persist usage record to database: {e}")

        # Update Prometheus metrics
        llm_token_usage.labels(
            provider=provider,
            model=model,
            token_type="input",
        ).inc(prompt_tokens)

        llm_token_usage.labels(
            provider=provider,
            model=model,
            token_type="output",
        ).inc(completion_tokens)

        llm_cost.labels(
            provider=provider,
            model=model,
        ).inc(float(estimated_cost_usd))

        return usage
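
    # Call-site sketch (the Anthropic SDK response shape here is an
    # assumption about the caller's code, not something this module uses):
    #   response = await anthropic_client.messages.create(...)
    #   await collector.record_usage(
    #       timestamp=datetime.now(UTC),
    #       user_id=user_id,
    #       session_id=session_id,
    #       model=response.model,
    #       provider="anthropic",
    #       prompt_tokens=response.usage.input_tokens,
    #       completion_tokens=response.usage.output_tokens,
    #   )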

    async def _persist_to_database(self, usage: TokenUsage) -> None:
        """
        Persist usage record to PostgreSQL.

        Args:
            usage: TokenUsage record to persist
        """
        if not self._database_url:
            return

        from mcp_server_langgraph.database import get_async_session
        from mcp_server_langgraph.database.models import TokenUsageRecord

        async with get_async_session(self._database_url) as session:
            # Create database record
            db_record = TokenUsageRecord(
                timestamp=usage.timestamp,
                user_id=usage.user_id,
                session_id=usage.session_id,
                model=usage.model,
                provider=usage.provider,
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
                estimated_cost_usd=usage.estimated_cost_usd,
                feature=usage.feature,
                metadata_=usage.metadata,
            )

            session.add(db_record)
            # Session commits automatically via context manager

    async def cleanup_old_records(self) -> int:
        """
        Remove records older than retention period.

        This method removes both in-memory and PostgreSQL records that exceed
        the configured retention period (default: 90 days).

        Returns:
            Number of records deleted

        Example:
            >>> collector = CostMetricsCollector(database_url="...", retention_days=90)
            >>> deleted = await collector.cleanup_old_records()
            >>> print(f"Deleted {deleted} old records")
        """
        cutoff_time = datetime.now(UTC) - timedelta(days=self._retention_days)
        deleted_count = 0

        # Clean up in-memory records
        async with self._lock:
            initial_count = len(self._records)
            self._records = [r for r in self._records if r.timestamp >= cutoff_time]
            deleted_count = initial_count - len(self._records)

        # Clean up PostgreSQL records
        if self._enable_persistence:
            try:
                from sqlalchemy import delete

                from mcp_server_langgraph.database import get_async_session
                from mcp_server_langgraph.database.models import TokenUsageRecord

                # Type guard: _database_url is guaranteed non-None when _enable_persistence is True
                assert self._database_url is not None, "database_url must be set when persistence is enabled"

                async with get_async_session(self._database_url) as session:
                    stmt = delete(TokenUsageRecord).where(TokenUsageRecord.timestamp < cutoff_time)
                    result = await session.execute(stmt)
                    db_deleted = result.rowcount or 0  # type: ignore[attr-defined]
                    deleted_count += db_deleted

                import logging

                logger = logging.getLogger(__name__)
                logger.info(
                    f"Cleaned up {deleted_count} records older than {self._retention_days} days "
                    f"(cutoff: {cutoff_time.isoformat()})"
                )
            except Exception as e:
                import logging

                logger = logging.getLogger(__name__)
                logger.exception(f"Failed to cleanup database records: {e}")

        return deleted_count
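
    # Scheduling sketch (an assumption about deployment: callers own the
    # event loop; the module does not start background tasks itself):
    #   async def retention_loop(collector: CostMetricsCollector) -> None:
    #       while True:
    #           await collector.cleanup_old_records()
    #           await asyncio.sleep(24 * 60 * 60)  # once a day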

    async def get_latest_record(self) -> TokenUsage | None:
        """Get the most recent usage record."""
        async with self._lock:
            return self._records[-1] if self._records else None

    async def get_records(
        self,
        period: str = "day",
        user_id: str | None = None,
        model: str | None = None,
    ) -> list[TokenUsage]:
        """
        Get usage records with optional filtering.

        Args:
            period: Time period ("day", "week", "month")
            user_id: Filter by user (optional)
            model: Filter by model (optional)

        Returns:
            List of TokenUsage records
        """
        async with self._lock:
            records = self._records.copy()

        # Apply filters
        if user_id:
            records = [r for r in records if r.user_id == user_id]

        if model:
            records = [r for r in records if r.model == model]

        # TODO: Apply time period filter

        return records
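
    # One possible shape for the pending period filter above (a sketch, not
    # the implemented behavior; "month" as 30 days is an assumption):
    #   cutoffs = {"day": timedelta(days=1), "week": timedelta(weeks=1),
    #              "month": timedelta(days=30)}
    #   cutoff = datetime.now(UTC) - cutoffs[period]
    #   records = [r for r in records if r.timestamp >= cutoff]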

    async def get_total_cost(
        self,
        user_id: str | None = None,
        period: str | None = None,
    ) -> Decimal:
        """
        Calculate total cost with optional filters.

        Args:
            user_id: Filter by user (optional)
            period: Time period (optional)

        Returns:
            Total cost in USD
        """
        records = await self.get_records(period=period or "day", user_id=user_id)
        return sum((r.estimated_cost_usd for r in records), Decimal("0"))
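
    # Usage sketch (threshold and policy are hypothetical): per-user spend
    # for the default "day" period, kept in Decimal to avoid float rounding
    # on money.
    #   spend = await collector.get_total_cost(user_id="user123")
    #   if spend > Decimal("10"):
    #       ...  # e.g. alert or throttle; the policy is up to the caller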


# ==============================================================================
# Cost Aggregator
# ==============================================================================


class CostAggregator:
    """
    Aggregates cost data by multiple dimensions.

    Provides:
    - Cost by model
    - Cost by user
    - Cost by feature
    - Total cost calculation
    """

    async def aggregate_by_model(self, records: list[dict[str, Any]]) -> dict[str, Decimal]:
        """
        Aggregate costs by model.

        Args:
            records: List of cost records

        Returns:
            Dict mapping model names to total costs
        """
        aggregated: dict[str, Decimal] = defaultdict(lambda: Decimal("0"))

        for record in records:
            model = record["model"]
            cost = record["cost"] if isinstance(record["cost"], Decimal) else Decimal(str(record["cost"]))
            aggregated[model] += cost

        return dict(aggregated)

    async def aggregate_by_user(self, records: list[dict[str, Any]]) -> dict[str, Decimal]:
        """
        Aggregate costs by user.

        Args:
            records: List of cost records

        Returns:
            Dict mapping user IDs to total costs
        """
        aggregated: dict[str, Decimal] = defaultdict(lambda: Decimal("0"))

        for record in records:
            user_id = record["user_id"]
            cost = record["cost"] if isinstance(record["cost"], Decimal) else Decimal(str(record["cost"]))
            aggregated[user_id] += cost

        return dict(aggregated)

    async def aggregate_by_feature(self, records: list[dict[str, Any]]) -> dict[str, Decimal]:
        """
        Aggregate costs by feature.

        Args:
            records: List of cost records

        Returns:
            Dict mapping feature names to total costs
        """
        aggregated: dict[str, Decimal] = defaultdict(lambda: Decimal("0"))

        for record in records:
            feature = record.get("feature", "unknown")
            cost = record["cost"] if isinstance(record["cost"], Decimal) else Decimal(str(record["cost"]))
            aggregated[feature] += cost

        return dict(aggregated)

    async def calculate_total(self, records: list[dict[str, Any]]) -> Decimal:
        """
        Calculate total cost across all records.

        Args:
            records: List of cost records

        Returns:
            Total cost in USD
        """
        total = Decimal("0")

        for record in records:
            cost = record["cost"] if isinstance(record["cost"], Decimal) else Decimal(str(record["cost"]))
            total += cost

        return total
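
# Record shape expected by CostAggregator, inferred from the key accesses
# above (plain dicts, not TokenUsage models; the producer of these dicts is
# outside this module, and the sample values are illustrative):
#   {"model": "claude-sonnet-4-5-20250929", "user_id": "user123",
#    "feature": "chat", "cost": Decimal("0.0105")}
# Non-Decimal costs (float/str/int) are coerced via Decimal(str(cost)).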


# ==============================================================================
# Singleton Instance
# ==============================================================================

_collector_instance: CostMetricsCollector | None = None


def get_cost_collector() -> CostMetricsCollector:
    """Get or create singleton cost collector instance."""
    global _collector_instance
    if _collector_instance is None:
        _collector_instance = CostMetricsCollector()
    return _collector_instance
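
# Wiring sketch (a minimal example of the singleton in request-handling code;
# the handle_request name is hypothetical):
#   async def handle_request(...) -> None:
#       collector = get_cost_collector()
#       await collector.record_usage(...)
# Note the default singleton is created with no database_url, so it records
# in-memory and to Prometheus only.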