diff --git a/src/halmos/__main__.py b/src/halmos/__main__.py index 4809f004..db931937 100644 --- a/src/halmos/__main__.py +++ b/src/halmos/__main__.py @@ -6,6 +6,7 @@ import uuid import json import re +import signal import traceback from argparse import Namespace @@ -1325,15 +1326,38 @@ def _main(_args=None) -> MainResult: timer.create_subtimer("tests") - # - # run - # - total_passed = 0 total_failed = 0 total_found = 0 test_results_map = {} + # + # exit and signal handlers to avoid dropping json output + # + + def on_exit(exitcode: int) -> MainResult: + result = MainResult(exitcode, test_results_map) + + if args.json_output: + with open(args.json_output, "w") as json_file: + json.dump(asdict(result), json_file, indent=4) + + return result + + def on_signal(signum, frame): + if args.debug: + debug(f"Signal {signum} received. Dumping {test_results_map}...") + exitcode = 128 + signum + on_exit(exitcode) + sys.exit(exitcode) + + for signum in [signal.SIGINT, signal.SIGTERM]: + signal.signal(signum, on_signal) + + # + # run + # + for build_out_map, filename, contract_name in build_output_iterator(build_out): if args.contract and args.contract != contract_name: continue @@ -1404,13 +1428,7 @@ def _main(_args=None) -> MainResult: return MainResult(1) exitcode = 0 if total_failed == 0 else 1 - result = MainResult(exitcode, test_results_map) - - if args.json_output: - with open(args.json_output, "w") as json_file: - json.dump(asdict(result), json_file, indent=4) - - return result + return on_exit(exitcode) # entrypoint for the `halmos` script diff --git a/src/halmos/utils.py b/src/halmos/utils.py index f6e79cf5..cf4845f2 100644 --- a/src/halmos/utils.py +++ b/src/halmos/utils.py @@ -369,6 +369,10 @@ def color_info(text: str) -> str: return cyan(text) +def color_debug(text: str) -> str: + return magenta(text) + + def error(text: str) -> None: print(color_error(text)) @@ -381,6 +385,10 @@ def info(text: str) -> None: print(color_info(text)) +def debug(text: str) -> None: + print(color_debug(text)) + + def indent_text(text: str, n: int = 4) -> str: return "\n".join(" " * n + line for line in text.splitlines())