Commit

First working RL training script (toy example in armscan_sac_hl.py)
Michael Panchenko committed Jun 17, 2024
1 parent 42bbae8 commit d01a26f
Showing 4 changed files with 62 additions and 58 deletions.
3 changes: 2 additions & 1 deletion docs/02_notebooks/L5_linear_sweep.ipynb
@@ -230,7 +230,8 @@
"outputs": [],
"source": [
"print(\n",
" \"Observed 'rewards': \\n\", [round(obs[1][-1], 4) for obs in projected_env_rollout.observations],\n",
" \"Observed 'rewards': \\n\",\n",
" [round(obs[1][-1], 4) for obs in projected_env_rollout.observations],\n",
")\n",
"print(\"Env rewards: \\n\", [round(r, 4) for r in projected_env_rollout.rewards])"
]
105 changes: 52 additions & 53 deletions scripts/armscan_sac_hl.py
@@ -1,11 +1,13 @@
import logging
import os

import SimpleITK as sitk
from armscan_env.config import get_config
from armscan_env.envs.labelmaps_navigation import LabelmapEnvTerminationCriterion
from armscan_env.envs.observations import LabelmapSliceAsChannelsObservation
from armscan_env.envs.observations import (
ActionRewardObservation,
)
from armscan_env.envs.rewards import LabelmapClusteringBasedReward
from armscan_env.network import ActorFactoryArmscanDQN
from armscan_env.wrapper import ArmscanEnvFactory

from tianshou.highlevel.config import SamplingConfig
@@ -18,62 +20,59 @@
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.utils.logging import datetime_tag

config = get_config()
if __name__ == "__main__":
config = get_config()
logging.basicConfig(level=logging.INFO)

volume_1 = sitk.ReadImage(config.get_labels_path(1))
volume_2 = sitk.ReadImage(config.get_labels_path(2))
volume_1 = sitk.ReadImage(config.get_labels_path(1))
volume_2 = sitk.ReadImage(config.get_labels_path(2))

log_name = os.path.join("sac", str(ExperimentConfig.seed), datetime_tag())
experiment_config = ExperimentConfig()
log_name = os.path.join("sac", str(ExperimentConfig.seed), datetime_tag())
experiment_config = ExperimentConfig()

sampling_config = SamplingConfig(
num_epochs=200,
step_per_epoch=5000,
num_train_envs=1,
num_test_envs=10,
buffer_size=1000000,
batch_size=256,
step_per_collect=1,
update_per_step=1,
start_timesteps=10000,
start_timesteps_random=True,
)
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=1000000,
num_train_envs=5,
num_test_envs=1,
buffer_size=1000000,
batch_size=256,
step_per_collect=200,
update_per_step=10,
start_timesteps=5000,
start_timesteps_random=True,
)

volume_size = volume_1.GetSize()
env_factory = ArmscanEnvFactory(
name2volume={"1": volume_1},
observation=LabelmapSliceAsChannelsObservation(
volume_size = volume_1.GetSize()
env_factory = ArmscanEnvFactory(
name2volume={"1": volume_1},
observation=ActionRewardObservation(action_shape=(1,)).to_array_observation(),
slice_shape=(volume_size[0], volume_size[2]),
action_shape=(4,),
),
slice_shape=(volume_size[0], volume_size[2]),
max_episode_len=100,
rotation_bounds=(90.0, 45.0),
translation_bounds=(0.0, None),
render_mode="animation",
seed=experiment_config.seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
n_stack=4,
termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.1),
reward_metric=LabelmapClusteringBasedReward(),
)
max_episode_len=10,
rotation_bounds=(90.0, 45.0),
translation_bounds=(0.0, None),
seed=experiment_config.seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
n_stack=3,
termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.1),
reward_metric=LabelmapClusteringBasedReward(),
project_actions_to="y",
)

experiment = (
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_sac_params(
SACParams(
tau=0.005,
gamma=0.99,
alpha=AutoAlphaFactoryDefault(lr=3e-4),
estimation_step=1,
actor_lr=1e-3,
critic1_lr=1e-3,
critic2_lr=1e-3,
),
experiment = (
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_sac_params(
SACParams(
tau=0.005,
gamma=0.99,
alpha=AutoAlphaFactoryDefault(lr=3e-4),
estimation_step=1,
actor_lr=1e-3,
critic1_lr=1e-3,
critic2_lr=1e-3,
),
)
.build()
)
.with_actor_factory(ActorFactoryArmscanDQN())
.with_common_critic_factory_use_actor()
.build()
)

experiment.run(run_name=log_name)
experiment.run(run_name=log_name)
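
Aside (not part of the commit): moving the whole setup under the `if __name__ == "__main__":` guard is the standard requirement once `VectorEnvType.SUBPROC_SHARED_MEM_AUTO` is in play, because subprocess workers re-import the launching module and must not re-run the experiment setup at import time. A minimal, generic sketch of that pattern using plain `multiprocessing` (illustrative only, no armscan_env or tianshou code):

import multiprocessing as mp


def rollout_worker(env_id: int) -> int:
    # Stand-in for one vectorized-environment worker process.
    return env_id * env_id


if __name__ == "__main__":
    # Runs only in the parent process; child processes re-import this module
    # and would otherwise re-execute any module-level setup code.
    with mp.Pool(processes=2) as pool:
        results = pool.map(rollout_worker, range(4))
    print(results)  # [0, 1, 4, 9]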
11 changes: 7 additions & 4 deletions src/armscan_env/envs/base.py
@@ -80,7 +80,7 @@ def to_array_observation(self) -> ArrayObservation[TStateAction]:
def observation_space(self) -> gym.spaces.Dict:
pass

def merged_with(self, other: Self) -> 'MergedDictObservation[TStateAction]':
def merged_with(self, other: Self) -> "MergedDictObservation[TStateAction]":
return MergedDictObservation([self, other])


@@ -90,17 +90,20 @@ def __init__(self, array_observations: list[ArrayObservation[TStateAction]]):

def compute_observation(self, state: TStateAction) -> np.ndarray:
return np.concatenate(
[obs.compute_observation(state) for obs in self.array_observations], axis=0,
[obs.compute_observation(state) for obs in self.array_observations],
axis=0,
)

@cached_property
def observation_space(self) -> gym.spaces.Box:
return gym.spaces.Box(
low=np.concatenate(
[obs.observation_space.low for obs in self.array_observations], axis=0,
[obs.observation_space.low for obs in self.array_observations],
axis=0,
),
high=np.concatenate(
[obs.observation_space.high for obs in self.array_observations], axis=0,
[obs.observation_space.high for obs in self.array_observations],
axis=0,
),
)

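Aside (illustrative only, not code from this commit): the merged observation above builds its Box space by concatenating the lows and highs of its child observations' spaces. In isolation, and assuming a Gymnasium-compatible gym.spaces API, that step looks roughly like this:

import numpy as np
import gymnasium as gym  # assumption: the repo's gym.spaces usage is Gymnasium-compatible

# Two toy component spaces, e.g. a clipped reward and a 4-dimensional action.
reward_space = gym.spaces.Box(low=np.array([-1.0]), high=np.array([0.0]))
action_space = gym.spaces.Box(low=-np.ones(4), high=np.ones(4))

# Concatenate the bounds, mirroring the merged observation_space property shown above.
merged = gym.spaces.Box(
    low=np.concatenate([reward_space.low, action_space.low], axis=0),
    high=np.concatenate([reward_space.high, action_space.high], axis=0),
)
print(merged.shape)  # (5,)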
1 change: 1 addition & 0 deletions src/armscan_env/wrapper.py
@@ -152,6 +152,7 @@ def __init__(
self,
name2volume: dict[str, sitk.Image],
observation: Observation[LabelmapStateAction, Any],
# TODO: remove mutable default values, make a proper config-based factory (not urgent)
reward_metric: RewardMetric[LabelmapStateAction] = LabelmapClusteringBasedReward(),
termination_criterion: TerminationCriterion | None = LabelmapEnvTerminationCriterion(),
slice_shape: tuple[int, int] | None = None,
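On the new TODO about mutable default values: the usual fix (sketched below as a hypothetical refactor, not something this commit implements) is to default such parameters to None and construct a fresh instance inside __init__, so that no single RewardMetric or TerminationCriterion object is shared across factory instances:

class RewardMetric:
    # Stand-in for armscan_env's RewardMetric (simplified for illustration).
    pass


class EnvFactorySketch:
    # Hypothetical refactor of the signature above: a None default instead of a
    # shared instance created once at import time.
    def __init__(self, reward_metric: RewardMetric | None = None) -> None:
        self.reward_metric = reward_metric if reward_metric is not None else RewardMetric()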
