From 31d329425be0310790bc8e9042b1ca8fa75df89d Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 14 Aug 2024 16:17:58 +0200 Subject: [PATCH] Added script demonstrating how to restore a policy --- scripts/restore_and_watch_policy.py | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 scripts/restore_and_watch_policy.py diff --git a/scripts/restore_and_watch_policy.py b/scripts/restore_and_watch_policy.py new file mode 100644 index 0000000..fbfd8aa --- /dev/null +++ b/scripts/restore_and_watch_policy.py @@ -0,0 +1,32 @@ +from pathlib import Path +from typing import cast + +from armscan_env.wrapper import ArmscanEnvFactory + +from tianshou.highlevel.env import EnvMode +from tianshou.highlevel.experiment import Experiment + +# Place your path here +saved_experiment_dir = Path("log/random-actions/42/20240814-150513") + +if __name__ == "__main__": + restored_experiment = Experiment.from_directory(str(saved_experiment_dir), restore_policy=True) + + restored_policy = restored_experiment.create_experiment_world().policy + # this can be now used to perform actions, you can also use a notebook + + env_factory: ArmscanEnvFactory = cast(ArmscanEnvFactory, restored_experiment.env_factory) + # TODO @Carlo: modify here to set different volumes + env_factory.name2volume = env_factory.name2volume + + # Create env manually and run policy on it + restored_env = env_factory.create_env(mode=EnvMode.WATCH) + obs, info = restored_env.reset() + for _ in range(5): + obs, *_ = restored_env.step(restored_policy.compute_action(obs)) + + # Or use the restored experiment to run the policy + restored_experiment.config.train = False + restored_experiment.config.watch = True + restored_experiment.config.persistence_enabled = False + restored_experiment.run()