# src/mcp_server_langgraph/core/parallel_executor.py

1""" 

2Parallel Tool Execution 

3 

4Implements Anthropic's parallelization pattern for independent operations. 

5""" 

6 

7import asyncio 

8import time 

9from collections import deque 

10from collections.abc import Callable 

11from dataclasses import dataclass, field 

12from typing import Any 

13 

14from mcp_server_langgraph.observability.telemetry import logger, metrics, tracer 

15 

16 

17@dataclass 

18class ToolInvocation: 

19 """Represents a tool invocation.""" 

20 

21 tool_name: str 

22 arguments: dict[str, Any] 

23 invocation_id: str 

24 dependencies: list[str] = field(default_factory=list) # IDs of tools this depends on 
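    # Example (hypothetical values): two independent searches plus a comparison
    # that must wait for both of them:
    #   ToolInvocation("search", {"query": "Python"}, invocation_id="1")
    #   ToolInvocation("search", {"query": "JavaScript"}, invocation_id="2")
    #   ToolInvocation("compare", {"items": ["1", "2"]}, invocation_id="3",
    #                  dependencies=["1", "2"])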

@dataclass
class ToolResult:
    """Result from tool execution."""

    invocation_id: str
    tool_name: str
    result: Any
    error: Exception | None = None
    duration_ms: float = 0.0

class ParallelToolExecutor:
    """
    Executes tools in parallel when they have no dependencies.

    Implements Anthropic's parallelization pattern:
    - Detects independent operations
    - Schedules parallel execution
    - Respects dependencies
    - Aggregates results
    """

    def __init__(self, max_parallelism: int = 5, task_timeout_seconds: float | None = None) -> None:
        """
        Initialize parallel executor.

        Args:
            max_parallelism: Maximum concurrent tool executions
            task_timeout_seconds: Optional timeout for each task (None = no timeout)
        """
        self.max_parallelism = max_parallelism
        self.task_timeout_seconds = task_timeout_seconds
        self.semaphore = asyncio.Semaphore(max_parallelism)

    async def execute_parallel(
        self, invocations: list[ToolInvocation], tool_executor: Callable[[str, dict[str, Any]], Any]
    ) -> list[ToolResult]:
        """
        Execute tool invocations in parallel where possible.

        Args:
            invocations: List of tool invocations
            tool_executor: Async function that executes a single tool

        Returns:
            List of tool results, in the same order as the input invocations
        """

        with tracer.start_as_current_span("tools.parallel_execute") as span:
            span.set_attribute("total_invocations", len(invocations))

            # Build dependency graph
            dependency_graph = self._build_dependency_graph(invocations)

            # Topological sort for execution order
            execution_order = self._topological_sort(dependency_graph)

            # Group by dependency level (all items in the same level can run in parallel)
            levels = self._group_by_level(execution_order, dependency_graph, invocations)

            span.set_attribute("parallelization_levels", len(levels))

            logger.info(
                "Executing tools in parallel",
                extra={
                    "total_tools": len(invocations),
                    "levels": len(levels),
                    "max_parallelism": self.max_parallelism,
                },
            )

            # Execute level by level
            all_results = {}

            for level_idx, level_invocations in enumerate(levels):
                logger.info(f"Executing level {level_idx + 1}/{len(levels)} with {len(level_invocations)} tools")

                # Execute all invocations in this level in parallel
                tasks = [self._execute_single(inv, tool_executor) for inv in level_invocations]

                level_results = await asyncio.gather(*tasks, return_exceptions=True)

                # Store results
                for inv, result in zip(level_invocations, level_results, strict=False):
                    all_results[inv.invocation_id] = result

            # Convert to a list that preserves the original invocation order
            results = [all_results[inv.invocation_id] for inv in invocations]

            # Calculate metrics
            successful = sum(1 for r in results if isinstance(r, ToolResult) and r.error is None)
            failed = len(results) - successful

            span.set_attribute("successful", successful)
            span.set_attribute("failed", failed)

            metrics.successful_calls.add(successful, {"operation": "parallel_tool_execution"})
            if failed > 0:
                metrics.failed_calls.add(failed, {"operation": "parallel_tool_execution"})

            return results  # type: ignore[return-value]

    async def _execute_single(self, invocation: ToolInvocation, tool_executor: Callable[..., Any]) -> ToolResult:
        """Execute a single tool invocation with optional timeout."""

        async with self.semaphore:  # Limit concurrency
            start_time = time.time()

            try:
                # Apply timeout if configured
                if self.task_timeout_seconds is not None:
                    result = await asyncio.wait_for(
                        tool_executor(invocation.tool_name, invocation.arguments), timeout=self.task_timeout_seconds
                    )
                else:
                    result = await tool_executor(invocation.tool_name, invocation.arguments)

                duration_ms = (time.time() - start_time) * 1000

                return ToolResult(
                    invocation_id=invocation.invocation_id,
                    tool_name=invocation.tool_name,
                    result=result,
                    duration_ms=duration_ms,
                )

            except TimeoutError:
                duration_ms = (time.time() - start_time) * 1000
                timeout_error = TimeoutError(
                    f"Tool '{invocation.tool_name}' exceeded timeout of {self.task_timeout_seconds}s"
                )
                logger.warning(
                    f"Tool execution timeout: {invocation.tool_name}",
                    extra={"timeout_seconds": self.task_timeout_seconds, "duration_ms": duration_ms},
                )

                return ToolResult(
                    invocation_id=invocation.invocation_id,
                    tool_name=invocation.tool_name,
                    result=None,
                    error=timeout_error,
                    duration_ms=duration_ms,
                )

            except Exception as e:
                duration_ms = (time.time() - start_time) * 1000
                logger.error(
                    f"Tool execution failed: {invocation.tool_name}",
                    extra={"error": str(e)},
                    exc_info=True,
                )

                return ToolResult(
                    invocation_id=invocation.invocation_id,
                    tool_name=invocation.tool_name,
                    result=None,
                    error=e,
                    duration_ms=duration_ms,
                )

    def _build_dependency_graph(self, invocations: list[ToolInvocation]) -> dict[str, list[str]]:
        """Build a dependency graph from invocations (invocation ID -> dependency IDs)."""

        graph = {inv.invocation_id: inv.dependencies for inv in invocations}
        return graph

    def _topological_sort(self, graph: dict[str, list[str]]) -> list[str]:
        """
        Topological sort of dependency graph.

        Returns nodes in an order where dependencies come before dependents.
        """

        # Calculate in-degrees (number of dependencies each node has)
        in_degree = {node: len(graph[node]) for node in graph}

        # Queue nodes with no dependencies (in-degree = 0)
        queue = deque([node for node in graph if in_degree[node] == 0])
        sorted_order = []

        while queue:
            node = queue.popleft()
            sorted_order.append(node)

            # Find nodes that depend on this node and reduce their in-degree
            for other_node in graph:
                if node in graph[other_node]:
                    in_degree[other_node] -= 1
                    if in_degree[other_node] == 0:
                        queue.append(other_node)

        # Check for circular dependencies
        if len(sorted_order) != len(graph):
            msg = "Circular dependency detected in tool execution graph"
            raise ValueError(msg)

        return sorted_order

    def _group_by_level(
        self, sorted_nodes: list[str], graph: dict[str, list[str]], invocations: list[ToolInvocation]
    ) -> list[list[ToolInvocation]]:
        """Group nodes by dependency level for parallel execution."""

        # Create invocation lookup
        inv_lookup = {inv.invocation_id: inv for inv in invocations}

        levels = []
        processed: set[str] = set()

        while processed != set(sorted_nodes):
            current_level = []

            for node in sorted_nodes:
                if node in processed:
                    continue

                # Check if all dependencies are processed
                deps_satisfied = all(dep in processed for dep in graph[node])

                if deps_satisfied:
                    current_level.append(node)

            if current_level:
                # Convert node IDs to ToolInvocations
                level_invocations = [inv_lookup[node_id] for node_id in current_level]
                levels.append(level_invocations)
                processed.update(current_level)
            else:
                # Circular dependency or error
                break

        return levels

# Example usage function
async def execute_multi_tool_request(user_request: str, tool_calls: list[dict[str, Any]]) -> list[ToolResult]:
    """
    Execute multiple tool calls efficiently with parallelization.

    Example:
        user_request = "Search for Python and JavaScript, then compare them"
        tool_calls = [
            {"tool": "search", "args": {"query": "Python"}},
            {"tool": "search", "args": {"query": "JavaScript"}},
            {"tool": "compare", "args": {"items": ["result_1", "result_2"]}, "deps": ["1", "2"]},
        ]

    Args:
        user_request: The original user request
        tool_calls: List of tool call specifications

    Returns:
        List of tool results
    """
    executor = ParallelToolExecutor(max_parallelism=5)

    # Convert to invocations
    invocations = []
    for i, call in enumerate(tool_calls):
        inv = ToolInvocation(
            tool_name=call["tool"],
            arguments=call["args"],
            invocation_id=str(i + 1),
            dependencies=call.get("deps", []),
        )
        invocations.append(inv)

    # Mock tool executor for demonstration
    async def mock_tool_executor(name: str, args: dict) -> Any:  # type: ignore[type-arg]
        await asyncio.sleep(0.1)  # Simulate work
        return f"Result from {name}"

    results = await executor.execute_parallel(invocations, mock_tool_executor)

    return results
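
# A minimal sketch of driving the example above (demo values are hypothetical;
# the two searches run in parallel, then the dependent "compare" call runs):
if __name__ == "__main__":
    demo_calls = [
        {"tool": "search", "args": {"query": "Python"}},
        {"tool": "search", "args": {"query": "JavaScript"}},
        {"tool": "compare", "args": {"items": ["result_1", "result_2"]}, "deps": ["1", "2"]},
    ]
    demo_results = asyncio.run(
        execute_multi_tool_request("Search for Python and JavaScript, then compare them", demo_calls)
    )
    for res in demo_results:
        print(f"{res.invocation_id}: {res.tool_name} -> {res.result} ({res.duration_ms:.0f} ms)")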