Skip to content

Commit

Permalink
Merge pull request #2589 from crytic/dev-event-selector-ir
Browse files Browse the repository at this point in the history
Fix IR conversion when an Event selector is accessed
  • Loading branch information
montyly authored Oct 17, 2024
2 parents 9e89bbb + 5488d50 commit 79619f6
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 29 deletions.
2 changes: 1 addition & 1 deletion slither/core/solidity_types/function_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def storage_size(self) -> Tuple[int, bool]:
def is_dynamic(self) -> bool:
return False

def __str__(self):
def __str__(self) -> str:
# Use x.type
# x.name may be empty
params = ",".join([str(x.type) for x in self._params])
Expand Down
7 changes: 4 additions & 3 deletions slither/core/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@

if TYPE_CHECKING:
from slither.core.expressions.expression import Expression
from slither.core.declarations import Function

# pylint: disable=too-many-instance-attributes
class Variable(SourceMapping):
def __init__(self) -> None:
super().__init__()
self._name: Optional[str] = None
self._initial_expression: Optional["Expression"] = None
self._type: Optional[Type] = None
self._type: Optional[Union[List, Type, "Function", str]] = None
self._initialized: Optional[bool] = None
self._visibility: Optional[str] = None
self._is_constant = False
Expand Down Expand Up @@ -77,7 +78,7 @@ def name(self, name: str) -> None:
self._name = name

@property
def type(self) -> Optional[Type]:
def type(self) -> Optional[Union[List, Type, "Function", str]]:
return self._type

@type.setter
Expand Down Expand Up @@ -120,7 +121,7 @@ def visibility(self) -> Optional[str]:
def visibility(self, v: str) -> None:
self._visibility = v

def set_type(self, t: Optional[Union[List, Type, str]]) -> None:
def set_type(self, t: Optional[Union[List, Type, "Function", str]]) -> None:
if isinstance(t, str):
self._type = ElementaryType(t)
return
Expand Down
51 changes: 26 additions & 25 deletions slither/slithir/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
)
from slither.core.solidity_types.type import Type
from slither.core.solidity_types.type_alias import TypeAliasTopLevel, TypeAlias
from slither.core.variables.function_type_variable import FunctionTypeVariable
from slither.core.variables.state_variable import StateVariable
from slither.core.variables.variable import Variable
from slither.slithir.exceptions import SlithIRError
Expand Down Expand Up @@ -81,7 +80,7 @@
from slither.slithir.tmp_operations.tmp_new_structure import TmpNewStructure
from slither.slithir.variables import Constant, ReferenceVariable, TemporaryVariable
from slither.slithir.variables import TupleVariable
from slither.utils.function import get_function_id
from slither.utils.function import get_function_id, get_event_id
from slither.utils.type import export_nested_types_from_variable
from slither.utils.using_for import USING_FOR
from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR
Expand Down Expand Up @@ -279,20 +278,6 @@ def is_temporary(ins: Operation) -> bool:
)


def _make_function_type(func: Function) -> FunctionType:
parameters = []
returns = []
for parameter in func.parameters:
v = FunctionTypeVariable()
v.name = parameter.name
parameters.append(v)
for return_var in func.returns:
v = FunctionTypeVariable()
v.name = return_var.name
returns.append(v)
return FunctionType(parameters, returns)


# endregion
###################################################################################
###################################################################################
Expand Down Expand Up @@ -793,12 +778,29 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo
assignment.set_node(ir.node)
assignment.lvalue.set_type(ElementaryType("bytes4"))
return assignment
if ir.variable_right == "selector" and isinstance(
ir.variable_left.type, (Function)
if ir.variable_right == "selector" and isinstance(ir.variable_left, (Event)):
# the event selector returns a bytes32, which is different from the error/function selector
# which returns a bytes4
assignment = Assignment(
ir.lvalue,
Constant(
str(get_event_id(ir.variable_left.full_name)), ElementaryType("bytes32")
),
ElementaryType("bytes32"),
)
assignment.set_expression(ir.expression)
assignment.set_node(ir.node)
assignment.lvalue.set_type(ElementaryType("bytes32"))
return assignment
if ir.variable_right == "selector" and (
isinstance(ir.variable_left.type, (Function))
):
assignment = Assignment(
ir.lvalue,
Constant(str(get_function_id(ir.variable_left.type.full_name))),
Constant(
str(get_function_id(ir.variable_left.type.full_name)),
ElementaryType("bytes4"),
),
ElementaryType("bytes4"),
)
assignment.set_expression(ir.expression)
Expand Down Expand Up @@ -826,10 +828,9 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo
targeted_function = next(
(x for x in ir_func.contract.functions if x.name == str(ir.variable_right))
)
t = _make_function_type(targeted_function)
ir.lvalue.set_type(t)
ir.lvalue.set_type(targeted_function)
elif isinstance(left, (Variable, SolidityVariable)):
t = ir.variable_left.type
t = left.type
elif isinstance(left, (Contract, Enum, Structure)):
t = UserDefinedType(left)
# can be None due to temporary operation
Expand All @@ -846,10 +847,10 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo
ir.lvalue.set_type(elems[elem].type)
else:
assert isinstance(type_t, Contract)
# Allow type propagtion as a Function
# Allow type propagation as a Function
# Only for reference variables
# This allows to track the selector keyword
# We dont need to check for function collision, as solc prevents the use of selector
# We don't need to check for function collision, as solc prevents the use of selector
# if there are multiple functions with the same name
f = next(
(f for f in type_t.functions if f.name == ir.variable_right),
Expand All @@ -858,7 +859,7 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo
if f:
ir.lvalue.set_type(f)
else:
# Allow propgation for variable access through contract's name
# Allow propagation for variable access through contract's name
# like Base_contract.my_variable
v = next(
(
Expand Down
13 changes: 13 additions & 0 deletions slither/utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,16 @@ def get_function_id(sig: str) -> int:
digest = keccak.new(digest_bits=256)
digest.update(sig.encode("utf-8"))
return int("0x" + digest.hexdigest()[:8], 16)


def get_event_id(sig: str) -> int:
"""'
Return the event id of the given signature
Args:
sig (str)
Return:
(int)
"""
digest = keccak.new(digest_bits=256)
digest.update(sig.encode("utf-8"))
return int("0x" + digest.hexdigest(), 16)
47 changes: 47 additions & 0 deletions tests/unit/slithir/test_data/selector.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
interface I{
function testFunction(uint a) external ;
}

contract A{
function testFunction() public{}
}

contract Test{
event TestEvent();
struct St{
uint a;
}
error TestError();

function testFunction(uint a) public {}


function testFunctionStructure(St memory s) public {}

function returnEvent() public returns (bytes32){
return TestEvent.selector;
}

function returnError() public returns (bytes4){
return TestError.selector;
}


function returnFunctionFromContract() public returns (bytes4){
return I.testFunction.selector;
}


function returnFunction() public returns (bytes4){
return this.testFunction.selector;
}

function returnFunctionWithStructure() public returns (bytes4){
return this.testFunctionStructure.selector;
}

function returnFunctionThroughLocaLVar() public returns(bytes4){
A a;
return a.testFunction.selector;
}
}
32 changes: 32 additions & 0 deletions tests/unit/slithir/test_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from pathlib import Path
from slither import Slither
from slither.slithir.operations import Assignment
from slither.slithir.variables import Constant

TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data"


func_to_results = {
"returnEvent()": "16700440330922901039223184000601971290390760458944929668086539975128325467771",
"returnError()": "224292994",
"returnFunctionFromContract()": "890000139",
"returnFunction()": "890000139",
"returnFunctionWithStructure()": "1430834845",
"returnFunctionThroughLocaLVar()": "3781905051",
}


def test_enum_max_min(solc_binary_path) -> None:
solc_path = solc_binary_path("0.8.19")
slither = Slither(Path(TEST_DATA_DIR, "selector.sol").as_posix(), solc=solc_path)

contract = slither.get_contract_from_name("Test")[0]

for func_name, value in func_to_results.items():
f = contract.get_function_from_signature(func_name)
assignment = f.slithir_operations[0]
assert (
isinstance(assignment, Assignment)
and isinstance(assignment.rvalue, Constant)
and assignment.rvalue.value == value
)

0 comments on commit 79619f6

Please sign in to comment.