diff --git a/src/halmos/__main__.py b/src/halmos/__main__.py index 14992af4..81919a1c 100644 --- a/src/halmos/__main__.py +++ b/src/halmos/__main__.py @@ -13,7 +13,7 @@ import uuid from collections import Counter from copy import deepcopy -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from datetime import timedelta from enum import Enum from importlib import metadata @@ -111,6 +111,138 @@ VERBOSITY_TRACE_CONSTRUCTOR = 5 +@dataclass +class PotentialModel: + model: AnyModel + is_valid: bool + + def __init__(self, model: ModelRef | str, args: HalmosConfig) -> None: + # convert model into string to avoid pickling errors for z3 (ctypes) objects containing pointers + self.model = ( + to_str_model(model, args.print_full_model) + if isinstance(model, ModelRef) + else model + ) + self.is_valid = is_model_valid(model) + + def __str__(self) -> str: + # expected to be a filename + if isinstance(self.model, str): + return f"see {self.model}" + + formatted = [f"\n {decl} = {val}" for decl, val in self.model.items()] + return "".join(sorted(formatted)) if formatted else "∅" + + +@dataclass(frozen=True) +class ContractContext: + # config with contract-specific overrides + args: HalmosConfig + + # name of this contract + name: str + + # signatures of test functions to run + funsigs: list[str] + + # data parsed from the build output for this contract + creation_hexcode: str + deployed_hexcode: str + abi: dict + method_identifiers: dict[str, str] + contract_json: dict + libs: dict + + # TODO: check if this is really a contract-level variable + build_out_map: dict + + +@dataclass(frozen=True) +class FunctionContext: + # config with function-specific overrides + args: HalmosConfig + + # function name, signature, and selector + info: FunctionInfo + + # solver using the function-specific config + solver: Solver + + # backlink to the parent contract context + contract_ctx: ContractContext + + # optional starting state + setup_ex: Exec | None = None + + # dump directory for this function (generated in __post_init__) + dump_dirname: str = field(init=False) + + def __post_init__(self): + object.__setattr__( + self, "dump_dirname", f"/tmp/{self.info.name}-{uuid.uuid4().hex}" + ) + + +@dataclass +class PathContext: + # id of this path + path_id: int + + # path execution object + ex: Exec + + # SMT query + query: SMTQuery + + # backlink to the parent function context + fun_ctx: FunctionContext + + # filename for this path (generated in __post_init__) + dump_filename: str = field(init=False) + + def __post_init__(self): + object.__setattr__( + self, "dump_filename", f"{self.fun_ctx.dump_dirname}/{self.path_id}.smt2" + ) + + +# XXX remove ModelWithContext +@dataclass(frozen=True) +class ModelWithContext: + # can be a filename containing the model or a dict with variable assignments + model: PotentialModel | None + index: int + result: CheckSatResult + unsat_core: list | None + + +@dataclass(frozen=True) +class TestResult: + name: str # test function name + exitcode: int + num_models: int = None + models: list[ModelWithContext] = None + num_paths: tuple[int, int, int] = None # number of paths: [total, success, blocked] + time: tuple[int, int, int] = None # time: [total, paths, models] + num_bounded_loops: int = None # number of incomplete loops + + +@dataclass(frozen=True) +class SolverOutput: + result: CheckSatResult + model: PotentialModel | None = None + unsat_core: list[str] | None = None + + +class Exitcode(Enum): + PASS = 0 + COUNTEREXAMPLE = 1 + TIMEOUT = 2 + STUCK = 3 + REVERT_ALL = 4 + EXCEPTION = 5 + + def with_devdoc(args: HalmosConfig, fn_sig: str, contract_json: dict) -> HalmosConfig: devdoc = parse_devdoc(fn_sig, contract_json) if not devdoc: @@ -182,21 +314,16 @@ def mk_this() -> Address: return con_addr(FOUNDRY_TEST) -def mk_solver(args: HalmosConfig, logic="QF_AUFBV", ctx=None, assertion=False): - timeout = ( - args.solver_timeout_assertion if assertion else args.solver_timeout_branching +def mk_solver(args: HalmosConfig, logic="QF_AUFBV", ctx=None): + return create_solver( + logic=logic, + ctx=ctx, + timeout=args.solver_timeout_branching, + max_memory=args.solver_max_memory, ) - return create_solver(logic, ctx, timeout, args.solver_max_memory) -def deploy_test( - creation_hexcode: str, - deployed_hexcode: str, - sevm: SEVM, - args: HalmosConfig, - libs: dict, - solver: Solver, -) -> Exec: +def deploy_test(ctx: ContractContext, sevm: SEVM) -> Exec: this = mk_this() message = Message( target=this, @@ -214,12 +341,12 @@ def deploy_test( block=mk_block(), context=CallContext(message=message), pgm=None, # to be added - path=Path(solver), + path=Path(ctx.solver), ) # deploy libraries and resolve library placeholders in hexcode - (creation_hexcode, deployed_hexcode) = ex.resolve_libs( - creation_hexcode, deployed_hexcode, libs + (creation_hexcode, _) = ex.resolve_libs( + ctx.creation_hexcode, ctx.deployed_hexcode, ctx.libs ) # test contract creation bytecode @@ -233,9 +360,9 @@ def deploy_test( if len(exs) != 1: raise ValueError(f"constructor: # of paths: {len(exs)}") - ex = exs[0] + [ex] = exs - if args.verbose >= VERBOSITY_TRACE_CONSTRUCTOR: + if ctx.args.verbose >= VERBOSITY_TRACE_CONSTRUCTOR: print("Constructor trace:") render_trace(ex.context) @@ -259,71 +386,71 @@ def deploy_test( return ex -def setup( - creation_hexcode: str, - deployed_hexcode: str, - abi: dict, - setup_info: FunctionInfo, - args: HalmosConfig, - libs: dict, - solver: Solver, -) -> Exec: +def setup(ctx: FunctionContext) -> Exec: setup_timer = NamedTimer("setup") setup_timer.create_subtimer("decode") + setup_info = ctx.info + args = ctx.args sevm = SEVM(args) - setup_ex = deploy_test(creation_hexcode, deployed_hexcode, sevm, args, libs, solver) + setup_ex = deploy_test(ctx, sevm) + + setup_sig = setup_info.sig + if not setup_sig: + if args.statistics: + print(setup_timer.report()) + return setup_ex setup_timer.create_subtimer("run") - setup_sig = setup_info.sig - if setup_sig: - # TODO: dyn_params may need to be passed to mk_calldata in run() - calldata, dyn_params = mk_calldata(abi, setup_info, args) - setup_ex.path.process_dyn_params(dyn_params) - - parent_message = setup_ex.message() - setup_ex.context = CallContext( - message=Message( - target=parent_message.target, - caller=parent_message.caller, - origin=parent_message.origin, - value=0, - data=calldata, - call_scheme=EVM.CALL, - ), - ) + # TODO: dyn_params may need to be passed to mk_calldata in run() + calldata, dyn_params = mk_calldata(ctx.abi, setup_info, args) + setup_ex.path.process_dyn_params(dyn_params) + + parent_message = setup_ex.message() + setup_ex.context = CallContext( + message=Message( + target=parent_message.target, + caller=parent_message.caller, + origin=parent_message.origin, + value=0, + data=calldata, + call_scheme=EVM.CALL, + ), + ) - setup_exs_all = sevm.run(setup_ex) - setup_exs_no_error = [] + setup_exs_all = sevm.run(setup_ex) + setup_exs_no_error = [] - for idx, setup_ex in enumerate(setup_exs_all): - if args.verbose >= VERBOSITY_TRACE_SETUP: - print(f"{setup_sig} trace #{idx+1}:") - render_trace(setup_ex.context) + for idx, setup_ex in enumerate(setup_exs_all): + if args.verbose >= VERBOSITY_TRACE_SETUP: + print(f"{setup_sig} trace #{idx+1}:") + render_trace(setup_ex.context) - if not (err := setup_ex.context.output.error): - setup_exs_no_error.append((setup_ex, setup_ex.path.to_smt2(args))) + if not (err := setup_ex.context.output.error): + setup_exs_no_error.append((setup_ex, setup_ex.path.to_smt2(args))) - else: - opcode = setup_ex.current_opcode() - if opcode not in [EVM.REVERT, EVM.INVALID]: - warn_code( - INTERNAL_ERROR, - f"in {setup_sig}, executing {mnemonic(opcode)} failed with: {err}", - ) + else: + opcode = setup_ex.current_opcode() + if opcode not in [EVM.REVERT, EVM.INVALID]: + warn_code( + INTERNAL_ERROR, + f"in {setup_sig}, executing {mnemonic(opcode)} failed with: {err}", + ) - # only render the trace if we didn't already do it - if ( - args.verbose < VERBOSITY_TRACE_SETUP - and args.verbose >= VERBOSITY_TRACE_COUNTEREXAMPLE - ): - print(f"{setup_sig} trace:") - render_trace(setup_ex.context) + # only render the trace if we didn't already do it + if VERBOSITY_TRACE_COUNTEREXAMPLE <= args.verbose < VERBOSITY_TRACE_SETUP: + print(f"{setup_sig} trace:") + render_trace(setup_ex.context) - setup_exs = [] + setup_exs = [] - if len(setup_exs_no_error) > 1: + match len(setup_exs_no_error): + case 0: + pass + case 1: + setup_exs.append(setup_exs_no_error[0][0]) + case _: for setup_ex, query in setup_exs_no_error: res, _, _ = solve(query, args) if res != unsat: @@ -331,28 +458,24 @@ def setup( if len(setup_exs) > 1: break - elif len(setup_exs_no_error) == 1: - setup_exs.append(setup_exs_no_error[0][0]) - - if len(setup_exs) == 0: + match len(setup_exs): + case 0: raise HalmosException(f"No successful path found in {setup_sig}") - - if len(setup_exs) > 1: + case n if n > 1: debug("\n".join(map(str, setup_exs))) - raise HalmosException(f"Multiple paths were found in {setup_sig}") - setup_ex = setup_exs[0] + [setup_ex] = setup_exs - if args.print_setup_states: - print(setup_ex) + if args.print_setup_states: + print(setup_ex) - if sevm.logs.bounded_loops: - warn_code( - LOOP_BOUND, - f"{setup_sig}: paths have not been fully explored due to the loop unrolling bound: {args.loop}", - ) - debug("\n".join(jumpid_str(x) for x in sevm.logs.bounded_loops)) + if sevm.logs.bounded_loops: + warn_code( + LOOP_BOUND, + f"{setup_sig}: paths have not been fully explored due to the loop unrolling bound: {args.loop}", + ) + debug("\n".join(jumpid_str(x) for x in sevm.logs.bounded_loops)) if args.statistics: print(setup_timer.report()) @@ -360,70 +483,14 @@ def setup( return setup_ex -@dataclass -class PotentialModel: - model: AnyModel - is_valid: bool - - def __init__(self, model: ModelRef | str, args: HalmosConfig) -> None: - # convert model into string to avoid pickling errors for z3 (ctypes) objects containing pointers - self.model = ( - to_str_model(model, args.print_full_model) - if isinstance(model, ModelRef) - else model - ) - self.is_valid = is_model_valid(model) - - def __str__(self) -> str: - # expected to be a filename - if isinstance(self.model, str): - return f"see {self.model}" - - formatted = [f"\n {decl} = {val}" for decl, val in self.model.items()] - return "".join(sorted(formatted)) if formatted else "∅" - - -@dataclass(frozen=True) -class ModelWithContext: - # can be a filename containing the model or a dict with variable assignments - model: PotentialModel | None - index: int - result: CheckSatResult - unsat_core: list | None - - -@dataclass(frozen=True) -class TestResult: - name: str # test function name - exitcode: int - num_models: int = None - models: list[ModelWithContext] = None - num_paths: tuple[int, int, int] = None # number of paths: [total, success, blocked] - time: tuple[int, int, int] = None # time: [total, paths, models] - num_bounded_loops: int = None # number of incomplete loops - - -class Exitcode(Enum): - PASS = 0 - COUNTEREXAMPLE = 1 - TIMEOUT = 2 - STUCK = 3 - REVERT_ALL = 4 - EXCEPTION = 5 - - def is_global_fail_set(context: CallContext) -> bool: hevm_fail = isinstance(context.output.error, FailCheatcode) return hevm_fail or any(is_global_fail_set(x) for x in context.subcalls()) -def run_test( - setup_ex: Exec, - abi: dict, - fun_info: FunctionInfo, - args: HalmosConfig, - solver: Solver, -) -> TestResult: +def run_test(ctx: FunctionContext) -> TestResult: + args = ctx.args + fun_info = ctx.info funname, funsig = fun_info.name, fun_info.sig if args.verbose >= 1: print(f"Executing {funname}") @@ -432,7 +499,7 @@ def run_test( # prepare test dump directory if needed # - dump_dirname = f"/tmp/{funname}-{uuid.uuid4().hex}" + dump_dirname = ctx.dump_dirname should_dump = args.dump_smt_queries or args.solver_command if should_dump and not os.path.isdir(dump_dirname): os.makedirs(dump_dirname) @@ -442,11 +509,12 @@ def run_test( # prepare calldata # + setup_ex = ctx.setup_ex sevm = SEVM(args) - path = Path(solver) + path = Path(ctx.solver) path.extend_path(setup_ex.path) - cd, dyn_params = mk_calldata(abi, fun_info, args) + cd, dyn_params = mk_calldata(ctx.abi, fun_info, args) path.process_dyn_params(dyn_params) message = Message( @@ -582,7 +650,14 @@ def future_callback(future: PopenFuture): query: SMTQuery = ex.path.to_smt2(args) - dump(query, args, dump_filename=f"{dump_dirname}/{path_id}.smt2") + path_ctx = PathContext( + path_id=path_id, + ex=ex, + query=query, + fun_info=ctx, + ) + + dump(path_ctx) # if the query contains an unsat-core, it is unsat; no need to run the solver if check_unsat_cores(query, unsat_cores): @@ -623,11 +698,15 @@ def future_callback(future: PopenFuture): break num_execs = path_id + + # the name is a bit misleading: this timer only starts after the exploration phase is complete + # but it's possible that solvers have already been running for a while timer.create_subtimer("models") if potential > 0 and args.verbose >= 1: print( - f"# of potential paths involving assertion violations: {potential} / {num_execs} (--solver-threads {args.solver_threads})" + f"# of potential paths involving assertion violations: {potential} / {num_execs}" + f" (--solver-threads {args.solver_threads})" ) # display assertion solving progress @@ -728,63 +807,51 @@ def extract_setup(methodIdentifiers: dict[str, str]) -> FunctionInfo: return FunctionInfo(setup_name, setup_sig, setup_selector) -@dataclass(frozen=True) -class RunArgs: - # signatures of test functions to run - funsigs: list[str] - - # code of the current contract - creation_hexcode: str - deployed_hexcode: str +def run_contract(ctx: ContractContext) -> list[TestResult]: + BuildOut().set_build_out(ctx.build_out_map) - abi: dict - methodIdentifiers: dict[str, str] - - args: HalmosConfig - contract_json: dict - libs: dict - - build_out_map: dict - - -def run_contract(run_args: RunArgs) -> list[TestResult]: - BuildOut().set_build_out(run_args.build_out_map) - - args = run_args.args - setup_info = extract_setup(run_args.methodIdentifiers) + args = ctx.args + setup_info = extract_setup(ctx.method_identifiers) try: - setup_config = with_devdoc(args, setup_info.sig, run_args.contract_json) + setup_config = with_devdoc(args, setup_info.sig, ctx.contract_json) setup_solver = mk_solver(setup_config) - setup_ex = setup( - run_args.creation_hexcode, - run_args.deployed_hexcode, - run_args.abi, - setup_info, - setup_config, - run_args.libs, - setup_solver, + setup_ctx = FunctionContext( + args=setup_config, + info=setup_info, + solver=setup_solver, + contract_ctx=ctx, ) + + setup_ex = setup(setup_ctx) except Exception as err: error(f"{setup_info.sig} failed: {type(err).__name__}: {err}") if args.debug: traceback.print_exc() + # reset any remaining solver states from the default context setup_solver.reset() + return [] test_results = [] - for funsig in run_args.funsigs: - fun_info = FunctionInfo( - funsig.split("(")[0], funsig, run_args.methodIdentifiers[funsig] - ) + for funsig in ctx.funsigs: + selector = ctx.method_identifiers[funsig] + fun_info = FunctionInfo(funsig.split("(")[0], funsig, selector) try: - test_config = with_devdoc(args, funsig, run_args.contract_json) + test_config = with_devdoc(args, funsig, ctx.contract_json) solver = mk_solver(test_config) debug(f"{test_config.formatted_layers()}") - test_result = run_test( - setup_ex, run_args.abi, fun_info, test_config, solver + + test_ctx = FunctionContext( + args=test_config, + info=fun_info, + solver=solver, + contract_ctx=ctx, + setup_ex=setup_ex, ) + + test_result = run_test(test_ctx) except Exception as err: print(f"{color_error('[ERROR]')} {funsig}") error(f"{type(err).__name__}: {err}") @@ -831,10 +898,11 @@ def parse_unsat_core(output) -> list[str] | None: def dump( - query: SMTQuery, args: HalmosConfig, dump_filename: str | None = None + path_ctx: PathContext, ) -> tuple[CheckSatResult, PotentialModel | None, list | None]: - if not dump_filename: - dump_filename = f"/tmp/{uuid.uuid4().hex}.smt2" + args = path_ctx.fun_ctx.args + query = path_ctx.query + dump_filename = path_ctx.dump_filename if args.verbose >= 1: debug(f"Writing SMT query to {dump_filename}") @@ -867,13 +935,6 @@ def dump( f.write("(get-model)\n") -@dataclass(frozen=True) -class SolverOutput: - result: CheckSatResult - model: PotentialModel | None = None - unsat_core: list[str] | None = None - - def solve(smt2_filename: str, args: HalmosConfig) -> SolverOutput: if args.verbose >= 1: debug(" Checking with external solver process") @@ -891,22 +952,27 @@ def solve(smt2_filename: str, args: HalmosConfig) -> SolverOutput: res_str = subprocess.run( cmd, capture_output=True, text=True, timeout=timeout_seconds ).stdout.strip() - res_str_head = res_str.split("\n", 1)[0] + # save solver output to file with open(f"{smt2_filename}.out", "w") as f: f.write(res_str) + # extract the first line (we expect sat/unsat/unknown) + newline_idx = res_str.find("\n") + res_str_head = res_str[:newline_idx] if newline_idx != -1 else res_str if args.verbose >= 1: debug(f" {res_str_head}") - if res_str_head == "unsat": - unsat_core = parse_unsat_core(res_str) if args.cache_solver else None - return SolverOutput(result=unsat, unsat_core=unsat_core) - elif res_str_head == "sat": - model = PotentialModel(f"{smt2_filename}.out", args) - return SolverOutput(result=sat, model=model) - else: - return SolverOutput(result=unknown) + match res_str_head: + case "unsat": + unsat_core = parse_unsat_core(res_str) if args.cache_solver else None + return SolverOutput(result=unsat, unsat_core=unsat_core) + case "sat": + model = PotentialModel(f"{smt2_filename}.out", args) + return SolverOutput(result=sat, model=model) + case _: + return SolverOutput(result=unknown) + except subprocess.TimeoutExpired: return SolverOutput(result=unknown) @@ -1326,26 +1392,27 @@ def on_signal(signum, frame): # support for `/// @custom:halmos` annotations contract_args = with_natspec(args, contract_name, natspec) - run_args = RunArgs( - funsigs, - creation_hexcode, - deployed_hexcode, - abi, - methodIdentifiers, - contract_args, - contract_json, - libs, - build_out_map, + contract_ctx = ContractContext( + args=contract_args, + name=contract_name, + creation_hexcode=creation_hexcode, + deployed_hexcode=deployed_hexcode, + abi=abi, + method_identifiers=methodIdentifiers, + contract_json=contract_json, + libs=libs, + build_out_map=build_out_map, ) - test_results = run_contract(run_args) - + test_results = run_contract(contract_ctx) num_passed = sum(r.exitcode == 0 for r in test_results) num_failed = num_found - num_passed print( - f"Symbolic test result: {num_passed} passed; " - f"{num_failed} failed; {contract_timer.report()}" + "Symbolic test result: " + f"{num_passed} passed; " + f"{num_failed} failed; " + f"{contract_timer.report()}" ) total_found += num_found