Skip to content

Commit

Permalink
add characteristic_array_obs
Browse files Browse the repository at this point in the history
  • Loading branch information
carlocagnetta committed Jun 15, 2024
1 parent a28d922 commit 22ce091
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 1 deletion.
81 changes: 81 additions & 0 deletions scripts/armscan_array_obs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os

import SimpleITK as sitk
from armscan_env.config import get_config
from armscan_env.envs.labelmaps_navigation import LabelmapEnvTerminationCriterion
from armscan_env.envs.observations import (
LabelmapClusterObservation,
)
from armscan_env.envs.rewards import LabelmapClusteringBasedReward
from armscan_env.wrapper import ArmscanEnvFactory

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType
from tianshou.highlevel.experiment import (
ExperimentConfig,
SACExperimentBuilder,
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.utils.logging import datetime_tag

config = get_config()

volume_1 = sitk.ReadImage(config.get_labels_path(1))
volume_2 = sitk.ReadImage(config.get_labels_path(2))

log_name = os.path.join("sac", str(ExperimentConfig.seed), datetime_tag())
experiment_config = ExperimentConfig()

sampling_config = SamplingConfig(
num_epochs=10,
step_per_epoch=100,
num_train_envs=1,
num_test_envs=1,
buffer_size=1000,
batch_size=256,
step_per_collect=1,
update_per_step=1,
start_timesteps=0,
start_timesteps_random=True,
)

volume_size = volume_1.GetSize()
env_factory = ArmscanEnvFactory(
name2volume={"1": volume_1},
observation=LabelmapClusterObservation(action_shape=(4,)),
slice_shape=(volume_size[0], volume_size[2]),
max_episode_len=100,
rotation_bounds=(90.0, 45.0),
translation_bounds=(0.0, None),
render_mode="animation",
seed=experiment_config.seed,
venv_type=VectorEnvType.DUMMY,
n_stack=4,
termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.1),
reward_metric=LabelmapClusteringBasedReward(n_landmarks=(4, 2, 1)),
)

experiment = (
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_sac_params(
SACParams(
tau=0.005,
gamma=0.99,
alpha=AutoAlphaFactoryDefault(lr=3e-4),
estimation_step=1,
actor_lr=1e-3,
critic1_lr=1e-3,
critic2_lr=1e-3,
),
)
.with_actor_factory_default(
(256, 256),
continuous_unbounded=True,
continuous_conditioned_sigma=True,
)
.with_common_critic_factory_default((256, 256))
.build()
)

experiment.run(run_name=log_name)
3 changes: 2 additions & 1 deletion src/armscan_env/envs/labelmaps_navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,8 @@ def get_cur_state_plot(
# REWARD
ax5.text(0, 0, f"Reward: {self.cur_reward:.2f}", fontsize=12, color="red")

fig.suptitle(title, x=0.2, y=0.95)
if fig is not None:
fig.suptitle(title, x=0.2, y=0.95)

plt.close()
return fig
Expand Down
23 changes: 23 additions & 0 deletions src/armscan_env/envs/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ class LabelmapClusterObservation(ArrayObservation[LabelmapStateAction]):
TODO: Implement this observation.
"""

def __init__(self, action_shape: tuple[int]):
self.action_shape = action_shape

def compute_observation(self, state: LabelmapStateAction) -> np.ndarray:
tissue_clusters = TissueClusters.from_labelmap_slice(state.labels_2d_slice)
return np.concatenate(
Expand All @@ -245,6 +248,21 @@ def compute_observation(self, state: LabelmapStateAction) -> np.ndarray:
state.last_reward,
)

def _compute_observation_space(
self,
) -> gym.spaces.Box:
"""Return the observation space as a Box, with the right bounds for each feature."""
DictObs = gym.spaces.Dict(
spaces=(
("num_clusters", gym.spaces.Box(low=0, high=np.inf, shape=(3,))),
("num_points", gym.spaces.Box(low=0, high=np.inf, shape=(3,))),
("cluster_center_mean", gym.spaces.Box(low=-np.inf, high=np.inf, shape=(3,))),
("action", gym.spaces.Box(low=-1, high=1, shape=self.action_shape)),
("reward", gym.spaces.Box(low=-1, high=0, shape=(1,))),
),
)
return cast(gym.spaces.Box, gym.spaces.flatten_space(DictObs))

@staticmethod
def cluster_characteristics_array(tissue_clusters: TissueClusters) -> np.ndarray:
cluster_characteristics = []
Expand All @@ -260,3 +278,8 @@ def cluster_characteristics_array(tissue_clusters: TissueClusters) -> np.ndarray
cluster_characteristics.append([len(clusters), num_points, clusters_center_mean])

return np.array(cluster_characteristics)

@property
def observation_space(self) -> gym.spaces.Box:
"""Boolean 2-d array representing segregated labelmap slice."""
return self._compute_observation_space()

0 comments on commit 22ce091

Please sign in to comment.