Coverage for src/mcp_server_langgraph/llm/factory.py: 84%
251 statements
coverage.py v7.12.0, created at 2025-12-08 06:31 +0000
1"""
2LiteLLM Factory for Multi-Provider LLM Support
4Supports: Anthropic, OpenAI, Google (Gemini), Azure OpenAI, AWS Bedrock,
5Ollama (Llama, Qwen, Mistral, etc.)
7Enhanced with resilience patterns (ADR-0026):
8- Circuit breaker for provider failures
9- Retry logic with exponential backoff
10- Timeout enforcement
11- Bulkhead isolation (10 concurrent LLM calls max, provider-aware)
12- Exponential backoff between fallback attempts
13"""
15import asyncio
16import os
18# Fallback resilience constants
19FALLBACK_BASE_DELAY_SECONDS = 1.0 # Initial delay between fallback attempts
20FALLBACK_DELAY_MULTIPLIER = 2.0 # Exponential multiplier
21FALLBACK_MAX_DELAY_SECONDS = 8.0 # Cap for fallback delays
22from typing import Any
24from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
25from litellm import acompletion, completion
26from litellm.utils import ModelResponse # type: ignore[attr-defined]
28from mcp_server_langgraph.core.exceptions import (
29 LLMModelNotFoundError,
30 LLMOverloadError,
31 LLMProviderError,
32 LLMRateLimitError,
33 LLMTimeoutError,
34)
35from mcp_server_langgraph.resilience.retry import extract_retry_after_from_exception, is_overload_error
36from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer
37from mcp_server_langgraph.resilience import circuit_breaker, retry_with_backoff, with_bulkhead, with_timeout
40class LLMFactory:
41 """
42 Factory for creating and managing LLM connections via LiteLLM
44 Supports multiple providers with automatic fallback and retry logic.
45 """
47 def __init__( # type: ignore[no-untyped-def]
48 self,
49 provider: str = "google",
50 model_name: str = "gemini-2.5-flash",
51 api_key: str | None = None,
52 temperature: float = 0.7,
53 max_tokens: int = 4096,
54 timeout: int = 60,
55 enable_fallback: bool = True,
56 fallback_models: list[str] | None = None,
57 **kwargs,
58 ):
59 """
60 Initialize LLM Factory
62 Args:
63 provider: LLM provider (anthropic, openai, google, azure, bedrock, ollama)
64 model_name: Model identifier
65 api_key: API key for the provider
66 temperature: Sampling temperature (0-1)
67 max_tokens: Maximum tokens to generate
68 timeout: Request timeout in seconds
69 enable_fallback: Enable fallback to alternative models
70 fallback_models: List of fallback model names
71 **kwargs: Additional provider-specific parameters
72 """
73 self.provider = provider
74 self.model_name = model_name
75 self.api_key = api_key
76 self.temperature = temperature
77 self.max_tokens = max_tokens
78 self.timeout = timeout
79 self.enable_fallback = enable_fallback
80 self.fallback_models = fallback_models or []
81 self.kwargs = kwargs
83 # Note: _setup_environment is now called by factory functions with config
84 # This allows multi-provider credential setup for fallbacks
86 logger.info(
87 "LLM Factory initialized",
88 extra={
89 "provider": provider,
90 "model": model_name,
91 "fallback_enabled": enable_fallback,
92 "fallback_count": len(fallback_models) if fallback_models else 0,
93 },
94 )
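# Illustrative usage sketch (an assumption about typical call sites, not code
# from this module): constructing the factory directly with cross-provider
# fallbacks. Model names and the key value are placeholders.
#
#     factory = LLMFactory(
#         provider="anthropic",
#         model_name="claude-sonnet-4-5",
#         api_key="sk-ant-...",
#         enable_fallback=True,
#         fallback_models=["gpt-4o", "gemini-2.5-flash"],
#     )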
96 def _get_provider_from_model(self, model_name: str) -> str:
97 """
98 Extract provider from model name.
100 Args:
101 model_name: Model identifier (e.g., "gpt-5", "claude-sonnet-4-5", "gemini-2.5-flash")
103 Returns:
104 Provider name (e.g., "openai", "anthropic", "google")
105 """
106 model_lower = model_name.lower()
108 # Check provider prefixes FIRST (azure/gpt-4 should be azure, not openai)
109 # Vertex AI (Google Cloud AI Platform)
110 if model_lower.startswith("vertex_ai/"):
111 return "vertex_ai"
113 # Azure (prefixed models)
114 if model_lower.startswith("azure/"):
115 return "azure"
117 # Bedrock (prefixed models)
118 if model_lower.startswith("bedrock/"):
119 return "bedrock"
121 # Ollama (local models)
122 if model_lower.startswith("ollama/"):
123 return "ollama"
125 # Then check model name patterns
126 # Anthropic models
127 if any(x in model_lower for x in ["claude", "anthropic"]):
128 return "anthropic"
130 # OpenAI models
131 if any(x in model_lower for x in ["gpt-", "o1-", "davinci", "curie", "babbage"]):
132 return "openai"
134 # Google models
135 if any(x in model_lower for x in ["gemini", "palm", "text-bison", "chat-bison"]):
136 return "google"
138 # Default to current provider
139 return self.provider
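# Illustrative sketch of the detection above for a factory whose primary
# provider is "google"; the model names are examples only.
#
#     factory._get_provider_from_model("azure/gpt-4o")       # -> "azure" (prefix wins)
#     factory._get_provider_from_model("claude-sonnet-4-5")  # -> "anthropic"
#     factory._get_provider_from_model("gpt-4o")             # -> "openai"
#     factory._get_provider_from_model("ollama/llama3")      # -> "ollama"
#     factory._get_provider_from_model("my-custom-model")    # -> "google" (primary provider default)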
141 def _get_provider_kwargs(self, model_name: str) -> dict[str, Any]:
142 """
143 Get provider-specific kwargs for a given model.
145 BUGFIX: Filters out provider-specific parameters that don't apply to the target provider.
146 For example, Azure-specific params (azure_endpoint, azure_deployment) are removed
147 when calling Anthropic or OpenAI fallback models.
149 Args:
150 model_name: Model identifier to get kwargs for
152 Returns:
153 dict: Filtered kwargs appropriate for the model's provider
154 """
155 provider = self._get_provider_from_model(model_name)
157 # Define provider-specific parameter prefixes
158 provider_specific_prefixes = {
159 "azure": ["azure_", "api_version"],
160 "bedrock": ["aws_", "bedrock_"],
161 "vertex": ["vertex_", "gcp_"],
162 }
164 # If this is the same provider as the primary model, return all kwargs
165 if provider == self.provider:
166 return self.kwargs
168 # Filter out parameters specific to other providers
169 filtered_kwargs = {}
170 for key, value in self.kwargs.items():
171 # Check if this parameter belongs to a different provider
172 is_provider_specific = False
173 for other_provider, prefixes in provider_specific_prefixes.items():
174 if other_provider != provider and any(key.startswith(prefix) for prefix in prefixes):
175 is_provider_specific = True
176 break
178 # Include parameter if it's not specific to another provider
179 if not is_provider_specific:
180 filtered_kwargs[key] = value
182 logger.debug(
183 "Filtered kwargs for fallback model",
184 extra={
185 "model": model_name,
186 "provider": provider,
187 "original_kwargs": list(self.kwargs.keys()),
188 "filtered_kwargs": list(filtered_kwargs.keys()),
189 },
190 )
192 return filtered_kwargs
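# Illustrative sketch of the filtering above (parameter values are placeholders):
# a factory created with Azure-specific kwargs keeps them for Azure models but
# drops them when preparing a call to a fallback on another provider.
#
#     factory = LLMFactory(
#         provider="azure",
#         model_name="azure/gpt-4o",
#         azure_deployment="prod-gpt4o",
#         api_version="2024-02-01",
#     )
#     factory._get_provider_kwargs("azure/gpt-4o")
#     # -> {"azure_deployment": "prod-gpt4o", "api_version": "2024-02-01"}
#     factory._get_provider_kwargs("claude-sonnet-4-5")
#     # -> {}  (Azure-only parameters filtered out for the Anthropic fallback)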
194 def _setup_environment(self, config=None) -> None: # type: ignore[no-untyped-def]
195 """
196 Set up environment variables for LiteLLM.
198 Configures credentials for primary provider AND all fallback providers
199 to enable seamless multi-provider fallback.
201 Args:
202 config: Settings object with API keys for all providers (optional)
203 """
204 # Collect all providers needed (primary + fallbacks)
205 providers_needed = {self.provider}
207 # Add providers for each fallback model
208 for fallback_model in self.fallback_models:
209 provider = self._get_provider_from_model(fallback_model)
210 providers_needed.add(provider)
212 # Map provider to environment variable and config attribute
213 # Provider-specific credential mapping (some providers need multiple env vars)
214 provider_config_map = {
215 "anthropic": [("ANTHROPIC_API_KEY", "anthropic_api_key")],
216 "openai": [("OPENAI_API_KEY", "openai_api_key")],
217 "google": [("GOOGLE_API_KEY", "google_api_key")],
218 "gemini": [("GOOGLE_API_KEY", "google_api_key")],
219 "azure": [
220 ("AZURE_API_KEY", "azure_api_key"),
221 ("AZURE_API_BASE", "azure_api_base"),
222 ("AZURE_API_VERSION", "azure_api_version"),
223 ("AZURE_DEPLOYMENT_NAME", "azure_deployment_name"),
224 ],
225 "bedrock": [
226 ("AWS_ACCESS_KEY_ID", "aws_access_key_id"),
227 ("AWS_SECRET_ACCESS_KEY", "aws_secret_access_key"),
228 ("AWS_REGION", "aws_region"),
229 ],
230 }
232 # Set up credentials for each needed provider
233 for provider in providers_needed:
234 if provider not in provider_config_map:
235 continue # Skip unknown providers (e.g., ollama doesn't need API key)
237 credential_pairs = provider_config_map[provider]
239 for env_var, config_attr in credential_pairs:
240 # Get value from config
241 value = getattr(config, config_attr) if config and hasattr(config, config_attr) else None
243 # For primary provider, use self.api_key as fallback for API key fields
244 if value is None and provider == self.provider and "api_key" in config_attr.lower():
245 value = self.api_key
247 # Set environment variable if we have a value
248 if value and env_var:
249 os.environ[env_var] = str(value)
250 logger.debug(
251 f"Configured credential for provider: {provider}", extra={"provider": provider, "env_var": env_var}
252 )
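# Illustrative sketch (the settings object is hypothetical): with an Anthropic
# primary model and an OpenAI fallback, both providers' credentials get exported
# so LiteLLM can switch between them.
#
#     factory = LLMFactory(
#         provider="anthropic",
#         model_name="claude-sonnet-4-5",
#         fallback_models=["gpt-4o"],
#     )
#     factory._setup_environment(config=settings)  # settings carries the API keys
#     # Result: ANTHROPIC_API_KEY and OPENAI_API_KEY are now set in os.environ.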
254 def _format_messages(self, messages: list[BaseMessage | dict[str, Any]]) -> list[dict[str, str]]:
255 """
256 Convert LangChain messages to LiteLLM format
258 Args:
259 messages: List of LangChain BaseMessage objects or dicts
261 Returns:
262 List of dictionaries in LiteLLM format
263 """
264 formatted = []
265 for msg in messages:
266 if isinstance(msg, HumanMessage):
267 content = msg.content if isinstance(msg.content, str) else str(msg.content)
268 formatted.append({"role": "user", "content": content})
269 elif isinstance(msg, AIMessage):
270 content = msg.content if isinstance(msg.content, str) else str(msg.content)
271 formatted.append({"role": "assistant", "content": content})
272 elif isinstance(msg, SystemMessage):
273 content = msg.content if isinstance(msg.content, str) else str(msg.content)
274 formatted.append({"role": "system", "content": content})
275 elif isinstance(msg, dict):
276 # Handle dict messages (already in correct format or need conversion)
277 if "role" in msg and "content" in msg:
278 # Already formatted dict
279 formatted.append(msg)
280 elif "content" in msg:
281 # Dict with content but no role
282 formatted.append({"role": "user", "content": str(msg["content"])})
283 else:
284 # Malformed dict, convert to string
285 formatted.append({"role": "user", "content": str(msg)})
286 else:
287 # Fallback for other types - check if it has content attribute
288 if hasattr(msg, "content"):
289 formatted.append({"role": "user", "content": str(msg.content)})
290 else:
291 # Last resort: convert entire object to string
292 formatted.append({"role": "user", "content": str(msg)})
294 return formatted
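# Illustrative sketch of the conversion above:
#
#     factory._format_messages([
#         SystemMessage(content="You are helpful."),
#         HumanMessage(content="Hi"),
#         {"role": "assistant", "content": "Hello!"},
#     ])
#     # -> [{"role": "system", "content": "You are helpful."},
#     #     {"role": "user", "content": "Hi"},
#     #     {"role": "assistant", "content": "Hello!"}]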
296 def invoke(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage: # type: ignore[no-untyped-def]
297 """
298 Synchronous LLM invocation
300 Args:
301 messages: List of messages
302 **kwargs: Additional parameters for the model
304 Returns:
305 AIMessage with the response
306 """
307 with tracer.start_as_current_span("llm.invoke") as span:
308 span.set_attribute("llm.provider", self.provider)
309 span.set_attribute("llm.model", self.model_name)
311 formatted_messages = self._format_messages(messages)
313 # Merge kwargs with defaults
314 params = {
315 "model": self.model_name,
316 "messages": formatted_messages,
317 "temperature": kwargs.get("temperature", self.temperature),
318 "max_tokens": kwargs.get("max_tokens", self.max_tokens),
319 "timeout": kwargs.get("timeout", self.timeout),
320 **self.kwargs,
321 }
323 try:
324 response: ModelResponse = completion(**params)
326 content = response.choices[0].message.content # type: ignore[union-attr]
328 # Track metrics
329 metrics.successful_calls.add(1, {"operation": "llm.invoke", "model": self.model_name})
331 logger.info(
332 "LLM invocation successful",
333 extra={
334 "model": self.model_name,
335 "tokens": response.usage.total_tokens if response.usage else 0, # type: ignore[attr-defined]
336 },
337 )
339 return AIMessage(content=content)
341 except Exception as e:
342 logger.error(
343 f"LLM invocation failed: {e}", extra={"model": self.model_name, "provider": self.provider}, exc_info=True
344 )
346 metrics.failed_calls.add(1, {"operation": "llm.invoke", "model": self.model_name})
347 span.record_exception(e)
349 # Try fallback if enabled
350 if self.enable_fallback and self.fallback_models:
351 return self._try_fallback(messages, **kwargs)
353 raise
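# Illustrative sketch: synchronous call with a per-request override (the prompt
# text is a placeholder).
#
#     reply = factory.invoke(
#         [HumanMessage(content="Summarize this ticket: ...")],
#         temperature=0.2,
#     )
#     print(reply.content)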
355 @circuit_breaker(name="llm", fail_max=5, timeout=60)
356 @retry_with_backoff(max_attempts=3, exponential_base=2)
357 @with_timeout(operation_type="llm")
358 @with_bulkhead(resource_type="llm")
359 async def ainvoke(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage: # type: ignore[no-untyped-def]
360 """
361 Asynchronous LLM invocation with full resilience protection.
363 Protected by:
364 - Circuit breaker: Fail fast if LLM provider is down (5 failures → open, 60s timeout)
365 - Retry logic: Up to 3 attempts with exponential backoff (1s, 2s, 4s)
366 - Timeout: 60s timeout for LLM operations
367 - Bulkhead: Limit to 10 concurrent LLM calls
369 Args:
370 messages: List of messages
371 **kwargs: Additional parameters for the model
373 Returns:
374 AIMessage with the response
376 Raises:
377 CircuitBreakerOpenError: If circuit breaker is open
378 RetryExhaustedError: If all retry attempts failed
379 TimeoutError: If operation exceeds 60s timeout
380 BulkheadRejectedError: If too many concurrent LLM calls
381 LLMProviderError: For other LLM provider errors
382 """
383 with tracer.start_as_current_span("llm.ainvoke") as span:
384 span.set_attribute("llm.provider", self.provider)
385 span.set_attribute("llm.model", self.model_name)
387 formatted_messages = self._format_messages(messages)
389 params = {
390 "model": self.model_name,
391 "messages": formatted_messages,
392 "temperature": kwargs.get("temperature", self.temperature),
393 "max_tokens": kwargs.get("max_tokens", self.max_tokens),
394 "timeout": kwargs.get("timeout", self.timeout),
395 **self.kwargs,
396 }
398 try:
399 response: ModelResponse = await acompletion(**params)
401 content = response.choices[0].message.content # type: ignore[union-attr]
403 metrics.successful_calls.add(1, {"operation": "llm.ainvoke", "model": self.model_name})
405 logger.info(
406 "Async LLM invocation successful",
407 extra={
408 "model": self.model_name,
409 "tokens": response.usage.total_tokens if response.usage else 0, # type: ignore[attr-defined]
410 },
411 )
413 return AIMessage(content=content)
415 except Exception as e:
416 # Convert to custom exceptions for better error handling
417 error_msg = str(e).lower()
419 # Check for overload errors (529 or equivalent) - handle first
420 if is_overload_error(e):
421 retry_after = extract_retry_after_from_exception(e)
422 raise LLMOverloadError(
423 message=f"LLM provider overloaded: {e}",
424 retry_after=retry_after,
425 metadata={
426 "model": self.model_name,
427 "provider": self.provider,
428 "retry_after": retry_after,
429 },
430 cause=e,
431 )
432 elif "rate limit" in error_msg or "429" in error_msg: 432 ↛ 434line 432 didn't jump to line 434 because the condition on line 432 was never true
433 # Extract Retry-After header for rate limit errors (same pattern as overload)
434 retry_after = extract_retry_after_from_exception(e)
435 raise LLMRateLimitError(
436 message=f"LLM provider rate limit exceeded: {e}",
437 retry_after=retry_after,
438 metadata={
439 "model": self.model_name,
440 "provider": self.provider,
441 "retry_after": retry_after,
442 },
443 cause=e,
444 )
445 elif "timeout" in error_msg or "timed out" in error_msg: 445 ↛ 446line 445 didn't jump to line 446 because the condition on line 445 was never true
446 raise LLMTimeoutError(
447 message=f"LLM request timed out: {e}",
448 metadata={"model": self.model_name, "provider": self.provider, "timeout": self.timeout},
449 cause=e,
450 )
451 elif "model not found" in error_msg or "404" in error_msg: 451 ↛ 452line 451 didn't jump to line 452 because the condition on line 451 was never true
452 raise LLMModelNotFoundError(
453 message=f"LLM model not found: {e}",
454 metadata={"model": self.model_name, "provider": self.provider},
455 cause=e,
456 )
457 else:
458 logger.error(
459 f"Async LLM invocation failed: {e}",
460 extra={"model": self.model_name, "provider": self.provider},
461 exc_info=True,
462 )
464 metrics.failed_calls.add(1, {"operation": "llm.ainvoke", "model": self.model_name})
465 span.record_exception(e)
467 # Try fallback if enabled
468 if self.enable_fallback and self.fallback_models:
469 return await self._try_fallback_async(messages, **kwargs)
471 raise LLMProviderError(
472 message=f"LLM provider error: {e}",
473 metadata={"model": self.model_name, "provider": self.provider},
474 cause=e,
475 )
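# Illustrative sketch of calling the resilient async path and handling the
# custom exceptions raised above; the caller-side function is a placeholder.
#
#     async def ask(factory: LLMFactory, prompt: str) -> str:
#         try:
#             reply = await factory.ainvoke([HumanMessage(content=prompt)])
#             return str(reply.content)
#         except (LLMOverloadError, LLMRateLimitError):
#             return "The model is busy; please retry shortly."
#         except LLMProviderError:
#             return "The model is currently unavailable."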
477 def _try_fallback(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage: # type: ignore[no-untyped-def]
478 """Try fallback models if primary fails"""
479 for fallback_model in self.fallback_models:
480 if fallback_model == self.model_name:
481 continue # Skip if it's the same model
483 logger.warning(f"Trying fallback model: {fallback_model}", extra={"primary_model": self.model_name})
485 try:
486 formatted_messages = self._format_messages(messages)
487 # BUGFIX: Use provider-specific kwargs to avoid cross-provider parameter errors
488 provider_kwargs = self._get_provider_kwargs(fallback_model)
489 response = completion(
490 model=fallback_model,
491 messages=formatted_messages,
492 temperature=self.temperature,
493 max_tokens=self.max_tokens,
494 timeout=self.timeout,
495 **provider_kwargs, # Forward provider-specific kwargs only
496 )
498 content = response.choices[0].message.content
500 logger.info("Fallback successful", extra={"fallback_model": fallback_model})
502 metrics.successful_calls.add(1, {"operation": "llm.fallback", "model": fallback_model})
504 return AIMessage(content=content)
506 except Exception as e:
507 logger.error(f"Fallback model {fallback_model} failed: {e}", exc_info=True)
508 continue
510 msg = "All models failed including fallbacks"
511 raise RuntimeError(msg)
513 async def _try_fallback_async(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage: # type: ignore[no-untyped-def]
514 """Try fallback models asynchronously with exponential backoff.
516 Implements resilience patterns:
517 - Exponential backoff between attempts (1s, 2s, 4s, 8s cap)
518 - Skips primary model if accidentally in fallback list
519 - Logs each attempt for observability
520 """
521 attempt = 0
522 current_delay = FALLBACK_BASE_DELAY_SECONDS
524 for fallback_model in self.fallback_models:
525 if fallback_model == self.model_name:
526 continue
528 # Apply exponential backoff delay before attempt (except first)
529 if attempt > 0:
530 delay = min(current_delay, FALLBACK_MAX_DELAY_SECONDS)
531 logger.info(
532 f"Waiting {delay:.1f}s before fallback attempt {attempt + 1}",
533 extra={"delay_seconds": delay, "fallback_model": fallback_model},
534 )
535 await asyncio.sleep(delay)
536 current_delay *= FALLBACK_DELAY_MULTIPLIER
538 attempt += 1
539 logger.warning(f"Trying fallback model: {fallback_model}", extra={"primary_model": self.model_name})
541 try:
542 formatted_messages = self._format_messages(messages)
543 # BUGFIX: Use provider-specific kwargs to avoid cross-provider parameter errors
544 provider_kwargs = self._get_provider_kwargs(fallback_model)
545 response = await acompletion(
546 model=fallback_model,
547 messages=formatted_messages,
548 temperature=self.temperature,
549 max_tokens=self.max_tokens,
550 timeout=self.timeout,
551 **provider_kwargs, # Forward provider-specific kwargs only
552 )
554 content = response.choices[0].message.content
556 logger.info("Async fallback successful", extra={"fallback_model": fallback_model})
558 metrics.successful_calls.add(1, {"operation": "llm.fallback_async", "model": fallback_model})
560 return AIMessage(content=content)
562 except Exception as e:
563 logger.error(f"Async fallback model {fallback_model} failed: {e}", exc_info=True)
564 continue
566 msg = "All async models failed including fallbacks"
567 raise RuntimeError(msg)
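# Illustrative sketch (not part of the original module): the capped exponential
# delay schedule _try_fallback_async sleeps between attempts, derived from the
# constants defined at the top of this file.
def _example_fallback_delays(attempts: int = 5) -> list[float]:
    """Return the delays used before fallback attempts 2..N."""
    delays: list[float] = []
    delay = FALLBACK_BASE_DELAY_SECONDS
    for _ in range(attempts - 1):  # no delay before the first fallback attempt
        delays.append(min(delay, FALLBACK_MAX_DELAY_SECONDS))
        delay *= FALLBACK_DELAY_MULTIPLIER
    return delays  # e.g. [1.0, 2.0, 4.0, 8.0] for five attempts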
570def create_llm_from_config(config) -> LLMFactory: # type: ignore[no-untyped-def]
571 """
572 Create primary LLM instance from configuration
574 Args:
575 config: Settings object with LLM configuration
577 Returns:
578 Configured LLMFactory instance for primary chat operations
579 """
580 # Determine API key based on provider
581 api_key_map = {
582 "anthropic": config.anthropic_api_key,
583 "openai": config.openai_api_key,
584 "google": config.google_api_key,
585 "gemini": config.google_api_key,
586 "vertex_ai": None, # Vertex AI uses Workload Identity or GOOGLE_APPLICATION_CREDENTIALS
587 "azure": config.azure_api_key,
588 "bedrock": config.aws_access_key_id,
589 }
591 api_key = api_key_map.get(config.llm_provider)
593 # Provider-specific kwargs
594 provider_kwargs = {}
596 if config.llm_provider == "azure":
597 provider_kwargs.update(
598 {
599 "api_base": config.azure_api_base,
600 "api_version": config.azure_api_version,
601 }
602 )
603 elif config.llm_provider == "bedrock":
604 provider_kwargs.update(
605 {
606 "aws_secret_access_key": config.aws_secret_access_key,
607 "aws_region_name": config.aws_region,
608 }
609 )
610 elif config.llm_provider == "ollama":
611 provider_kwargs.update(
612 {
613 "api_base": config.ollama_base_url,
614 }
615 )
616 elif config.llm_provider in ["vertex_ai", "google"]:
617 # Vertex AI configuration
618 # LiteLLM requires vertex_project and vertex_location for Vertex AI models
619 # If using Workload Identity on GKE, authentication is automatic
620 vertex_project = config.vertex_project or config.google_project_id
621 if vertex_project:
622 provider_kwargs.update(
623 {
624 "vertex_project": vertex_project,
625 "vertex_location": config.vertex_location,
626 }
627 )
629 factory = LLMFactory(
630 provider=config.llm_provider,
631 model_name=config.model_name,
632 api_key=api_key,
633 temperature=config.model_temperature,
634 max_tokens=config.model_max_tokens,
635 timeout=config.model_timeout,
636 enable_fallback=config.enable_fallback,
637 fallback_models=config.fallback_models,
638 **provider_kwargs,
639 )
641 # Set up credentials for primary + all fallback providers
642 factory._setup_environment(config=config)
644 return factory
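# Illustrative sketch (a hypothetical settings object, not the project's real
# Settings class): the attributes create_llm_from_config reads for a Google
# primary model with an OpenAI fallback.
#
#     from types import SimpleNamespace
#
#     settings = SimpleNamespace(
#         llm_provider="google",
#         model_name="gemini-2.5-flash",
#         model_temperature=0.7,
#         model_max_tokens=4096,
#         model_timeout=60,
#         enable_fallback=True,
#         fallback_models=["gpt-4o"],
#         google_api_key="...",
#         openai_api_key="...",
#         anthropic_api_key=None,
#         azure_api_key=None,
#         aws_access_key_id=None,
#         vertex_project=None,
#         google_project_id=None,
#         vertex_location="us-central1",
#     )
#     llm = create_llm_from_config(settings)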
647def create_summarization_model(config) -> LLMFactory: # type: ignore[no-untyped-def]
648 """
649 Create dedicated LLM instance for summarization (cost-optimized).
651 Uses lighter/cheaper model for context compaction to reduce costs.
652 Falls back to primary model if dedicated model not configured.
654 Args:
655 config: Settings object with LLM configuration
657 Returns:
658 Configured LLMFactory instance for summarization
659 """
660 # If dedicated summarization model not enabled, use primary model
661 if not getattr(config, "use_dedicated_summarization_model", False):
662 return create_llm_from_config(config)
664 # Determine provider and API key
665 provider = config.summarization_model_provider or config.llm_provider
667 api_key_map = {
668 "anthropic": config.anthropic_api_key,
669 "openai": config.openai_api_key,
670 "google": config.google_api_key,
671 "gemini": config.google_api_key,
672 "vertex_ai": None, # Vertex AI uses Workload Identity or GOOGLE_APPLICATION_CREDENTIALS
673 "azure": config.azure_api_key,
674 "bedrock": config.aws_access_key_id,
675 }
677 api_key = api_key_map.get(provider)
679 # Provider-specific kwargs
680 provider_kwargs = {}
681 if provider == "azure":
682 provider_kwargs.update({"api_base": config.azure_api_base, "api_version": config.azure_api_version})
683 elif provider == "bedrock":
684 provider_kwargs.update({"aws_secret_access_key": config.aws_secret_access_key, "aws_region_name": config.aws_region})
685 elif provider == "ollama":
686 provider_kwargs.update({"api_base": config.ollama_base_url})
687 elif provider in ["vertex_ai", "google"]:
688 vertex_project = config.vertex_project or config.google_project_id
689 if vertex_project:
690 provider_kwargs.update({"vertex_project": vertex_project, "vertex_location": config.vertex_location})
692 factory = LLMFactory(
693 provider=provider,
694 model_name=config.summarization_model_name or config.model_name,
695 api_key=api_key,
696 temperature=config.summarization_model_temperature,
697 max_tokens=config.summarization_model_max_tokens,
698 timeout=config.model_timeout,
699 enable_fallback=config.enable_fallback,
700 fallback_models=config.fallback_models,
701 **provider_kwargs,
702 )
704 # Set up credentials for all providers
705 factory._setup_environment(config=config)
707 return factory
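# Illustrative sketch continuing the hypothetical settings above: enabling a
# cheaper dedicated model for summarization (model name is a placeholder).
#
#     settings.use_dedicated_summarization_model = True
#     settings.summarization_model_provider = "google"
#     settings.summarization_model_name = "gemini-2.5-flash-lite"
#     settings.summarization_model_temperature = 0.3
#     settings.summarization_model_max_tokens = 1024
#     summarizer = create_summarization_model(settings)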
710def create_verification_model(config) -> LLMFactory: # type: ignore[no-untyped-def]
711 """
712 Create dedicated LLM instance for verification (LLM-as-judge).
714 Uses potentially different model for output verification to balance
715 cost and quality. Falls back to primary model if not configured.
717 Args:
718 config: Settings object with LLM configuration
720 Returns:
721 Configured LLMFactory instance for verification
722 """
723 # If dedicated verification model not enabled, use primary model
724 if not getattr(config, "use_dedicated_verification_model", False):
725 return create_llm_from_config(config)
727 # Determine provider and API key
728 provider = config.verification_model_provider or config.llm_provider
730 api_key_map = {
731 "anthropic": config.anthropic_api_key,
732 "openai": config.openai_api_key,
733 "google": config.google_api_key,
734 "gemini": config.google_api_key,
735 "vertex_ai": None, # Vertex AI uses Workload Identity or GOOGLE_APPLICATION_CREDENTIALS
736 "azure": config.azure_api_key,
737 "bedrock": config.aws_access_key_id,
738 }
740 api_key = api_key_map.get(provider)
742 # Provider-specific kwargs
743 provider_kwargs = {}
744 if provider == "azure":
745 provider_kwargs.update({"api_base": config.azure_api_base, "api_version": config.azure_api_version})
746 elif provider == "bedrock":
747 provider_kwargs.update({"aws_secret_access_key": config.aws_secret_access_key, "aws_region_name": config.aws_region})
748 elif provider == "ollama":
749 provider_kwargs.update({"api_base": config.ollama_base_url})
750 elif provider in ["vertex_ai", "google"]:
751 vertex_project = config.vertex_project or config.google_project_id
752 if vertex_project:
753 provider_kwargs.update({"vertex_project": vertex_project, "vertex_location": config.vertex_location})
755 factory = LLMFactory(
756 provider=provider,
757 model_name=config.verification_model_name or config.model_name,
758 api_key=api_key,
759 temperature=config.verification_model_temperature,
760 max_tokens=config.verification_model_max_tokens,
761 timeout=config.model_timeout,
762 enable_fallback=config.enable_fallback,
763 fallback_models=config.fallback_models,
764 **provider_kwargs,
765 )
767 # Set up credentials for all providers
768 factory._setup_environment(config=config)
770 return factory
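# Illustrative sketch tying the three factories together inside an async
# handler; the agent-side names and prompts are placeholders.
#
#     chat_llm = create_llm_from_config(settings)
#     summarizer = create_summarization_model(settings)
#     judge = create_verification_model(settings)
#
#     draft = await chat_llm.ainvoke([HumanMessage(content=user_prompt)])
#     verdict = await judge.ainvoke([
#         SystemMessage(content="You are a strict reviewer."),
#         HumanMessage(content=f"Evaluate this answer:\n{draft.content}"),
#     ])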