Coverage for src / mcp_server_langgraph / patterns / swarm.py: 93%

100 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Swarm Pattern for Parallel Multi-Agent Execution 

3 

4The swarm pattern executes multiple agents in parallel on the same task 

5and aggregates their results for a comprehensive answer. 

6 

7Architecture: 

8 User Query → [Agent 1, Agent 2, Agent 3, ...] (parallel) → Aggregator → Response 

9 

10Use Cases: 

11 - Consensus building (multiple agents analyze same problem) 

12 - Diverse perspectives (different approaches to same task) 

13 - Redundancy and validation (cross-checking results) 

14 - Parallel research (multiple sources simultaneously) 

15 

16Example: 

17 from mcp_server_langgraph.patterns import Swarm 

18 

19 swarm = Swarm( 

20 agents={ 

21 "gemini": gemini_agent, 

22 "claude": claude_agent, 

23 "gpt4": gpt4_agent 

24 }, 

25 aggregation_strategy="consensus" 

26 ) 

27 

28 result = swarm.invoke("What is the capital of France?") 

29""" 

30 

31from collections.abc import Callable 

32from typing import Annotated, Any, Literal 

33 

34from langchain_core.runnables import RunnableLambda 

35from langgraph.graph import StateGraph 

36from pydantic import BaseModel, Field 

37 

38 

def merge_agent_results(left: dict[str, Any], right: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict holding the entries of ``left`` overlaid with ``right``.

    Neither input is modified. When a key appears in both mappings, the
    value from ``right`` is kept.
    """
    return {**left, **right}

44 

45 

class SwarmState(BaseModel):
    """Data carried through the swarm graph: the query, per-agent outputs and the final aggregate."""

    # Task text handed to every agent.
    query: str = Field(description="The query/task for all agents")

    # Outputs keyed by agent name. The Annotated metadata names
    # merge_agent_results as the function used to combine concurrent updates.
    agent_results: Annotated[dict[str, Any], merge_agent_results] = Field(
        description="Results from each agent",
        default_factory=dict,
    )

    # Combined answer; stays empty until an aggregation step fills it in.
    aggregated_result: str = Field(description="Aggregated final result", default="")

    # Agreement level between agents, described as lying in the 0-1 range.
    consensus_score: float = Field(description="Agreement level between agents (0-1)", default=0.0)

    # Free-form extra information.
    metadata: dict[str, Any] = Field(description="Additional metadata", default_factory=dict)

56 

57 

58class Swarm: 

59 """ 

60 Swarm pattern for parallel multi-agent execution. 

61 

62 All agents work on the same task simultaneously, results are aggregated. 

63 """ 

64 

65 def __init__( 

66 self, 

67 agents: dict[str, Callable[[str], Any]], 

68 aggregation_strategy: Literal["consensus", "voting", "synthesis", "concatenate"] = "synthesis", 

69 min_agreement: float = 0.7, 

70 ): 

71 """ 

72 Initialize swarm. 

73 

74 Args: 

75 agents: Dictionary of {agent_name: agent_function} 

76 aggregation_strategy: How to combine results 

77 - consensus: Find common ground between all agents 

78 - voting: Majority vote on discrete choices 

79 - synthesis: LLM synthesizes all perspectives 

80 - concatenate: Simple combination of all outputs 

81 min_agreement: Minimum agreement threshold for consensus (0-1) 

82 """ 

83 self.agents = agents 

84 self.aggregation_strategy = aggregation_strategy 

85 self.min_agreement = min_agreement 

86 self._graph: StateGraph[SwarmState] | None = None 

87 

88 def _create_agent_wrapper( 

89 self, agent_name: str, agent_func: Callable[[str], Any] 

90 ) -> Callable[[SwarmState], dict[str, Any]]: 

91 """ 

92 Create wrapper for agent execution. 

93 

94 Args: 

95 agent_name: Name of the agent 

96 agent_func: Agent function 

97 

98 Returns: 

99 Wrapped function that stores result in state 

100 """ 

101 

102 def agent_node(state: SwarmState) -> dict[str, Any]: 

103 """Execute agent and store result.""" 

104 try: 

105 result = agent_func(state.query) 

106 # Return only the updated field to avoid concurrent update conflicts 

107 updated_results = state.agent_results.copy() 

108 updated_results[agent_name] = result 

109 return {"agent_results": updated_results} 

110 except Exception as e: 

111 updated_results = state.agent_results.copy() 

112 updated_results[agent_name] = f"Error: {e!s}" 

113 return {"agent_results": updated_results} 

114 

115 return agent_node 

116 

117 def _calculate_consensus(self, results: list[str]) -> float: 

118 """ 

119 Calculate consensus score between results. 

120 

121 Simple implementation based on common words. 

122 In production, use semantic similarity with embeddings. 

123 

124 Args: 

125 results: List of agent results 

126 

127 Returns: 

128 Consensus score (0-1) 

129 """ 

130 if len(results) < 2: 

131 return 1.0 

132 

133 # Extract words from all results 

134 all_words = [] 

135 for result in results: 

136 words = set(result.lower().split()) 

137 all_words.append(words) 

138 

139 # Calculate pairwise similarity 

140 similarities = [] 

141 for i in range(len(all_words)): 

142 for j in range(i + 1, len(all_words)): 

143 intersection = all_words[i] & all_words[j] 

144 union = all_words[i] | all_words[j] 

145 similarity = len(intersection) / len(union) if union else 0 

146 similarities.append(similarity) 

147 

148 return sum(similarities) / len(similarities) if similarities else 0.0 

149 

150 def _aggregate_results(self, state: SwarmState) -> dict[str, Any]: 

151 """ 

152 Aggregate results from all agents. 

153 

154 Args: 

155 state: Current state with agent results 

156 

157 Returns: 

158 Dict with updated aggregated_result and consensus_score fields 

159 """ 

160 results = list(state.agent_results.values()) 

161 aggregated_result = "" 

162 consensus_score = 0.0 

163 

164 if self.aggregation_strategy == "concatenate": 164 ↛ 166line 164 didn't jump to line 166 because the condition on line 164 was never true

165 # Simple concatenation 

166 formatted_results = [] 

167 for agent_name, result in state.agent_results.items(): 

168 formatted_results.append(f"**{agent_name.title()}:**\n{result}") 

169 aggregated_result = "\n\n".join(formatted_results) 

170 

171 elif self.aggregation_strategy == "consensus": 

172 # Find consensus 

173 consensus_score = self._calculate_consensus(results) 

174 

175 if consensus_score >= self.min_agreement: 

176 # High agreement - present consensus 

177 aggregated_result = f"**Consensus (agreement: {consensus_score:.0%}):**\n\n" 

178 aggregated_result += f"All agents agree:\n{results[0]}" 

179 else: 

180 # Low agreement - present all viewpoints 

181 aggregated_result = f"**Multiple Perspectives (agreement: {consensus_score:.0%}):**\n\n" 

182 for agent_name, result in state.agent_results.items(): 

183 aggregated_result += f"**{agent_name.title()}:** {result}\n\n" 

184 

185 elif self.aggregation_strategy == "synthesis": 

186 # Synthesize all results (simplified - in production use LLM) 

187 aggregated_result = "**Synthesized Response:**\n\n" 

188 aggregated_result += "Based on analysis from multiple agents:\n\n" 

189 

190 for agent_name, result in state.agent_results.items(): 

191 # Extract key points (simplified) 

192 aggregated_result += f"• From {agent_name}: {result[:100]}...\n" 

193 

194 consensus_score = self._calculate_consensus(results) 

195 

196 elif self.aggregation_strategy == "voting": 196 ↛ 206line 196 didn't jump to line 206 because the condition on line 196 was always true

197 # Simple voting (count identical responses) 

198 from collections import Counter 

199 

200 vote_counter = Counter(results) 

201 winner, count = vote_counter.most_common(1)[0] 

202 

203 aggregated_result = f"**Voting Result ({count}/{len(results)} agents agree):**\n\n{winner}" 

204 consensus_score = count / len(results) 

205 

206 return {"aggregated_result": aggregated_result, "consensus_score": consensus_score} 

207 

208 def build(self) -> "StateGraph[SwarmState]": 

209 """Build the swarm graph.""" 

210 from langgraph.graph import END, START 

211 

212 graph: StateGraph[SwarmState] = StateGraph(SwarmState) 

213 

214 # Add all agent nodes (wrapped with RunnableLambda for type safety) 

215 for agent_name, agent_func in self.agents.items(): 

216 agent_wrapper = self._create_agent_wrapper(agent_name, agent_func) 

217 graph.add_node(agent_name, RunnableLambda(agent_wrapper)) 

218 

219 # Add aggregator (wrapped with RunnableLambda for type safety) 

220 graph.add_node("aggregate", RunnableLambda(self._aggregate_results)) 

221 

222 # All agents run in parallel from start 

223 for agent_name in self.agents: 

224 graph.add_edge(START, agent_name) 

225 

226 # All agents feed into aggregator 

227 for agent_name in self.agents: 

228 graph.add_edge(agent_name, "aggregate") 

229 

230 # Aggregator is the end 

231 graph.add_edge("aggregate", END) 

232 

233 return graph 

234 

235 def compile(self, checkpointer: Any = None) -> Any: 

236 """ 

237 Compile the swarm graph. 

238 

239 Args: 

240 checkpointer: Optional checkpointer 

241 

242 Returns: 

243 Compiled graph 

244 """ 

245 if self._graph is None: 245 ↛ 248line 245 didn't jump to line 248 because the condition on line 245 was always true

246 self._graph = self.build() 

247 

248 return self._graph.compile(checkpointer=checkpointer) 

249 

250 def invoke(self, query: str, config: dict[str, Any] | None = None) -> dict[str, Any]: 

251 """ 

252 Execute the swarm pattern. 

253 

254 Args: 

255 query: Query for all agents 

256 config: Optional configuration 

257 

258 Returns: 

259 Aggregated results 

260 """ 

261 compiled = self.compile() 

262 state = SwarmState(query=query) 

263 

264 result = compiled.invoke(state, config=config or {}) 

265 

266 return { 

267 "query": result["query"], 

268 "aggregated_result": result["aggregated_result"], 

269 "agent_results": result["agent_results"], 

270 "consensus_score": result["consensus_score"], 

271 "num_agents": len(result["agent_results"]), 

272 } 

273 

274 

275# Example usage and testing 

if __name__ == "__main__":
    # Three toy agents that answer the same query from different angles.
    def optimistic_agent(query: str) -> str:
        """Answer with an upbeat take on the query."""
        return f"Analysis of '{query}': This is a great opportunity with many positive aspects and potential benefits."

    def pessimistic_agent(query: str) -> str:
        """Answer with a cautious, risk-focused take on the query."""
        return f"Analysis of '{query}': There are significant risks and challenges that need careful consideration."

    def neutral_agent(query: str) -> str:
        """Answer with an even-handed take on the query."""
        return f"Analysis of '{query}': There are both advantages and disadvantages that should be weighed carefully."

    demo_agents = {
        "optimist": optimistic_agent,
        "pessimist": pessimistic_agent,
        "neutral": neutral_agent,
    }
    banner = "=" * 80

    # Run the same query once per aggregation strategy and print the outcome.
    for strategy_name in ("concatenate", "consensus", "synthesis", "voting"):
        print(banner)
        print(f"SWARM PATTERN - {strategy_name.upper()} STRATEGY")
        print(banner)

        demo_swarm = Swarm(
            agents=demo_agents,
            aggregation_strategy=strategy_name,  # type: ignore
        )
        outcome = demo_swarm.invoke("Adopting AI in our business")

        print(f"\nQuery: {outcome['query']}")
        print(f"Number of Agents: {outcome['num_agents']}")
        print(f"Consensus Score: {outcome['consensus_score']:.0%}")
        print(f"\nAGGREGATED RESULT:\n{outcome['aggregated_result']}")
        print()

    print(banner)