Skip to content

Commit

Permalink
cleanup code and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
0xalpharush committed Apr 7, 2024
1 parent e2d0047 commit 5a6ef0d
Show file tree
Hide file tree
Showing 17 changed files with 110 additions and 164 deletions.
3 changes: 3 additions & 0 deletions slither/core/declarations/function_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def is_declared_by(self, contract: "Contract") -> bool:

@property
def file_scope(self) -> "FileScope":
# This is the contract declarer's file scope because inherited functions have access
# to the file scope which their declared in. This scope may contain references not
# available in the child contract's scope. See inherited_function_scope.sol for an example.
return self.contract_declarer.file_scope

# endregion
Expand Down
3 changes: 1 addition & 2 deletions slither/core/declarations/using_for_top_level.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, List, Dict, Union
from typing import TYPE_CHECKING

from slither.core.solidity_types.type import Type
from slither.core.declarations.top_level import TopLevel
from slither.utils.using_for import USING_FOR

Expand Down
44 changes: 7 additions & 37 deletions slither/core/scope/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,62 +60,32 @@ def __init__(self, filename: Filename) -> None:
def add_accessible_scopes(self) -> bool: # pylint: disable=too-many-branches
"""
Add information from accessible scopes. Return true if new information was obtained
:return:
:rtype:
"""

learn_something = False

# This is a hacky way to support using for directives on user defined types and user defined functions
# since it is not reflected in the "exportedSymbols" field of the AST.
for new_scope in self.accessible_scopes:
# To support using for directives on user defined types and user defined functions,
# we need to propagate the using for directives from the imported file to the importing file
# since it is not reflected in the "exportedSymbols" field of the AST.
if not new_scope.using_for_directives.issubset(self.using_for_directives):
self.using_for_directives |= new_scope.using_for_directives
learn_something = True
print("using_for_directives", learn_something)
if not _dict_contain(new_scope.type_aliases, self.type_aliases):
self.type_aliases.update(new_scope.type_aliases)
learn_something = True
if not new_scope.functions.issubset(self.functions):
self.functions |= new_scope.functions
learn_something = True

# Hack to get around https://github.com/ethereum/solidity/pull/11881
# To get around this bug for aliases https://github.com/ethereum/solidity/pull/11881,
# we propagate the exported_symbols from the imported file to the importing file
# See tests/e2e/solc_parsing/test_data/top-level-nested-import-0.7.1.sol
if not new_scope.exported_symbols.issubset(self.exported_symbols):
self.exported_symbols |= new_scope.exported_symbols
learn_something = True
# if not new_scope.imports.issubset(self.imports):
# self.imports |= new_scope.imports
# learn_something = True
# if not _dict_contain(new_scope.contracts, self.contracts):
# self.contracts.update(new_scope.contracts)
# learn_something = True
# if not new_scope.custom_errors.issubset(self.custom_errors):
# self.custom_errors |= new_scope.custom_errors
# learn_something = True
# if not _dict_contain(new_scope.enums, self.enums):
# self.enums.update(new_scope.enums)
# learn_something = True
# if not new_scope.events.issubset(self.events):
# self.events |= new_scope.events
# learn_something = True
# if not new_scope.functions.issubset(self.functions):
# self.functions |= new_scope.functions
# learn_something = True
# if not new_scope.using_for_directives.issubset(self.using_for_directives):
# self.using_for_directives |= new_scope.using_for_directives
# learn_something = True

# if not new_scope.pragmas.issubset(self.pragmas):
# self.pragmas |= new_scope.pragmas
# learn_something = True
# if not _dict_contain(new_scope.structures, self.structures):
# self.structures.update(new_scope.structures)
# learn_something = True
# if not _dict_contain(new_scope.variables, self.variables):
# self.variables.update(new_scope.variables)
# learn_something = True

# This is need to support aliasing when we do a late lookup using SolidityImportPlaceholder
if not _dict_contain(new_scope.renaming, self.renaming):
self.renaming.update(new_scope.renaming)
Expand Down
99 changes: 47 additions & 52 deletions slither/slither.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,66 @@ def _check_common_things(
) -> None:

if not issubclass(cls, base_cls) or cls is base_cls:
raise Exception(
raise SlitherError(
f"You can't register {cls!r} as a {thing_name}. You need to pass a class that inherits from {base_cls.__name__}"
)

if any(type(obj) == cls for obj in instances_list): # pylint: disable=unidiomatic-typecheck
raise Exception(f"You can't register {cls!r} twice.")
raise SlitherError(f"You can't register {cls!r} twice.")


def _update_file_scopes(candidates: ValuesView[FileScope]):
def _update_file_scopes(sol_parser: SlitherCompilationUnitSolc):

Check warning on line 38 in slither/slither.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

R0914: Too many local variables (16/15) (too-many-locals)
"""
Because solc's import allows cycle in the import
We iterate until we aren't adding new information to the scope
Since all definitions in a file are exported by default, including definitions from its (transitive) dependencies,
we can identify all top level items that could possibly be referenced within the file from its exportedSymbols.
It is not as straightforward for user defined types and functions as well as aliasing. See add_accessible_scopes for more details.
"""
candidates = sol_parser.compilation_unit.scopes.values()
learned_something = False
# Because solc's import allows cycle in the import graph, iterate until we aren't adding new information to the scope.
while True:
for candidate in candidates:
learned_something |= candidate.add_accessible_scopes()
if not learned_something:
break
learned_something = False

for scope in candidates:
for refId in scope.exported_symbols:
if refId in sol_parser.contracts_by_id:
contract = sol_parser.contracts_by_id[refId]
scope.contracts[contract.name] = contract
elif refId in sol_parser.functions_by_id:
functions = sol_parser.functions_by_id[refId]
assert len(functions) == 1
function = functions[0]
scope.functions.add(function)
elif refId in sol_parser._imports_by_id:

Check warning on line 64 in slither/slither.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _imports_by_id of a client class (protected-access)
import_directive = sol_parser._imports_by_id[refId]

Check warning on line 65 in slither/slither.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _imports_by_id of a client class (protected-access)
scope.imports.add(import_directive)
elif refId in sol_parser._top_level_variables_by_id:

Check warning on line 67 in slither/slither.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _top_level_variables_by_id of a client class (protected-access)
top_level_variable = sol_parser._top_level_variables_by_id[refId]

Check warning on line 68 in slither/slither.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _top_level_variables_by_id of a client class (protected-access)
scope.variables[top_level_variable.name] = top_level_variable
elif refId in sol_parser._top_level_events_by_id:

Check warning on line 70 in slither/slither.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _top_level_events_by_id of a client class (protected-access)
top_level_event = sol_parser._top_level_events_by_id[refId]

Check warning on line 71 in slither/slither.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _top_level_events_by_id of a client class (protected-access)
scope.events.add(top_level_event)
elif refId in sol_parser._top_level_structures_by_id:

Check warning on line 73 in slither/slither.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _top_level_structures_by_id of a client class (protected-access)
top_level_struct = sol_parser._top_level_structures_by_id[refId]

Check warning on line 74 in slither/slither.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _top_level_structures_by_id of a client class (protected-access)
scope.structures[top_level_struct.name] = top_level_struct
elif refId in sol_parser._top_level_type_aliases_by_id:

Check warning on line 76 in slither/slither.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

W0212: Access to a protected member _top_level_type_aliases_by_id of a client class (protected-access)
top_level_type_alias = sol_parser._top_level_type_aliases_by_id[refId]
scope.type_aliases[top_level_type_alias.name] = top_level_type_alias
elif refId in sol_parser._top_level_enums_by_id:
top_level_enum = sol_parser._top_level_enums_by_id[refId]
scope.enums[top_level_enum.name] = top_level_enum
elif refId in sol_parser._top_level_errors_by_id:
top_level_custom_error = sol_parser._top_level_errors_by_id[refId]
scope.custom_errors.add(top_level_custom_error)
else:
logger.warning(
f"Failed to resolved name for reference id {refId} in {scope.filename}."
)


class Slither(
SlitherCore
Expand Down Expand Up @@ -118,61 +156,18 @@ def __init__(self, target: Union[str, CryticCompile], **kwargs) -> None:
sol_parser.parse_top_level_items(ast, path)
self.add_source_code(path)

_update_file_scopes(compilation_unit_slither.scopes.values())
# First we save all the contracts in a dict
# the key is the contractid
for contract in sol_parser._underlying_contract_to_parser:
if contract.name.startswith("SlitherInternalTopLevelContract"):
raise Exception(
raise SlitherError(
# region multi-line-string
"""Your codebase has a contract named 'SlitherInternalTopLevelContract'.
Please rename it, this name is reserved for Slither's internals"""
# endregion multi-line
)
sol_parser._contracts_by_id[contract.id] = contract
sol_parser._compilation_unit.contracts.append(contract)
print("avalilable")
for k, v in sol_parser.contracts_by_id.items():
print(k, v.name)
for scope in compilation_unit_slither.scopes.values():
for refId in scope.exported_symbols:
print("scope", scope)
print("target", refId)
if refId in sol_parser.contracts_by_id:
contract = sol_parser.contracts_by_id[refId]
scope.contracts[contract.name] = contract
elif refId in sol_parser.functions_by_id:
print("found in functions")
functions = sol_parser.functions_by_id[refId]
assert len(functions) == 1
function = functions[0]
scope.functions.add(function)
elif refId in sol_parser._imports_by_id:
import_directive = sol_parser._imports_by_id[refId]
scope.imports.add(import_directive)
elif refId in sol_parser._top_level_variables_by_id:
top_level_variable = sol_parser._top_level_variables_by_id[refId]
scope.variables[top_level_variable.name] = top_level_variable
elif refId in sol_parser._top_level_events_by_id:
top_level_event = sol_parser._top_level_events_by_id[refId]
scope.events.add(top_level_event)
elif refId in sol_parser._top_level_structures_by_id:
top_level_struct = sol_parser._top_level_structures_by_id[refId]
scope.structures[top_level_struct.name] = top_level_struct
elif refId in sol_parser._top_level_type_aliases_by_id:
top_level_type_alias = sol_parser._top_level_type_aliases_by_id[refId]
scope.type_aliases[top_level_type_alias.name] = top_level_type_alias
elif refId in sol_parser._top_level_enums_by_id:
top_level_enum = sol_parser._top_level_enums_by_id[refId]
scope.enums[top_level_enum.name] = top_level_enum
elif refId in sol_parser._top_level_errors_by_id:
print("found in errors")
top_level_custom_error = sol_parser._top_level_errors_by_id[refId]
print(top_level_custom_error.name)
scope.custom_errors.add(top_level_custom_error)
else:
print("not found", refId)
assert False

_update_file_scopes(sol_parser)

if kwargs.get("generate_patches", False):
self.generate_patches = True
Expand Down
23 changes: 3 additions & 20 deletions slither/slithir/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,12 +600,6 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo
using_for = node_function.contract.using_for_complete
elif isinstance(node_function, FunctionTopLevel):
using_for = node_function.using_for_complete
# print("\n")
# print("using_for", )
# for key,v in using_for.items():
# print("key",key, )
# for i in v:
# print("value",i,i.__class__ )

if isinstance(ir, OperationWithLValue) and ir.lvalue:
# Force assignment in case of missing previous correct type
Expand Down Expand Up @@ -668,7 +662,6 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo
ir, node_function.contract
)
if can_be_low_level(ir):
print("can be low level")
return convert_to_low_level(ir)

# Convert push operations
Expand Down Expand Up @@ -1509,7 +1502,6 @@ def convert_to_pop(ir: HighLevelCall, node: "Node") -> List[Operation]:


def look_for_library_or_top_level(
contract: Contract,
ir: HighLevelCall,
using_for,
t: Union[
Expand All @@ -1519,11 +1511,7 @@ def look_for_library_or_top_level(
TypeAliasTopLevel,
],
) -> Optional[Union[LibraryCall, InternalCall,]]:
print("look_for_library_or_top_level")
print(ir.expression.source_mapping.to_detailed_str())
print(ir.function_name)
for destination in using_for[t]:
print("destionation", destination, destination.__class__)
if isinstance(destination, FunctionTopLevel) and destination.name == ir.function_name:
arguments = [ir.destination] + ir.arguments
if (
Expand Down Expand Up @@ -1566,7 +1554,6 @@ def look_for_library_or_top_level(
new_ir = convert_type_library_call(lib_call, lib_contract)
if new_ir:
new_ir.set_node(ir.node)
print("new_ir", new_ir)
return new_ir
return None

Expand All @@ -1583,12 +1570,12 @@ def convert_to_library_or_top_level(
t = ir.destination.type

if t in using_for:
new_ir = look_for_library_or_top_level(contract, ir, using_for, t)
new_ir = look_for_library_or_top_level(ir, using_for, t)
if new_ir:
return new_ir

if "*" in using_for:
new_ir = look_for_library_or_top_level(contract, ir, using_for, "*")
new_ir = look_for_library_or_top_level(ir, using_for, "*")
if new_ir:
return new_ir

Expand All @@ -1599,7 +1586,7 @@ def convert_to_library_or_top_level(
and UserDefinedType(node.function.contract) in using_for
):
new_ir = look_for_library_or_top_level(
contract, ir, using_for, UserDefinedType(node.function.contract)
ir, using_for, UserDefinedType(node.function.contract)
)
if new_ir:
return new_ir
Expand Down Expand Up @@ -1754,9 +1741,6 @@ def convert_type_of_high_and_internal_level_call(
]

for import_statement in contract.file_scope.imports:
print(import_statement)
print(import_statement.alias)
print(ir.contract_name)
if (
import_statement.alias is not None
and import_statement.alias == ir.contract_name
Expand Down Expand Up @@ -1961,7 +1945,6 @@ def convert_constant_types(irs: List[Operation]) -> None:
if isinstance(func, StateVariable):
types = export_nested_types_from_variable(func)
else:
print("func", func, ir.expression.source_mapping.to_detailed_str())
types = [p.type for p in func.parameters]
assert len(types) == len(ir.arguments)
for idx, arg in enumerate(ir.arguments):
Expand Down
13 changes: 0 additions & 13 deletions slither/solc_parsing/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,11 +585,6 @@ def _analyze_params_elements( # pylint: disable=too-many-arguments,too-many-loc
element.is_shadowed = True
accessible_elements[element.full_name].shadows = True
except (VariableNotFound, KeyError) as e:
for c in self._contract.inheritance:
print(c.name, c.id)
for c2 in c.inheritance:
print("\t", c2.name, c2.id)
print("\n")
self.log_incorrect_parsing(
f"Missing params {e} {self._contract.source_mapping.to_detailed_str()}"
)
Expand Down Expand Up @@ -626,11 +621,6 @@ def analyze_using_for(self) -> None: # pylint: disable=too-many-branches
self._contract.using_for[type_name] = []

if "libraryName" in using_for:
# print(using_for["libraryName"])
# x =
# for f in x.type.functions:

# assert isinstance(f, Function), x.__class__
self._contract.using_for[type_name].append(
parse_type(using_for["libraryName"], self)
)
Expand All @@ -649,7 +639,6 @@ def analyze_using_for(self) -> None: # pylint: disable=too-many-branches
old = "*"
if old not in self._contract.using_for:
self._contract.using_for[old] = []

self._contract.using_for[old].append(new)
self._usingForNotParsed = []
except (VariableNotFound, KeyError) as e:
Expand Down Expand Up @@ -689,7 +678,6 @@ def _check_aliased_import(
def _analyze_top_level_function(self, function_name: str, type_name: USING_FOR_KEY) -> None:
for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name:
assert isinstance(tl_function, Function)
self._contract.using_for[type_name].append(tl_function)

def _analyze_library_function(
Expand All @@ -703,7 +691,6 @@ def _analyze_library_function(
if c.name == library_name:
for f in c.functions:
if f.name == function_name:
assert isinstance(f, FunctionContract)
self._contract.using_for[type_name].append(f)
found = True
break
Expand Down
3 changes: 0 additions & 3 deletions slither/solc_parsing/declarations/using_for_top_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ def analyze(self) -> None:

if self._library_name:
library_name = parse_type(self._library_name, self)
assert isinstance(library_name, UserDefinedType)
# for f in library_name.type.functions:
self._using_for.using_for[type_name].append(library_name)

self._propagate_global(type_name)
else:
for f in self._functions:
Expand Down
6 changes: 1 addition & 5 deletions slither/solc_parsing/expressions/find_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,8 @@ def _find_variable_init(
scope = underlying_function.file_scope
else:
assert isinstance(underlying_function, FunctionContract)
scope = underlying_function.file_scope
scope = underlying_function.contract.file_scope

# scope = underlying_function.file_scope
# assert False
elif isinstance(caller_context, StructureTopLevelSolc):
direct_contracts = []
direct_functions_parser = []
Expand Down Expand Up @@ -463,8 +461,6 @@ def find_variable(
return all_enums[var_name], False

contracts = current_scope.contracts
# print(*contracts)
# print(var_name in contracts)
if var_name in contracts:
return contracts[var_name], False

Expand Down
Loading

0 comments on commit 5a6ef0d

Please sign in to comment.