Skip to content

Commit

Permalink
Fix issue on checking variable assignment for code verification (#337)
Browse files Browse the repository at this point in the history
The previous check accidently forbids complex assignment, which should
be allowd.

See this issue:
#277
  • Loading branch information
ShilinHe authored May 10, 2024
2 parents 3207520 + 8bee922 commit c5446c5
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 12 deletions.
35 changes: 23 additions & 12 deletions taskweaver/code_interpreter/code_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def _is_allowed_function_call(self, func_name: str) -> bool:
return True

def visit_Call(self, node):
if self.allowed_functions is None and self.blocked_functions is None:
return

if isinstance(node.func, ast.Name):
function_name = node.func.id
elif isinstance(node.func, ast.Attribute):
Expand All @@ -67,6 +70,9 @@ def _is_allowed_module_import(self, mod_name: str) -> bool:
return True

def visit_Import(self, node):
if self.allowed_modules is None and self.blocked_modules is None:
return

for alias in node.names:
if "." in alias.name:
module_name = alias.name.split(".")[0]
Expand All @@ -80,6 +86,9 @@ def visit_Import(self, node):
)

def visit_ImportFrom(self, node):
if self.allowed_modules is None and self.blocked_modules is None:
return

if "." in node.module:
module_name = node.module.split(".")[0]
else:
Expand All @@ -99,21 +108,23 @@ def _is_allowed_variable(self, var_name: str) -> bool:
return True

def visit_Assign(self, node: ast.Assign):
if self.allowed_variables is None:
return

for target in node.targets:
variable_names = []
if isinstance(target, ast.Name):
variable_name = target.id
variable_names.append(target.id)
else:
self.errors.append(
f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} "
"=> Complex assignments are not allowed.",
)
continue

if not self._is_allowed_variable(variable_name):
self.errors.append(
f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} "
f"=> Assigning to {variable_name} is not allowed.",
)
for name in ast.walk(target):
if isinstance(name, ast.Name):
variable_names.append(name.id)
for variable_name in variable_names:
if not self._is_allowed_variable(variable_name):
self.errors.append(
f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} "
f"=> Assigning to {variable_name} is not allowed.",
)

def generic_visit(self, node):
super().generic_visit(node)
Expand Down
10 changes: 10 additions & 0 deletions tests/unit_tests/test_code_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,16 @@ def test_allow_variable():
print("---->", code_verify_errors)
assert len(code_verify_errors) == 0

code_snippet = "name, age = 'John', 25\n" "print(f'Hello, {name}! You are {age} years old.')\n"
allowed_variables = ["name", "age"]
code_verify_errors = code_snippet_verification(
code_snippet,
code_verification_on=True,
allowed_variables=allowed_variables,
)
print("---->", code_verify_errors)
assert len(code_verify_errors) == 0


def test_magic_code():
code_snippet = (
Expand Down

0 comments on commit c5446c5

Please sign in to comment.