Skip to content

Commit

Permalink
avoid magic string for sha3 names
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark committed Sep 24, 2024
1 parent 610c8dc commit edd2dea
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 13 deletions.
18 changes: 11 additions & 7 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,16 @@
debug,
extract_bytes,
f_ecrecover,
f_sha3_256_name,
f_sha3_512_name,
f_sha3_name,
hexify,
int_of,
is_bool,
is_bv,
is_bv_value,
is_concrete,
is_f_sha3_name,
is_non_zero,
is_zero,
match_dynamic_array_overflow_condition,
Expand Down Expand Up @@ -1014,7 +1018,7 @@ def quick_custom_check(self, cond: BitVecRef) -> CheckSatResult | None:
if is_false(cond):
return unsat

# Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot))), where offset < 2**64
# Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))), where offset < 2**64
if match_dynamic_array_overflow_condition(cond):
return unsat

Expand Down Expand Up @@ -1085,7 +1089,7 @@ def sha3_data(self, data: Bytes) -> Word:
data = bytes_to_bv_value(data)

f_sha3 = Function(
f"f_sha3_{size * 8}", BitVecSorts[size * 8], BitVecSort256
f_sha3_name(size * 8), BitVecSorts[size * 8], BitVecSort256
)
sha3_expr = f_sha3(data)
else:
Expand Down Expand Up @@ -1310,17 +1314,17 @@ def get_key_structure(cls, loc) -> tuple:
def decode(cls, loc: Any) -> Any:
loc = normalize(loc)
# m[k] : hash(k.m)
if loc.decl().name() == "f_sha3_512":
if loc.decl().name() == f_sha3_512_name:
args = loc.arg(0)
offset = simplify(Extract(511, 256, args))
base = simplify(Extract(255, 0, args))
return cls.decode(base) + (offset, ZERO)
# a[i] : hash(a) + i
elif loc.decl().name() == "f_sha3_256":
elif loc.decl().name() == f_sha3_256_name:
base = loc.arg(0)
return cls.decode(base) + (ZERO,)
# m[k] : hash(k.m) where |k| != 256-bit
elif loc.decl().name().startswith("f_sha3_"):
elif is_f_sha3_name(loc.decl().name()):
sha3_input = normalize(loc.arg(0))
if sha3_input.decl().name() == "concat" and sha3_input.num_args() == 2:
offset = simplify(sha3_input.arg(0))
Expand Down Expand Up @@ -1439,12 +1443,12 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None:
@classmethod
def decode(cls, loc: Any) -> Any:
loc = normalize(loc)
if loc.decl().name() == "f_sha3_512": # hash(hi,lo), recursively
if loc.decl().name() == f_sha3_512_name: # hash(hi,lo), recursively
args = loc.arg(0)
hi = cls.decode(simplify(Extract(511, 256, args)))
lo = cls.decode(simplify(Extract(255, 0, args)))
return cls.simple_hash(Concat(hi, lo))
elif loc.decl().name().startswith("f_sha3_"):
elif is_f_sha3_name(loc.decl().name()):
sha3_input = normalize(loc.arg(0))
if sha3_input.decl().name() == "concat":
decoded_sha3_input_args = [
Expand Down
20 changes: 16 additions & 4 deletions src/halmos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ def __getitem__(self, size: int) -> BitVecSort:
)


def is_f_sha3_name(name: str) -> bool:
return name.startswith("f_sha3_")


def f_sha3_name(bitsize: int) -> str:
return f"f_sha3_{bitsize}"


f_sha3_256_name = f_sha3_name(256)
f_sha3_512_name = f_sha3_name(512)


def wrap(x: Any) -> Word:
if is_bv(x):
return x
Expand Down Expand Up @@ -356,7 +368,7 @@ def byte_length(x: Any, strict=True) -> int:
def match_dynamic_array_overflow_condition(cond: BitVecRef) -> bool:
"""
Check if `cond` matches the following pattern:
Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot))), where offset < 2**64
Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))), where offset < 2**64
This condition is satisfied when a dynamic array at `slot` exceeds the storage limit.
Since such an overflow is highly unlikely in practice, we assume that this condition is unsat.
Expand All @@ -378,12 +390,12 @@ def match_dynamic_array_overflow_condition(cond: BitVecRef) -> bool:
return False
left, right = ule.arg(0), ule.arg(1)

# Not(ULE(f_sha3_256(slot), offset + base))
if not (left.decl().name() == "f_sha3_256" and is_app_of(right, Z3_OP_BADD)):
# Not(ULE(f_sha3_N(slot), offset + base))
if not (is_f_sha3_name(left.decl().name()) and is_app_of(right, Z3_OP_BADD)):
return False
offset, base = right.arg(0), right.arg(1)

# Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot))) and offset < 2**64
# Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))) and offset < 2**64
return eq(left, base) and is_bv_value(offset) and offset.as_long() < 2**64


Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
simplify,
)

from halmos.utils import match_dynamic_array_overflow_condition
from halmos.utils import f_sha3_256_name, match_dynamic_array_overflow_condition


def test_match_dynamic_array_overflow_condition():
# Create Z3 objects
f_sha3_256 = Function("f_sha3_256", BitVecSort(256), BitVecSort(256))
f_sha3_256 = Function(f_sha3_256_name, BitVecSort(256), BitVecSort(256))
slot = BitVec("slot", 256)
offset = BitVecVal(1000, 256) # Less than 2**64

Expand Down

0 comments on commit edd2dea

Please sign in to comment.