Skip to content

Commit

Permalink
running array obs experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
carlocagnetta committed Jun 15, 2024
1 parent 22ce091 commit 62bb320
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion scripts/armscan_array_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
name2volume={"1": volume_1},
observation=LabelmapClusterObservation(action_shape=(4,)),
slice_shape=(volume_size[0], volume_size[2]),
max_episode_len=100,
max_episode_len=20,
rotation_bounds=(90.0, 45.0),
translation_bounds=(0.0, None),
render_mode="animation",
Expand Down
2 changes: 1 addition & 1 deletion scripts/armscan_ppo_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
action_shape=(4,),
),
slice_shape=(volume_size[0], volume_size[2]),
max_episode_len=100,
max_episode_len=20,
rotation_bounds=(90.0, 45.0),
translation_bounds=(0.0, None),
render_mode="animation",
Expand Down
15 changes: 10 additions & 5 deletions src/armscan_env/envs/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,12 @@ def __init__(self, action_shape: tuple[int]):
def compute_observation(self, state: LabelmapStateAction) -> np.ndarray:
tissue_clusters = TissueClusters.from_labelmap_slice(state.labels_2d_slice)
return np.concatenate(
self.cluster_characteristics_array(tissue_clusters=tissue_clusters).flatten(),
state.action,
state.last_reward,
(
self.cluster_characteristics_array(tissue_clusters=tissue_clusters).flatten(),
np.atleast_1d(state.normalized_action_arr),
np.atleast_1d(state.last_reward),
),
axis=0,
)

def _compute_observation_space(
Expand All @@ -256,7 +259,7 @@ def _compute_observation_space(
spaces=(
("num_clusters", gym.spaces.Box(low=0, high=np.inf, shape=(3,))),
("num_points", gym.spaces.Box(low=0, high=np.inf, shape=(3,))),
("cluster_center_mean", gym.spaces.Box(low=-np.inf, high=np.inf, shape=(3,))),
("cluster_center_mean", gym.spaces.Box(low=-np.inf, high=np.inf, shape=(6,))),
("action", gym.spaces.Box(low=-1, high=1, shape=self.action_shape)),
("reward", gym.spaces.Box(low=-1, high=0, shape=(1,))),
),
Expand All @@ -275,7 +278,9 @@ def cluster_characteristics_array(tissue_clusters: TissueClusters) -> np.ndarray
num_points += len(cluster.datapoints)
cluster_centers.append(cluster.center)
clusters_center_mean = np.mean(np.array(cluster_centers), axis=0)
cluster_characteristics.append([len(clusters), num_points, clusters_center_mean])
if np.any(np.isnan(clusters_center_mean)):
clusters_center_mean = np.zeros(2)
cluster_characteristics.append([len(clusters), num_points, *clusters_center_mean])

return np.array(cluster_characteristics)

Expand Down
4 changes: 3 additions & 1 deletion src/armscan_env/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from typing import Any, Literal, SupportsFloat, cast

import gymnasium as gym
import numpy as np
import SimpleITK as sitk
from armscan_env.envs.base import Observation, RewardMetric, TerminationCriterion
Expand All @@ -30,7 +31,8 @@
class PatchedFrameStackObservation(FrameStackObservation):
def __init__(self, env: Env[ObsType, ActType], n_stack: int):
super().__init__(env, n_stack)
self.observation_space = MultiBoxSpace(self.observation_space)
if isinstance(self.observation_space, gym.spaces.Dict):
self.observation_space = MultiBoxSpace(self.observation_space)


class ArmscanEnvFactory(EnvFactory):
Expand Down

0 comments on commit 62bb320

Please sign in to comment.