Coverage for src / mcp_server_langgraph / core / parallel_executor.py: 79%
115 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2Parallel Tool Execution
4Implements Anthropic's parallelization pattern for independent operations.
5"""
7import asyncio
8import time
9from collections import deque
10from collections.abc import Callable
11from dataclasses import dataclass, field
12from typing import Any
14from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer
@dataclass
class ToolInvocation:
    """Represents a tool invocation.

    A single tool call plus the IDs of other invocations that must complete
    before this one may run; ParallelToolExecutor uses the dependencies to
    group invocations into parallelizable levels.
    """

    tool_name: str  # Name of the tool to invoke
    arguments: dict[str, Any]  # Arguments passed to the tool executor
    invocation_id: str  # Unique ID for this invocation within one request
    dependencies: list[str] = field(default_factory=list)  # IDs of tools this depends on
@dataclass
class ToolResult:
    """Result from tool execution.

    On success, `result` holds the tool's output and `error` is None.
    On failure (including timeout), `result` is None and `error` holds
    the captured exception.
    """

    invocation_id: str  # Matches ToolInvocation.invocation_id
    tool_name: str  # Name of the tool that ran
    result: Any  # Tool output on success; None on failure
    error: Exception | None = None  # Exception captured on failure, else None
    duration_ms: float = 0.0  # Elapsed execution time in milliseconds
38class ParallelToolExecutor:
39 """
40 Executes tools in parallel when they have no dependencies.
42 Implements Anthropic's parallelization pattern:
43 - Detects independent operations
44 - Schedules parallel execution
45 - Respects dependencies
46 - Aggregates results
47 """
49 def __init__(self, max_parallelism: int = 5, task_timeout_seconds: float | None = None) -> None:
50 """
51 Initialize parallel executor.
53 Args:
54 max_parallelism: Maximum concurrent tool executions
55 task_timeout_seconds: Optional timeout for each task (None = no timeout)
56 """
57 self.max_parallelism = max_parallelism
58 self.task_timeout_seconds = task_timeout_seconds
59 self.semaphore = asyncio.Semaphore(max_parallelism)
61 async def execute_parallel(
62 self, invocations: list[ToolInvocation], tool_executor: Callable[[str, dict[str, Any]], Any]
63 ) -> list[ToolResult]:
64 """
65 Execute tool invocations in parallel where possible.
67 Args:
68 invocations: List of tool invocations
69 tool_executor: Async function to execute a single tool
71 Returns:
72 List of tool results
73 """
74 with tracer.start_as_current_span("tools.parallel_execute") as span:
75 span.set_attribute("total_invocations", len(invocations))
77 # Build dependency graph
78 dependency_graph = self._build_dependency_graph(invocations)
80 # Topological sort for execution order
81 execution_order = self._topological_sort(dependency_graph)
83 # Group by dependency level (all items in same level can run in parallel)
84 levels = self._group_by_level(execution_order, dependency_graph, invocations)
86 span.set_attribute("parallelization_levels", len(levels))
88 logger.info(
89 "Executing tools in parallel",
90 extra={
91 "total_tools": len(invocations),
92 "levels": len(levels),
93 "max_parallelism": self.max_parallelism,
94 },
95 )
97 # Execute level by level
98 all_results = {}
100 for level_idx, level_invocations in enumerate(levels):
101 logger.info(f"Executing level {level_idx + 1}/{len(levels)} with {len(level_invocations)} tools")
103 # Execute all invocations in this level in parallel
104 tasks = [self._execute_single(inv, tool_executor) for inv in level_invocations]
106 level_results = await asyncio.gather(*tasks, return_exceptions=True)
108 # Store results
109 for inv, result in zip(level_invocations, level_results, strict=False):
110 all_results[inv.invocation_id] = result
112 # Convert to list maintaining original order
113 results = [all_results[inv.invocation_id] for inv in invocations]
115 # Calculate metrics
116 successful = sum(1 for r in results if isinstance(r, ToolResult) and r.error is None)
117 failed = len(results) - successful
119 span.set_attribute("successful", successful)
120 span.set_attribute("failed", failed)
122 metrics.successful_calls.add(successful, {"operation": "parallel_tool_execution"})
123 if failed > 0:
124 metrics.failed_calls.add(failed, {"operation": "parallel_tool_execution"})
126 return results # type: ignore[return-value]
128 async def _execute_single(self, invocation: ToolInvocation, tool_executor: Callable[..., Any]) -> ToolResult:
129 """Execute a single tool invocation with optional timeout."""
130 async with self.semaphore: # Limit concurrency
131 start_time = time.time()
133 try:
134 # Apply timeout if configured
135 if self.task_timeout_seconds is not None:
136 result = await asyncio.wait_for(
137 tool_executor(invocation.tool_name, invocation.arguments), timeout=self.task_timeout_seconds
138 )
139 else:
140 result = await tool_executor(invocation.tool_name, invocation.arguments)
142 duration_ms = (time.time() - start_time) * 1000
144 return ToolResult(
145 invocation_id=invocation.invocation_id,
146 tool_name=invocation.tool_name,
147 result=result,
148 duration_ms=duration_ms,
149 )
151 except TimeoutError:
152 duration_ms = (time.time() - start_time) * 1000
153 timeout_error = TimeoutError(f"Tool '{invocation.tool_name}' exceeded timeout of {self.task_timeout_seconds}s")
154 logger.warning(
155 f"Tool execution timeout: {invocation.tool_name}",
156 extra={"timeout_seconds": self.task_timeout_seconds, "duration_ms": duration_ms},
157 )
159 return ToolResult(
160 invocation_id=invocation.invocation_id,
161 tool_name=invocation.tool_name,
162 result=None,
163 error=timeout_error,
164 duration_ms=duration_ms,
165 )
167 except Exception as e:
168 duration_ms = (time.time() - start_time) * 1000
169 logger.error(
170 f"Tool execution failed: {invocation.tool_name}",
171 extra={"error": str(e)},
172 exc_info=True,
173 )
175 return ToolResult(
176 invocation_id=invocation.invocation_id,
177 tool_name=invocation.tool_name,
178 result=None,
179 error=e,
180 duration_ms=duration_ms,
181 )
183 def _build_dependency_graph(self, invocations: list[ToolInvocation]) -> dict[str, list[str]]:
184 """Build dependency graph from invocations."""
185 graph = {inv.invocation_id: inv.dependencies for inv in invocations}
186 return graph
188 def _topological_sort(self, graph: dict[str, list[str]]) -> list[str]:
189 """
190 Topological sort of dependency graph.
192 Returns nodes in an order where dependencies come before dependents.
193 """
194 # Calculate in-degrees (number of dependencies each node has)
195 in_degree = {node: len(graph[node]) for node in graph}
197 # Queue nodes with no dependencies (in-degree = 0)
198 queue = deque([node for node in graph if in_degree[node] == 0])
199 sorted_order = []
201 while queue:
202 node = queue.popleft()
203 sorted_order.append(node)
205 # Find nodes that depend on this node and reduce their in-degree
206 for other_node in graph:
207 if node in graph[other_node]: 207 ↛ 208line 207 didn't jump to line 208 because the condition on line 207 was never true
208 in_degree[other_node] -= 1
209 if in_degree[other_node] == 0:
210 queue.append(other_node)
212 # Check for circular dependencies
213 if len(sorted_order) != len(graph): 213 ↛ 214line 213 didn't jump to line 214 because the condition on line 213 was never true
214 msg = "Circular dependency detected in tool execution graph"
215 raise ValueError(msg)
217 return sorted_order
219 def _group_by_level(
220 self, sorted_nodes: list[str], graph: dict[str, list[str]], invocations: list[ToolInvocation]
221 ) -> list[list[ToolInvocation]]:
222 """Group nodes by dependency level for parallel execution."""
223 # Create invocation lookup
224 inv_lookup = {inv.invocation_id: inv for inv in invocations}
226 levels = []
227 processed: set[str] = set()
229 while processed != set(sorted_nodes):
230 current_level = []
232 for node in sorted_nodes:
233 if node in processed: 233 ↛ 234line 233 didn't jump to line 234 because the condition on line 233 was never true
234 continue
236 # Check if all dependencies are processed
237 deps_satisfied = all(dep in processed for dep in graph[node])
239 if deps_satisfied: 239 ↛ 232line 239 didn't jump to line 232 because the condition on line 239 was always true
240 current_level.append(node)
242 if current_level: 242 ↛ 249line 242 didn't jump to line 249 because the condition on line 242 was always true
243 # Convert node IDs to ToolInvocations
244 level_invocations = [inv_lookup[node_id] for node_id in current_level]
245 levels.append(level_invocations)
246 processed.update(current_level)
247 else:
248 # Circular dependency or error
249 break
251 return levels
254# Example usage function
# Example usage function
async def execute_multi_tool_request(user_request: str, tool_calls: list[dict[str, Any]]) -> list[ToolResult]:
    """
    Execute multiple tool calls efficiently with parallelization.

    Example:
        user_request = "Search for Python and JavaScript, then compare them"
        tool_calls = [
            {"tool": "search", "args": {"query": "Python"}},
            {"tool": "search", "args": {"query": "JavaScript"}},
            {"tool": "compare", "args": {"items": ["result_1", "result_2"]}, "deps": ["1", "2"]}
        ]

    Args:
        user_request: The original user request
        tool_calls: List of tool call specifications

    Returns:
        List of tool results
    """
    executor = ParallelToolExecutor(max_parallelism=5)

    # Build invocations; IDs are 1-based positions so "deps" can refer to
    # earlier calls by number, as in the example above.
    invocations = [
        ToolInvocation(
            tool_name=spec["tool"],
            arguments=spec["args"],
            invocation_id=str(position),
            dependencies=spec.get("deps", []),
        )
        for position, spec in enumerate(tool_calls, start=1)
    ]

    # Mock tool executor for demonstration
    async def mock_tool_executor(name: str, args: dict) -> Any:  # type: ignore[type-arg]
        await asyncio.sleep(0.1)  # Simulate work
        return f"Result from {name}"

    return await executor.execute_parallel(invocations, mock_tool_executor)