Coverage for src/mcp_server_langgraph/llm/factory.py: 81% (238 statements), coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2LiteLLM Factory for Multi-Provider LLM Support
4Supports: Anthropic, OpenAI, Google (Gemini), Azure OpenAI, AWS Bedrock,
5Ollama (Llama, Qwen, Mistral, etc.)
7Enhanced with resilience patterns (ADR-0026):
8- Circuit breaker for provider failures
9- Retry logic with exponential backoff
10- Timeout enforcement
11- Bulkhead isolation (10 concurrent LLM calls max)
12"""

import os
from typing import Any

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from litellm import acompletion, completion
from litellm.utils import ModelResponse  # type: ignore[attr-defined]

from mcp_server_langgraph.core.exceptions import (
    LLMModelNotFoundError,
    LLMOverloadError,
    LLMProviderError,
    LLMRateLimitError,
    LLMTimeoutError,
)
from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer
from mcp_server_langgraph.resilience import circuit_breaker, retry_with_backoff, with_bulkhead, with_timeout
from mcp_server_langgraph.resilience.retry import extract_retry_after_from_exception, is_overload_error


class LLMFactory:
    """
    Factory for creating and managing LLM connections via LiteLLM

    Supports multiple providers with automatic fallback and retry logic.
    """

    def __init__(  # type: ignore[no-untyped-def]
        self,
        provider: str = "google",
        model_name: str = "gemini-2.5-flash",
        api_key: str | None = None,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        timeout: int = 60,
        enable_fallback: bool = True,
        fallback_models: list[str] | None = None,
        **kwargs,
    ):
        """
        Initialize LLM Factory

        Args:
            provider: LLM provider (anthropic, openai, google, azure, bedrock, ollama)
            model_name: Model identifier
            api_key: API key for the provider
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            timeout: Request timeout in seconds
            enable_fallback: Enable fallback to alternative models
            fallback_models: List of fallback model names
            **kwargs: Additional provider-specific parameters
        """
        self.provider = provider
        self.model_name = model_name
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.enable_fallback = enable_fallback
        self.fallback_models = fallback_models or []
        self.kwargs = kwargs

        # Note: _setup_environment is now called by factory functions with config
        # This allows multi-provider credential setup for fallbacks

        logger.info(
            "LLM Factory initialized",
            extra={
                "provider": provider,
                "model": model_name,
                "fallback_enabled": enable_fallback,
                "fallback_count": len(fallback_models) if fallback_models else 0,
            },
        )
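
    # Illustrative construction (a sketch; assumes the relevant provider API key is
    # already exported, e.g. via _setup_environment or the factory functions below):
    #
    #     factory = LLMFactory(
    #         provider="anthropic",
    #         model_name="claude-sonnet-4-5",
    #         fallback_models=["gemini-2.5-flash"],
    #     )
    #     reply = factory.invoke([HumanMessage(content="Hello")])
    #     print(reply.content)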

    def _get_provider_from_model(self, model_name: str) -> str:
        """
        Extract provider from model name.

        Args:
            model_name: Model identifier (e.g., "gpt-5", "claude-sonnet-4-5", "gemini-2.5-flash")

        Returns:
            Provider name (e.g., "openai", "anthropic", "google")
        """
        model_lower = model_name.lower()

        # Check provider prefixes FIRST (azure/gpt-4 should be azure, not openai)
        # Vertex AI (Google Cloud AI Platform)
        if model_lower.startswith("vertex_ai/"):
            return "vertex_ai"

        # Azure (prefixed models)
        if model_lower.startswith("azure/"):
            return "azure"

        # Bedrock (prefixed models)
        if model_lower.startswith("bedrock/"):
            return "bedrock"

        # Ollama (local models)
        if model_lower.startswith("ollama/"):
            return "ollama"

        # Then check model name patterns
        # Anthropic models
        if any(x in model_lower for x in ["claude", "anthropic"]):
            return "anthropic"

        # OpenAI models
        if any(x in model_lower for x in ["gpt-", "o1-", "davinci", "curie", "babbage"]):
            return "openai"

        # Google models
        if any(x in model_lower for x in ["gemini", "palm", "text-bison", "chat-bison"]):
            return "google"

        # Default to current provider
        return self.provider
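
    # For illustration, the resolution above behaves like this (model names reuse the
    # examples from the docstring; unrecognized names fall back to the primary provider):
    #
    #     _get_provider_from_model("azure/gpt-4")        -> "azure"    (prefix match)
    #     _get_provider_from_model("claude-sonnet-4-5")  -> "anthropic"
    #     _get_provider_from_model("gemini-2.5-flash")   -> "google"
    #     _get_provider_from_model("some-local-model")   -> self.provider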

    def _get_provider_kwargs(self, model_name: str) -> dict[str, Any]:
        """
        Get provider-specific kwargs for a given model.

        BUGFIX: Filters out provider-specific parameters that don't apply to the target provider.
        For example, Azure-specific params (azure_endpoint, azure_deployment) are removed
        when calling Anthropic or OpenAI fallback models.

        Args:
            model_name: Model identifier to get kwargs for

        Returns:
            dict: Filtered kwargs appropriate for the model's provider
        """
        provider = self._get_provider_from_model(model_name)

        # Define provider-specific parameter prefixes
        provider_specific_prefixes = {
            "azure": ["azure_", "api_version"],
            "bedrock": ["aws_", "bedrock_"],
            "vertex": ["vertex_", "gcp_"],
        }

        # If this is the same provider as the primary model, return all kwargs
        if provider == self.provider:
            return self.kwargs

        # Filter out parameters specific to other providers
        filtered_kwargs = {}
        for key, value in self.kwargs.items():
            # Check if this parameter belongs to a different provider
            is_provider_specific = False
            for other_provider, prefixes in provider_specific_prefixes.items():
                if other_provider != provider and any(key.startswith(prefix) for prefix in prefixes):
                    is_provider_specific = True
                    break

            # Include parameter if it's not specific to another provider
            if not is_provider_specific:
                filtered_kwargs[key] = value

        logger.debug(
            "Filtered kwargs for fallback model",
            extra={
                "model": model_name,
                "provider": provider,
                "original_kwargs": list(self.kwargs.keys()),
                "filtered_kwargs": list(filtered_kwargs.keys()),
            },
        )

        return filtered_kwargs
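
    # Illustrative filtering (hypothetical kwargs and values): with an Azure-primary
    # factory falling back to an Anthropic model, Azure-only parameters are dropped
    # while neutral ones pass through:
    #
    #     self.kwargs = {"api_version": "2024-02-01", "azure_deployment": "prod", "seed": 7}
    #     self._get_provider_kwargs("claude-sonnet-4-5")  # -> {"seed": 7}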

    def _setup_environment(self, config=None) -> None:  # type: ignore[no-untyped-def]
        """
        Set up environment variables for LiteLLM.

        Configures credentials for primary provider AND all fallback providers
        to enable seamless multi-provider fallback.

        Args:
            config: Settings object with API keys for all providers (optional)
        """
        # Collect all providers needed (primary + fallbacks)
        providers_needed = {self.provider}

        # Add providers for each fallback model
        for fallback_model in self.fallback_models:
            provider = self._get_provider_from_model(fallback_model)
            providers_needed.add(provider)

        # Map provider to environment variable and config attribute
        # Provider-specific credential mapping (some providers need multiple env vars)
        provider_config_map = {
            "anthropic": [("ANTHROPIC_API_KEY", "anthropic_api_key")],
            "openai": [("OPENAI_API_KEY", "openai_api_key")],
            "google": [("GOOGLE_API_KEY", "google_api_key")],
            "gemini": [("GOOGLE_API_KEY", "google_api_key")],
            "azure": [
                ("AZURE_API_KEY", "azure_api_key"),
                ("AZURE_API_BASE", "azure_api_base"),
                ("AZURE_API_VERSION", "azure_api_version"),
                ("AZURE_DEPLOYMENT_NAME", "azure_deployment_name"),
            ],
            "bedrock": [
                ("AWS_ACCESS_KEY_ID", "aws_access_key_id"),
                ("AWS_SECRET_ACCESS_KEY", "aws_secret_access_key"),
                ("AWS_REGION", "aws_region"),
            ],
        }

        # Set up credentials for each needed provider
        for provider in providers_needed:
            if provider not in provider_config_map:
                continue  # Skip unknown providers (e.g., ollama doesn't need API key)

            credential_pairs = provider_config_map[provider]

            for env_var, config_attr in credential_pairs:
                # Get value from config
                value = getattr(config, config_attr) if config and hasattr(config, config_attr) else None

                # For primary provider, use self.api_key as fallback for API key fields
                if value is None and provider == self.provider and "api_key" in config_attr.lower():
                    value = self.api_key

                # Set environment variable if we have a value
                if value and env_var:
                    os.environ[env_var] = str(value)
                    logger.debug(
                        f"Configured credential for provider: {provider}", extra={"provider": provider, "env_var": env_var}
                    )
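
    # Sketch of the resulting environment (assumes `settings` is the application's
    # Settings object carrying the attributes named in provider_config_map):
    #
    #     factory = LLMFactory(provider="google", fallback_models=["claude-sonnet-4-5"])
    #     factory._setup_environment(config=settings)
    #     # -> GOOGLE_API_KEY and ANTHROPIC_API_KEY are now exported for LiteLLM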

    def _format_messages(self, messages: list[BaseMessage | dict[str, Any]]) -> list[dict[str, str]]:
        """
        Convert LangChain messages to LiteLLM format

        Args:
            messages: List of LangChain BaseMessage objects or dicts

        Returns:
            List of dictionaries in LiteLLM format
        """
        formatted = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                content = msg.content if isinstance(msg.content, str) else str(msg.content)
                formatted.append({"role": "user", "content": content})
            elif isinstance(msg, AIMessage):
                content = msg.content if isinstance(msg.content, str) else str(msg.content)
                formatted.append({"role": "assistant", "content": content})
            elif isinstance(msg, SystemMessage):
                content = msg.content if isinstance(msg.content, str) else str(msg.content)
                formatted.append({"role": "system", "content": content})
            elif isinstance(msg, dict):
                # Handle dict messages (already in correct format or need conversion)
                if "role" in msg and "content" in msg:
                    # Already formatted dict
                    formatted.append(msg)
                elif "content" in msg:
                    # Dict with content but no role
                    formatted.append({"role": "user", "content": str(msg["content"])})
                else:
                    # Malformed dict, convert to string
                    formatted.append({"role": "user", "content": str(msg)})
            else:
                # Fallback for other types - check if it has content attribute
                if hasattr(msg, "content"):
                    formatted.append({"role": "user", "content": str(msg.content)})
                else:
                    # Last resort: convert entire object to string
                    formatted.append({"role": "user", "content": str(msg)})

        return formatted
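
    # Conversion example (illustrative):
    #
    #     self._format_messages([
    #         SystemMessage(content="You are helpful."),
    #         HumanMessage(content="Hi"),
    #         {"role": "assistant", "content": "Hello!"},
    #     ])
    #     # -> [{"role": "system", "content": "You are helpful."},
    #     #     {"role": "user", "content": "Hi"},
    #     #     {"role": "assistant", "content": "Hello!"}]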

    def invoke(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage:  # type: ignore[no-untyped-def]
        """
        Synchronous LLM invocation

        Args:
            messages: List of messages
            **kwargs: Additional parameters for the model

        Returns:
            AIMessage with the response
        """
        with tracer.start_as_current_span("llm.invoke") as span:
            span.set_attribute("llm.provider", self.provider)
            span.set_attribute("llm.model", self.model_name)

            formatted_messages = self._format_messages(messages)

            # Merge kwargs with defaults
            params = {
                "model": self.model_name,
                "messages": formatted_messages,
                "temperature": kwargs.get("temperature", self.temperature),
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "timeout": kwargs.get("timeout", self.timeout),
                **self.kwargs,
            }

            try:
                response: ModelResponse = completion(**params)

                content = response.choices[0].message.content  # type: ignore[union-attr]

                # Track metrics
                metrics.successful_calls.add(1, {"operation": "llm.invoke", "model": self.model_name})

                logger.info(
                    "LLM invocation successful",
                    extra={
                        "model": self.model_name,
                        "tokens": response.usage.total_tokens if response.usage else 0,  # type: ignore[attr-defined]
                    },
                )

                return AIMessage(content=content)

            except Exception as e:
                logger.error(
                    f"LLM invocation failed: {e}", extra={"model": self.model_name, "provider": self.provider}, exc_info=True
                )

                metrics.failed_calls.add(1, {"operation": "llm.invoke", "model": self.model_name})
                span.record_exception(e)

                # Try fallback if enabled
                if self.enable_fallback and self.fallback_models:
                    return self._try_fallback(messages, **kwargs)

                raise

    @circuit_breaker(name="llm", fail_max=5, timeout=60)
    @retry_with_backoff(max_attempts=3, exponential_base=2)
    @with_timeout(operation_type="llm")
    @with_bulkhead(resource_type="llm")
    async def ainvoke(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage:  # type: ignore[no-untyped-def]
        """
        Asynchronous LLM invocation with full resilience protection.

        Protected by:
        - Circuit breaker: Fail fast if LLM provider is down (5 failures → open, 60s timeout)
        - Retry logic: Up to 3 attempts with exponential backoff (1s, 2s, 4s)
        - Timeout: 60s timeout for LLM operations
        - Bulkhead: Limit to 10 concurrent LLM calls

        Args:
            messages: List of messages
            **kwargs: Additional parameters for the model

        Returns:
            AIMessage with the response

        Raises:
            CircuitBreakerOpenError: If circuit breaker is open
            RetryExhaustedError: If all retry attempts failed
            TimeoutError: If operation exceeds 60s timeout
            BulkheadRejectedError: If too many concurrent LLM calls
            LLMProviderError: For other LLM provider errors
        """
        with tracer.start_as_current_span("llm.ainvoke") as span:
            span.set_attribute("llm.provider", self.provider)
            span.set_attribute("llm.model", self.model_name)

            formatted_messages = self._format_messages(messages)

            params = {
                "model": self.model_name,
                "messages": formatted_messages,
                "temperature": kwargs.get("temperature", self.temperature),
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "timeout": kwargs.get("timeout", self.timeout),
                **self.kwargs,
            }

            try:
                response: ModelResponse = await acompletion(**params)

                content = response.choices[0].message.content  # type: ignore[union-attr]

                metrics.successful_calls.add(1, {"operation": "llm.ainvoke", "model": self.model_name})

                logger.info(
                    "Async LLM invocation successful",
                    extra={
                        "model": self.model_name,
                        "tokens": response.usage.total_tokens if response.usage else 0,  # type: ignore[attr-defined]
                    },
                )

                return AIMessage(content=content)

            except Exception as e:
                # Convert to custom exceptions for better error handling
                error_msg = str(e).lower()

                # Check for overload errors (529 or equivalent) - handle first
                if is_overload_error(e):
                    retry_after = extract_retry_after_from_exception(e)
                    raise LLMOverloadError(
                        message=f"LLM provider overloaded: {e}",
                        retry_after=retry_after,
                        metadata={
                            "model": self.model_name,
                            "provider": self.provider,
                            "retry_after": retry_after,
                        },
                        cause=e,
                    )
                elif "rate limit" in error_msg or "429" in error_msg:
                    raise LLMRateLimitError(
                        message=f"LLM provider rate limit exceeded: {e}",
                        metadata={"model": self.model_name, "provider": self.provider},
                        cause=e,
                    )
                elif "timeout" in error_msg or "timed out" in error_msg:
                    raise LLMTimeoutError(
                        message=f"LLM request timed out: {e}",
                        metadata={"model": self.model_name, "provider": self.provider, "timeout": self.timeout},
                        cause=e,
                    )
                elif "model not found" in error_msg or "404" in error_msg:
                    raise LLMModelNotFoundError(
                        message=f"LLM model not found: {e}",
                        metadata={"model": self.model_name, "provider": self.provider},
                        cause=e,
                    )
                else:
                    logger.error(
                        f"Async LLM invocation failed: {e}",
                        extra={"model": self.model_name, "provider": self.provider},
                        exc_info=True,
                    )

                    metrics.failed_calls.add(1, {"operation": "llm.ainvoke", "model": self.model_name})
                    span.record_exception(e)

                    # Try fallback if enabled
                    if self.enable_fallback and self.fallback_models:
                        return await self._try_fallback_async(messages, **kwargs)

                    raise LLMProviderError(
                        message=f"LLM provider error: {e}",
                        metadata={"model": self.model_name, "provider": self.provider},
                        cause=e,
                    )
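
    # Illustrative caller-side handling of the exceptions raised above (a sketch;
    # the exception classes are the ones imported at the top of this module):
    #
    #     try:
    #         reply = await factory.ainvoke([HumanMessage(content="Summarize this.")])
    #     except LLMRateLimitError:
    #         ...  # back off and retry later
    #     except LLMTimeoutError:
    #         ...  # shorten the prompt or raise the timeout
    #     except LLMProviderError:
    #         ...  # other provider failures (see the else branch above)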

    def _try_fallback(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage:  # type: ignore[no-untyped-def]
        """Try fallback models if primary fails"""
        for fallback_model in self.fallback_models:
            if fallback_model == self.model_name:
                continue  # Skip if it's the same model

            logger.warning(f"Trying fallback model: {fallback_model}", extra={"primary_model": self.model_name})

            try:
                formatted_messages = self._format_messages(messages)
                # BUGFIX: Use provider-specific kwargs to avoid cross-provider parameter errors
                provider_kwargs = self._get_provider_kwargs(fallback_model)
                response = completion(
                    model=fallback_model,
                    messages=formatted_messages,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    timeout=self.timeout,
                    **provider_kwargs,  # Forward provider-specific kwargs only
                )

                content = response.choices[0].message.content

                logger.info("Fallback successful", extra={"fallback_model": fallback_model})

                metrics.successful_calls.add(1, {"operation": "llm.fallback", "model": fallback_model})

                return AIMessage(content=content)

            except Exception as e:
                logger.error(f"Fallback model {fallback_model} failed: {e}", exc_info=True)
                continue

        msg = "All models failed including fallbacks"
        raise RuntimeError(msg)

    async def _try_fallback_async(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage:  # type: ignore[no-untyped-def]
        """Try fallback models asynchronously"""
        for fallback_model in self.fallback_models:
            if fallback_model == self.model_name:
                continue

            logger.warning(f"Trying fallback model: {fallback_model}", extra={"primary_model": self.model_name})

            try:
                formatted_messages = self._format_messages(messages)
                # BUGFIX: Use provider-specific kwargs to avoid cross-provider parameter errors
                provider_kwargs = self._get_provider_kwargs(fallback_model)
                response = await acompletion(
                    model=fallback_model,
                    messages=formatted_messages,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    timeout=self.timeout,
                    **provider_kwargs,  # Forward provider-specific kwargs only
                )

                content = response.choices[0].message.content

                logger.info("Async fallback successful", extra={"fallback_model": fallback_model})

                metrics.successful_calls.add(1, {"operation": "llm.fallback_async", "model": fallback_model})

                return AIMessage(content=content)

            except Exception as e:
                logger.error(f"Async fallback model {fallback_model} failed: {e}", exc_info=True)
                continue

        msg = "All async models failed including fallbacks"
        raise RuntimeError(msg)
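
# Fallback behaviour (illustrative, model names are examples): with
# fallback_models=["gemini-2.5-flash", "gpt-5"], a failed primary call tries each
# model in order and returns the first successful AIMessage; only after every
# fallback fails does the RuntimeError above reach the caller.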


def create_llm_from_config(config) -> LLMFactory:  # type: ignore[no-untyped-def]
    """
    Create primary LLM instance from configuration

    Args:
        config: Settings object with LLM configuration

    Returns:
        Configured LLMFactory instance for primary chat operations
    """
    # Determine API key based on provider
    api_key_map = {
        "anthropic": config.anthropic_api_key,
        "openai": config.openai_api_key,
        "google": config.google_api_key,
        "gemini": config.google_api_key,
        "vertex_ai": None,  # Vertex AI uses Workload Identity or GOOGLE_APPLICATION_CREDENTIALS
        "azure": config.azure_api_key,
        "bedrock": config.aws_access_key_id,
    }

    api_key = api_key_map.get(config.llm_provider)

    # Provider-specific kwargs
    provider_kwargs = {}

    if config.llm_provider == "azure":
        provider_kwargs.update(
            {
                "api_base": config.azure_api_base,
                "api_version": config.azure_api_version,
            }
        )
    elif config.llm_provider == "bedrock":
        provider_kwargs.update(
            {
                "aws_secret_access_key": config.aws_secret_access_key,
                "aws_region_name": config.aws_region,
            }
        )
    elif config.llm_provider == "ollama":
        provider_kwargs.update(
            {
                "api_base": config.ollama_base_url,
            }
        )
    elif config.llm_provider in ["vertex_ai", "google"]:
        # Vertex AI configuration
        # LiteLLM requires vertex_project and vertex_location for Vertex AI models
        # If using Workload Identity on GKE, authentication is automatic
        vertex_project = config.vertex_project or config.google_project_id
        if vertex_project:
            provider_kwargs.update(
                {
                    "vertex_project": vertex_project,
                    "vertex_location": config.vertex_location,
                }
            )

    factory = LLMFactory(
        provider=config.llm_provider,
        model_name=config.model_name,
        api_key=api_key,
        temperature=config.model_temperature,
        max_tokens=config.model_max_tokens,
        timeout=config.model_timeout,
        enable_fallback=config.enable_fallback,
        fallback_models=config.fallback_models,
        **provider_kwargs,
    )

    # Set up credentials for primary + all fallback providers
    factory._setup_environment(config=config)

    return factory
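
# Typical wiring (a sketch; `get_settings` is a hypothetical accessor for the
# application's Settings object):
#
#     settings = get_settings()
#     llm = create_llm_from_config(settings)
#     reply = llm.invoke([HumanMessage(content="ping")])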


def create_summarization_model(config) -> LLMFactory:  # type: ignore[no-untyped-def]
    """
    Create dedicated LLM instance for summarization (cost-optimized).

    Uses lighter/cheaper model for context compaction to reduce costs.
    Falls back to primary model if dedicated model not configured.

    Args:
        config: Settings object with LLM configuration

    Returns:
        Configured LLMFactory instance for summarization
    """
    # If dedicated summarization model not enabled, use primary model
    if not getattr(config, "use_dedicated_summarization_model", False):
        return create_llm_from_config(config)

    # Determine provider and API key
    provider = config.summarization_model_provider or config.llm_provider

    api_key_map = {
        "anthropic": config.anthropic_api_key,
        "openai": config.openai_api_key,
        "google": config.google_api_key,
        "gemini": config.google_api_key,
        "vertex_ai": None,  # Vertex AI uses Workload Identity or GOOGLE_APPLICATION_CREDENTIALS
        "azure": config.azure_api_key,
        "bedrock": config.aws_access_key_id,
    }

    api_key = api_key_map.get(provider)

    # Provider-specific kwargs
    provider_kwargs = {}
    if provider == "azure":
        provider_kwargs.update({"api_base": config.azure_api_base, "api_version": config.azure_api_version})
    elif provider == "bedrock":
        provider_kwargs.update({"aws_secret_access_key": config.aws_secret_access_key, "aws_region_name": config.aws_region})
    elif provider == "ollama":
        provider_kwargs.update({"api_base": config.ollama_base_url})
    elif provider in ["vertex_ai", "google"]:
        vertex_project = config.vertex_project or config.google_project_id
        if vertex_project:
            provider_kwargs.update({"vertex_project": vertex_project, "vertex_location": config.vertex_location})

    factory = LLMFactory(
        provider=provider,
        model_name=config.summarization_model_name or config.model_name,
        api_key=api_key,
        temperature=config.summarization_model_temperature,
        max_tokens=config.summarization_model_max_tokens,
        timeout=config.model_timeout,
        enable_fallback=config.enable_fallback,
        fallback_models=config.fallback_models,
        **provider_kwargs,
    )

    # Set up credentials for all providers
    factory._setup_environment(config=config)

    return factory


def create_verification_model(config) -> LLMFactory:  # type: ignore[no-untyped-def]
    """
    Create dedicated LLM instance for verification (LLM-as-judge).

    Uses potentially different model for output verification to balance
    cost and quality. Falls back to primary model if not configured.

    Args:
        config: Settings object with LLM configuration

    Returns:
        Configured LLMFactory instance for verification
    """
    # If dedicated verification model not enabled, use primary model
    if not getattr(config, "use_dedicated_verification_model", False):
        return create_llm_from_config(config)

    # Determine provider and API key
    provider = config.verification_model_provider or config.llm_provider

    api_key_map = {
        "anthropic": config.anthropic_api_key,
        "openai": config.openai_api_key,
        "google": config.google_api_key,
        "gemini": config.google_api_key,
        "vertex_ai": None,  # Vertex AI uses Workload Identity or GOOGLE_APPLICATION_CREDENTIALS
        "azure": config.azure_api_key,
        "bedrock": config.aws_access_key_id,
    }

    api_key = api_key_map.get(provider)

    # Provider-specific kwargs
    provider_kwargs = {}
    if provider == "azure":
        provider_kwargs.update({"api_base": config.azure_api_base, "api_version": config.azure_api_version})
    elif provider == "bedrock":
        provider_kwargs.update({"aws_secret_access_key": config.aws_secret_access_key, "aws_region_name": config.aws_region})
    elif provider == "ollama":
        provider_kwargs.update({"api_base": config.ollama_base_url})
    elif provider in ["vertex_ai", "google"]:
        vertex_project = config.vertex_project or config.google_project_id
        if vertex_project:
            provider_kwargs.update({"vertex_project": vertex_project, "vertex_location": config.vertex_location})

    factory = LLMFactory(
        provider=provider,
        model_name=config.verification_model_name or config.model_name,
        api_key=api_key,
        temperature=config.verification_model_temperature,
        max_tokens=config.verification_model_max_tokens,
        timeout=config.model_timeout,
        enable_fallback=config.enable_fallback,
        fallback_models=config.fallback_models,
        **provider_kwargs,
    )

    # Set up credentials for all providers
    factory._setup_environment(config=config)

    return factory
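
# The three factory functions are designed to share one Settings object, so each
# call exports credentials for its primary and fallback providers (illustrative):
#
#     chat_llm = create_llm_from_config(settings)
#     summarizer = create_summarization_model(settings)  # cheaper model when enabled
#     judge = create_verification_model(settings)        # LLM-as-judge model when enabled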