Coverage for src / mcp_server_langgraph / patterns / supervisor.py: 90%

87 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Supervisor Pattern for Multi-Agent Coordination 

3 

4The supervisor pattern uses one coordinating agent that delegates tasks 

5to specialized worker agents based on the task requirements. 

6 

7Architecture: 

8 User Query → Supervisor → Routes to Worker(s) → Aggregator → Response 

9 

10Use Cases: 

11 - Customer support (routing to billing, technical, sales specialists) 

12 - Research (delegating to search, analysis, writing specialists) 

13 - Code review (routing to security, performance, style checkers) 

14 

15Example: 

16 from mcp_server_langgraph.patterns import Supervisor 

17 

18 supervisor = Supervisor( 

19 agents={ 

20 "research": research_agent, 

21 "writer": writer_agent, 

22 "reviewer": reviewer_agent 

23 }, 

24 routing_strategy="sequential" 

25 ) 

26 

27 result = supervisor.invoke("Write a research report") 

28""" 

29 

30from collections.abc import Callable 

31from typing import Any, Literal 

32 

33from langgraph.graph import StateGraph 

34from pydantic import BaseModel, Field 

35 

36 

class SupervisorState(BaseModel):
    """Graph state shared by the supervisor, worker, and aggregator nodes."""

    # Input: the user's task text; workers are called with this string.
    task: str = Field(description="The task to accomplish")

    # Routing bookkeeping, updated as nodes execute.
    current_agent: str | None = Field(None, description="Currently active agent")
    agent_history: list[str] = Field(default_factory=list, description="Sequence of agents executed")
    agent_results: dict[str, Any] = Field(default_factory=dict, description="Results from each agent")
    next_agent: str | None = Field(None, description="Next agent to execute")

    # Output, filled in by the supervisor and aggregator nodes.
    final_result: str = Field("", description="Final aggregated result")
    routing_decision: str = Field("", description="Supervisor's routing decision")
    completed: bool = Field(False, description="Whether task is completed")

48 

49 

class Supervisor:
    """
    Supervisor pattern for multi-agent coordination.

    One coordinating agent delegates work to specialized workers. The graph
    produced by ``build`` flows: ``supervisor`` -> worker agent(s) ->
    ``aggregate`` (finish point).
    """

    # Node names the graph itself uses. A worker with one of these names would
    # collide with an internal node ("aggregate" is also the routing sentinel
    # that workers write into ``next_agent``).
    _RESERVED_NODE_NAMES = frozenset({"supervisor", "aggregate"})

    _ROUTING_STRATEGIES = ("sequential", "conditional", "parallel")

    # Keyword routing table for non-sequential modes: (preferred agent name,
    # trigger substrings). Groups are checked in order and the first matching
    # group wins, even if its preferred agent is not registered (then the first
    # registered agent is used instead).
    _KEYWORD_ROUTES = (
        ("research", ("research", "search", "find")),
        ("writer", ("write", "create", "draft")),
        ("reviewer", ("review", "check", "analyze")),
    )

    def __init__(
        self,
        agents: dict[str, Callable[[str], Any]],
        routing_strategy: Literal["sequential", "conditional", "parallel"] = "conditional",
        supervisor_prompt: str | None = None,
    ):
        """
        Initialize supervisor.

        Args:
            agents: Dictionary of {agent_name: agent_function}. Each function is
                called with the task string and its return value is stored in
                ``agent_results``. Dict insertion order is the sequential order.
            routing_strategy: How to route between agents
                - sequential: Execute all agents in insertion order, starting
                  with the first one
                - conditional: Supervisor picks a single agent by keyword matching
                - parallel: Not implemented yet; currently behaves like conditional
            supervisor_prompt: Custom prompt for supervisor's routing logic.
                It is stored on the instance; the built-in keyword router does
                not read it.

        Raises:
            ValueError: If ``agents`` is empty, an agent name is reserved for an
                internal graph node, or ``routing_strategy`` is unknown.
        """
        # Validate up front: these previously failed later (IndexError during
        # routing, duplicate-node errors at build time) or were silently ignored.
        if not agents:
            raise ValueError("Supervisor requires at least one agent")
        reserved = sorted(self._RESERVED_NODE_NAMES.intersection(agents))
        if reserved:
            raise ValueError(f"Agent names {reserved} are reserved for internal graph nodes")
        if routing_strategy not in self._ROUTING_STRATEGIES:
            raise ValueError(
                f"Unknown routing_strategy {routing_strategy!r}; "
                f"expected one of {', '.join(self._ROUTING_STRATEGIES)}"
            )

        self.agents = agents
        self.routing_strategy = routing_strategy
        self.supervisor_prompt = supervisor_prompt or self._default_supervisor_prompt()
        self._graph: StateGraph[SupervisorState] | None = None

    def _default_supervisor_prompt(self) -> str:
        """Build the default supervisor prompt listing the registered agents."""
        agent_list = ", ".join(self.agents.keys())
        return f"""You are a supervisor coordinating a team of AI agents: {agent_list}.

Your job is to:
1. Analyze the user's task
2. Decide which agent(s) should handle it
3. Route the task to the appropriate agent
4. Aggregate results from multiple agents if needed
5. Provide a final response to the user

Available agents: {agent_list}

Choose the best agent for each sub-task."""

    def _create_supervisor_node(self, state: SupervisorState) -> SupervisorState:
        """
        Supervisor decision node.

        Sets ``state.next_agent`` to the first worker to run and records a
        human-readable ``routing_decision``. In production, this would use an
        LLM for intelligent routing.

        - sequential: always the first registered agent, so that every agent
          runs in order.
        - otherwise: simple substring keyword matching on the lower-cased task
          (see ``_KEYWORD_ROUTES``), falling back to the first registered agent.
        """
        first_agent = next(iter(self.agents))

        if self.routing_strategy == "sequential":
            # Starting anywhere but the first agent would silently skip the
            # agents before it, contradicting "execute all agents in order".
            state.next_agent = first_agent
        else:
            task_lower = state.task.lower()
            preferred = next(
                (
                    name
                    for name, keywords in self._KEYWORD_ROUTES
                    if any(keyword in task_lower for keyword in keywords)
                ),
                None,
            )
            state.next_agent = (
                preferred if preferred is not None and preferred in self.agents else first_agent
            )

        state.routing_decision = f"Routing to {state.next_agent} based on task analysis"

        return state

    def _create_worker_wrapper(
        self, agent_name: str, agent_func: Callable[[str], Any]
    ) -> Callable[[SupervisorState], SupervisorState]:
        """
        Create a wrapper for worker agent.

        Args:
            agent_name: Name of the agent
            agent_func: Agent function to wrap

        Returns:
            Wrapped function that updates SupervisorState
        """

        def worker_node(state: SupervisorState) -> SupervisorState:
            """Run the agent on the task, record its result, and pick the next node."""
            # Deliberate best-effort: a failing agent does not abort the graph;
            # its error text is recorded as that agent's result instead.
            try:
                result = agent_func(state.task)
                state.agent_results[agent_name] = result
            except Exception as e:
                state.agent_results[agent_name] = f"Error: {e!s}"

            # Track agent execution
            state.agent_history.append(agent_name)
            state.current_agent = agent_name

            if self.routing_strategy == "sequential":
                # Advance to the next agent in insertion order, or finish.
                agent_keys = list(self.agents.keys())
                current_idx = agent_keys.index(agent_name)
                if current_idx < len(agent_keys) - 1:
                    state.next_agent = agent_keys[current_idx + 1]
                else:
                    state.next_agent = "aggregate"
            else:
                # Conditional/parallel: a single worker runs, then aggregate.
                state.next_agent = "aggregate"

            return state

        return worker_node

    def _create_aggregator_node(self, state: SupervisorState) -> SupervisorState:
        """
        Aggregate results from all workers into ``final_result``.

        Results are joined in execution order as ``**Name:** result`` sections.
        In production, use LLM to synthesize results.
        """
        results_summary = []
        for agent_name in state.agent_history:
            result = state.agent_results.get(agent_name, "")
            results_summary.append(f"**{agent_name.title()}:** {result}")

        state.final_result = "\n\n".join(results_summary)
        state.completed = True

        return state

    def _routing_function(self, state: SupervisorState) -> str:
        """Map ``state.next_agent`` to the name of the next graph node."""
        if state.next_agent == "aggregate":
            return "aggregate"
        elif state.next_agent:
            return state.next_agent
        else:
            # Not reached by the built-in nodes, which always set next_agent.
            return "supervisor"

    def build(self) -> "StateGraph[SupervisorState]":
        """Build the (uncompiled) supervisor graph: supervisor -> workers -> aggregate."""
        graph: StateGraph[SupervisorState] = StateGraph(SupervisorState)

        # Add supervisor node
        graph.add_node("supervisor", self._create_supervisor_node)

        # Add worker nodes
        for agent_name, agent_func in self.agents.items():
            worker_node = self._create_worker_wrapper(agent_name, agent_func)
            graph.add_node(agent_name, worker_node)  # type: ignore[arg-type]

        # Add aggregator node
        graph.add_node("aggregate", self._create_aggregator_node)

        # Define edges
        graph.set_entry_point("supervisor")

        # Supervisor routes to workers
        graph.add_conditional_edges("supervisor", self._routing_function)

        # Workers route based on strategy (next worker or aggregate)
        for agent_name in self.agents:
            graph.add_conditional_edges(agent_name, self._routing_function)

        # Aggregator is the end
        graph.set_finish_point("aggregate")

        return graph

    def compile(self, checkpointer: Any = None) -> Any:
        """
        Compile the supervisor graph.

        The uncompiled graph is built once and cached on the instance; each
        call compiles it again with the given checkpointer.

        Args:
            checkpointer: Optional checkpointer for state persistence

        Returns:
            Compiled LangGraph application
        """
        if self._graph is None:
            self._graph = self.build()

        return self._graph.compile(checkpointer=checkpointer)

    def invoke(self, task: str, config: dict[str, Any] | None = None) -> dict[str, Any]:
        """
        Execute the supervisor pattern.

        Args:
            task: Task description
            config: Optional configuration passed to the compiled graph

        Returns:
            Dict with keys ``task``, ``final_result``, ``agent_history``,
            ``agent_results`` and ``routing_decision``.
        """
        compiled = self.compile()
        state = SupervisorState(task=task)

        result = compiled.invoke(state, config=config or {})

        return {
            "task": result["task"],
            "final_result": result["final_result"],
            "agent_history": result["agent_history"],
            "agent_results": result["agent_results"],
            "routing_decision": result["routing_decision"],
        }

258 

259 

# Example usage and testing
if __name__ == "__main__":
    # Toy workers: each echoes the task followed by canned bullet points.
    def research_agent(task: str) -> str:
        """Return canned research notes for *task*."""
        return f"Research findings for: {task}\n- Found relevant papers\n- Analyzed data sources"

    def writer_agent(task: str) -> str:
        """Return canned draft notes for *task*."""
        return f"Written content for: {task}\n- Created outline\n- Drafted sections"

    def reviewer_agent(task: str) -> str:
        """Return canned review notes for *task*."""
        return f"Review feedback for: {task}\n- Checked accuracy\n- Suggested improvements"

    team = Supervisor(
        agents={"research": research_agent, "writer": writer_agent, "reviewer": reviewer_agent},
        routing_strategy="sequential",
    )

    rule = "=" * 80

    print(rule)
    print("SUPERVISOR PATTERN - TEST RUN")
    print(rule)

    for sample_task in (
        "Research the latest AI trends",
        "Write a blog post about LangGraph",
        "Review this code for security issues",
    ):
        print(f"\n{rule}")
        print(f"TASK: {sample_task}")
        print(rule)

        outcome = team.invoke(sample_task)

        print(f"\nRouting Decision: {outcome['routing_decision']}")
        print(f"Agent History: {' → '.join(outcome['agent_history'])}")
        print(f"\nFINAL RESULT:\n{outcome['final_result']}")

    print(f"\n{rule}\n")