Coverage for src/mcp_server_langgraph/llm/factory.py: 81%

238 statements  

coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2LiteLLM Factory for Multi-Provider LLM Support 

3 

4Supports: Anthropic, OpenAI, Google (Gemini), Azure OpenAI, AWS Bedrock, 

5Ollama (Llama, Qwen, Mistral, etc.) 

6 

7Enhanced with resilience patterns (ADR-0026): 

8- Circuit breaker for provider failures 

9- Retry logic with exponential backoff 

10- Timeout enforcement 

11- Bulkhead isolation (10 concurrent LLM calls max) 

12""" 

13 

14import os 

15from typing import Any 

16 

17from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage 

18from litellm import acompletion, completion 

19from litellm.utils import ModelResponse # type: ignore[attr-defined] 

20 

21from mcp_server_langgraph.core.exceptions import ( 

22 LLMModelNotFoundError, 

23 LLMOverloadError, 

24 LLMProviderError, 

25 LLMRateLimitError, 

26 LLMTimeoutError, 

27) 

28from mcp_server_langgraph.resilience.retry import extract_retry_after_from_exception, is_overload_error 

29from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer 

30from mcp_server_langgraph.resilience import circuit_breaker, retry_with_backoff, with_bulkhead, with_timeout 

31 

32 

33class LLMFactory: 

34 """ 

35 Factory for creating and managing LLM connections via LiteLLM 

36 

37 Supports multiple providers with automatic fallback and retry logic. 

38 """ 

39 

40 def __init__( # type: ignore[no-untyped-def] 

41 self, 

42 provider: str = "google", 

43 model_name: str = "gemini-2.5-flash", 

44 api_key: str | None = None, 

45 temperature: float = 0.7, 

46 max_tokens: int = 4096, 

47 timeout: int = 60, 

48 enable_fallback: bool = True, 

49 fallback_models: list[str] | None = None, 

50 **kwargs, 

51 ): 

52 """ 

53 Initialize LLM Factory 

54 

55 Args: 

56 provider: LLM provider (anthropic, openai, google, azure, bedrock, ollama) 

57 model_name: Model identifier 

58 api_key: API key for the provider 

59 temperature: Sampling temperature (0-1) 

60 max_tokens: Maximum tokens to generate 

61 timeout: Request timeout in seconds 

62 enable_fallback: Enable fallback to alternative models 

63 fallback_models: List of fallback model names 

64 **kwargs: Additional provider-specific parameters 

65 """ 

66 self.provider = provider 

67 self.model_name = model_name 

68 self.api_key = api_key 

69 self.temperature = temperature 

70 self.max_tokens = max_tokens 

71 self.timeout = timeout 

72 self.enable_fallback = enable_fallback 

73 self.fallback_models = fallback_models or [] 

74 self.kwargs = kwargs 

75 

76 # Note: _setup_environment is now called by factory functions with config 

77 # This allows multi-provider credential setup for fallbacks 

78 

79 logger.info( 

80 "LLM Factory initialized", 

81 extra={ 

82 "provider": provider, 

83 "model": model_name, 

84 "fallback_enabled": enable_fallback, 

85 "fallback_count": len(fallback_models) if fallback_models else 0, 

86 }, 

87 ) 
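# Construction sketch showing provider-specific **kwargs being forwarded through
# self.kwargs; the key, endpoint, and version values are placeholders, and the
# "api_base"/"api_version" names mirror the Azure parameters used by
# create_llm_from_config() later in this file.
from mcp_server_langgraph.llm.factory import LLMFactory

azure_llm = LLMFactory(
    provider="azure",
    model_name="azure/gpt-4o",
    api_key="<azure-api-key>",
    temperature=0.2,
    timeout=30,
    fallback_models=["claude-sonnet-4-5"],
    api_base="https://example.openai.azure.com",  # stored in self.kwargs, passed to LiteLLM
    api_version="2024-02-01",
)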

88 

89 def _get_provider_from_model(self, model_name: str) -> str: 

90 """ 

91 Extract provider from model name. 

92 

93 Args: 

94 model_name: Model identifier (e.g., "gpt-5", "claude-sonnet-4-5", "gemini-2.5-flash") 

95 

96 Returns: 

97 Provider name (e.g., "openai", "anthropic", "google") 

98 """ 

99 model_lower = model_name.lower() 

100 

101 # Check provider prefixes FIRST (azure/gpt-4 should be azure, not openai) 

102 # Vertex AI (Google Cloud AI Platform) 

103 if model_lower.startswith("vertex_ai/"): 

104 return "vertex_ai" 

105 

106 # Azure (prefixed models) 

107 if model_lower.startswith("azure/"): 

108 return "azure" 

109 

110 # Bedrock (prefixed models) 

111 if model_lower.startswith("bedrock/"): 

112 return "bedrock" 

113 

114 # Ollama (local models) 

115 if model_lower.startswith("ollama/"): 

116 return "ollama" 

117 

118 # Then check model name patterns 

119 # Anthropic models 

120 if any(x in model_lower for x in ["claude", "anthropic"]): 

121 return "anthropic" 

122 

123 # OpenAI models 

124 if any(x in model_lower for x in ["gpt-", "o1-", "davinci", "curie", "babbage"]): 

125 return "openai" 

126 

127 # Google models 

128 if any(x in model_lower for x in ["gemini", "palm", "text-bison", "chat-bison"]): 

129 return "google" 

130 

131 # Default to current provider 

132 return self.provider 
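# Expected mappings under the rules above (prefix checks win over name patterns);
# the factory instance and model strings are illustrative only.
from mcp_server_langgraph.llm.factory import LLMFactory

f = LLMFactory(provider="google", model_name="gemini-2.5-flash")
assert f._get_provider_from_model("azure/gpt-4o") == "azure"          # prefix beats the "gpt-" pattern
assert f._get_provider_from_model("claude-sonnet-4-5") == "anthropic"
assert f._get_provider_from_model("ollama/llama3.1") == "ollama"
assert f._get_provider_from_model("text-bison") == "google"
assert f._get_provider_from_model("mistral-large") == "google"        # no match -> defaults to self.provider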

133 

134 def _get_provider_kwargs(self, model_name: str) -> dict[str, Any]: 

135 """ 

136 Get provider-specific kwargs for a given model. 

137 

138 BUGFIX: Filters out provider-specific parameters that don't apply to the target provider. 

139 For example, Azure-specific params (azure_endpoint, azure_deployment) are removed 

140 when calling Anthropic or OpenAI fallback models. 

141 

142 Args: 

143 model_name: Model identifier to get kwargs for 

144 

145 Returns: 

146 dict: Filtered kwargs appropriate for the model's provider 

147 """ 

148 provider = self._get_provider_from_model(model_name) 

149 

150 # Define provider-specific parameter prefixes 

151 provider_specific_prefixes = { 

152 "azure": ["azure_", "api_version"], 

153 "bedrock": ["aws_", "bedrock_"], 

154 "vertex": ["vertex_", "gcp_"], 

155 } 

156 

157 # If this is the same provider as the primary model, return all kwargs 

158 if provider == self.provider: 

159 return self.kwargs 

160 

161 # Filter out parameters specific to other providers 

162 filtered_kwargs = {} 

163 for key, value in self.kwargs.items():    163 ↛ 165: line 163 didn't jump to line 165 because the loop on line 163 never started

164 # Check if this parameter belongs to a different provider 

165 is_provider_specific = False 

166 for other_provider, prefixes in provider_specific_prefixes.items(): 

167 if other_provider != provider and any(key.startswith(prefix) for prefix in prefixes): 

168 is_provider_specific = True 

169 break 

170 

171 # Include parameter if it's not specific to another provider 

172 if not is_provider_specific: 

173 filtered_kwargs[key] = value 

174 

175 logger.debug( 

176 "Filtered kwargs for fallback model", 

177 extra={ 

178 "model": model_name, 

179 "provider": provider, 

180 "original_kwargs": list(self.kwargs.keys()), 

181 "filtered_kwargs": list(filtered_kwargs.keys()), 

182 }, 

183 ) 

184 

185 return filtered_kwargs 
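# Filtering sketch for the bugfix described in the docstring: Azure-only kwargs are
# dropped when the fallback model resolves to another provider. The parameter
# values are placeholders; the key names reuse the docstring's own example.
from mcp_server_langgraph.llm.factory import LLMFactory

f = LLMFactory(
    provider="azure",
    model_name="azure/gpt-4o",
    azure_endpoint="https://example.openai.azure.com",
    azure_deployment="gpt-4o-prod",
)
assert f._get_provider_kwargs("azure/gpt-4o") == f.kwargs     # same provider: kwargs pass through
assert f._get_provider_kwargs("claude-sonnet-4-5") == {}      # azure_* keys filtered out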

186 

187 def _setup_environment(self, config=None) -> None: # type: ignore[no-untyped-def] 

188 """ 

189 Set up environment variables for LiteLLM. 

190 

191 Configures credentials for primary provider AND all fallback providers 

192 to enable seamless multi-provider fallback. 

193 

194 Args: 

195 config: Settings object with API keys for all providers (optional) 

196 """ 

197 # Collect all providers needed (primary + fallbacks) 

198 providers_needed = {self.provider} 

199 

200 # Add providers for each fallback model 

201 for fallback_model in self.fallback_models: 

202 provider = self._get_provider_from_model(fallback_model) 

203 providers_needed.add(provider) 

204 

205 # Map provider to environment variable and config attribute 

206 # Provider-specific credential mapping (some providers need multiple env vars) 

207 provider_config_map = { 

208 "anthropic": [("ANTHROPIC_API_KEY", "anthropic_api_key")], 

209 "openai": [("OPENAI_API_KEY", "openai_api_key")], 

210 "google": [("GOOGLE_API_KEY", "google_api_key")], 

211 "gemini": [("GOOGLE_API_KEY", "google_api_key")], 

212 "azure": [ 

213 ("AZURE_API_KEY", "azure_api_key"), 

214 ("AZURE_API_BASE", "azure_api_base"), 

215 ("AZURE_API_VERSION", "azure_api_version"), 

216 ("AZURE_DEPLOYMENT_NAME", "azure_deployment_name"), 

217 ], 

218 "bedrock": [ 

219 ("AWS_ACCESS_KEY_ID", "aws_access_key_id"), 

220 ("AWS_SECRET_ACCESS_KEY", "aws_secret_access_key"), 

221 ("AWS_REGION", "aws_region"), 

222 ], 

223 } 

224 

225 # Set up credentials for each needed provider 

226 for provider in providers_needed: 

227 if provider not in provider_config_map: 

228 continue # Skip unknown providers (e.g., ollama doesn't need API key) 

229 

230 credential_pairs = provider_config_map[provider] 

231 

232 for env_var, config_attr in credential_pairs: 

233 # Get value from config 

234 value = getattr(config, config_attr) if config and hasattr(config, config_attr) else None 

235 

236 # For primary provider, use self.api_key as fallback for API key fields 

237 if value is None and provider == self.provider and "api_key" in config_attr.lower(): 

238 value = self.api_key 

239 

240 # Set environment variable if we have a value 

241 if value and env_var: 

242 os.environ[env_var] = str(value) 

243 logger.debug( 

244 f"Configured credential for provider: {provider}", extra={"provider": provider, "env_var": env_var} 

245 ) 
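# Credential-wiring sketch: the primary and every fallback provider get their
# environment variables set, assuming a config object exposing the attribute names
# in provider_config_map. The key values are dummies.
import os
from types import SimpleNamespace

from mcp_server_langgraph.llm.factory import LLMFactory

cfg = SimpleNamespace(google_api_key="google-demo-key", anthropic_api_key="anthropic-demo-key")
f = LLMFactory(provider="google", model_name="gemini-2.5-flash", fallback_models=["claude-sonnet-4-5"])
f._setup_environment(config=cfg)
assert os.environ["GOOGLE_API_KEY"] == "google-demo-key"        # primary provider
assert os.environ["ANTHROPIC_API_KEY"] == "anthropic-demo-key"  # fallback provider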

246 

247 def _format_messages(self, messages: list[BaseMessage | dict[str, Any]]) -> list[dict[str, str]]: 

248 """ 

249 Convert LangChain messages to LiteLLM format 

250 

251 Args: 

252 messages: List of LangChain BaseMessage objects or dicts 

253 

254 Returns: 

255 List of dictionaries in LiteLLM format 

256 """ 

257 formatted = [] 

258 for msg in messages: 

259 if isinstance(msg, HumanMessage): 

260 content = msg.content if isinstance(msg.content, str) else str(msg.content) 

261 formatted.append({"role": "user", "content": content}) 

262 elif isinstance(msg, AIMessage): 

263 content = msg.content if isinstance(msg.content, str) else str(msg.content) 

264 formatted.append({"role": "assistant", "content": content}) 

265 elif isinstance(msg, SystemMessage): 

266 content = msg.content if isinstance(msg.content, str) else str(msg.content) 

267 formatted.append({"role": "system", "content": content}) 

268 elif isinstance(msg, dict): 

269 # Handle dict messages (already in correct format or need conversion) 

270 if "role" in msg and "content" in msg: 

271 # Already formatted dict 

272 formatted.append(msg) 

273 elif "content" in msg: 

274 # Dict with content but no role 

275 formatted.append({"role": "user", "content": str(msg["content"])}) 

276 else: 

277 # Malformed dict, convert to string 

278 formatted.append({"role": "user", "content": str(msg)}) 

279 else: 

280 # Fallback for other types - check if it has content attribute 

281 if hasattr(msg, "content"): 

282 formatted.append({"role": "user", "content": str(msg.content)}) 

283 else: 

284 # Last resort: convert entire object to string 

285 formatted.append({"role": "user", "content": str(msg)}) 

286 

287 return formatted 
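# Conversion sketch for the branches above: LangChain messages and pre-formatted
# dicts both come out as LiteLLM role/content dicts, in input order.
from langchain_core.messages import HumanMessage, SystemMessage

from mcp_server_langgraph.llm.factory import LLMFactory

f = LLMFactory()
assert f._format_messages([
    SystemMessage(content="Be terse."),
    HumanMessage(content="Hi"),
    {"role": "assistant", "content": "Hello!"},   # already in LiteLLM format, passed through
]) == [
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]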

288 

289 def invoke(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage: # type: ignore[no-untyped-def] 

290 """ 

291 Synchronous LLM invocation 

292 

293 Args: 

294 messages: List of messages 

295 **kwargs: Additional parameters for the model 

296 

297 Returns: 

298 AIMessage with the response 

299 """ 

300 with tracer.start_as_current_span("llm.invoke") as span: 

301 span.set_attribute("llm.provider", self.provider) 

302 span.set_attribute("llm.model", self.model_name) 

303 

304 formatted_messages = self._format_messages(messages) 

305 

306 # Merge kwargs with defaults 

307 params = { 

308 "model": self.model_name, 

309 "messages": formatted_messages, 

310 "temperature": kwargs.get("temperature", self.temperature), 

311 "max_tokens": kwargs.get("max_tokens", self.max_tokens), 

312 "timeout": kwargs.get("timeout", self.timeout), 

313 **self.kwargs, 

314 } 

315 

316 try: 

317 response: ModelResponse = completion(**params) 

318 

319 content = response.choices[0].message.content # type: ignore[union-attr] 

320 

321 # Track metrics 

322 metrics.successful_calls.add(1, {"operation": "llm.invoke", "model": self.model_name}) 

323 

324 logger.info( 

325 "LLM invocation successful", 

326 extra={ 

327 "model": self.model_name, 

328 "tokens": response.usage.total_tokens if response.usage else 0, # type: ignore[attr-defined] 

329 }, 

330 ) 

331 

332 return AIMessage(content=content) 

333 

334 except Exception as e: 

335 logger.error( 

336 f"LLM invocation failed: {e}", extra={"model": self.model_name, "provider": self.provider}, exc_info=True 

337 ) 

338 

339 metrics.failed_calls.add(1, {"operation": "llm.invoke", "model": self.model_name}) 

340 span.record_exception(e) 

341 

342 # Try fallback if enabled 

343 if self.enable_fallback and self.fallback_models:    343 ↛ 346: line 343 didn't jump to line 346 because the condition on line 343 was always true

344 return self._try_fallback(messages, **kwargs) 

345 

346 raise 
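# Per-call overrides win over the constructor defaults via the kwargs.get(...)
# merge above; the prompt and values are illustrative, and a real call needs
# provider credentials in the environment.
from mcp_server_langgraph.llm.factory import LLMFactory

llm = LLMFactory(provider="google", model_name="gemini-2.5-flash", temperature=0.7)
reply = llm.invoke(
    [{"role": "user", "content": "Summarize LiteLLM in one sentence."}],
    temperature=0.0,   # overrides self.temperature for this call only
    max_tokens=128,
)
print(reply.content)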

347 

348 @circuit_breaker(name="llm", fail_max=5, timeout=60) 

349 @retry_with_backoff(max_attempts=3, exponential_base=2) 

350 @with_timeout(operation_type="llm") 

351 @with_bulkhead(resource_type="llm") 

352 async def ainvoke(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage: # type: ignore[no-untyped-def] 

353 """ 

354 Asynchronous LLM invocation with full resilience protection. 

355 

356 Protected by: 

357 - Circuit breaker: Fail fast if LLM provider is down (5 failures → open, 60s timeout) 

358 - Retry logic: Up to 3 attempts with exponential backoff (1s, 2s, 4s) 

359 - Timeout: 60s timeout for LLM operations 

360 - Bulkhead: Limit to 10 concurrent LLM calls 

361 

362 Args: 

363 messages: List of messages 

364 **kwargs: Additional parameters for the model 

365 

366 Returns: 

367 AIMessage with the response 

368 

369 Raises: 

370 CircuitBreakerOpenError: If circuit breaker is open 

371 RetryExhaustedError: If all retry attempts failed 

372 TimeoutError: If operation exceeds 60s timeout 

373 BulkheadRejectedError: If too many concurrent LLM calls 

374 LLMProviderError: For other LLM provider errors 

375 """ 

376 with tracer.start_as_current_span("llm.ainvoke") as span: 

377 span.set_attribute("llm.provider", self.provider) 

378 span.set_attribute("llm.model", self.model_name) 

379 

380 formatted_messages = self._format_messages(messages) 

381 

382 params = { 

383 "model": self.model_name, 

384 "messages": formatted_messages, 

385 "temperature": kwargs.get("temperature", self.temperature), 

386 "max_tokens": kwargs.get("max_tokens", self.max_tokens), 

387 "timeout": kwargs.get("timeout", self.timeout), 

388 **self.kwargs, 

389 } 

390 

391 try: 

392 response: ModelResponse = await acompletion(**params) 

393 

394 content = response.choices[0].message.content # type: ignore[union-attr] 

395 

396 metrics.successful_calls.add(1, {"operation": "llm.ainvoke", "model": self.model_name}) 

397 

398 logger.info( 

399 "Async LLM invocation successful", 

400 extra={ 

401 "model": self.model_name, 

402 "tokens": response.usage.total_tokens if response.usage else 0, # type: ignore[attr-defined] 

403 }, 

404 ) 

405 

406 return AIMessage(content=content) 

407 

408 except Exception as e: 

409 # Convert to custom exceptions for better error handling 

410 error_msg = str(e).lower() 

411 

412 # Check for overload errors (529 or equivalent) - handle first 

413 if is_overload_error(e):    413 ↛ 414: line 413 didn't jump to line 414 because the condition on line 413 was never true

414 retry_after = extract_retry_after_from_exception(e) 

415 raise LLMOverloadError( 

416 message=f"LLM provider overloaded: {e}", 

417 retry_after=retry_after, 

418 metadata={ 

419 "model": self.model_name, 

420 "provider": self.provider, 

421 "retry_after": retry_after, 

422 }, 

423 cause=e, 

424 ) 

425 elif "rate limit" in error_msg or "429" in error_msg: 425 ↛ 426line 425 didn't jump to line 426 because the condition on line 425 was never true

426 raise LLMRateLimitError( 

427 message=f"LLM provider rate limit exceeded: {e}", 

428 metadata={"model": self.model_name, "provider": self.provider}, 

429 cause=e, 

430 ) 

431 elif "timeout" in error_msg or "timed out" in error_msg: 431 ↛ 432line 431 didn't jump to line 432 because the condition on line 431 was never true

432 raise LLMTimeoutError( 

433 message=f"LLM request timed out: {e}", 

434 metadata={"model": self.model_name, "provider": self.provider, "timeout": self.timeout}, 

435 cause=e, 

436 ) 

437 elif "model not found" in error_msg or "404" in error_msg: 437 ↛ 438line 437 didn't jump to line 438 because the condition on line 437 was never true

438 raise LLMModelNotFoundError( 

439 message=f"LLM model not found: {e}", 

440 metadata={"model": self.model_name, "provider": self.provider}, 

441 cause=e, 

442 ) 

443 else: 

444 logger.error( 

445 f"Async LLM invocation failed: {e}", 

446 extra={"model": self.model_name, "provider": self.provider}, 

447 exc_info=True, 

448 ) 

449 

450 metrics.failed_calls.add(1, {"operation": "llm.ainvoke", "model": self.model_name}) 

451 span.record_exception(e) 

452 

453 # Try fallback if enabled 

454 if self.enable_fallback and self.fallback_models:    454 ↛ 457: line 454 didn't jump to line 457 because the condition on line 454 was always true

455 return await self._try_fallback_async(messages, **kwargs) 

456 

457 raise LLMProviderError( 

458 message=f"LLM provider error: {e}", 

459 metadata={"model": self.model_name, "provider": self.provider}, 

460 cause=e, 

461 ) 
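# Async sketch pairing ainvoke() with the custom exceptions it raises; the model
# and prompt are illustrative, and real calls need provider credentials.
import asyncio

from mcp_server_langgraph.core.exceptions import LLMOverloadError, LLMRateLimitError
from mcp_server_langgraph.llm.factory import LLMFactory


async def ask(prompt: str) -> str:
    llm = LLMFactory(provider="google", model_name="gemini-2.5-flash")
    try:
        reply = await llm.ainvoke([{"role": "user", "content": prompt}])
        return str(reply.content)
    except LLMOverloadError:
        return "provider overloaded (529) - retry later"
    except LLMRateLimitError:
        return "rate limited (429) - back off"


print(asyncio.run(ask("ping")))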

462 

463 def _try_fallback(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage: # type: ignore[no-untyped-def] 

464 """Try fallback models if primary fails""" 

465 for fallback_model in self.fallback_models: 

466 if fallback_model == self.model_name:    466 ↛ 467: line 466 didn't jump to line 467 because the condition on line 466 was never true

467 continue # Skip if it's the same model 

468 

469 logger.warning(f"Trying fallback model: {fallback_model}", extra={"primary_model": self.model_name}) 

470 

471 try: 

472 formatted_messages = self._format_messages(messages) 

473 # BUGFIX: Use provider-specific kwargs to avoid cross-provider parameter errors 

474 provider_kwargs = self._get_provider_kwargs(fallback_model) 

475 response = completion( 

476 model=fallback_model, 

477 messages=formatted_messages, 

478 temperature=self.temperature, 

479 max_tokens=self.max_tokens, 

480 timeout=self.timeout, 

481 **provider_kwargs, # Forward provider-specific kwargs only 

482 ) 

483 

484 content = response.choices[0].message.content 

485 

486 logger.info("Fallback successful", extra={"fallback_model": fallback_model}) 

487 

488 metrics.successful_calls.add(1, {"operation": "llm.fallback", "model": fallback_model}) 

489 

490 return AIMessage(content=content) 

491 

492 except Exception as e: 

493 logger.error(f"Fallback model {fallback_model} failed: {e}", exc_info=True) 

494 continue 

495 

496 msg = "All models failed including fallbacks" 

497 raise RuntimeError(msg) 

498 

499 async def _try_fallback_async(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage: # type: ignore[no-untyped-def] 

500 """Try fallback models asynchronously""" 

501 for fallback_model in self.fallback_models: 

502 if fallback_model == self.model_name:    502 ↛ 503: line 502 didn't jump to line 503 because the condition on line 502 was never true

503 continue 

504 

505 logger.warning(f"Trying fallback model: {fallback_model}", extra={"primary_model": self.model_name}) 

506 

507 try: 

508 formatted_messages = self._format_messages(messages) 

509 # BUGFIX: Use provider-specific kwargs to avoid cross-provider parameter errors 

510 provider_kwargs = self._get_provider_kwargs(fallback_model) 

511 response = await acompletion( 

512 model=fallback_model, 

513 messages=formatted_messages, 

514 temperature=self.temperature, 

515 max_tokens=self.max_tokens, 

516 timeout=self.timeout, 

517 **provider_kwargs, # Forward provider-specific kwargs only 

518 ) 

519 

520 content = response.choices[0].message.content 

521 

522 logger.info("Async fallback successful", extra={"fallback_model": fallback_model}) 

523 

524 metrics.successful_calls.add(1, {"operation": "llm.fallback_async", "model": fallback_model}) 

525 

526 return AIMessage(content=content) 

527 

528 except Exception as e: 

529 logger.error(f"Async fallback model {fallback_model} failed: {e}", exc_info=True) 

530 continue 

531 

532 msg = "All async models failed including fallbacks" 

533 raise RuntimeError(msg) 

534 

535 

536def create_llm_from_config(config) -> LLMFactory: # type: ignore[no-untyped-def] 

537 """ 

538 Create primary LLM instance from configuration 

539 

540 Args: 

541 config: Settings object with LLM configuration 

542 

543 Returns: 

544 Configured LLMFactory instance for primary chat operations 

545 """ 

546 # Determine API key based on provider 

547 api_key_map = { 

548 "anthropic": config.anthropic_api_key, 

549 "openai": config.openai_api_key, 

550 "google": config.google_api_key, 

551 "gemini": config.google_api_key, 

552 "vertex_ai": None, # Vertex AI uses Workload Identity or GOOGLE_APPLICATION_CREDENTIALS 

553 "azure": config.azure_api_key, 

554 "bedrock": config.aws_access_key_id, 

555 } 

556 

557 api_key = api_key_map.get(config.llm_provider) 

558 

559 # Provider-specific kwargs 

560 provider_kwargs = {} 

561 

562 if config.llm_provider == "azure": 562 ↛ 563line 562 didn't jump to line 563 because the condition on line 562 was never true

563 provider_kwargs.update( 

564 { 

565 "api_base": config.azure_api_base, 

566 "api_version": config.azure_api_version, 

567 } 

568 ) 

569 elif config.llm_provider == "bedrock": 569 ↛ 570line 569 didn't jump to line 570 because the condition on line 569 was never true

570 provider_kwargs.update( 

571 { 

572 "aws_secret_access_key": config.aws_secret_access_key, 

573 "aws_region_name": config.aws_region, 

574 } 

575 ) 

576 elif config.llm_provider == "ollama": 576 ↛ 577line 576 didn't jump to line 577 because the condition on line 576 was never true

577 provider_kwargs.update( 

578 { 

579 "api_base": config.ollama_base_url, 

580 } 

581 ) 

582 elif config.llm_provider in ["vertex_ai", "google"]:    582 ↛ 595: line 582 didn't jump to line 595 because the condition on line 582 was always true

583 # Vertex AI configuration 

584 # LiteLLM requires vertex_project and vertex_location for Vertex AI models 

585 # If using Workload Identity on GKE, authentication is automatic 

586 vertex_project = config.vertex_project or config.google_project_id 

587 if vertex_project:    587 ↛ 588: line 587 didn't jump to line 588 because the condition on line 587 was never true

588 provider_kwargs.update( 

589 { 

590 "vertex_project": vertex_project, 

591 "vertex_location": config.vertex_location, 

592 } 

593 ) 

594 

595 factory = LLMFactory( 

596 provider=config.llm_provider, 

597 model_name=config.model_name, 

598 api_key=api_key, 

599 temperature=config.model_temperature, 

600 max_tokens=config.model_max_tokens, 

601 timeout=config.model_timeout, 

602 enable_fallback=config.enable_fallback, 

603 fallback_models=config.fallback_models, 

604 **provider_kwargs, 

605 ) 

606 

607 # Set up credentials for primary + all fallback providers 

608 factory._setup_environment(config=config) 

609 

610 return factory 
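# Stand-in for the Settings object this function reads; the attribute names come
# from the accesses above, the values are placeholders, and a real deployment would
# pass the project's Settings instance instead of a SimpleNamespace.
from types import SimpleNamespace

from mcp_server_langgraph.llm.factory import create_llm_from_config

cfg = SimpleNamespace(
    llm_provider="google",
    model_name="gemini-2.5-flash",
    model_temperature=0.7,
    model_max_tokens=4096,
    model_timeout=60,
    enable_fallback=True,
    fallback_models=["claude-sonnet-4-5"],
    google_api_key="google-demo-key",
    anthropic_api_key="anthropic-demo-key",
    openai_api_key=None,
    azure_api_key=None,
    aws_access_key_id=None,
    vertex_project=None,
    google_project_id=None,
    vertex_location=None,
)
llm = create_llm_from_config(cfg)   # sets GOOGLE_API_KEY and ANTHROPIC_API_KEY as a side effect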

611 

612 

613def create_summarization_model(config) -> LLMFactory: # type: ignore[no-untyped-def] 

614 """ 

615 Create dedicated LLM instance for summarization (cost-optimized). 

616 

617 Uses lighter/cheaper model for context compaction to reduce costs. 

618 Falls back to primary model if dedicated model not configured. 

619 

620 Args: 

621 config: Settings object with LLM configuration 

622 

623 Returns: 

624 Configured LLMFactory instance for summarization 

625 """ 

626 # If dedicated summarization model not enabled, use primary model 

627 if not getattr(config, "use_dedicated_summarization_model", False):    627 ↛ 628: line 627 didn't jump to line 628 because the condition on line 627 was never true

628 return create_llm_from_config(config) 

629 

630 # Determine provider and API key 

631 provider = config.summarization_model_provider or config.llm_provider 

632 

633 api_key_map = { 

634 "anthropic": config.anthropic_api_key, 

635 "openai": config.openai_api_key, 

636 "google": config.google_api_key, 

637 "gemini": config.google_api_key, 

638 "vertex_ai": None, # Vertex AI uses Workload Identity or GOOGLE_APPLICATION_CREDENTIALS 

639 "azure": config.azure_api_key, 

640 "bedrock": config.aws_access_key_id, 

641 } 

642 

643 api_key = api_key_map.get(provider) 

644 

645 # Provider-specific kwargs 

646 provider_kwargs = {} 

647 if provider == "azure": 647 ↛ 648line 647 didn't jump to line 648 because the condition on line 647 was never true

648 provider_kwargs.update({"api_base": config.azure_api_base, "api_version": config.azure_api_version}) 

649 elif provider == "bedrock": 649 ↛ 650line 649 didn't jump to line 650 because the condition on line 649 was never true

650 provider_kwargs.update({"aws_secret_access_key": config.aws_secret_access_key, "aws_region_name": config.aws_region}) 

651 elif provider == "ollama": 651 ↛ 652line 651 didn't jump to line 652 because the condition on line 651 was never true

652 provider_kwargs.update({"api_base": config.ollama_base_url}) 

653 elif provider in ["vertex_ai", "google"]: 

654 vertex_project = config.vertex_project or config.google_project_id 

655 if vertex_project:    655 ↛ 656: line 655 didn't jump to line 656 because the condition on line 655 was never true

656 provider_kwargs.update({"vertex_project": vertex_project, "vertex_location": config.vertex_location}) 

657 

658 factory = LLMFactory( 

659 provider=provider, 

660 model_name=config.summarization_model_name or config.model_name, 

661 api_key=api_key, 

662 temperature=config.summarization_model_temperature, 

663 max_tokens=config.summarization_model_max_tokens, 

664 timeout=config.model_timeout, 

665 enable_fallback=config.enable_fallback, 

666 fallback_models=config.fallback_models, 

667 **provider_kwargs, 

668 ) 

669 

670 # Set up credentials for all providers 

671 factory._setup_environment(config=config) 

672 

673 return factory 

674 

675 

676def create_verification_model(config) -> LLMFactory: # type: ignore[no-untyped-def] 

677 """ 

678 Create dedicated LLM instance for verification (LLM-as-judge). 

679 

680 Uses potentially different model for output verification to balance 

681 cost and quality. Falls back to primary model if not configured. 

682 

683 Args: 

684 config: Settings object with LLM configuration 

685 

686 Returns: 

687 Configured LLMFactory instance for verification 

688 """ 

689 # If dedicated verification model not enabled, use primary model 

690 if not getattr(config, "use_dedicated_verification_model", False):    690 ↛ 691: line 690 didn't jump to line 691 because the condition on line 690 was never true

691 return create_llm_from_config(config) 

692 

693 # Determine provider and API key 

694 provider = config.verification_model_provider or config.llm_provider 

695 

696 api_key_map = { 

697 "anthropic": config.anthropic_api_key, 

698 "openai": config.openai_api_key, 

699 "google": config.google_api_key, 

700 "gemini": config.google_api_key, 

701 "vertex_ai": None, # Vertex AI uses Workload Identity or GOOGLE_APPLICATION_CREDENTIALS 

702 "azure": config.azure_api_key, 

703 "bedrock": config.aws_access_key_id, 

704 } 

705 

706 api_key = api_key_map.get(provider) 

707 

708 # Provider-specific kwargs 

709 provider_kwargs = {} 

710 if provider == "azure": 710 ↛ 711line 710 didn't jump to line 711 because the condition on line 710 was never true

711 provider_kwargs.update({"api_base": config.azure_api_base, "api_version": config.azure_api_version}) 

712 elif provider == "bedrock": 712 ↛ 713line 712 didn't jump to line 713 because the condition on line 712 was never true

713 provider_kwargs.update({"aws_secret_access_key": config.aws_secret_access_key, "aws_region_name": config.aws_region}) 

714 elif provider == "ollama": 714 ↛ 715line 714 didn't jump to line 715 because the condition on line 714 was never true

715 provider_kwargs.update({"api_base": config.ollama_base_url}) 

716 elif provider in ["vertex_ai", "google"]: 

717 vertex_project = config.vertex_project or config.google_project_id 

718 if vertex_project:    718 ↛ 719: line 718 didn't jump to line 719 because the condition on line 718 was never true

719 provider_kwargs.update({"vertex_project": vertex_project, "vertex_location": config.vertex_location}) 

720 

721 factory = LLMFactory( 

722 provider=provider, 

723 model_name=config.verification_model_name or config.model_name, 

724 api_key=api_key, 

725 temperature=config.verification_model_temperature, 

726 max_tokens=config.verification_model_max_tokens, 

727 timeout=config.model_timeout, 

728 enable_fallback=config.enable_fallback, 

729 fallback_models=config.fallback_models, 

730 **provider_kwargs, 

731 ) 

732 

733 # Set up credentials for all providers 

734 factory._setup_environment(config=config) 

735 

736 return factory
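# Companion sketch for the two dedicated-model factories: the same Settings stand-in,
# extended with the summarization/verification attributes they read. Attribute names
# come from this file; the model names and values are placeholders.
from types import SimpleNamespace

from mcp_server_langgraph.llm.factory import (
    create_summarization_model,
    create_verification_model,
)

cfg = SimpleNamespace(
    llm_provider="google",
    model_name="gemini-2.5-flash",
    model_temperature=0.7,
    model_max_tokens=4096,
    model_timeout=60,
    enable_fallback=False,
    fallback_models=[],
    google_api_key="google-demo-key",
    anthropic_api_key=None,
    openai_api_key=None,
    azure_api_key=None,
    aws_access_key_id=None,
    vertex_project=None,
    google_project_id=None,
    vertex_location=None,
    # Cheaper model for context compaction
    use_dedicated_summarization_model=True,
    summarization_model_provider="google",
    summarization_model_name="gemini-2.5-flash-lite",
    summarization_model_temperature=0.3,
    summarization_model_max_tokens=1024,
    # Separate LLM-as-judge model for output verification
    use_dedicated_verification_model=True,
    verification_model_provider="google",
    verification_model_name="gemini-2.5-flash",
    verification_model_temperature=0.0,
    verification_model_max_tokens=1024,
)
summarizer = create_summarization_model(cfg)
judge = create_verification_model(cfg)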