Coverage for src/mcp_server_langgraph/llm/factory.py: 84%

251 statements  

coverage.py v7.12.0, created at 2025-12-08 06:31 +0000

1""" 

2LiteLLM Factory for Multi-Provider LLM Support 

3 

4Supports: Anthropic, OpenAI, Google (Gemini), Azure OpenAI, AWS Bedrock, 

5Ollama (Llama, Qwen, Mistral, etc.) 

6 

7Enhanced with resilience patterns (ADR-0026): 

8- Circuit breaker for provider failures 

9- Retry logic with exponential backoff 

10- Timeout enforcement 

11- Bulkhead isolation (10 concurrent LLM calls max, provider-aware) 

12- Exponential backoff between fallback attempts 

13""" 

14 

15import asyncio 

16import os 

17 

18# Fallback resilience constants 

19FALLBACK_BASE_DELAY_SECONDS = 1.0 # Initial delay between fallback attempts 

20FALLBACK_DELAY_MULTIPLIER = 2.0 # Exponential multiplier 

21FALLBACK_MAX_DELAY_SECONDS = 8.0 # Cap for fallback delays 

22from typing import Any 

23 

24from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage 

25from litellm import acompletion, completion 

26from litellm.utils import ModelResponse # type: ignore[attr-defined] 

27 

28from mcp_server_langgraph.core.exceptions import ( 

29 LLMModelNotFoundError, 

30 LLMOverloadError, 

31 LLMProviderError, 

32 LLMRateLimitError, 

33 LLMTimeoutError, 

34) 

35from mcp_server_langgraph.resilience.retry import extract_retry_after_from_exception, is_overload_error 

36from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer 

37from mcp_server_langgraph.resilience import circuit_breaker, retry_with_backoff, with_bulkhead, with_timeout 

38 

39 

40class LLMFactory: 

41 """ 

42 Factory for creating and managing LLM connections via LiteLLM 

43 

44 Supports multiple providers with automatic fallback and retry logic. 

45 """ 

46 

47 def __init__( # type: ignore[no-untyped-def] 

48 self, 

49 provider: str = "google", 

50 model_name: str = "gemini-2.5-flash", 

51 api_key: str | None = None, 

52 temperature: float = 0.7, 

53 max_tokens: int = 4096, 

54 timeout: int = 60, 

55 enable_fallback: bool = True, 

56 fallback_models: list[str] | None = None, 

57 **kwargs, 

58 ): 

59 """ 

60 Initialize LLM Factory 

61 

62 Args: 

63 provider: LLM provider (anthropic, openai, google, azure, bedrock, ollama) 

64 model_name: Model identifier 

65 api_key: API key for the provider 

66 temperature: Sampling temperature (0-1) 

67 max_tokens: Maximum tokens to generate 

68 timeout: Request timeout in seconds 

69 enable_fallback: Enable fallback to alternative models 

70 fallback_models: List of fallback model names 

71 **kwargs: Additional provider-specific parameters 

72 """ 

73 self.provider = provider 

74 self.model_name = model_name 

75 self.api_key = api_key 

76 self.temperature = temperature 

77 self.max_tokens = max_tokens 

78 self.timeout = timeout 

79 self.enable_fallback = enable_fallback 

80 self.fallback_models = fallback_models or [] 

81 self.kwargs = kwargs 

82 

83 # Note: _setup_environment is now called by factory functions with config 

84 # This allows multi-provider credential setup for fallbacks 

85 

86 logger.info( 

87 "LLM Factory initialized", 

88 extra={ 

89 "provider": provider, 

90 "model": model_name, 

91 "fallback_enabled": enable_fallback, 

92 "fallback_count": len(fallback_models) if fallback_models else 0, 

93 }, 

94 ) 

95 

96 def _get_provider_from_model(self, model_name: str) -> str: 

97 """ 

98 Extract provider from model name. 

99 

100 Args: 

101 model_name: Model identifier (e.g., "gpt-5", "claude-sonnet-4-5", "gemini-2.5-flash") 

102 

103 Returns: 

104 Provider name (e.g., "openai", "anthropic", "google") 

105 """ 

106 model_lower = model_name.lower() 

107 

108 # Check provider prefixes FIRST (azure/gpt-4 should be azure, not openai) 

109 # Vertex AI (Google Cloud AI Platform) 

110 if model_lower.startswith("vertex_ai/"): 

111 return "vertex_ai" 

112 

113 # Azure (prefixed models) 

114 if model_lower.startswith("azure/"): 

115 return "azure" 

116 

117 # Bedrock (prefixed models) 

118 if model_lower.startswith("bedrock/"): 

119 return "bedrock" 

120 

121 # Ollama (local models) 

122 if model_lower.startswith("ollama/"): 

123 return "ollama" 

124 

125 # Then check model name patterns 

126 # Anthropic models 

127 if any(x in model_lower for x in ["claude", "anthropic"]): 

128 return "anthropic" 

129 

130 # OpenAI models 

131 if any(x in model_lower for x in ["gpt-", "o1-", "davinci", "curie", "babbage"]): 

132 return "openai" 

133 

134 # Google models 

135 if any(x in model_lower for x in ["gemini", "palm", "text-bison", "chat-bison"]): 

136 return "google" 

137 

138 # Default to current provider 

139 return self.provider 

140 
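A few concrete resolutions of the checks above, assuming a factory constructed with provider="google"; the model strings are illustrative.

from mcp_server_langgraph.llm.factory import LLMFactory

factory = LLMFactory(provider="google", model_name="gemini-2.5-flash")

factory._get_provider_from_model("azure/gpt-4")        # "azure" (prefix wins over the "gpt-" pattern)
factory._get_provider_from_model("claude-sonnet-4-5")  # "anthropic"
factory._get_provider_from_model("gpt-5")              # "openai"
factory._get_provider_from_model("ollama/llama3")      # "ollama"
factory._get_provider_from_model("some-custom-model")  # "google" (defaults to self.provider)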

141 def _get_provider_kwargs(self, model_name: str) -> dict[str, Any]: 

142 """ 

143 Get provider-specific kwargs for a given model. 

144 

145 BUGFIX: Filters out provider-specific parameters that don't apply to the target provider. 

146 For example, Azure-specific params (azure_endpoint, azure_deployment) are removed 

147 when calling Anthropic or OpenAI fallback models. 

148 

149 Args: 

150 model_name: Model identifier to get kwargs for 

151 

152 Returns: 

153 dict: Filtered kwargs appropriate for the model's provider 

154 """ 

155 provider = self._get_provider_from_model(model_name) 

156 

157 # Define provider-specific parameter prefixes 

158 provider_specific_prefixes = { 

159 "azure": ["azure_", "api_version"], 

160 "bedrock": ["aws_", "bedrock_"], 

161 "vertex": ["vertex_", "gcp_"], 

162 } 

163 

164 # If this is the same provider as the primary model, return all kwargs 

165 if provider == self.provider: 

166 return self.kwargs 

167 

168 # Filter out parameters specific to other providers 

169 filtered_kwargs = {} 

170 for key, value in self.kwargs.items(): 170 ↛ 172: line 170 didn't jump to line 172 because the loop on line 170 never started

171 # Check if this parameter belongs to a different provider 

172 is_provider_specific = False 

173 for other_provider, prefixes in provider_specific_prefixes.items(): 

174 if other_provider != provider and any(key.startswith(prefix) for prefix in prefixes): 

175 is_provider_specific = True 

176 break 

177 

178 # Include parameter if it's not specific to another provider 

179 if not is_provider_specific: 

180 filtered_kwargs[key] = value 

181 

182 logger.debug( 

183 "Filtered kwargs for fallback model", 

184 extra={ 

185 "model": model_name, 

186 "provider": provider, 

187 "original_kwargs": list(self.kwargs.keys()), 

188 "filtered_kwargs": list(filtered_kwargs.keys()), 

189 }, 

190 ) 

191 

192 return filtered_kwargs 

193 
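A sketch of the filtering behaviour, assuming an Azure-primary factory carrying Azure-specific kwargs (the kwarg names here are hypothetical examples):

from mcp_server_langgraph.llm.factory import LLMFactory

azure_factory = LLMFactory(
    provider="azure",
    model_name="azure/gpt-4",
    azure_deployment="prod-gpt4",   # hypothetical Azure-only kwargs
    api_version="2024-02-01",
)

azure_factory._get_provider_kwargs("azure/gpt-4")
# {"azure_deployment": "prod-gpt4", "api_version": "2024-02-01"}  (same provider: kwargs pass through unchanged)

azure_factory._get_provider_kwargs("claude-sonnet-4-5")
# {}  (azure_* and api_version keys are stripped before calling the Anthropic fallback)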

194 def _setup_environment(self, config=None) -> None: # type: ignore[no-untyped-def] 

195 """ 

196 Set up environment variables for LiteLLM. 

197 

198 Configures credentials for primary provider AND all fallback providers 

199 to enable seamless multi-provider fallback. 

200 

201 Args: 

202 config: Settings object with API keys for all providers (optional) 

203 """ 

204 # Collect all providers needed (primary + fallbacks) 

205 providers_needed = {self.provider} 

206 

207 # Add providers for each fallback model 

208 for fallback_model in self.fallback_models: 

209 provider = self._get_provider_from_model(fallback_model) 

210 providers_needed.add(provider) 

211 

212 # Map provider to environment variable and config attribute 

213 # Provider-specific credential mapping (some providers need multiple env vars) 

214 provider_config_map = { 

215 "anthropic": [("ANTHROPIC_API_KEY", "anthropic_api_key")], 

216 "openai": [("OPENAI_API_KEY", "openai_api_key")], 

217 "google": [("GOOGLE_API_KEY", "google_api_key")], 

218 "gemini": [("GOOGLE_API_KEY", "google_api_key")], 

219 "azure": [ 

220 ("AZURE_API_KEY", "azure_api_key"), 

221 ("AZURE_API_BASE", "azure_api_base"), 

222 ("AZURE_API_VERSION", "azure_api_version"), 

223 ("AZURE_DEPLOYMENT_NAME", "azure_deployment_name"), 

224 ], 

225 "bedrock": [ 

226 ("AWS_ACCESS_KEY_ID", "aws_access_key_id"), 

227 ("AWS_SECRET_ACCESS_KEY", "aws_secret_access_key"), 

228 ("AWS_REGION", "aws_region"), 

229 ], 

230 } 

231 

232 # Set up credentials for each needed provider 

233 for provider in providers_needed: 

234 if provider not in provider_config_map: 

235 continue # Skip unknown providers (e.g., ollama doesn't need API key) 

236 

237 credential_pairs = provider_config_map[provider] 

238 

239 for env_var, config_attr in credential_pairs: 

240 # Get value from config 

241 value = getattr(config, config_attr) if config and hasattr(config, config_attr) else None 

242 

243 # For primary provider, use self.api_key as fallback for API key fields 

244 if value is None and provider == self.provider and "api_key" in config_attr.lower(): 

245 value = self.api_key 

246 

247 # Set environment variable if we have a value 

248 if value and env_var: 

249 os.environ[env_var] = str(value) 

250 logger.debug( 

251 f"Configured credential for provider: {provider}", extra={"provider": provider, "env_var": env_var} 

252 ) 

253 
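A sketch of the resulting environment for a Google-primary factory with an Anthropic fallback; the config object is a stand-in for the project's Settings class and the key values are placeholders.

from types import SimpleNamespace
from mcp_server_langgraph.llm.factory import LLMFactory

config = SimpleNamespace(google_api_key="AIza-placeholder", anthropic_api_key="sk-ant-placeholder")

factory = LLMFactory(
    provider="google",
    model_name="gemini-2.5-flash",
    fallback_models=["claude-sonnet-4-5"],
)
factory._setup_environment(config=config)
# Exports GOOGLE_API_KEY (primary provider) and ANTHROPIC_API_KEY (fallback provider);
# providers without entries in provider_config_map (e.g. ollama) are skipped.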

254 def _format_messages(self, messages: list[BaseMessage | dict[str, Any]]) -> list[dict[str, str]]: 

255 """ 

256 Convert LangChain messages to LiteLLM format 

257 

258 Args: 

259 messages: List of LangChain BaseMessage objects or dicts 

260 

261 Returns: 

262 List of dictionaries in LiteLLM format 

263 """ 

264 formatted = [] 

265 for msg in messages: 

266 if isinstance(msg, HumanMessage): 

267 content = msg.content if isinstance(msg.content, str) else str(msg.content) 

268 formatted.append({"role": "user", "content": content}) 

269 elif isinstance(msg, AIMessage): 

270 content = msg.content if isinstance(msg.content, str) else str(msg.content) 

271 formatted.append({"role": "assistant", "content": content}) 

272 elif isinstance(msg, SystemMessage): 

273 content = msg.content if isinstance(msg.content, str) else str(msg.content) 

274 formatted.append({"role": "system", "content": content}) 

275 elif isinstance(msg, dict): 

276 # Handle dict messages (already in correct format or need conversion) 

277 if "role" in msg and "content" in msg: 

278 # Already formatted dict 

279 formatted.append(msg) 

280 elif "content" in msg: 

281 # Dict with content but no role 

282 formatted.append({"role": "user", "content": str(msg["content"])}) 

283 else: 

284 # Malformed dict, convert to string 

285 formatted.append({"role": "user", "content": str(msg)}) 

286 else: 

287 # Fallback for other types - check if it has content attribute 

288 if hasattr(msg, "content"): 

289 formatted.append({"role": "user", "content": str(msg.content)}) 

290 else: 

291 # Last resort: convert entire object to string 

292 formatted.append({"role": "user", "content": str(msg)}) 

293 

294 return formatted 

295 
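For example, mixed LangChain messages and pre-formatted dicts normalise as follows (a sketch; the texts are placeholders):

from langchain_core.messages import HumanMessage, SystemMessage
from mcp_server_langgraph.llm.factory import LLMFactory

factory = LLMFactory()
factory._format_messages([
    SystemMessage(content="You are terse."),
    HumanMessage(content="Ping?"),
    {"role": "assistant", "content": "Pong."},   # already in LiteLLM format, passed through
])
# [{"role": "system", "content": "You are terse."},
#  {"role": "user", "content": "Ping?"},
#  {"role": "assistant", "content": "Pong."}]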

296 def invoke(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage: # type: ignore[no-untyped-def] 

297 """ 

298 Synchronous LLM invocation 

299 

300 Args: 

301 messages: List of messages 

302 **kwargs: Additional parameters for the model 

303 

304 Returns: 

305 AIMessage with the response 

306 """ 

307 with tracer.start_as_current_span("llm.invoke") as span: 

308 span.set_attribute("llm.provider", self.provider) 

309 span.set_attribute("llm.model", self.model_name) 

310 

311 formatted_messages = self._format_messages(messages) 

312 

313 # Merge kwargs with defaults 

314 params = { 

315 "model": self.model_name, 

316 "messages": formatted_messages, 

317 "temperature": kwargs.get("temperature", self.temperature), 

318 "max_tokens": kwargs.get("max_tokens", self.max_tokens), 

319 "timeout": kwargs.get("timeout", self.timeout), 

320 **self.kwargs, 

321 } 

322 

323 try: 

324 response: ModelResponse = completion(**params) 

325 

326 content = response.choices[0].message.content # type: ignore[union-attr] 

327 

328 # Track metrics 

329 metrics.successful_calls.add(1, {"operation": "llm.invoke", "model": self.model_name}) 

330 

331 logger.info( 

332 "LLM invocation successful", 

333 extra={ 

334 "model": self.model_name, 

335 "tokens": response.usage.total_tokens if response.usage else 0, # type: ignore[attr-defined] 

336 }, 

337 ) 

338 

339 return AIMessage(content=content) 

340 

341 except Exception as e: 

342 logger.error( 

343 f"LLM invocation failed: {e}", extra={"model": self.model_name, "provider": self.provider}, exc_info=True 

344 ) 

345 

346 metrics.failed_calls.add(1, {"operation": "llm.invoke", "model": self.model_name}) 

347 span.record_exception(e) 

348 

349 # Try fallback if enabled 

350 if self.enable_fallback and self.fallback_models: 350 ↛ 353: line 350 didn't jump to line 353 because the condition on line 350 was always true

351 return self._try_fallback(messages, **kwargs) 

352 

353 raise 

354 

355 @circuit_breaker(name="llm", fail_max=5, timeout=60) 

356 @retry_with_backoff(max_attempts=3, exponential_base=2) 

357 @with_timeout(operation_type="llm") 

358 @with_bulkhead(resource_type="llm") 

359 async def ainvoke(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage: # type: ignore[no-untyped-def] 

360 """ 

361 Asynchronous LLM invocation with full resilience protection. 

362 

363 Protected by: 

364 - Circuit breaker: Fail fast if LLM provider is down (5 failures → open, 60s timeout) 

365 - Retry logic: Up to 3 attempts with exponential backoff (1s, 2s, 4s) 

366 - Timeout: 60s timeout for LLM operations 

367 - Bulkhead: Limit to 10 concurrent LLM calls 

368 

369 Args: 

370 messages: List of messages 

371 **kwargs: Additional parameters for the model 

372 

373 Returns: 

374 AIMessage with the response 

375 

376 Raises: 

377 CircuitBreakerOpenError: If circuit breaker is open 

378 RetryExhaustedError: If all retry attempts failed 

379 TimeoutError: If operation exceeds 60s timeout 

380 BulkheadRejectedError: If too many concurrent LLM calls 

381 LLMProviderError: For other LLM provider errors 

382 """ 

383 with tracer.start_as_current_span("llm.ainvoke") as span: 

384 span.set_attribute("llm.provider", self.provider) 

385 span.set_attribute("llm.model", self.model_name) 

386 

387 formatted_messages = self._format_messages(messages) 

388 

389 params = { 

390 "model": self.model_name, 

391 "messages": formatted_messages, 

392 "temperature": kwargs.get("temperature", self.temperature), 

393 "max_tokens": kwargs.get("max_tokens", self.max_tokens), 

394 "timeout": kwargs.get("timeout", self.timeout), 

395 **self.kwargs, 

396 } 

397 

398 try: 

399 response: ModelResponse = await acompletion(**params) 

400 

401 content = response.choices[0].message.content # type: ignore[union-attr] 

402 

403 metrics.successful_calls.add(1, {"operation": "llm.ainvoke", "model": self.model_name}) 

404 

405 logger.info( 

406 "Async LLM invocation successful", 

407 extra={ 

408 "model": self.model_name, 

409 "tokens": response.usage.total_tokens if response.usage else 0, # type: ignore[attr-defined] 

410 }, 

411 ) 

412 

413 return AIMessage(content=content) 

414 

415 except Exception as e: 

416 # Convert to custom exceptions for better error handling 

417 error_msg = str(e).lower() 

418 

419 # Check for overload errors (529 or equivalent) - handle first 

420 if is_overload_error(e): 420 ↛ 421: line 420 didn't jump to line 421 because the condition on line 420 was never true

421 retry_after = extract_retry_after_from_exception(e) 

422 raise LLMOverloadError( 

423 message=f"LLM provider overloaded: {e}", 

424 retry_after=retry_after, 

425 metadata={ 

426 "model": self.model_name, 

427 "provider": self.provider, 

428 "retry_after": retry_after, 

429 }, 

430 cause=e, 

431 ) 

432 elif "rate limit" in error_msg or "429" in error_msg: 432 ↛ 434line 432 didn't jump to line 434 because the condition on line 432 was never true

433 # Extract Retry-After header for rate limit errors (same pattern as overload) 

434 retry_after = extract_retry_after_from_exception(e) 

435 raise LLMRateLimitError( 

436 message=f"LLM provider rate limit exceeded: {e}", 

437 retry_after=retry_after, 

438 metadata={ 

439 "model": self.model_name, 

440 "provider": self.provider, 

441 "retry_after": retry_after, 

442 }, 

443 cause=e, 

444 ) 

445 elif "timeout" in error_msg or "timed out" in error_msg: 445 ↛ 446line 445 didn't jump to line 446 because the condition on line 445 was never true

446 raise LLMTimeoutError( 

447 message=f"LLM request timed out: {e}", 

448 metadata={"model": self.model_name, "provider": self.provider, "timeout": self.timeout}, 

449 cause=e, 

450 ) 

451 elif "model not found" in error_msg or "404" in error_msg: 451 ↛ 452line 451 didn't jump to line 452 because the condition on line 451 was never true

452 raise LLMModelNotFoundError( 

453 message=f"LLM model not found: {e}", 

454 metadata={"model": self.model_name, "provider": self.provider}, 

455 cause=e, 

456 ) 

457 else: 

458 logger.error( 

459 f"Async LLM invocation failed: {e}", 

460 extra={"model": self.model_name, "provider": self.provider}, 

461 exc_info=True, 

462 ) 

463 

464 metrics.failed_calls.add(1, {"operation": "llm.ainvoke", "model": self.model_name}) 

465 span.record_exception(e) 

466 

467 # Try fallback if enabled 

468 if self.enable_fallback and self.fallback_models: 468 ↛ 471: line 468 didn't jump to line 471 because the condition on line 468 was always true

469 return await self._try_fallback_async(messages, **kwargs) 

470 

471 raise LLMProviderError( 

472 message=f"LLM provider error: {e}", 

473 metadata={"model": self.model_name, "provider": self.provider}, 

474 cause=e, 

475 ) 

476 
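A sketch of how a caller might react to the typed errors raised above; reading retry_after back as an attribute of the exception is an assumption about the exception classes, so it is guarded with getattr.

import asyncio

from mcp_server_langgraph.core.exceptions import LLMOverloadError, LLMRateLimitError
from mcp_server_langgraph.llm.factory import FALLBACK_BASE_DELAY_SECONDS, LLMFactory

async def ask(factory: LLMFactory, prompt: str) -> str:
    try:
        reply = await factory.ainvoke([{"role": "user", "content": prompt}])
        return str(reply.content)
    except (LLMOverloadError, LLMRateLimitError) as exc:
        # Both errors are constructed with a retry_after hint extracted from the provider response.
        wait = getattr(exc, "retry_after", None) or FALLBACK_BASE_DELAY_SECONDS
        await asyncio.sleep(wait)
        raise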

477 def _try_fallback(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage: # type: ignore[no-untyped-def] 

478 """Try fallback models if primary fails""" 

479 for fallback_model in self.fallback_models: 

480 if fallback_model == self.model_name: 480 ↛ 481: line 480 didn't jump to line 481 because the condition on line 480 was never true

481 continue # Skip if it's the same model 

482 

483 logger.warning(f"Trying fallback model: {fallback_model}", extra={"primary_model": self.model_name}) 

484 

485 try: 

486 formatted_messages = self._format_messages(messages) 

487 # BUGFIX: Use provider-specific kwargs to avoid cross-provider parameter errors 

488 provider_kwargs = self._get_provider_kwargs(fallback_model) 

489 response = completion( 

490 model=fallback_model, 

491 messages=formatted_messages, 

492 temperature=self.temperature, 

493 max_tokens=self.max_tokens, 

494 timeout=self.timeout, 

495 **provider_kwargs, # Forward provider-specific kwargs only 

496 ) 

497 

498 content = response.choices[0].message.content 

499 

500 logger.info("Fallback successful", extra={"fallback_model": fallback_model}) 

501 

502 metrics.successful_calls.add(1, {"operation": "llm.fallback", "model": fallback_model}) 

503 

504 return AIMessage(content=content) 

505 

506 except Exception as e: 

507 logger.error(f"Fallback model {fallback_model} failed: {e}", exc_info=True) 

508 continue 

509 

510 msg = "All models failed including fallbacks" 

511 raise RuntimeError(msg) 

512 

513 async def _try_fallback_async(self, messages: list[BaseMessage | dict[str, Any]], **kwargs) -> AIMessage: # type: ignore[no-untyped-def] 

514 """Try fallback models asynchronously with exponential backoff. 

515 

516 Implements resilience patterns: 

517 - Exponential backoff between attempts (1s, 2s, 4s, 8s cap) 

518 - Skips primary model if accidentally in fallback list 

519 - Logs each attempt for observability 

520 """ 

521 attempt = 0 

522 current_delay = FALLBACK_BASE_DELAY_SECONDS 

523 

524 for fallback_model in self.fallback_models: 

525 if fallback_model == self.model_name: 

526 continue 

527 

528 # Apply exponential backoff delay before attempt (except first) 

529 if attempt > 0: 

530 delay = min(current_delay, FALLBACK_MAX_DELAY_SECONDS) 

531 logger.info( 

532 f"Waiting {delay:.1f}s before fallback attempt {attempt + 1}", 

533 extra={"delay_seconds": delay, "fallback_model": fallback_model}, 

534 ) 

535 await asyncio.sleep(delay) 

536 current_delay *= FALLBACK_DELAY_MULTIPLIER 

537 

538 attempt += 1 

539 logger.warning(f"Trying fallback model: {fallback_model}", extra={"primary_model": self.model_name}) 

540 

541 try: 

542 formatted_messages = self._format_messages(messages) 

543 # BUGFIX: Use provider-specific kwargs to avoid cross-provider parameter errors 

544 provider_kwargs = self._get_provider_kwargs(fallback_model) 

545 response = await acompletion( 

546 model=fallback_model, 

547 messages=formatted_messages, 

548 temperature=self.temperature, 

549 max_tokens=self.max_tokens, 

550 timeout=self.timeout, 

551 **provider_kwargs, # Forward provider-specific kwargs only 

552 ) 

553 

554 content = response.choices[0].message.content 

555 

556 logger.info("Async fallback successful", extra={"fallback_model": fallback_model}) 

557 

558 metrics.successful_calls.add(1, {"operation": "llm.fallback_async", "model": fallback_model}) 

559 

560 return AIMessage(content=content) 

561 

562 except Exception as e: 

563 logger.error(f"Async fallback model {fallback_model} failed: {e}", exc_info=True) 

564 continue 

565 

566 msg = "All async models failed including fallbacks" 

567 raise RuntimeError(msg) 

568 
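With the module-level constants, the waits inserted before successive fallback attempts follow a capped doubling schedule; a small sketch of the arithmetic:

from mcp_server_langgraph.llm.factory import (
    FALLBACK_BASE_DELAY_SECONDS,
    FALLBACK_DELAY_MULTIPLIER,
    FALLBACK_MAX_DELAY_SECONDS,
)

delay = FALLBACK_BASE_DELAY_SECONDS
waits = []
for _ in range(5):                    # waits apply from the second fallback attempt onward
    waits.append(min(delay, FALLBACK_MAX_DELAY_SECONDS))
    delay *= FALLBACK_DELAY_MULTIPLIER

print(waits)                          # [1.0, 2.0, 4.0, 8.0, 8.0]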

569 

570def create_llm_from_config(config) -> LLMFactory: # type: ignore[no-untyped-def] 

571 """ 

572 Create primary LLM instance from configuration 

573 

574 Args: 

575 config: Settings object with LLM configuration 

576 

577 Returns: 

578 Configured LLMFactory instance for primary chat operations 

579 """ 

580 # Determine API key based on provider 

581 api_key_map = { 

582 "anthropic": config.anthropic_api_key, 

583 "openai": config.openai_api_key, 

584 "google": config.google_api_key, 

585 "gemini": config.google_api_key, 

586 "vertex_ai": None, # Vertex AI uses Workload Identity or GOOGLE_APPLICATION_CREDENTIALS 

587 "azure": config.azure_api_key, 

588 "bedrock": config.aws_access_key_id, 

589 } 

590 

591 api_key = api_key_map.get(config.llm_provider) 

592 

593 # Provider-specific kwargs 

594 provider_kwargs = {} 

595 

596 if config.llm_provider == "azure": 596 ↛ 597: line 596 didn't jump to line 597 because the condition on line 596 was never true

597 provider_kwargs.update( 

598 { 

599 "api_base": config.azure_api_base, 

600 "api_version": config.azure_api_version, 

601 } 

602 ) 

603 elif config.llm_provider == "bedrock": 603 ↛ 604: line 603 didn't jump to line 604 because the condition on line 603 was never true

604 provider_kwargs.update( 

605 { 

606 "aws_secret_access_key": config.aws_secret_access_key, 

607 "aws_region_name": config.aws_region, 

608 } 

609 ) 

610 elif config.llm_provider == "ollama": 610 ↛ 611: line 610 didn't jump to line 611 because the condition on line 610 was never true

611 provider_kwargs.update( 

612 { 

613 "api_base": config.ollama_base_url, 

614 } 

615 ) 

616 elif config.llm_provider in ["vertex_ai", "google"]: 616 ↛ 629: line 616 didn't jump to line 629 because the condition on line 616 was always true

617 # Vertex AI configuration 

618 # LiteLLM requires vertex_project and vertex_location for Vertex AI models 

619 # If using Workload Identity on GKE, authentication is automatic 

620 vertex_project = config.vertex_project or config.google_project_id 

621 if vertex_project: 621 ↛ 622: line 621 didn't jump to line 622 because the condition on line 621 was never true

622 provider_kwargs.update( 

623 { 

624 "vertex_project": vertex_project, 

625 "vertex_location": config.vertex_location, 

626 } 

627 ) 

628 

629 factory = LLMFactory( 

630 provider=config.llm_provider, 

631 model_name=config.model_name, 

632 api_key=api_key, 

633 temperature=config.model_temperature, 

634 max_tokens=config.model_max_tokens, 

635 timeout=config.model_timeout, 

636 enable_fallback=config.enable_fallback, 

637 fallback_models=config.fallback_models, 

638 **provider_kwargs, 

639 ) 

640 

641 # Set up credentials for primary + all fallback providers 

642 factory._setup_environment(config=config) 

643 

644 return factory 

645 
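A sketch of calling create_llm_from_config with a minimal Settings-like object; the attribute names mirror those read in the function body, and every value is a placeholder.

from types import SimpleNamespace
from mcp_server_langgraph.llm.factory import create_llm_from_config

cfg = SimpleNamespace(
    llm_provider="google",
    model_name="gemini-2.5-flash",
    model_temperature=0.3,
    model_max_tokens=2048,
    model_timeout=60,
    enable_fallback=True,
    fallback_models=["claude-sonnet-4-5"],
    anthropic_api_key="sk-ant-placeholder",
    openai_api_key=None,
    google_api_key="AIza-placeholder",
    azure_api_key=None,
    aws_access_key_id=None,
    vertex_project=None,
    google_project_id=None,
    vertex_location="us-central1",
)

llm = create_llm_from_config(cfg)
# Sets GOOGLE_API_KEY and ANTHROPIC_API_KEY via _setup_environment, then returns the factory.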

646 

647def create_summarization_model(config) -> LLMFactory: # type: ignore[no-untyped-def] 

648 """ 

649 Create dedicated LLM instance for summarization (cost-optimized). 

650 

651 Uses lighter/cheaper model for context compaction to reduce costs. 

652 Falls back to primary model if dedicated model not configured. 

653 

654 Args: 

655 config: Settings object with LLM configuration 

656 

657 Returns: 

658 Configured LLMFactory instance for summarization 

659 """ 

660 # If dedicated summarization model not enabled, use primary model 

661 if not getattr(config, "use_dedicated_summarization_model", False): 661 ↛ 662: line 661 didn't jump to line 662 because the condition on line 661 was never true

662 return create_llm_from_config(config) 

663 

664 # Determine provider and API key 

665 provider = config.summarization_model_provider or config.llm_provider 

666 

667 api_key_map = { 

668 "anthropic": config.anthropic_api_key, 

669 "openai": config.openai_api_key, 

670 "google": config.google_api_key, 

671 "gemini": config.google_api_key, 

672 "vertex_ai": None, # Vertex AI uses Workload Identity or GOOGLE_APPLICATION_CREDENTIALS 

673 "azure": config.azure_api_key, 

674 "bedrock": config.aws_access_key_id, 

675 } 

676 

677 api_key = api_key_map.get(provider) 

678 

679 # Provider-specific kwargs 

680 provider_kwargs = {} 

681 if provider == "azure": 681 ↛ 682line 681 didn't jump to line 682 because the condition on line 681 was never true

682 provider_kwargs.update({"api_base": config.azure_api_base, "api_version": config.azure_api_version}) 

683 elif provider == "bedrock": 683 ↛ 684line 683 didn't jump to line 684 because the condition on line 683 was never true

684 provider_kwargs.update({"aws_secret_access_key": config.aws_secret_access_key, "aws_region_name": config.aws_region}) 

685 elif provider == "ollama": 685 ↛ 686line 685 didn't jump to line 686 because the condition on line 685 was never true

686 provider_kwargs.update({"api_base": config.ollama_base_url}) 

687 elif provider in ["vertex_ai", "google"]: 

688 vertex_project = config.vertex_project or config.google_project_id 

689 if vertex_project: 689 ↛ 690: line 689 didn't jump to line 690 because the condition on line 689 was never true

690 provider_kwargs.update({"vertex_project": vertex_project, "vertex_location": config.vertex_location}) 

691 

692 factory = LLMFactory( 

693 provider=provider, 

694 model_name=config.summarization_model_name or config.model_name, 

695 api_key=api_key, 

696 temperature=config.summarization_model_temperature, 

697 max_tokens=config.summarization_model_max_tokens, 

698 timeout=config.model_timeout, 

699 enable_fallback=config.enable_fallback, 

700 fallback_models=config.fallback_models, 

701 **provider_kwargs, 

702 ) 

703 

704 # Set up credentials for all providers 

705 factory._setup_environment(config=config) 

706 

707 return factory 

708 

709 

710def create_verification_model(config) -> LLMFactory: # type: ignore[no-untyped-def] 

711 """ 

712 Create dedicated LLM instance for verification (LLM-as-judge). 

713 

714 Uses potentially different model for output verification to balance 

715 cost and quality. Falls back to primary model if not configured. 

716 

717 Args: 

718 config: Settings object with LLM configuration 

719 

720 Returns: 

721 Configured LLMFactory instance for verification 

722 """ 

723 # If dedicated verification model not enabled, use primary model 

724 if not getattr(config, "use_dedicated_verification_model", False): 724 ↛ 725: line 724 didn't jump to line 725 because the condition on line 724 was never true

725 return create_llm_from_config(config) 

726 

727 # Determine provider and API key 

728 provider = config.verification_model_provider or config.llm_provider 

729 

730 api_key_map = { 

731 "anthropic": config.anthropic_api_key, 

732 "openai": config.openai_api_key, 

733 "google": config.google_api_key, 

734 "gemini": config.google_api_key, 

735 "vertex_ai": None, # Vertex AI uses Workload Identity or GOOGLE_APPLICATION_CREDENTIALS 

736 "azure": config.azure_api_key, 

737 "bedrock": config.aws_access_key_id, 

738 } 

739 

740 api_key = api_key_map.get(provider) 

741 

742 # Provider-specific kwargs 

743 provider_kwargs = {} 

744 if provider == "azure": 744 ↛ 745line 744 didn't jump to line 745 because the condition on line 744 was never true

745 provider_kwargs.update({"api_base": config.azure_api_base, "api_version": config.azure_api_version}) 

746 elif provider == "bedrock": 746 ↛ 747line 746 didn't jump to line 747 because the condition on line 746 was never true

747 provider_kwargs.update({"aws_secret_access_key": config.aws_secret_access_key, "aws_region_name": config.aws_region}) 

748 elif provider == "ollama": 748 ↛ 749line 748 didn't jump to line 749 because the condition on line 748 was never true

749 provider_kwargs.update({"api_base": config.ollama_base_url}) 

750 elif provider in ["vertex_ai", "google"]: 

751 vertex_project = config.vertex_project or config.google_project_id 

752 if vertex_project: 752 ↛ 753: line 752 didn't jump to line 753 because the condition on line 752 was never true

753 provider_kwargs.update({"vertex_project": vertex_project, "vertex_location": config.vertex_location}) 

754 

755 factory = LLMFactory( 

756 provider=provider, 

757 model_name=config.verification_model_name or config.model_name, 

758 api_key=api_key, 

759 temperature=config.verification_model_temperature, 

760 max_tokens=config.verification_model_max_tokens, 

761 timeout=config.model_timeout, 

762 enable_fallback=config.enable_fallback, 

763 fallback_models=config.fallback_models, 

764 **provider_kwargs, 

765 ) 

766 

767 # Set up credentials for all providers 

768 factory._setup_environment(config=config) 

769 

770 return factory