Skip to content

Commit

Permalink
restore policy update
Browse files Browse the repository at this point in the history
  • Loading branch information
carlocagnetta committed Aug 15, 2024
1 parent 7ec0c46 commit c687557
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions scripts/restore_and_watch_policy.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,45 @@
from pathlib import Path
from typing import cast
from typing import cast, Sequence

from armscan_env.volumes.loading import load_sitk_volumes
from armscan_env.wrapper import ArmscanEnvFactory

from tianshou.highlevel.env import EnvMode
from tianshou.highlevel.experiment import Experiment

# Place your path here
saved_experiment_dir = Path("log/random-actions/42/20240814-150513")
saved_experiment_dir = Path("log/sac-characteristic-array-rew-details-y/42/20240630-191219")

if __name__ == "__main__":
volumes = load_sitk_volumes()
restored_experiment = Experiment.from_directory(str(saved_experiment_dir), restore_policy=True)

restored_policy = restored_experiment.create_experiment_world().policy
# this can be now used to perform actions, you can also use a notebook

env_factory: ArmscanEnvFactory = cast(ArmscanEnvFactory, restored_experiment.env_factory)
old_env_factory: ArmscanEnvFactory = cast(ArmscanEnvFactory, restored_experiment.env_factory)
# TODO @Carlo: modify here to set different volumes
env_factory.name2volume = env_factory.name2volume
env_factory = ArmscanEnvFactory(
name2volume={"3": volumes[2]},
observation=old_env_factory.observation,
reward_metric=old_env_factory.reward_metric,
termination_criterion=old_env_factory.termination_criterion,
slice_shape=old_env_factory.slice_shape,
max_episode_len=old_env_factory.max_episode_len,
rotation_bounds=old_env_factory.rotation_bounds,
translation_bounds=old_env_factory.translation_bounds,
render_mode_train=old_env_factory.render_modes[EnvMode.TRAIN],
render_mode_test=old_env_factory.render_modes[EnvMode.TEST],
render_mode_watch=old_env_factory.render_modes[EnvMode.WATCH],
venv_type=old_env_factory.venv_type,
seed=old_env_factory.seed,
n_stack=old_env_factory.n_stack,
project_actions_to=old_env_factory.project_actions_to,
apply_volume_transformation=old_env_factory.apply_volume_transformation,
best_reward_memory=0,
exclude_keys_from_framestack=(),
**old_env_factory.make_kwargs
)

# Create env manually and run policy on it
restored_env = env_factory.create_env(mode=EnvMode.WATCH)
Expand Down

0 comments on commit c687557

Please sign in to comment.