Skip to content

Commit

Permalink
Merge pull request #4 from appliedAI-Initiative/array_observation
Browse files Browse the repository at this point in the history
Array observation
  • Loading branch information
MischaPanch authored Jun 17, 2024
2 parents f835e2c + 62bb320 commit 0370a01
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 37,003 deletions.
36,987 changes: 22 additions & 36,965 deletions docs/02_notebooks/L5_linear_sweep.ipynb

Large diffs are not rendered by default.

81 changes: 81 additions & 0 deletions scripts/armscan_array_obs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os

import SimpleITK as sitk
from armscan_env.config import get_config
from armscan_env.envs.labelmaps_navigation import LabelmapEnvTerminationCriterion
from armscan_env.envs.observations import (
LabelmapClusterObservation,
)
from armscan_env.envs.rewards import LabelmapClusteringBasedReward
from armscan_env.wrapper import ArmscanEnvFactory

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType
from tianshou.highlevel.experiment import (
ExperimentConfig,
SACExperimentBuilder,
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.utils.logging import datetime_tag

config = get_config()

volume_1 = sitk.ReadImage(config.get_labels_path(1))
volume_2 = sitk.ReadImage(config.get_labels_path(2))

log_name = os.path.join("sac", str(ExperimentConfig.seed), datetime_tag())
experiment_config = ExperimentConfig()

sampling_config = SamplingConfig(
num_epochs=10,
step_per_epoch=100,
num_train_envs=1,
num_test_envs=1,
buffer_size=1000,
batch_size=256,
step_per_collect=1,
update_per_step=1,
start_timesteps=0,
start_timesteps_random=True,
)

volume_size = volume_1.GetSize()
env_factory = ArmscanEnvFactory(
name2volume={"1": volume_1},
observation=LabelmapClusterObservation(action_shape=(4,)),
slice_shape=(volume_size[0], volume_size[2]),
max_episode_len=20,
rotation_bounds=(90.0, 45.0),
translation_bounds=(0.0, None),
render_mode="animation",
seed=experiment_config.seed,
venv_type=VectorEnvType.DUMMY,
n_stack=4,
termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.1),
reward_metric=LabelmapClusteringBasedReward(n_landmarks=(4, 2, 1)),
)

experiment = (
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_sac_params(
SACParams(
tau=0.005,
gamma=0.99,
alpha=AutoAlphaFactoryDefault(lr=3e-4),
estimation_step=1,
actor_lr=1e-3,
critic1_lr=1e-3,
critic2_lr=1e-3,
),
)
.with_actor_factory_default(
(256, 256),
continuous_unbounded=True,
continuous_conditioned_sigma=True,
)
.with_common_critic_factory_default((256, 256))
.build()
)

experiment.run(run_name=log_name)
12 changes: 6 additions & 6 deletions scripts/armscan_ppo_hl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import SimpleITK as sitk
from armscan_env.config import get_config
from armscan_env.envs.labelmaps_navigation import LabelmapEnvTerminationCriterion
from armscan_env.envs.observations import (
LabelmapSliceAsChannelsObservation,
Expand All @@ -18,10 +19,10 @@
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils.logging import datetime_tag

path_to_labels_1 = os.path.join("..", "data", "labels", "00001_labels.nii")
volume_1 = sitk.ReadImage(path_to_labels_1)
path_to_labels_2 = os.path.join("..", "data", "labels", "00002_labels.nii")
volume_2 = sitk.ReadImage(path_to_labels_2)
config = get_config()

volume_1 = sitk.ReadImage(config.get_labels_path(1))
volume_2 = sitk.ReadImage(config.get_labels_path(2))

log_name = os.path.join("ppo", str(ExperimentConfig.seed), "4_stack-lin_sweep_v1", datetime_tag())
experiment_config = ExperimentConfig()
Expand All @@ -45,14 +46,13 @@
action_shape=(4,),
),
slice_shape=(volume_size[0], volume_size[2]),
max_episode_len=100,
max_episode_len=20,
rotation_bounds=(90.0, 45.0),
translation_bounds=(0.0, None),
render_mode="animation",
seed=experiment_config.seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
n_stack=4,
project_to_x_translation=True,
termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.1),
reward_metric=LabelmapClusteringBasedReward(n_landmarks=(4, 2, 1)),
)
Expand Down
10 changes: 5 additions & 5 deletions scripts/armscan_sac_hl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import SimpleITK as sitk
from armscan_env.config import get_config
from armscan_env.envs.labelmaps_navigation import LabelmapEnvTerminationCriterion
from armscan_env.envs.observations import LabelmapSliceAsChannelsObservation
from armscan_env.envs.rewards import LabelmapClusteringBasedReward
Expand All @@ -17,10 +18,10 @@
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.utils.logging import datetime_tag

path_to_labels_1 = os.path.join("..", "data", "labels", "00001_labels.nii")
volume_1 = sitk.ReadImage(path_to_labels_1)
path_to_labels_2 = os.path.join("..", "data", "labels", "00002_labels.nii")
volume_2 = sitk.ReadImage(path_to_labels_2)
config = get_config()

volume_1 = sitk.ReadImage(config.get_labels_path(1))
volume_2 = sitk.ReadImage(config.get_labels_path(2))

log_name = os.path.join("sac", str(ExperimentConfig.seed), datetime_tag())
experiment_config = ExperimentConfig()
Expand Down Expand Up @@ -53,7 +54,6 @@
seed=experiment_config.seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
n_stack=4,
project_to_x_translation=True,
termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.1),
reward_metric=LabelmapClusteringBasedReward(n_landmarks=(4, 2, 1)),
)
Expand Down
5 changes: 0 additions & 5 deletions scripts/run_sample.py

This file was deleted.

6 changes: 3 additions & 3 deletions src/armscan_env/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def find_DBSCAN_clusters(self, labelmap_slice: np.ndarray) -> list["DataCluster"
class DataCluster:
"""Data class for a cluster of a tissue in a slice."""

cluster: list[tuple[float, float]] | np.ndarray
datapoints: list[tuple[float, float]] | np.ndarray
center: tuple[np.floating[Any], np.floating[Any]]


Expand Down Expand Up @@ -93,7 +93,7 @@ def find_clusters(tissue_value: int, slice: np.ndarray) -> list[DataCluster]:

cluster_list.append(
DataCluster(
cluster=list(zip(cluster_indices[0], cluster_indices[1], strict=True)),
datapoints=list(zip(cluster_indices[0], cluster_indices[1], strict=True)),
center=center,
),
)
Expand Down Expand Up @@ -152,6 +152,6 @@ def find_DBSCAN_clusters(
label_to_pos_array = label_positions[clusters == cluster] # get positions of each cluster
cluster_centers = np.mean(label_to_pos_array, axis=0) # mean of each column

cluster_list.append(DataCluster(cluster=label_to_pos_array, center=cluster_centers))
cluster_list.append(DataCluster(datapoints=label_to_pos_array, center=cluster_centers))

return cluster_list
3 changes: 2 additions & 1 deletion src/armscan_env/envs/labelmaps_navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,8 @@ def get_cur_state_plot(
# REWARD
ax5.text(0, 0, f"Reward: {self.cur_reward:.2f}", fontsize=12, color="red")

fig.suptitle(title, x=0.2, y=0.95)
if fig is not None:
fig.suptitle(title, x=0.2, y=0.95)

plt.close()
return fig
Expand Down
58 changes: 49 additions & 9 deletions src/armscan_env/envs/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import gymnasium as gym
import numpy as np
from armscan_env.clustering import TissueClusters, TissueLabel
from armscan_env.envs.base import ArrayObservation, DictObservation, TStateAction
from armscan_env.envs.base import ArrayObservation, DictObservation
from armscan_env.envs.state_action import LabelmapStateAction
from armscan_env.util.img_processing import crop_center

Expand Down Expand Up @@ -237,14 +237,54 @@ class LabelmapClusterObservation(ArrayObservation[LabelmapStateAction]):
TODO: Implement this observation.
"""

def compute_observation(self, state: TStateAction) -> np.ndarray:
def __init__(self, action_shape: tuple[int]):
self.action_shape = action_shape

def compute_observation(self, state: LabelmapStateAction) -> np.ndarray:
tissue_clusters = TissueClusters.from_labelmap_slice(state.labels_2d_slice)
return self.cluster_characteristics_array(tissue_cluster=tissue_clusters)
return np.concatenate(
(
self.cluster_characteristics_array(tissue_clusters=tissue_clusters).flatten(),
np.atleast_1d(state.normalized_action_arr),
np.atleast_1d(state.last_reward),
),
axis=0,
)

def _compute_observation_space(
self,
) -> gym.spaces.Box:
"""Return the observation space as a Box, with the right bounds for each feature."""
DictObs = gym.spaces.Dict(
spaces=(
("num_clusters", gym.spaces.Box(low=0, high=np.inf, shape=(3,))),
("num_points", gym.spaces.Box(low=0, high=np.inf, shape=(3,))),
("cluster_center_mean", gym.spaces.Box(low=-np.inf, high=np.inf, shape=(6,))),
("action", gym.spaces.Box(low=-1, high=1, shape=self.action_shape)),
("reward", gym.spaces.Box(low=-1, high=0, shape=(1,))),
),
)
return cast(gym.spaces.Box, gym.spaces.flatten_space(DictObs))

@staticmethod
def cluster_characteristics_array(tissue_cluster: TissueClusters) -> np.ndarray:
characteristics_array = np.zeros((3, 2))
characteristics_array[0, 0] = len(tissue_cluster.bones)
characteristics_array[1, 0] = len(tissue_cluster.tendons)
characteristics_array[2, 0] = len(tissue_cluster.ulnar)
return characteristics_array
def cluster_characteristics_array(tissue_clusters: TissueClusters) -> np.ndarray:
cluster_characteristics = []

for tissue_label in TissueLabel:
clusters = tissue_clusters.get_cluster_for_label(tissue_label)
num_points = 0
cluster_centers = []
for cluster in clusters:
num_points += len(cluster.datapoints)
cluster_centers.append(cluster.center)
clusters_center_mean = np.mean(np.array(cluster_centers), axis=0)
if np.any(np.isnan(clusters_center_mean)):
clusters_center_mean = np.zeros(2)
cluster_characteristics.append([len(clusters), num_points, *clusters_center_mean])

return np.array(cluster_characteristics)

@property
def observation_space(self) -> gym.spaces.Box:
"""Boolean 2-d array representing segregated labelmap slice."""
return self._compute_observation_space()
12 changes: 6 additions & 6 deletions src/armscan_env/envs/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,23 @@ def anatomy_based_rwd(
bones_centers_mean = np.mean(bones_centers, axis=0)
log.debug(f"{bones_centers_mean=}")

ligament_centers = np.array([cluster.center for cluster in tissue_clusters.tendons])
ligament_centers_mean = np.mean(ligament_centers, axis=0)
log.debug(f"{ligament_centers_mean=}")
tendons_centers = np.array([cluster.center for cluster in tissue_clusters.tendons])
tendons_centers_mean = np.mean(tendons_centers, axis=0)
log.debug(f"{tendons_centers_mean=}")

# There must be only one ulnar tissue so there is no need to take the mean
ulnar_center = tissue_clusters.ulnar[0].center
log.debug(f"{ulnar_center=}")

# Check the orientation of the arm:
# The bones center might be over or under the tendons center depending on the origin
orientation = (bones_centers_mean[0] - ligament_centers_mean[0]) // abs(
bones_centers_mean[0] - ligament_centers_mean[0],
orientation = (bones_centers_mean[0] - tendons_centers_mean[0]) // abs(
bones_centers_mean[0] - tendons_centers_mean[0],
)
log.debug(f"{orientation=}")

# Ulnar artery must be under tendons in the positive orientation:
if orientation * ulnar_center[0] < orientation * ligament_centers_mean[0]:
if orientation * ulnar_center[0] < orientation * tendons_centers_mean[0]:
location_loss = 0
else:
log.debug("Ulnar center not where expected")
Expand Down
4 changes: 2 additions & 2 deletions src/armscan_env/util/visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def show_clusters(
for tissue in TissueLabel:
for label, data in enumerate(tissue_clusters.get_cluster_for_label(tissue)):
# plot clusters with different colors
cluster_labels[tuple(np.array(data.cluster).T)] = (label + 1) * 10
cluster_labels[tuple(np.array(data.datapoints).T)] = (label + 1) * 10
ax.scatter(data.center[0], data.center[1], color="red", marker="*", s=20)
ax.imshow(cluster_labels.T, aspect=aspect, origin="lower")
return ax
Expand All @@ -135,7 +135,7 @@ def show_only_clusters(
for tissue in TissueLabel:
for label, data in enumerate(tissue_clusters.get_cluster_for_label(tissue)):
# plot clusters with different colors
cluster_labels[tuple(np.array(data.cluster).T)] = (label + 1) * 10
cluster_labels[tuple(np.array(data.datapoints).T)] = (label + 1) * 10
ax.scatter(data.center[0], data.center[1], color="red", marker="*", s=20)
ax.imshow(cluster_labels.T, aspect=6, origin="lower")
return ax
4 changes: 3 additions & 1 deletion src/armscan_env/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from typing import Any, Literal, SupportsFloat, cast

import gymnasium as gym
import numpy as np
import SimpleITK as sitk
from armscan_env.envs.base import Observation, RewardMetric, TerminationCriterion
Expand All @@ -30,7 +31,8 @@
class PatchedFrameStackObservation(FrameStackObservation):
def __init__(self, env: Env[ObsType, ActType], n_stack: int):
super().__init__(env, n_stack)
self.observation_space = MultiBoxSpace(self.observation_space)
if isinstance(self.observation_space, gym.spaces.Dict):
self.observation_space = MultiBoxSpace(self.observation_space)


class ArmscanEnvFactory(EnvFactory):
Expand Down

0 comments on commit 0370a01

Please sign in to comment.