From c68755767e7beab45164fbf7ac7d5490ed258acf Mon Sep 17 00:00:00 2001 From: charliebrownies Date: Thu, 15 Aug 2024 11:23:53 +0200 Subject: [PATCH] restore policy update --- scripts/restore_and_watch_policy.py | 30 +++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/scripts/restore_and_watch_policy.py b/scripts/restore_and_watch_policy.py index fbfd8aa..68bf9b8 100644 --- a/scripts/restore_and_watch_policy.py +++ b/scripts/restore_and_watch_policy.py @@ -1,23 +1,45 @@ from pathlib import Path -from typing import cast +from typing import cast, Sequence +from armscan_env.volumes.loading import load_sitk_volumes from armscan_env.wrapper import ArmscanEnvFactory from tianshou.highlevel.env import EnvMode from tianshou.highlevel.experiment import Experiment # Place your path here -saved_experiment_dir = Path("log/random-actions/42/20240814-150513") +saved_experiment_dir = Path("log/sac-characteristic-array-rew-details-y/42/20240630-191219") if __name__ == "__main__": + volumes = load_sitk_volumes() restored_experiment = Experiment.from_directory(str(saved_experiment_dir), restore_policy=True) restored_policy = restored_experiment.create_experiment_world().policy # this can be now used to perform actions, you can also use a notebook - env_factory: ArmscanEnvFactory = cast(ArmscanEnvFactory, restored_experiment.env_factory) + old_env_factory: ArmscanEnvFactory = cast(ArmscanEnvFactory, restored_experiment.env_factory) # TODO @Carlo: modify here to set different volumes - env_factory.name2volume = env_factory.name2volume + env_factory = ArmscanEnvFactory( + name2volume={"3": volumes[2]}, + observation=old_env_factory.observation, + reward_metric=old_env_factory.reward_metric, + termination_criterion=old_env_factory.termination_criterion, + slice_shape=old_env_factory.slice_shape, + max_episode_len=old_env_factory.max_episode_len, + rotation_bounds=old_env_factory.rotation_bounds, + translation_bounds=old_env_factory.translation_bounds, + render_mode_train=old_env_factory.render_modes[EnvMode.TRAIN], + render_mode_test=old_env_factory.render_modes[EnvMode.TEST], + render_mode_watch=old_env_factory.render_modes[EnvMode.WATCH], + venv_type=old_env_factory.venv_type, + seed=old_env_factory.seed, + n_stack=old_env_factory.n_stack, + project_actions_to=old_env_factory.project_actions_to, + apply_volume_transformation=old_env_factory.apply_volume_transformation, + best_reward_memory=0, + exclude_keys_from_framestack=(), + **old_env_factory.make_kwargs + ) # Create env manually and run policy on it restored_env = env_factory.create_env(mode=EnvMode.WATCH)