Coverage for src/mcp_server_langgraph/builder/importer/graph_extractor.py: 87%
93 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 00:43 +0000
"""
Graph Extractor for LangGraph Code

Extracts workflow structure from LangGraph Python code:
- Locates graph-construction calls (add_node, add_edge,
  add_conditional_edges, set_entry_point)
- Extracts nodes and their functions
- Extracts edges and conditions
- Identifies the entry point (finish points are not extracted)
- Infers state schema

Example:
    from mcp_server_langgraph.builder.importer import GraphExtractor

    extractor = GraphExtractor()
    workflow = extractor.extract_from_file("agent.py")

    print(f"Nodes: {len(workflow['nodes'])}")
    print(f"Edges: {len(workflow['edges'])}")
"""
import ast
import tokenize
from typing import Any

from .ast_parser import PythonCodeParser
class GraphExtractor:
    """
    Extract LangGraph workflow structure from Python code.

    Uses AST analysis (through ``PythonCodeParser``) to find graph
    construction calls -- ``add_node``, ``add_edge``,
    ``add_conditional_edges`` and ``set_entry_point`` -- plus the state
    class, and turns them into a plain workflow-definition dict.
    """

    # Names too generic to be useful as a workflow name.
    _GENERIC_NAMES = frozenset({"graph", "agent", "workflow"})

    # Ordered (node_type, keywords) table used by _infer_node_type: the first
    # entry with a keyword contained in the lower-cased function name wins.
    _NODE_TYPE_KEYWORDS: tuple[tuple[str, tuple[str, ...]], ...] = (
        ("tool", ("tool",)),
        ("llm", ("llm", "completion", "chat")),
        ("conditional", ("condition", "route", "decide")),
        ("approval", ("approval", "approve", "review")),
    )

    def __init__(self) -> None:
        """Initialize the extractor with its own ``PythonCodeParser``."""
        self.parser = PythonCodeParser()

    def extract_from_code(self, code: str) -> dict[str, Any]:
        """
        Extract a workflow definition from a Python source string.

        Args:
            code: Python source code

        Returns:
            Workflow definition dict with the keys ``name``, ``description``,
            ``nodes``, ``edges``, ``entry_point``, ``state_schema`` and
            ``metadata``.

        Example:
            For code containing ``def create_research_agent(): ...``:

            >>> workflow = extractor.extract_from_code(python_code)
            >>> workflow["name"]
            'research'
        """
        tree = self.parser.parse_code(code)

        workflow_name = self._extract_workflow_name(tree, code)
        state_schema = self._extract_state_schema(tree)
        nodes = self._extract_nodes(tree)
        edges = self._extract_edges(tree)
        entry_point = self._extract_entry_point(tree)

        return {
            "name": workflow_name,
            "description": self._extract_description(tree),
            "nodes": nodes,
            "edges": edges,
            "entry_point": entry_point,
            "state_schema": state_schema,
            "metadata": {"source": "imported", "parser_version": "1.0"},
        }

    def extract_from_file(self, file_path: str) -> dict[str, Any]:
        """
        Extract a workflow definition from a Python source file.

        The file is decoded the way the Python interpreter decodes source
        (UTF-8 by default, honouring a BOM or a PEP 263 ``coding:`` cookie)
        instead of with the platform's locale encoding.

        Args:
            file_path: Path to Python file

        Returns:
            Workflow definition (see ``extract_from_code``)

        Raises:
            OSError: if the file cannot be opened.
            SyntaxError: if the file declares an invalid or unknown encoding.

        Example:
            >>> workflow = extractor.extract_from_file("agent.py")
        """
        with tokenize.open(file_path) as f:
            code = f.read()

        return self.extract_from_code(code)

    @staticmethod
    def _get_call_arg(call: dict[str, Any], position: int, keyword: str) -> tuple[bool, Any]:
        """
        Look up one argument of a call record from ``PythonCodeParser``.

        The argument is taken from ``call["args"][position]`` when present,
        otherwise from ``call["kwargs"][keyword]``, so both positional and
        keyword call styles are recognised.

        Args:
            call: Call record with ``"args"`` (list) and ``"kwargs"`` (dict)
            position: Positional index of the argument
            keyword: Keyword name of the same argument

        Returns:
            ``(found, value)``; ``value`` is ``None`` when not found.
        """
        args = call["args"]
        if len(args) > position:
            return True, args[position]
        kwargs = call["kwargs"]
        if keyword in kwargs:
            return True, kwargs[keyword]
        return False, None

    @staticmethod
    def _strip_workflow_suffix(name: str) -> str:
        """Remove a trailing ``_agent`` and then a trailing ``_workflow`` from *name*."""
        return name.removesuffix("_agent").removesuffix("_workflow")

    def _extract_workflow_name(self, tree: ast.Module, code: str) -> str:
        """
        Extract the workflow name from code.

        Strategies, in order:
        1. Factory function name: ``create_<name>[_agent|_workflow]`` -> ``<name>``
           (sync or async ``def``).
        2. Variable name: ``<name>_agent`` / ``<name>_workflow`` -> ``<name>``.
        3. Default to ``"imported_workflow"``.

        Empty results and generic names ("graph", "agent", "workflow") are
        skipped. Only the ``create_`` prefix and the ``_agent``/``_workflow``
        suffixes are removed; occurrences elsewhere in the name are kept.

        Args:
            tree: AST tree
            code: Original source code (currently unused; kept for interface
                compatibility)

        Returns:
            Workflow name
        """
        # Strategy 1: factory functions such as create_research_agent.
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name.startswith("create_"):
                name = self._strip_workflow_suffix(node.name.removeprefix("create_"))
                # Skip generic factories like create_graph / create_agent.
                if name and name not in self._GENERIC_NAMES:
                    return name

        # Strategy 2: variables such as research_agent = ...
        for assignment in self.parser.find_variable_assignments(tree):
            var_name = str(assignment["variable"])
            if var_name.endswith(("_agent", "_workflow")):
                name = self._strip_workflow_suffix(var_name)
                if name and name not in self._GENERIC_NAMES:
                    return name

        # Strategy 3: default
        return "imported_workflow"

    def _extract_description(self, tree: ast.Module) -> str:
        """
        Extract the module docstring as the description.

        Args:
            tree: AST tree

        Returns:
            The module docstring, or ``"Imported workflow"`` when there is
            none (or it is empty).
        """
        docstring = ast.get_docstring(tree)
        return docstring or "Imported workflow"

    def _extract_state_schema(self, tree: ast.Module) -> dict[str, str]:
        """
        Extract the state schema from the first state-like class.

        A class qualifies when its name contains ``"State"`` or its bases
        (as reported by ``PythonCodeParser.find_class_definitions``;
        presumably base-class names -- verify) include ``"TypedDict"`` or
        ``"BaseModel"``.

        Args:
            tree: AST tree

        Returns:
            State schema dict ``{field_name: type}``, or ``{}`` if no class
            qualifies.

        Example:
            >>> schema = extractor._extract_state_schema(tree)
            >>> schema
            {'query': 'str', 'result': 'str'}
        """
        for cls in self.parser.find_class_definitions(tree):
            bases = cls["bases"]
            if "State" in cls["name"] or "TypedDict" in bases or "BaseModel" in bases:
                return self._extract_class_fields(tree, cls["name"])

        return {}

    def _extract_class_fields(self, tree: ast.Module, class_name: str) -> dict[str, str]:
        """
        Extract annotated fields from a class definition.

        Only simple annotated names directly in the class body
        (``field: type`` / ``field: type = default``) are collected; the
        annotation is rendered back to source text. If several classes share
        ``class_name``, their fields are merged.

        Args:
            tree: AST tree
            class_name: Name of class

        Returns:
            Fields dict ``{field_name: annotation_source}``
        """
        fields: dict[str, str] = {}

        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef) and node.name == class_name:
                for item in node.body:
                    if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
                        fields[item.target.id] = ast.unparse(item.annotation)

        return fields

    def _extract_nodes(self, tree: ast.Module) -> list[dict[str, Any]]:
        """
        Extract nodes from ``add_node()`` calls.

        Recognised forms (argument names follow LangGraph's
        ``StateGraph.add_node``):
        - ``add_node(node_id, function, **options)``
        - ``add_node(node_id, action=function, **options)``
        - ``add_node(node=node_id, action=function, **options)``
        - ``add_node(function)`` -- the first argument is both id and function

        Args:
            tree: AST tree

        Returns:
            List of node definitions

        Example:
            >>> nodes = extractor._extract_nodes(tree)
            >>> nodes[0]
            {'id': 'search', 'type': 'custom', 'label': 'search', 'config': {}, 'position': {'x': 0, 'y': 0}}
        """
        nodes: list[dict[str, Any]] = []
        for call in self.parser.find_function_calls(tree, "add_node"):
            found, node_id = self._get_call_arg(call, 0, "node")
            if not found:
                continue

            has_action, action = self._get_call_arg(call, 1, "action")
            # Single-argument form: the only argument identifies the function.
            function_name = action if has_action else node_id

            kwargs = call["kwargs"]
            # The node id and action are not configuration options.
            config = {key: value for key, value in kwargs.items() if key not in ("node", "action")}

            nodes.append(
                {
                    "id": node_id,
                    "type": self._infer_node_type(function_name, kwargs),
                    "label": node_id,
                    "config": config,
                    "position": {"x": 0, "y": 0},  # Will be set by layout engine
                }
            )

        return nodes

    def _infer_node_type(self, function_name: Any, kwargs: dict[str, Any]) -> str:
        """
        Infer the node type from keywords in the function name.

        Matching is case-insensitive substring matching against
        ``_NODE_TYPE_KEYWORDS``, checked in order (tool, llm, conditional,
        approval).

        Args:
            function_name: Function name or reference (converted with ``str``)
            kwargs: Keyword arguments of the ``add_node`` call (currently
                not used for inference)

        Returns:
            Node type (tool, llm, conditional, approval, custom)
        """
        func_str = str(function_name).lower()
        for node_type, keywords in self._NODE_TYPE_KEYWORDS:
            if any(keyword in func_str for keyword in keywords):
                return node_type
        return "custom"

    def _extract_edges(self, tree: ast.Module) -> list[dict[str, Any]]:
        """
        Extract edges from ``add_edge()`` and ``add_conditional_edges()`` calls.

        ``add_edge(start_key, end_key)`` yields one edge with
        ``condition=None``. ``add_conditional_edges(source, path, ...)``
        yields a single placeholder edge to ``"conditional"`` whose condition
        is ``"route_via_<path>"``; the concrete routing targets are not
        extracted. Positional and keyword argument styles are both accepted.

        Args:
            tree: AST tree

        Returns:
            List of edge definitions

        Example:
            >>> edges = extractor._extract_edges(tree)
            >>> edges[0]
            {'from': 'search', 'to': 'summarize', 'condition': None}
        """
        edges: list[dict[str, Any]] = []

        for call in self.parser.find_function_calls(tree, "add_edge"):
            has_start, start = self._get_call_arg(call, 0, "start_key")
            has_end, end = self._get_call_arg(call, 1, "end_key")
            if has_start and has_end:
                edges.append({"from": start, "to": end, "condition": None})

        for call in self.parser.find_function_calls(tree, "add_conditional_edges"):
            has_source, source = self._get_call_arg(call, 0, "source")
            has_path, routing_func = self._get_call_arg(call, 1, "path")
            if has_source and has_path:
                edges.append({"from": source, "to": "conditional", "condition": f"route_via_{routing_func}"})

        return edges

    def _extract_entry_point(self, tree: ast.Module) -> str:
        """
        Extract the entry point from ``set_entry_point()`` calls.

        The first call that passes a key (positionally or as ``key=``) wins.

        Args:
            tree: AST tree

        Returns:
            Entry point node ID (as ``str``), or ``"start"`` when none is set.
        """
        for call in self.parser.find_function_calls(tree, "set_entry_point"):
            found, key = self._get_call_arg(call, 0, "key")
            if found:
                return str(key)

        return "start"  # Default
# ==============================================================================
# Example Usage
# ==============================================================================

if __name__ == "__main__":
    # NOTE: this module uses a relative import (``from .ast_parser import ...``),
    # so run it as a module rather than as a script, e.g.:
    #   python -m mcp_server_langgraph.builder.importer.graph_extractor
    #
    # Sample LangGraph code. The extractor only analyses it and never runs it,
    # but it is kept self-consistent: every name it uses is imported or builtin.
    sample_code = '''
"""
Research Agent

Searches and summarizes information.
"""

from typing import TypedDict
from langgraph.graph import StateGraph


class ResearchState(TypedDict):
    """State for research agent."""
    query: str
    search_results: list[str]
    summary: str


def search_web(state: ResearchState) -> ResearchState:
    """Search the web."""
    # Implementation
    return state


def summarize_results(state: ResearchState) -> ResearchState:
    """Summarize search results."""
    # Implementation
    return state


def create_research_agent():
    """Create research agent."""
    graph = StateGraph(ResearchState)

    graph.add_node("search", search_web)
    graph.add_node("summarize", summarize_results)

    graph.add_edge("search", "summarize")

    graph.set_entry_point("search")
    graph.set_finish_point("summarize")

    return graph.compile()
'''

    # Extract workflow
    extractor = GraphExtractor()
    workflow = extractor.extract_from_code(sample_code)

    print("=" * 80)
    print("GRAPH EXTRACTOR - TEST RUN")
    print("=" * 80)

    print(f"\nWorkflow Name: {workflow['name']}")
    print(f"Description: {workflow['description']}")
    print(f"Entry Point: {workflow['entry_point']}")

    print("\nState Schema:")
    for field, type_annotation in workflow["state_schema"].items():
        print(f"  - {field}: {type_annotation}")

    print(f"\nNodes ({len(workflow['nodes'])}):")
    for node in workflow["nodes"]:
        print(f"  - {node['id']} (type: {node['type']})")

    print(f"\nEdges ({len(workflow['edges'])}):")
    for edge in workflow["edges"]:
        # Conditional edges carry a route_via_<fn> condition; show it when set.
        condition = f" [{edge['condition']}]" if edge.get("condition") else ""
        print(f"  - {edge['from']} → {edge['to']}{condition}")

    print("\n" + "=" * 80)