Coverage for src / mcp_server_langgraph / patterns / swarm.py: 93%
100 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
1"""
2Swarm Pattern for Parallel Multi-Agent Execution
4The swarm pattern executes multiple agents in parallel on the same task
5and aggregates their results for a comprehensive answer.
7Architecture:
8 User Query → [Agent 1, Agent 2, Agent 3, ...] (parallel) → Aggregator → Response
10Use Cases:
11 - Consensus building (multiple agents analyze same problem)
12 - Diverse perspectives (different approaches to same task)
13 - Redundancy and validation (cross-checking results)
14 - Parallel research (multiple sources simultaneously)
16Example:
17 from mcp_server_langgraph.patterns import Swarm
19 swarm = Swarm(
20 agents={
21 "gemini": gemini_agent,
22 "claude": claude_agent,
23 "gpt4": gpt4_agent
24 },
25 aggregation_strategy="consensus"
26 )
28 result = swarm.invoke("What is the capital of France?")
29"""
31from collections.abc import Callable
32from typing import Annotated, Any, Literal
34from langchain_core.runnables import RunnableLambda
35from langgraph.graph import StateGraph
36from pydantic import BaseModel, Field
def merge_agent_results(left: dict[str, Any], right: dict[str, Any]) -> dict[str, Any]:
    """
    Combine two agent-result mappings for concurrent state updates.

    Builds a shallow copy of ``left`` and overlays ``right`` on it, so on a
    key collision the value from ``right`` wins. Neither argument is mutated.
    """
    merged = left.copy()
    merged |= right
    return merged
class SwarmState(BaseModel):
    """
    Graph state shared by every node of the swarm pattern.

    ``agent_results`` is annotated with ``merge_agent_results`` so that updates
    written by parallel agent nodes are merged rather than overwritten.
    """

    # Task text handed unchanged to every agent.
    query: str = Field(description="The query/task for all agents")
    # agent name -> that agent's output; concurrent writes are merged by the reducer.
    agent_results: Annotated[dict[str, Any], merge_agent_results] = Field(
        description="Results from each agent",
        default_factory=dict,
    )
    # Text produced by the aggregation step.
    aggregated_result: str = Field(description="Aggregated final result", default="")
    # Agreement between agents as computed by the aggregation step (0-1).
    consensus_score: float = Field(description="Agreement level between agents (0-1)", default=0.0)
    # Free-form extra information.
    metadata: dict[str, Any] = Field(description="Additional metadata", default_factory=dict)
class Swarm:
    """
    Swarm pattern for parallel multi-agent execution.

    All agents work on the same task simultaneously and their results are
    aggregated. The graph fans out from START to one node per agent and fans
    back in to a single ``"aggregate"`` node, which leads to END.
    """

    # Strategies understood by _aggregate_results. Anything else is rejected up
    # front instead of silently producing an empty aggregated_result.
    _VALID_STRATEGIES: tuple[str, ...] = ("consensus", "voting", "synthesis", "concatenate")
    # Maximum number of characters of each agent result quoted by "synthesis".
    _SYNTHESIS_SNIPPET_CHARS = 100

    def __init__(
        self,
        agents: dict[str, Callable[[str], Any]],
        aggregation_strategy: Literal["consensus", "voting", "synthesis", "concatenate"] = "synthesis",
        min_agreement: float = 0.7,
    ):
        """
        Initialize swarm.

        Args:
            agents: Dictionary of {agent_name: agent_function}. Each function
                receives the query string; its return value may be any object.
            aggregation_strategy: How to combine results
                - consensus: Find common ground between all agents
                - voting: Majority vote on discrete choices
                - synthesis: LLM synthesizes all perspectives
                - concatenate: Simple combination of all outputs
            min_agreement: Minimum agreement threshold for consensus (0-1)

        Raises:
            ValueError: If ``aggregation_strategy`` is not a supported strategy.
        """
        if aggregation_strategy not in self._VALID_STRATEGIES:
            msg = (
                f"Unknown aggregation_strategy {aggregation_strategy!r}; "
                f"expected one of {self._VALID_STRATEGIES}"
            )
            raise ValueError(msg)
        self.agents = agents
        self.aggregation_strategy = aggregation_strategy
        self.min_agreement = min_agreement
        self._graph: StateGraph[SwarmState] | None = None

    def _create_agent_wrapper(
        self, agent_name: str, agent_func: Callable[[str], Any]
    ) -> Callable[[SwarmState], dict[str, Any]]:
        """
        Create wrapper for agent execution.

        Args:
            agent_name: Name of the agent
            agent_func: Agent function

        Returns:
            Node function that runs the agent on ``state.query`` and reports its
            result (or an ``"Error: ..."`` string) under ``agent_name``.
        """

        def agent_node(state: SwarmState) -> dict[str, Any]:
            """Execute agent and report only this agent's result."""
            try:
                result = agent_func(state.query)
            except Exception as e:
                # Best-effort: one failing agent must not abort the whole swarm.
                result = f"Error: {e!s}"
            # Parallel branches update agent_results in the same step; the
            # merge_agent_results reducer combines these single-entry dicts.
            return {"agent_results": {agent_name: result}}

        return agent_node

    def _calculate_consensus(self, results: list[Any]) -> float:
        """
        Calculate consensus score between results.

        Mean pairwise Jaccard similarity of the lower-cased, whitespace-split
        word sets of each result. This is a simple word-overlap heuristic; in
        production, use semantic similarity with embeddings.

        Args:
            results: List of agent results (rendered with ``str()``).

        Returns:
            Consensus score (0-1); 1.0 when there are fewer than two results.
        """
        from itertools import combinations

        if len(results) < 2:
            return 1.0

        word_sets = [set(str(result).lower().split()) for result in results]

        similarities = []
        for words_a, words_b in combinations(word_sets, 2):
            union = words_a | words_b
            # Two empty results are identical, so they count as full agreement.
            similarities.append(len(words_a & words_b) / len(union) if union else 1.0)

        return sum(similarities) / len(similarities)

    def _aggregate_results(self, state: SwarmState) -> dict[str, Any]:
        """
        Aggregate results from all agents.

        Agent results may be any object. They are rendered with ``str()``
        wherever text is needed, so non-string results do not crash aggregation.

        Args:
            state: Current state with agent results

        Returns:
            Dict with updated aggregated_result and consensus_score fields.
            With no agent results this is an empty string and 0.0.
        """
        named_results = list(state.agent_results.items())
        results = [result for _, result in named_results]
        aggregated_result = ""
        consensus_score = 0.0

        if not named_results:
            # Nothing to aggregate; the branches below would index or divide
            # by an empty list.
            return {"aggregated_result": aggregated_result, "consensus_score": consensus_score}

        if self.aggregation_strategy == "concatenate":
            # Simple concatenation
            aggregated_result = "\n\n".join(
                f"**{agent_name.title()}:**\n{result}" for agent_name, result in named_results
            )

        elif self.aggregation_strategy == "consensus":
            consensus_score = self._calculate_consensus(results)

            if consensus_score >= self.min_agreement:
                # High agreement - present consensus
                aggregated_result = f"**Consensus (agreement: {consensus_score:.0%}):**\n\n"
                aggregated_result += f"All agents agree:\n{results[0]}"
            else:
                # Low agreement - present all viewpoints
                aggregated_result = f"**Multiple Perspectives (agreement: {consensus_score:.0%}):**\n\n"
                aggregated_result += "".join(
                    f"**{agent_name.title()}:** {result}\n\n" for agent_name, result in named_results
                )

        elif self.aggregation_strategy == "synthesis":
            # Synthesize all results (simplified - in production use LLM)
            limit = self._SYNTHESIS_SNIPPET_CHARS
            bullets = []
            for agent_name, result in named_results:
                text = str(result)
                # Only mark text as elided when it was actually truncated.
                snippet = f"{text[:limit]}..." if len(text) > limit else text
                bullets.append(f"• From {agent_name}: {snippet}\n")
            aggregated_result = (
                "**Synthesized Response:**\n\n"
                "Based on analysis from multiple agents:\n\n" + "".join(bullets)
            )
            consensus_score = self._calculate_consensus(results)

        elif self.aggregation_strategy == "voting":
            # Simple voting (count identical responses)
            from collections import Counter

            # Compare votes by their text so unhashable results (e.g. dicts) still count.
            vote_counter = Counter(str(result) for result in results)
            winner, count = vote_counter.most_common(1)[0]

            aggregated_result = f"**Voting Result ({count}/{len(results)} agents agree):**\n\n{winner}"
            consensus_score = count / len(results)

        return {"aggregated_result": aggregated_result, "consensus_score": consensus_score}

    def build(self) -> "StateGraph[SwarmState]":
        """
        Build the swarm graph.

        Adds one node per agent plus an ``"aggregate"`` node. Edges run from
        START to every agent, from every agent to ``"aggregate"``, and from
        ``"aggregate"`` to END.
        """
        from langgraph.graph import END, START

        graph: StateGraph[SwarmState] = StateGraph(SwarmState)

        # Add all agent nodes (wrapped with RunnableLambda for type safety)
        for agent_name, agent_func in self.agents.items():
            agent_wrapper = self._create_agent_wrapper(agent_name, agent_func)
            graph.add_node(agent_name, RunnableLambda(agent_wrapper))

        # Add aggregator (wrapped with RunnableLambda for type safety)
        graph.add_node("aggregate", RunnableLambda(self._aggregate_results))

        # All agents run in parallel from start, then feed into the aggregator
        for agent_name in self.agents:
            graph.add_edge(START, agent_name)
            graph.add_edge(agent_name, "aggregate")

        # Aggregator is the end
        graph.add_edge("aggregate", END)

        return graph

    def compile(self, checkpointer: Any = None) -> Any:
        """
        Compile the swarm graph.

        The uncompiled graph is built once and cached on the instance; it is
        compiled on every call.

        Args:
            checkpointer: Optional checkpointer

        Returns:
            Compiled graph
        """
        if self._graph is None:
            self._graph = self.build()

        return self._graph.compile(checkpointer=checkpointer)

    def invoke(self, query: str, config: dict[str, Any] | None = None) -> dict[str, Any]:
        """
        Execute the swarm pattern.

        Args:
            query: Query for all agents
            config: Optional configuration

        Returns:
            Dict with ``query``, ``aggregated_result``, ``agent_results``,
            ``consensus_score`` and ``num_agents`` (number of agent results).
        """
        compiled = self.compile()
        state = SwarmState(query=query)

        result = compiled.invoke(state, config=config or {})

        return {
            "query": result["query"],
            "aggregated_result": result["aggregated_result"],
            "agent_results": result["agent_results"],
            "consensus_score": result["consensus_score"],
            "num_agents": len(result["agent_results"]),
        }
# Example usage and testing
if __name__ == "__main__":
    # Three canned agents that answer the same query from different viewpoints.
    def optimistic_agent(query: str) -> str:
        """Optimistic perspective."""
        return f"Analysis of '{query}': This is a great opportunity with many positive aspects and potential benefits."

    def pessimistic_agent(query: str) -> str:
        """Pessimistic perspective."""
        return f"Analysis of '{query}': There are significant risks and challenges that need careful consideration."

    def neutral_agent(query: str) -> str:
        """Balanced perspective."""
        return f"Analysis of '{query}': There are both advantages and disadvantages that should be weighed carefully."

    demo_agents = {"optimist": optimistic_agent, "pessimist": pessimistic_agent, "neutral": neutral_agent}
    separator = "=" * 80

    # Run the same question through every aggregation strategy for comparison.
    for strategy_name in ("concatenate", "consensus", "synthesis", "voting"):
        print(separator)
        print(f"SWARM PATTERN - {strategy_name.upper()} STRATEGY")
        print(separator)

        demo_swarm = Swarm(agents=demo_agents, aggregation_strategy=strategy_name)  # type: ignore
        outcome = demo_swarm.invoke("Adopting AI in our business")

        print(f"\nQuery: {outcome['query']}")
        print(f"Number of Agents: {outcome['num_agents']}")
        print(f"Consensus Score: {outcome['consensus_score']:.0%}")
        print(f"\nAGGREGATED RESULT:\n{outcome['aggregated_result']}")
        print()

    print(separator)