diff --git a/runtime/tools/python/ttrt/__init__.py b/runtime/tools/python/ttrt/__init__.py index 37d4c6e79..be197112c 100644 --- a/runtime/tools/python/ttrt/__init__.py +++ b/runtime/tools/python/ttrt/__init__.py @@ -80,6 +80,11 @@ def main(): action="store_true", help="save all artifacts during run", ) + run_parser.add_argument( + "--seed", + default=0, + help="Seed for random number generator", + ) run_parser.add_argument("binary", help="flatbuffer binary file") run_parser.set_defaults(func=run) diff --git a/runtime/tools/python/ttrt/common/api.py b/runtime/tools/python/ttrt/common/api.py index 0517f4b7b..8f2de3738 100644 --- a/runtime/tools/python/ttrt/common/api.py +++ b/runtime/tools/python/ttrt/common/api.py @@ -133,6 +133,8 @@ def run(args): device = ttrt.runtime.open_device(device_ids) atexit.register(lambda: ttrt.runtime.close_device(device)) + torch.manual_seed(args.seed) + for (binary_name, fbb, fbb_dict) in fbb_list: torch_inputs[binary_name] = [] torch_outputs[binary_name] = []