Coverage for src / mcp_server_langgraph / patterns / supervisor.py: 90%
87 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2Supervisor Pattern for Multi-Agent Coordination
4The supervisor pattern uses one coordinating agent that delegates tasks
5to specialized worker agents based on the task requirements.
7Architecture:
8 User Query → Supervisor → Routes to Worker → Result → Supervisor → Response
10Use Cases:
11 - Customer support (routing to billing, technical, sales specialists)
12 - Research (delegating to search, analysis, writing specialists)
13 - Code review (routing to security, performance, style checkers)
15Example:
16 from mcp_server_langgraph.patterns import Supervisor
18 supervisor = Supervisor(
19 agents={
20 "research": research_agent,
21 "writer": writer_agent,
22 "reviewer": reviewer_agent
23 },
24 routing_strategy="sequential"
25 )
27 result = supervisor.invoke("Write a research report")
28"""
30from collections.abc import Callable
31from typing import Any, Literal
33from langgraph.graph import StateGraph
34from pydantic import BaseModel, Field
class SupervisorState(BaseModel):
    """State for supervisor pattern."""

    # Input: the task text handed to every worker agent.
    task: str = Field(description="The task to accomplish")

    # Execution tracking: which worker ran last and the order workers ran in.
    current_agent: str | None = Field(
        default=None,
        description="Currently active agent",
    )
    agent_history: list[str] = Field(
        default_factory=list,
        description="Sequence of agents executed",
    )

    # Per-agent output, keyed by agent name (an error string is stored when
    # the agent raised).
    agent_results: dict[str, Any] = Field(
        default_factory=dict,
        description="Results from each agent",
    )

    # Routing: name of the node to run next; the value "aggregate" sends
    # control to the aggregator node.
    next_agent: str | None = Field(
        default=None,
        description="Next agent to execute",
    )

    # Output: filled in by the aggregator node.
    final_result: str = Field(
        default="",
        description="Final aggregated result",
    )
    routing_decision: str = Field(
        default="",
        description="Supervisor's routing decision",
    )
    completed: bool = Field(
        default=False,
        description="Whether task is completed",
    )
class Supervisor:
    """
    Supervisor pattern for multi-agent coordination.

    One coordinating agent delegates work to specialized workers.

    The graph produced by :meth:`build` contains a ``supervisor`` entry node,
    one node per entry in ``agents`` and an ``aggregate`` finish node. Agent
    names therefore must not be ``"supervisor"`` or ``"aggregate"``.
    """

    # Ordered keyword groups used by the simplified (non-LLM) router in
    # conditional mode. The first group with a keyword contained in the
    # lower-cased task decides the target agent.
    _ROUTING_KEYWORDS = (
        ("research", ("research", "search", "find")),
        ("writer", ("write", "create", "draft")),
        ("reviewer", ("review", "check", "analyze")),
    )

    def __init__(
        self,
        agents: dict[str, Callable[[str], Any]],
        routing_strategy: Literal["sequential", "conditional", "parallel"] = "conditional",
        supervisor_prompt: str | None = None,
    ):
        """
        Initialize supervisor.

        Args:
            agents: Dictionary of {agent_name: agent_function}. Insertion
                order defines the execution order in sequential mode and the
                fallback agent (the first one) in conditional mode.
            routing_strategy: How to route between agents
                - sequential: Execute all agents in order
                - conditional: Supervisor decides next agent
                - parallel: Execute agents in parallel (future); currently
                  handled exactly like "conditional"
            supervisor_prompt: Custom prompt for supervisor's routing logic
        """
        self.agents = agents
        self.routing_strategy = routing_strategy
        self.supervisor_prompt = supervisor_prompt or self._default_supervisor_prompt()
        # Built lazily by compile() and reused on later calls.
        self._graph: StateGraph[SupervisorState] | None = None

    def _default_supervisor_prompt(self) -> str:
        """Default supervisor prompt listing the configured agent names."""
        agent_list = ", ".join(self.agents)
        return f"""You are a supervisor coordinating a team of AI agents: {agent_list}.

Your job is to:
1. Analyze the user's task
2. Decide which agent(s) should handle it
3. Route the task to the appropriate agent
4. Aggregate results from multiple agents if needed
5. Provide a final response to the user

Available agents: {agent_list}

Choose the best agent for each sub-task."""

    def _create_supervisor_node(self, state: SupervisorState) -> SupervisorState:
        """
        Supervisor decision node.

        Picks the first worker to run and records the decision on the state.
        In production, this would use an LLM for intelligent routing.

        Routing rules:
            - No agents configured: go straight to "aggregate".
            - Sequential strategy: always start with the first agent so that
              every agent runs, in order.
            - Otherwise: keyword-based routing, falling back to the first
              agent when nothing matches or the matched agent is missing.

        Args:
            state: Current graph state; mutated in place.

        Returns:
            The same state with next_agent and routing_decision set.
        """
        # Without workers there is nothing to route to; finish gracefully
        # instead of failing with an IndexError on the empty agent list.
        if not self.agents:
            state.next_agent = "aggregate"
            state.routing_decision = "No agents configured; routing to aggregate"
            return state

        first_agent = next(iter(self.agents))

        # Sequential mode must run all agents in order, so the chain has to
        # start at the first agent regardless of the task wording. Keyword
        # routing here would skip every agent before the matched one.
        if self.routing_strategy == "sequential":
            state.next_agent = first_agent
            state.routing_decision = f"Routing to {state.next_agent} as first agent in sequential order"
            return state

        # Simplified routing logic (in production, use LLM)
        task_lower = state.task.lower()

        chosen = first_agent
        for agent_name, keywords in self._ROUTING_KEYWORDS:
            if any(keyword in task_lower for keyword in keywords):
                # The first matching group decides; if that agent is not
                # configured we keep the default (first) agent.
                if agent_name in self.agents:
                    chosen = agent_name
                break

        state.next_agent = chosen
        state.routing_decision = f"Routing to {state.next_agent} based on task analysis"

        return state

    def _create_worker_wrapper(
        self, agent_name: str, agent_func: Callable[[str], Any]
    ) -> Callable[[SupervisorState], SupervisorState]:
        """
        Create a wrapper for worker agent.

        Args:
            agent_name: Name of the agent
            agent_func: Agent function to wrap

        Returns:
            Wrapped function that updates SupervisorState
        """

        def worker_node(state: SupervisorState) -> SupervisorState:
            """Execute worker agent and decide which node runs next."""
            # Call the actual agent function with error handling. A failing
            # agent must not abort the whole run, so the error text is stored
            # as that agent's result instead.
            try:
                state.agent_results[agent_name] = agent_func(state.task)
            except Exception as e:
                state.agent_results[agent_name] = f"Error: {e!s}"

            # Track agent execution
            state.agent_history.append(agent_name)
            state.current_agent = agent_name

            if self.routing_strategy == "sequential":
                # In sequential mode, move to the next agent in insertion
                # order, or to the aggregator after the last one.
                agent_keys = list(self.agents)
                next_idx = agent_keys.index(agent_name) + 1
                state.next_agent = agent_keys[next_idx] if next_idx < len(agent_keys) else "aggregate"
            else:
                # In every other mode a single worker runs and control goes
                # directly to the aggregator.
                state.next_agent = "aggregate"

            return state

        return worker_node

    def _create_aggregator_node(self, state: SupervisorState) -> SupervisorState:
        """
        Aggregate results from all workers.

        Joins one "**Name:** result" section per executed agent, in execution
        order, and marks the state as completed.
        In production, use LLM to synthesize results.

        Args:
            state: Current graph state; mutated in place.

        Returns:
            The same state with final_result and completed set.
        """
        state.final_result = "\n\n".join(
            f"**{agent_name.title()}:** {state.agent_results.get(agent_name, '')}"
            for agent_name in state.agent_history
        )
        state.completed = True

        return state

    def _routing_function(self, state: SupervisorState) -> str:
        """
        Determine next node based on state.

        Returns:
            state.next_agent when it is set (including "aggregate"),
            otherwise "supervisor".
        """
        return state.next_agent or "supervisor"

    def build(self) -> "StateGraph[SupervisorState]":
        """
        Build the supervisor graph.

        Returns:
            An uncompiled StateGraph with the supervisor, worker and
            aggregator nodes wired together.
        """
        graph: StateGraph[SupervisorState] = StateGraph(SupervisorState)

        # Add supervisor node
        graph.add_node("supervisor", self._create_supervisor_node)

        # Add worker nodes
        for agent_name, agent_func in self.agents.items():
            worker_node = self._create_worker_wrapper(agent_name, agent_func)
            graph.add_node(agent_name, worker_node)  # type: ignore[arg-type]

        # Add aggregator node
        graph.add_node("aggregate", self._create_aggregator_node)

        # Define edges
        graph.set_entry_point("supervisor")

        # Supervisor routes to workers
        graph.add_conditional_edges("supervisor", self._routing_function)

        # Workers route based on strategy
        for agent_name in self.agents:
            graph.add_conditional_edges(agent_name, self._routing_function)

        # Aggregator is the end
        graph.set_finish_point("aggregate")

        return graph

    def compile(self, checkpointer: Any = None) -> Any:
        """
        Compile the supervisor graph.

        The graph itself is built once and cached; each call compiles it
        again with the given checkpointer.

        Args:
            checkpointer: Optional checkpointer for state persistence

        Returns:
            Compiled LangGraph application
        """
        if self._graph is None:
            self._graph = self.build()

        return self._graph.compile(checkpointer=checkpointer)

    def invoke(self, task: str, config: dict[str, Any] | None = None) -> dict[str, Any]:
        """
        Execute the supervisor pattern.

        Args:
            task: Task description
            config: Optional configuration

        Returns:
            Dictionary with the keys task, final_result, agent_history,
            agent_results and routing_decision.
        """
        compiled = self.compile()
        state = SupervisorState(task=task)

        result = compiled.invoke(state, config=config or {})

        return {
            "task": result["task"],
            "final_result": result["final_result"],
            "agent_history": result["agent_history"],
            "agent_results": result["agent_results"],
            "routing_decision": result["routing_decision"],
        }
# Example usage and testing
if __name__ == "__main__":
    # Stand-in workers: each one just returns a canned report for the task.
    def research_agent(task: str) -> str:
        """Research agent."""
        return f"Research findings for: {task}\n- Found relevant papers\n- Analyzed data sources"

    def writer_agent(task: str) -> str:
        """Writer agent."""
        return f"Written content for: {task}\n- Created outline\n- Drafted sections"

    def reviewer_agent(task: str) -> str:
        """Reviewer agent."""
        return f"Review feedback for: {task}\n- Checked accuracy\n- Suggested improvements"

    # Create supervisor (dict order is the agent order)
    supervisor = Supervisor(
        routing_strategy="sequential",
        agents={
            "research": research_agent,
            "writer": writer_agent,
            "reviewer": reviewer_agent,
        },
    )

    # Horizontal rule reused for every separator line below.
    banner = "=" * 80

    print(banner)
    print("SUPERVISOR PATTERN - TEST RUN")
    print(banner)

    # Run each sample task through the supervisor and print what came back.
    for demo_task in (
        "Research the latest AI trends",
        "Write a blog post about LangGraph",
        "Review this code for security issues",
    ):
        print(f"\n{banner}")
        print(f"TASK: {demo_task}")
        print(banner)

        outcome = supervisor.invoke(demo_task)

        print(f"\nRouting Decision: {outcome['routing_decision']}")
        print(f"Agent History: {' → '.join(outcome['agent_history'])}")
        print(f"\nFINAL RESULT:\n{outcome['final_result']}")

    print(f"\n{banner}\n")