diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index dc12753f..1e9cd617 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -95,12 +95,16 @@ debug, extract_bytes, f_ecrecover, + f_sha3_256_name, + f_sha3_512_name, + f_sha3_name, hexify, int_of, is_bool, is_bv, is_bv_value, is_concrete, + is_f_sha3_name, is_non_zero, is_zero, match_dynamic_array_overflow_condition, @@ -1014,7 +1018,7 @@ def quick_custom_check(self, cond: BitVecRef) -> CheckSatResult | None: if is_false(cond): return unsat - # Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot))), where offset < 2**64 + # Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))), where offset < 2**64 if match_dynamic_array_overflow_condition(cond): return unsat @@ -1085,7 +1089,7 @@ def sha3_data(self, data: Bytes) -> Word: data = bytes_to_bv_value(data) f_sha3 = Function( - f"f_sha3_{size * 8}", BitVecSorts[size * 8], BitVecSort256 + f_sha3_name(size * 8), BitVecSorts[size * 8], BitVecSort256 ) sha3_expr = f_sha3(data) else: @@ -1310,17 +1314,17 @@ def get_key_structure(cls, loc) -> tuple: def decode(cls, loc: Any) -> Any: loc = normalize(loc) # m[k] : hash(k.m) - if loc.decl().name() == "f_sha3_512": + if loc.decl().name() == f_sha3_512_name: args = loc.arg(0) offset = simplify(Extract(511, 256, args)) base = simplify(Extract(255, 0, args)) return cls.decode(base) + (offset, ZERO) # a[i] : hash(a) + i - elif loc.decl().name() == "f_sha3_256": + elif loc.decl().name() == f_sha3_256_name: base = loc.arg(0) return cls.decode(base) + (ZERO,) # m[k] : hash(k.m) where |k| != 256-bit - elif loc.decl().name().startswith("f_sha3_"): + elif is_f_sha3_name(loc.decl().name()): sha3_input = normalize(loc.arg(0)) if sha3_input.decl().name() == "concat" and sha3_input.num_args() == 2: offset = simplify(sha3_input.arg(0)) @@ -1439,12 +1443,12 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None: @classmethod def decode(cls, loc: Any) -> Any: loc = normalize(loc) - if loc.decl().name() == "f_sha3_512": # hash(hi,lo), recursively + if loc.decl().name() == f_sha3_512_name: # hash(hi,lo), recursively args = loc.arg(0) hi = cls.decode(simplify(Extract(511, 256, args))) lo = cls.decode(simplify(Extract(255, 0, args))) return cls.simple_hash(Concat(hi, lo)) - elif loc.decl().name().startswith("f_sha3_"): + elif is_f_sha3_name(loc.decl().name()): sha3_input = normalize(loc.arg(0)) if sha3_input.decl().name() == "concat": decoded_sha3_input_args = [ diff --git a/src/halmos/utils.py b/src/halmos/utils.py index 54a54dbf..05a13f7f 100644 --- a/src/halmos/utils.py +++ b/src/halmos/utils.py @@ -98,6 +98,18 @@ def __getitem__(self, size: int) -> BitVecSort: ) +def is_f_sha3_name(name: str) -> bool: + return name.startswith("f_sha3_") + + +def f_sha3_name(bitsize: int) -> str: + return f"f_sha3_{bitsize}" + + +f_sha3_256_name = f_sha3_name(256) +f_sha3_512_name = f_sha3_name(512) + + def wrap(x: Any) -> Word: if is_bv(x): return x @@ -356,7 +368,7 @@ def byte_length(x: Any, strict=True) -> int: def match_dynamic_array_overflow_condition(cond: BitVecRef) -> bool: """ Check if `cond` matches the following pattern: - Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot))), where offset < 2**64 + Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))), where offset < 2**64 This condition is satisfied when a dynamic array at `slot` exceeds the storage limit. Since such an overflow is highly unlikely in practice, we assume that this condition is unsat. @@ -378,12 +390,12 @@ def match_dynamic_array_overflow_condition(cond: BitVecRef) -> bool: return False left, right = ule.arg(0), ule.arg(1) - # Not(ULE(f_sha3_256(slot), offset + base)) - if not (left.decl().name() == "f_sha3_256" and is_app_of(right, Z3_OP_BADD)): + # Not(ULE(f_sha3_N(slot), offset + base)) + if not (is_f_sha3_name(left.decl().name()) and is_app_of(right, Z3_OP_BADD)): return False offset, base = right.arg(0), right.arg(1) - # Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot))) and offset < 2**64 + # Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))) and offset < 2**64 return eq(left, base) and is_bv_value(offset) and offset.as_long() < 2**64 diff --git a/tests/test_utils.py b/tests/test_utils.py index e986eeec..cdb214ad 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,12 +8,12 @@ simplify, ) -from halmos.utils import match_dynamic_array_overflow_condition +from halmos.utils import f_sha3_256_name, match_dynamic_array_overflow_condition def test_match_dynamic_array_overflow_condition(): # Create Z3 objects - f_sha3_256 = Function("f_sha3_256", BitVecSort(256), BitVecSort(256)) + f_sha3_256 = Function(f_sha3_256_name, BitVecSort(256), BitVecSort(256)) slot = BitVec("slot", 256) offset = BitVecVal(1000, 256) # Less than 2**64