Coverage for src / mcp_server_langgraph / builder / importer / graph_extractor.py: 87%

93 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Graph Extractor for LangGraph Code 

3 

4Extracts workflow structure from LangGraph Python code: 

5- Identifies StateGraph creation 

6- Extracts nodes and their functions 

7- Extracts edges and conditions 

8- Identifies entry/finish points 

9- Infers state schema 

10 

11Example: 

12 from mcp_server_langgraph.builder.importer import GraphExtractor 

13 

14 extractor = GraphExtractor() 

15 workflow = extractor.extract_from_file("agent.py") 

16 

17 print(f"Nodes: {len(workflow['nodes'])}") 

18 print(f"Edges: {len(workflow['edges'])}") 

19""" 

20 

import ast
import tokenize
from typing import Any

from .ast_parser import PythonCodeParser

25 

26 

class GraphExtractor:
    """
    Extract LangGraph workflow structure from Python code.

    Uses AST analysis to identify graph construction patterns: ``add_node``,
    ``add_edge``, ``add_conditional_edges`` and ``set_entry_point`` calls, plus
    the first state-like class (name containing "State", or a ``TypedDict`` /
    ``BaseModel`` base).
    """

    # Trailing name parts dropped when deriving a workflow name
    # (``create_research_agent`` -> ``research``).
    _NAME_SUFFIXES = ("_agent", "_workflow")

    # Derived names too generic to identify a workflow; these are skipped.
    _GENERIC_NAMES = frozenset({"graph", "agent", "workflow"})

    # (node_type, keywords) pairs checked in order; the first keyword found in
    # the lower-cased function name decides the node type.
    _NODE_TYPE_KEYWORDS = (
        ("tool", ("tool",)),
        ("llm", ("llm", "completion", "chat")),
        ("conditional", ("condition", "route", "decide")),
        ("approval", ("approval", "approve", "review")),
    )

    # Spellings of LangGraph's virtual start node as the source of an add_edge()
    # call. NOTE(review): assumes the parser renders the bare name ``START`` as
    # the string "START" -- confirm against PythonCodeParser.find_function_calls.
    _START_NODE_NAMES = frozenset({"START", "__start__"})

    def __init__(self) -> None:
        """Initialize extractor."""
        self.parser = PythonCodeParser()

    def extract_from_code(self, code: str) -> dict[str, Any]:
        """
        Extract workflow from Python code string.

        Args:
            code: Python source code

        Returns:
            Workflow definition dict with keys ``name``, ``description``,
            ``nodes``, ``edges``, ``entry_point``, ``state_schema`` and
            ``metadata``.

        Example:
            >>> workflow = extractor.extract_from_code(python_code)
            >>> workflow["name"]  # for code defining create_research_agent()
            'research'
        """
        tree = self.parser.parse_code(code)

        return {
            "name": self._extract_workflow_name(tree, code),
            "description": self._extract_description(tree),
            "nodes": self._extract_nodes(tree),
            "edges": self._extract_edges(tree),
            "entry_point": self._extract_entry_point(tree),
            "state_schema": self._extract_state_schema(tree),
            "metadata": {"source": "imported", "parser_version": "1.0"},
        }

    def extract_from_file(self, file_path: str) -> dict[str, Any]:
        """
        Extract workflow from Python file.

        Args:
            file_path: Path to Python file

        Returns:
            Workflow definition (see ``extract_from_code``)

        Raises:
            OSError: If the file cannot be opened.
            SyntaxError: If the file has an invalid encoding declaration.

        Example:
            >>> workflow = extractor.extract_from_file("agent.py")
        """
        # tokenize.open honours PEP 263 coding cookies and a UTF-8 BOM, and
        # otherwise defaults to UTF-8, exactly like the interpreter does. A bare
        # open() would use the platform locale encoding (e.g. cp1252 on Windows).
        with tokenize.open(file_path) as f:
            code = f.read()

        return self.extract_from_code(code)

    @classmethod
    def _strip_name_suffixes(cls, name: str) -> str:
        """Remove any trailing ``_agent`` / ``_workflow`` parts from *name*."""
        # Loop so that stacked suffixes ("x_agent_workflow") are all removed;
        # each pass removes at least one suffix, so this terminates.
        while name.endswith(cls._NAME_SUFFIXES):
            for suffix in cls._NAME_SUFFIXES:
                name = name.removesuffix(suffix)
        return name

    def _extract_workflow_name(self, tree: ast.Module, code: str) -> str:
        """
        Extract workflow name from code.

        Tries these strategies in order:
        1. Function name: ``create_xxx``, ``create_xxx_agent`` or
           ``create_xxx_workflow`` (sync or async) -> ``xxx``.
        2. Variable name: ``xxx_agent = ...`` or ``xxx_workflow = ...`` -> ``xxx``.
        3. Default to ``"imported_workflow"``.

        Candidates that reduce to an empty string or to a generic word
        ("graph", "agent", "workflow") are skipped.

        Args:
            tree: AST tree
            code: Original source code (currently unused; kept for interface
                compatibility)

        Returns:
            Workflow name
        """
        # Strategy 1: create_xxx factory functions (ast.walk is breadth-first,
        # so module-level definitions are seen before nested ones).
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name.startswith("create_"):
                name = self._strip_name_suffixes(node.name.removeprefix("create_"))
                if name and name not in self._GENERIC_NAMES:
                    return name

        # Strategy 2: xxx_agent / xxx_workflow variables
        for assignment in self.parser.find_variable_assignments(tree):
            var_name = assignment["variable"]
            if not var_name.endswith(self._NAME_SUFFIXES):
                continue
            name = self._strip_name_suffixes(var_name)
            if name and name not in self._GENERIC_NAMES:
                return name

        return "imported_workflow"

    def _extract_description(self, tree: ast.Module) -> str:
        """
        Extract module docstring as description.

        Args:
            tree: AST tree

        Returns:
            The module docstring, or "Imported workflow" when there is none.
        """
        docstring = ast.get_docstring(tree)
        return docstring or "Imported workflow"

    def _extract_state_schema(self, tree: ast.Module) -> dict[str, str]:
        """
        Extract state schema from TypedDict or Pydantic model.

        The first class whose name contains "State" or whose bases include
        ``TypedDict`` or ``BaseModel`` is used.

        Args:
            tree: AST tree

        Returns:
            State schema dict {field_name: type}, or {} if no state class found.

        Example:
            >>> schema = extractor._extract_state_schema(tree)
            >>> schema
            {'query': 'str', 'result': 'str'}
        """
        for cls in self.parser.find_class_definitions(tree):
            if "State" in cls["name"] or "TypedDict" in cls["bases"] or "BaseModel" in cls["bases"]:
                return self._extract_class_fields(tree, cls["name"])

        return {}

    def _extract_class_fields(self, tree: ast.Module, class_name: str) -> dict[str, str]:
        """
        Extract annotated fields from a class definition.

        Only simple ``name: annotation`` statements directly in the class body
        are collected. Inherited fields are not. If several classes share
        *class_name*, their fields are merged.

        Args:
            tree: AST tree
            class_name: Name of class

        Returns:
            Fields dict {field_name: annotation source text}
        """
        fields: dict[str, str] = {}

        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef) and node.name == class_name:
                for item in node.body:
                    if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
                        fields[item.target.id] = ast.unparse(item.annotation)

        return fields

    def _extract_nodes(self, tree: ast.Module) -> list[dict[str, Any]]:
        """
        Extract nodes from add_node() calls.

        Args:
            tree: AST tree

        Returns:
            List of node definitions

        Example:
            >>> nodes = extractor._extract_nodes(tree)
            >>> nodes[0]
            {'id': 'search', 'type': 'custom', 'label': 'search', 'config': {}, 'position': {'x': 0, 'y': 0}}
        """
        nodes = []
        for call in self.parser.find_function_calls(tree, "add_node"):
            args = call["args"]
            if not args:
                continue

            node_id = args[0]
            # add_node(node_id, action, **kwargs) names the node explicitly. In
            # the single-argument form add_node(action), the action itself is
            # the only hint about what the node does, so infer the type from it.
            action = args[1] if len(args) > 1 else args[0]

            nodes.append(
                {
                    "id": node_id,
                    "type": self._infer_node_type(action, call["kwargs"]),
                    "label": node_id,
                    "config": call["kwargs"],
                    "position": {"x": 0, "y": 0},  # Will be set by layout engine
                }
            )

        return nodes

    def _infer_node_type(self, function_name: Any, kwargs: dict[str, Any]) -> str:
        """
        Infer node type from a function name.

        Args:
            function_name: Function name or reference (converted with ``str``)
            kwargs: Keyword arguments of the add_node() call (currently unused)

        Returns:
            Node type: "tool", "llm", "conditional", "approval", or "custom"
            when no keyword matches.
        """
        func_str = str(function_name).lower()

        for node_type, keywords in self._NODE_TYPE_KEYWORDS:
            if any(keyword in func_str for keyword in keywords):
                return node_type

        return "custom"

    def _extract_edges(self, tree: ast.Module) -> list[dict[str, Any]]:
        """
        Extract edges from add_edge() and add_conditional_edges() calls.

        Conditional edges are emitted with the placeholder target
        ``"conditional"`` and a ``route_via_<routing_function>`` condition;
        their concrete destinations are not resolved.

        Args:
            tree: AST tree

        Returns:
            List of edge definitions with keys ``from``, ``to`` and
            ``condition`` (``None`` for unconditional edges).

        Example:
            >>> edges = extractor._extract_edges(tree)
            >>> edges[0]
            {'from': 'search', 'to': 'summarize', 'condition': None}
        """
        edges: list[dict[str, Any]] = []

        # add_edge(source, target)
        for call in self.parser.find_function_calls(tree, "add_edge"):
            args = call["args"]
            if len(args) >= 2:
                edges.append({"from": args[0], "to": args[1], "condition": None})

        # add_conditional_edges(source, routing_function, ...)
        for call in self.parser.find_function_calls(tree, "add_conditional_edges"):
            args = call["args"]
            if len(args) >= 2:
                source, routing_func = args[0], args[1]
                edges.append({"from": source, "to": "conditional", "condition": f"route_via_{routing_func}"})

        return edges

    def _extract_entry_point(self, tree: ast.Module) -> str:
        """
        Extract the entry point node ID.

        Uses the first ``set_entry_point(node)`` call. Failing that, it uses the
        target of the first ``add_edge(START, node)`` call, which is how newer
        LangGraph code declares the entry point.

        Args:
            tree: AST tree

        Returns:
            Entry point node ID, or "start" if none is declared.
        """
        entry_calls = self.parser.find_function_calls(tree, "set_entry_point")
        if entry_calls and entry_calls[0]["args"]:
            return str(entry_calls[0]["args"][0])

        for call in self.parser.find_function_calls(tree, "add_edge"):
            args = call["args"]
            if len(args) >= 2 and str(args[0]) in self._START_NODE_NAMES:
                return str(args[1])

        return "start"  # Default

311 

312 

313# ============================================================================== 

314# Example Usage 

315# ============================================================================== 

316 

if __name__ == "__main__":
    # Sample LangGraph module to analyse. It is only parsed here, never
    # executed, but it is kept valid Python (no undefined names such as an
    # unimported ``List``) so the example stays runnable on its own.
    sample_code = '''
"""
Research Agent

Searches and summarizes information.
"""

from typing import TypedDict
from langgraph.graph import StateGraph


class ResearchState(TypedDict):
    """State for research agent."""
    query: str
    search_results: list[str]
    summary: str


def search_web(state: ResearchState) -> ResearchState:
    """Search the web."""
    # Implementation
    return state


def summarize_results(state: ResearchState) -> ResearchState:
    """Summarize search results."""
    # Implementation
    return state


def create_research_agent():
    """Create research agent."""
    graph = StateGraph(ResearchState)

    graph.add_node("search", search_web)
    graph.add_node("summarize", summarize_results)

    graph.add_edge("search", "summarize")

    graph.set_entry_point("search")
    graph.set_finish_point("summarize")

    return graph.compile()
'''

    # Extract workflow
    extractor = GraphExtractor()
    workflow = extractor.extract_from_code(sample_code)

    print("=" * 80)
    print("GRAPH EXTRACTOR - TEST RUN")
    print("=" * 80)

    print(f"\nWorkflow Name: {workflow['name']}")
    print(f"Description: {workflow['description']}")
    print(f"Entry Point: {workflow['entry_point']}")

    print("\nState Schema:")
    for field, type_annotation in workflow["state_schema"].items():
        print(f"  - {field}: {type_annotation}")

    print(f"\nNodes ({len(workflow['nodes'])}):")
    for node in workflow["nodes"]:
        print(f"  - {node['id']} (type: {node['type']})")

    print(f"\nEdges ({len(workflow['edges'])}):")
    for edge in workflow["edges"]:
        # Show the routing condition, if any, so conditional edges are distinguishable.
        condition = f" [{edge['condition']}]" if edge.get("condition") else ""
        print(f"  - {edge['from']} -> {edge['to']}{condition}")

    print("\n" + "=" * 80)