Coverage for src / mcp_server_langgraph / execution / code_validator.py: 95%

109 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Code validator for secure Python code execution 

3 

4Uses AST-based validation to detect dangerous patterns and enforce import whitelists. 

5Security-first design following defense-in-depth principles. 

6""" 

7 

8import ast 

9from dataclasses import dataclass, field 

10 

11 

class CodeValidationError(Exception):
    """Exception type for code validation failures.

    A plain ``Exception`` subclass that adds no state or behaviour of its own.
    """

14 

15 

@dataclass
class ValidationResult:
    """Result of code validation.

    Attributes:
        is_valid: Whether the code passed validation.
        errors: Error messages collected during validation.
        warnings: Warning messages collected during validation.
    """

    is_valid: bool
    errors: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)

    def __repr__(self) -> str:
        """Return a compact summary that shows message counts, not the messages."""
        return (
            f"ValidationResult(is_valid={self.is_valid}, "
            f"errors={len(self.errors)}, warnings={len(self.warnings)})"
        )

27 

28 

class CodeValidator:
    """
    Validates Python code for safe execution in sandboxed environments.

    The source is parsed with ``ast`` and walked by ``SecurityVisitor``,
    which reports:
    - Imports of blocked modules (os, subprocess, sys, etc.) and imports
      that are not in the allowed list
    - Calls to / references of blocked builtins (eval, exec, compile, etc.)
    - Reflection/introspection helpers (globals, locals, getattr, etc.)
    - File access (open)
    - Network modules (socket, urllib, requests)

    Example:
        >>> validator = CodeValidator(allowed_imports=["json", "math"])
        >>> result = validator.validate("import json\\ndata = json.dumps({})")
        >>> assert result.is_valid is True
    """

    # Dangerous modules that should never be allowed
    BLOCKED_MODULES = {
        "os",
        "sys",
        "subprocess",
        "socket",
        "pickle",
        "marshal",
        "ctypes",
        "importlib",
        "pty",
        "fcntl",
        "termios",
        "tty",
        "shutil",
        "tempfile",
        "urllib",
        "urllib.request",
        "urllib2",
        "httplib",
        "http.client",
        "requests",
        "ptrace",
        "resource",  # Can modify resource limits
        "signal",  # Can send signals to processes
        "multiprocessing",  # Can spawn processes
        "threading",  # Excessive threading can DoS
        "asyncio.subprocess",  # Subprocess via asyncio
        "code",  # Interactive interpreter
        "pdb",  # Debugger
        "inspect",  # Introspection
        "gc",  # Garbage collector manipulation
        "weakref",  # Weak reference manipulation
        "ast",  # Can be used to bypass restrictions
        "dis",  # Disassembler (introspection)
        "imp",  # Import hooks (deprecated but dangerous)
        "importlib.util",  # Dynamic imports
        "importlib.machinery",  # Import machinery
        "pkgutil",  # Package utilities
        "modulefinder",  # Module finding
        "runpy",  # Run Python modules
        "platform",  # System information disclosure
    }

    # Dangerous builtin functions
    BLOCKED_BUILTINS = {
        "eval",
        "exec",
        "compile",
        "__import__",
        "globals",
        "locals",
        "vars",
        "getattr",
        "setattr",
        "delattr",
        "open",
        "input",
        "raw_input",  # Python 2
        "execfile",  # Python 2
        "reload",  # Reload modules
        "breakpoint",  # Debugger
        "help",  # Interactive help (can leak info)
        "dir",  # Introspection
        "id",  # Object ID (info disclosure)
        "memoryview",  # Direct memory access
    }

    # Patterns that suggest malicious intent
    SUSPICIOUS_PATTERNS = {
        "__builtins__",
        "__globals__",
        "__dict__",
        "__class__",
        "__bases__",
        "__subclasses__",
        "__init__",
        "__code__",
        "func_code",  # Python 2
        "func_globals",  # Python 2
    }

    def __init__(self, allowed_imports: list[str] | None = None):
        """
        Initialize code validator.

        Args:
            allowed_imports: List of allowed module names (e.g., ["json", "math", "pandas"]).
                ``None`` or an empty list means no import is allowed.
        """
        self.allowed_imports = set(allowed_imports or [])

    def validate(self, code: str) -> ValidationResult:
        """
        Validate Python code for security issues.

        Args:
            code: Python source code to validate

        Returns:
            ValidationResult with validation status and any errors/warnings.
            Empty input, parse failures and input too deeply nested to walk
            are reported as errors in the result instead of being raised.
        """
        errors: list[str] = []
        warnings: list[str] = []

        # Check for empty code
        if not code or not code.strip():
            errors.append("Code is empty or contains only whitespace")
            return ValidationResult(is_valid=False, errors=errors, warnings=warnings)

        # Parse code into AST
        try:
            tree = ast.parse(code)
        except SyntaxError as e:
            errors.append(f"Syntax error: {e}")
            return ValidationResult(is_valid=False, errors=errors, warnings=warnings)
        except Exception as e:
            # Any other parser failure is treated the same way: the code is rejected
            errors.append(f"Failed to parse code: {e}")
            return ValidationResult(is_valid=False, errors=errors, warnings=warnings)

        # Analyze AST for security issues
        visitor = SecurityVisitor(self.allowed_imports, self.BLOCKED_MODULES, self.BLOCKED_BUILTINS)
        try:
            visitor.visit(tree)
        except RecursionError:
            # NodeVisitor recurses once per nesting level, so input that parsed
            # successfully can still exhaust the interpreter stack while being
            # walked. Fail closed: reject the code instead of crashing the caller.
            visitor.errors.append("Code is too deeply nested to analyze")

        errors.extend(visitor.errors)
        warnings.extend(visitor.warnings)

        return ValidationResult(is_valid=not errors, errors=errors, warnings=warnings)

175 

176 

177class SecurityVisitor(ast.NodeVisitor): 

178 """AST visitor that detects security issues in Python code""" 

179 

180 def __init__( 

181 self, 

182 allowed_imports: set[str], 

183 blocked_modules: set[str], 

184 blocked_builtins: set[str], 

185 ): 

186 self.allowed_imports = allowed_imports 

187 self.blocked_modules = blocked_modules 

188 self.blocked_builtins = blocked_builtins 

189 self.errors: list[str] = [] 

190 self.warnings: list[str] = [] 

191 

192 def visit_Import(self, node: ast.Import) -> None: 

193 """Check import statements""" 

194 for alias in node.names: 

195 module_name = alias.name 

196 

197 # Check if module is explicitly blocked 

198 if self._is_blocked_module(module_name): 

199 self.errors.append(f"Import of blocked module '{module_name}' not allowed") 

200 continue 

201 

202 # Check if module is in allowed list 

203 if module_name not in self.allowed_imports: 

204 self.errors.append(f"Import of module '{module_name}' not in allowed list") 

205 

206 self.generic_visit(node) 

207 

208 def visit_ImportFrom(self, node: ast.ImportFrom) -> None: 

209 """Check from ... import statements""" 

210 if node.module: 210 ↛ 221line 210 didn't jump to line 221 because the condition on line 210 was always true

211 # Check if module is explicitly blocked 

212 if self._is_blocked_module(node.module): 

213 self.errors.append(f"Import from blocked module '{node.module}' not allowed") 

214 self.generic_visit(node) 

215 return 

216 

217 # Check if module is in allowed list 

218 if node.module not in self.allowed_imports: 218 ↛ 221line 218 didn't jump to line 221 because the condition on line 218 was always true

219 self.errors.append(f"Import from module '{node.module}' not in allowed list") 

220 

221 self.generic_visit(node) 

222 

223 def visit_Call(self, node: ast.Call) -> None: 

224 """Check function calls for dangerous builtins""" 

225 func_name = self._get_call_name(node) 

226 

227 if func_name: 

228 # Check for blocked builtin functions 

229 if func_name in self.blocked_builtins: 

230 self.errors.append(f"Call to blocked builtin '{func_name}' not allowed") 

231 

232 # Check for dangerous attribute access patterns 

233 if func_name == "system": 

234 self.errors.append("Call to 'system' function not allowed") 

235 

236 # Check for suspicious patterns in call arguments 

237 self._check_suspicious_patterns(node) 

238 

239 self.generic_visit(node) 

240 

241 def visit_Attribute(self, node: ast.Attribute) -> None: 

242 """Check attribute access for suspicious patterns""" 

243 attr_name = node.attr 

244 

245 # Check for dangerous attribute access 

246 if attr_name in ["system", "popen", "spawn", "exec"]: 

247 self.errors.append(f"Access to attribute '{attr_name}' not allowed") 

248 

249 # Check for introspection patterns 

250 if attr_name in CodeValidator.SUSPICIOUS_PATTERNS: 

251 self.warnings.append(f"Suspicious attribute access '{attr_name}' detected") 

252 

253 self.generic_visit(node) 

254 

255 def visit_Name(self, node: ast.Name) -> None: 

256 """Check variable/name access for suspicious patterns""" 

257 name = node.id 

258 

259 # Check for direct access to dangerous builtins 

260 if name in self.blocked_builtins: 

261 self.errors.append(f"Access to blocked name '{name}' not allowed") 

262 

263 # Check for suspicious patterns 

264 if name in CodeValidator.SUSPICIOUS_PATTERNS: 

265 self.warnings.append(f"Suspicious name '{name}' detected") 

266 

267 self.generic_visit(node) 

268 

269 def visit_While(self, node: ast.While) -> None: 

270 """Check while loops for obvious infinite loops""" 

271 # Detect 'while True:' or 'while 1:' 

272 if isinstance(node.test, ast.Constant) and (node.test.value is True or node.test.value == 1): 272 ↛ 275line 272 didn't jump to line 275 because the condition on line 272 was always true

273 self.warnings.append("Infinite loop detected: 'while True' or 'while 1'") 

274 

275 self.generic_visit(node) 

276 

277 def _get_call_name(self, node: ast.Call) -> str | None: 

278 """Extract function name from call node""" 

279 if isinstance(node.func, ast.Name): 

280 return node.func.id 

281 elif isinstance(node.func, ast.Attribute): 

282 return node.func.attr 

283 return None 

284 

285 def _is_blocked_module(self, module_name: str) -> bool: 

286 """Check if module or any of its parents are blocked""" 

287 # Check exact match 

288 if module_name in self.blocked_modules: 

289 return True 

290 

291 # Check parent modules (e.g., 'os.path' should be blocked if 'os' is blocked) 

292 parts = module_name.split(".") 

293 for i in range(len(parts)): 

294 parent = ".".join(parts[: i + 1]) 

295 if parent in self.blocked_modules: 

296 return True 

297 

298 return False 

299 

300 def _check_suspicious_patterns(self, node: ast.Call) -> None: 

301 """Check for suspicious patterns in call arguments""" 

302 # Check for eval/exec with string formatting (code injection) 

303 func_name = self._get_call_name(node) 

304 if func_name in ["eval", "exec", "compile"]: 

305 for arg in node.args: 

306 if isinstance(arg, ast.JoinedStr): # f-string 

307 self.errors.append(f"Dangerous pattern: {func_name} with f-string") 

308 elif isinstance(arg, ast.BinOp): # String concatenation 308 ↛ 309line 308 didn't jump to line 309 because the condition on line 308 was never true

309 self.errors.append(f"Dangerous pattern: {func_name} with string concatenation")