Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: injectivity for truncated hashes #408

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,19 +1196,22 @@ def assume_sha3_distinct(self, sha3_expr) -> None:
#
# this approach results in O(n) constraints, where each constraint is independent from other hashes.

# injectivity is assumed for the lower 160 bits of the hash, which are the bits used for Ethereum addresses
sha3_expr_core = Extract(159, 0, sha3_expr)

if eq(sha3_expr, f_sha3_empty):
self.path.append(f_inv_sha3_size(sha3_expr) == ZERO)
self.path.append(f_inv_sha3_size(sha3_expr_core) == ZERO)

else:
# sha3_expr is expected to be in the format: `sha3_<input_size>(input_data)`
input_data = sha3_expr.arg(0)
input_size = input_data.size()

f_inv_name = f_inv_sha3_name(input_size)
f_inv_sha3 = Function(f_inv_name, BitVecSort256, BitVecSorts[input_size])
self.path.append(f_inv_sha3(sha3_expr) == input_data)
f_inv_sha3 = Function(f_inv_name, BitVecSort160, BitVecSorts[input_size])
self.path.append(f_inv_sha3(sha3_expr_core) == input_data)

self.path.append(f_inv_sha3_size(sha3_expr) == con(input_size))
self.path.append(f_inv_sha3_size(sha3_expr_core) == con(input_size))

self.sha3s[sha3_expr] = len(self.sha3s)

Expand Down
2 changes: 1 addition & 1 deletion src/halmos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def f_inv_sha3_name(bitsize: int) -> str:


# TODO: explore the impact of using a smaller bitsize for the range sort
f_inv_sha3_size = Function("f_inv_sha3_size", BitVecSort256, BitVecSort256)
f_inv_sha3_size = Function("f_inv_sha3_size", BitVecSort160, BitVecSort256)


f_sha3_0_name = f_sha3_name(0)
Expand Down
36 changes: 36 additions & 0 deletions tests/expected/all.json
Original file line number Diff line number Diff line change
Expand Up @@ -2288,6 +2288,15 @@
}
],
"test/Sha3.t.sol:Sha3Test": [
{
"name": "check_address_collision_pass(uint256,uint256)",
"exitcode": 0,
"num_models": 0,
"models": null,
"num_paths": null,
"time": null,
"num_bounded_loops": null
},
{
"name": "check_concrete_keccak_does_not_split_paths()",
"exitcode": 0,
Expand Down Expand Up @@ -2350,6 +2359,33 @@
"num_paths": null,
"time": null,
"num_bounded_loops": null
},
{
"name": "check_uint128_collision_fail(uint256,uint256)",
"exitcode": 1,
"num_models": 1,
"models": null,
"num_paths": null,
"time": null,
"num_bounded_loops": null
},
{
"name": "check_uint160_collision_pass(uint256,uint256)",
"exitcode": 0,
"num_models": 0,
"models": null,
"num_paths": null,
"time": null,
"num_bounded_loops": null
},
{
"name": "check_uint256_collision(uint256,uint256)",
"exitcode": 0,
"num_models": 0,
"models": null,
"num_paths": null,
"time": null,
"num_bounded_loops": null
}
],
"test/SignExtend.t.sol:SignExtendTest": [
Expand Down
27 changes: 27 additions & 0 deletions tests/regression/test/Sha3.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,31 @@ contract Sha3Test is Test, SymTest {
// Asserts that the keccak256 hashes of the two byte strings differ.
function _assert_neq(bytes memory data1, bytes memory data2) internal {
assert(keccak256(data1) != keccak256(data2));
}

// For distinct x and y, asserts that the full 256-bit keccak256 hashes of
// their ABI encodings differ.
function check_uint256_collision(uint256 x, uint256 y) public {
vm.assume(x != y);
assertNotEq(keccak256(abi.encode(x)), keccak256(abi.encode(y)));
}

// We assume that the lower 160 bits of two hashes do not collide, so
// addresses derived from hashes of distinct inputs are treated as distinct.
// see: https://github.com/a16z/halmos/issues/347
function check_address_collision_pass(uint256 x, uint256 y) public {
vm.assume(x != y);
assertNotEq(to_address(x), to_address(y)); // pass
}

// Derives an address from x: hashes abi.encode(x) with keccak256 and keeps
// the low 160 bits of the result (the uint160 conversion drops the high bits).
function to_address(uint256 x) internal pure returns (address) {
return address(uint160(uint256(keccak256(abi.encode(x)))));
}

// For distinct x and y, asserts that the low 160 bits of the keccak256
// hashes of their ABI encodings differ. Expected to pass.
function check_uint160_collision_pass(uint256 x, uint256 y) public {
vm.assume(x != y);
assertNotEq(uint160(uint256(keccak256(abi.encode(x)))), uint160(uint256(keccak256(abi.encode(y))))); // pass
}

// We do not rule out collisions when a hash is truncated to fewer than
// 160 bits, so this assertion on the low 128 bits is expected to fail.
function check_uint128_collision_fail(uint256 x, uint256 y) public {
vm.assume(x != y);
assertNotEq(uint128(uint256(keccak256(abi.encode(x)))), uint128(uint256(keccak256(abi.encode(y))))); // fail
}
}
Loading