Skip to content

Commit

Permalink
Replaced LinearSweep wrapper buy projecting in the base env
Browse files Browse the repository at this point in the history
Also fixed reset for the projected case, extended interfaces to better deal with
actions in various formats
Added multiple getters to the LabelmapEnv and made names more explicit
  • Loading branch information
Michael Panchenko committed Jun 14, 2024
1 parent 5dbac4c commit f835e2c
Show file tree
Hide file tree
Showing 6 changed files with 37,203 additions and 139 deletions.
37,123 changes: 37,046 additions & 77 deletions docs/02_notebooks/L5_linear_sweep.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/armscan_env/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class EnvPreconditionError(RuntimeError):

@dataclass(kw_only=True)
class StateAction:
action: Any
normalized_action_arr: Any
# state of the env will be reflected by fields added to subclasses
# but action is a reserved field name. Subclasses should override the
# type of action to be more specific
Expand Down Expand Up @@ -249,16 +249,16 @@ def append_step(
action: TAction,
observation: TObs,
reward: float,
info: dict[str, Any],
terminated: bool,
truncated: bool,
info: dict[str, Any],
) -> None:
self.observations.append(observation)
self.rewards.append(reward)
self.actions.append(action)
self.infos.append(info)
self.terminated.append(terminated)
self.truncated.append(truncated)
self.infos.append(info)

def append_reset(
self,
Expand Down
163 changes: 132 additions & 31 deletions src/armscan_env/envs/labelmaps_navigation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from abc import ABC
from copy import copy
from copy import copy, deepcopy
from typing import Any, ClassVar, Literal

import gymnasium as gym
Expand Down Expand Up @@ -28,7 +28,6 @@

log = logging.getLogger(__name__)


_VOL_NAME_TO_OPTIMAL_ACTION = {
"1": ManipulatorAction(rotation=(19.3, 0.0), translation=(0.0, 140.0)),
"2": ManipulatorAction(rotation=(0.0, 0.0), translation=(0.0, 115.0)),
Expand Down Expand Up @@ -59,7 +58,7 @@ class LabelmapEnv(ModularEnv[LabelmapStateAction, np.ndarray, np.ndarray]):
:param seed: seed for the random number generator
"""

_INITIAL_POS_ROTATION = np.zeros(4)
_INITIAL_FULL_NORMALIZED_ACTION_ARR = np.zeros(4)

metadata: ClassVar[dict] = {"render_modes": ["plt", "animation", None], "render_fps": 10}

Expand All @@ -75,6 +74,7 @@ def __init__(
translation_bounds: tuple[float | None, float | None] = (None, None),
render_mode: Literal["plt", "animation"] | None = None,
seed: int | None = DEFAULT_SEED,
project_actions_to: Literal["x", "y", "xy"] | None = None,
):
if not name2volume:
raise ValueError("name2volume must not be empty")
Expand All @@ -87,6 +87,7 @@ def __init__(
self.name2volume = name2volume
self._slice_shape = slice_shape
self._seed = seed
self._project_actions_to = project_actions_to

# set at reset
self._cur_labelmap_name: str | None = None
Expand All @@ -101,27 +102,59 @@ def __init__(
self._axes: tuple[plt.Axes, plt.Axes, plt.Axes, plt.Axes, plt.Axes] | None = None
self._camera: Camera | None = None

@property
def project_actions_to(self) -> Literal["x", "y", "xy"] | None:
return self._project_actions_to

def get_optimal_action(self) -> ManipulatorAction:
if self.cur_labelmap_name is None:
raise RuntimeError("The labelmap name must not be None, did you call reset?")
return copy(_VOL_NAME_TO_OPTIMAL_ACTION[self.cur_labelmap_name])
return deepcopy(_VOL_NAME_TO_OPTIMAL_ACTION[self.cur_labelmap_name])

def step_to_solution(self) -> None:
self.step(self.get_optimal_action())
def get_full_optimal_action_array(self) -> np.ndarray:
return self.get_optimal_action().to_normalized_array(
self.rotation_bounds,
self.translation_bounds,
)

def unnormalize_rotation_translation(self, action: np.ndarray) -> ManipulatorAction:
"""Unnormalizes an array with values in the range [-1, 1] to the original range that is
consumed by :func:`slice_volume`.
def _get_projected_action_arr_idx(self) -> list[int]:
match self._project_actions_to:
case None:
return list(range(4))
case "x":
return [2]
case "y":
return [3]
case "xy":
return [2, 3]
case _:
raise ValueError(f"Unknown {self._project_actions_to=}")

:param action: normalized action with values in the range [-1, 1]
:return: unnormalized action that can be used with :func:`slice_volume`
"""
def _get_full_action_arr_len(self) -> int:
return len(self._INITIAL_FULL_NORMALIZED_ACTION_ARR)

def _get_full_action_leading_to_initial_state_normalized_arr(self) -> np.ndarray:
initial_action_arr = self.get_full_optimal_action_array()
project_idx = self._get_projected_action_arr_idx()
initial_action_arr[project_idx] = copy(
self._INITIAL_FULL_NORMALIZED_ACTION_ARR[project_idx],
)
return initial_action_arr

def _get_action_leading_to_initial_state(self) -> ManipulatorAction:
return ManipulatorAction.from_normalized_array(
action,
self._get_full_action_leading_to_initial_state_normalized_arr(),
self.rotation_bounds,
self.translation_bounds,
)

def get_optimal_action_array(self) -> np.ndarray:
full_action_arr = self.get_full_optimal_action_array()
return full_action_arr[self._get_projected_action_arr_idx()]

def step_to_optimal_state(self) -> None:
self.step(self.get_optimal_action())

@property
def cur_labelmap_name(self) -> str | None:
return self._cur_labelmap_name
Expand All @@ -132,20 +165,60 @@ def cur_labelmap_volume(self) -> sitk.Image | None:

@property
def action_space(self) -> gym.spaces.Space[np.ndarray]:
return gym.spaces.Box(
low=-1,
high=1.0,
shape=(4,),
) # 2 rotations, 2 translations.
action_dim = len(self._get_projected_action_arr_idx())
return gym.spaces.Box(low=-1, high=1, shape=(action_dim,), dtype=np.float32)

def close(self) -> None:
super().close()
self.name2volume = {}
self._cur_labelmap_name = None
self._cur_labelmap_volume = None

def _get_slice_from_action(self, action: np.ndarray) -> np.ndarray:
manipulator_action = self.unnormalize_rotation_translation(action)
def get_full_action_array_from_projected_action(
self,
normalized_action_arr: np.ndarray,
) -> np.ndarray:
"""Converts a (potentially projected and) normalized action array to a full action array.
If `project_actions_to` is not None, the `normalized_action_arr` is assumed to be a projection
of the correct dimension.
"""
full_action_arr = self.get_full_optimal_action_array()
project_idx = self._get_projected_action_arr_idx()
if len(normalized_action_arr) != len(project_idx):
raise ValueError(
f"Expected {len(project_idx)} elements in normalized_action_arr, "
f"but got {len(normalized_action_arr)}",
)
full_action_arr[project_idx] = normalized_action_arr
return full_action_arr

def get_manipulator_action_from_normalized_action(
self,
normalized_action_arr: np.ndarray,
) -> ManipulatorAction:
"""Converts a (potentially projected and) normalized action array to a ManipulatorAction.
Passing a full action array is also supported, even if `project_actions_to` is not None.
If `normalized_action_arr` is of a lower len and `project_actions_to` is not None, the
`normalized_action_arr` is assumed to be a projection
of the correct dimension.
"""
if len(normalized_action_arr) != self._get_full_action_arr_len():
normalized_action_arr = self.get_full_action_array_from_projected_action(
normalized_action_arr,
)
return ManipulatorAction.from_normalized_array(
normalized_action_arr,
self.rotation_bounds,
self.translation_bounds,
)

def _get_slice_from_action(self, action: np.ndarray | ManipulatorAction) -> np.ndarray:
if isinstance(action, np.ndarray):
manipulator_action = self.get_manipulator_action_from_normalized_action(action)
else:
manipulator_action = action
sliced_volume = slice_volume(
volume=self.cur_labelmap_volume,
slice_shape=self._slice_shape,
Expand All @@ -157,12 +230,16 @@ def _get_slice_from_action(self, action: np.ndarray) -> np.ndarray:
return sitk.GetArrayFromImage(sliced_volume)[:, 0, :].T

def _get_initial_slice(self) -> np.ndarray:
return self._get_slice_from_action(self._INITIAL_POS_ROTATION)
action_to_initial_slice = self._get_action_leading_to_initial_state()
return self._get_slice_from_action(action_to_initial_slice)

def compute_next_state(self, action: np.ndarray) -> LabelmapStateAction:
new_slice = self._get_slice_from_action(action)
def compute_next_state(
self,
normalized_action_arr: np.ndarray | ManipulatorAction,
) -> LabelmapStateAction:
new_slice = self._get_slice_from_action(normalized_action_arr)
return LabelmapStateAction(
action=action,
normalized_action_arr=normalized_action_arr,
labels_2d_slice=new_slice,
last_reward=self.reward_metric.compute_reward(self.cur_state_action),
# cur_state_action is the previous state, so this reward is computed for the previous state
Expand All @@ -184,9 +261,12 @@ def sample_initial_state(self) -> LabelmapStateAction:
self.compute_slice_shape(volume=self.cur_labelmap_volume)
initial_slice = self._get_initial_slice()
return LabelmapStateAction(
action=self._INITIAL_POS_ROTATION,
normalized_action_arr=copy(
self._INITIAL_FULL_NORMALIZED_ACTION_ARR[self._get_projected_action_arr_idx()],
),
labels_2d_slice=initial_slice,
last_reward=-1.0,
# TODO: pass the env's optimal position and labelmap or remove them from the StateAction?
optimal_position=None,
optimal_labelmap=None,
)
Expand Down Expand Up @@ -214,6 +294,19 @@ def compute_translation_bounds(self) -> None:
def get_translation_bounds(self) -> tuple[float | None, float | None]:
return self.translation_bounds

def get_cur_full_normalized_action_arr(self) -> np.ndarray:
return self.get_full_action_array_from_projected_action(
self.cur_state_action.normalized_action_arr,
)

def get_cur_manipulator_action(self) -> ManipulatorAction:
cur_full_action_arr = self.get_cur_full_normalized_action_arr()
return ManipulatorAction.from_normalized_array(
cur_full_action_arr,
self.rotation_bounds,
self.translation_bounds,
)

def step(
self,
action: np.ndarray | ManipulatorAction,
Expand All @@ -240,13 +333,13 @@ def reset(
obs, info = super().reset(seed=self._seed, **kwargs)
return obs, info

def render(self) -> plt.Figure | Camera | None:
def render(self, title: str = "Labelmap slice") -> plt.Figure | Camera | None:
match self.render_mode:
case "plt":
return self.get_cur_state_plot(create_new_figure=False)
return self.get_cur_state_plot(create_new_figure=False, title=title)
case "animation":
camera = self.get_camera()
self.get_cur_state_plot(create_new_figure=False)
self.get_cur_state_plot(create_new_figure=False, title=title)
camera.snap()
return camera
case None:
Expand All @@ -261,7 +354,11 @@ def render(self) -> plt.Figure | Camera | None:
f"Unknown render mode: {self.render_mode}, this should not happen.",
)

def get_cur_state_plot(self, create_new_figure: bool = True) -> plt.Figure | None:
def get_cur_state_plot(
self,
create_new_figure: bool = True,
title: str = "Labelmap slice",
) -> plt.Figure | None:
"""Retrieve a figure visualizing the current state of the environment.
:param create_new_figure: if True, a new figure will be created. Otherwise, a single figure will be used
Expand All @@ -277,7 +374,9 @@ def get_cur_state_plot(self, create_new_figure: bool = True) -> plt.Figure | Non
volume = self.cur_labelmap_volume
o = volume.GetOrigin()
img_array = sitk.GetArrayFromImage(volume)
action = self.unnormalize_rotation_translation(self.cur_state_action.action)
action = self.get_manipulator_action_from_normalized_action(
self.cur_state_action.normalized_action_arr,
)
translation = action.translation
rotation = action.rotation

Expand All @@ -290,7 +389,7 @@ def get_cur_state_plot(self, create_new_figure: bool = True) -> plt.Figure | Non
y_dash = np.tan(np.deg2rad(rotation[0])) * x_dash + b_x
y_dash = np.clip(y_dash, 0, img_array.shape[1] - 1)
ax1.plot(x_dash, y_dash, linestyle="--", color="red")
ax1.set_title("Slice cut")
ax1.set_title(f"Slice cut (labelmap name: {self.cur_labelmap_name})")

# Subplot 2: from the side
ix = volume.GetSize()[0] // 2
Expand Down Expand Up @@ -324,6 +423,8 @@ def get_cur_state_plot(self, create_new_figure: bool = True) -> plt.Figure | Non
# REWARD
ax5.text(0, 0, f"Reward: {self.cur_reward:.2f}", fontsize=12, color="red")

fig.suptitle(title, x=0.2, y=0.95)

plt.close()
return fig

Expand Down
12 changes: 10 additions & 2 deletions src/armscan_env/envs/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def compute_observation(
state: LabelmapStateAction,
) -> ChanneledLabelmapsObsWithActReward:
"""Return the observation as a dictionary of the type ChanneledLabelmapsObsWithActReward."""
return self.compute_from_slice(state.labels_2d_slice, state.action, state.last_reward)
return self.compute_from_slice(
state.labels_2d_slice,
state.normalized_action_arr,
state.last_reward,
)

def compute_from_slice(
self,
Expand Down Expand Up @@ -176,7 +180,11 @@ def compute_observation(
state: LabelmapStateAction,
) -> ChanneledLabelmapsObsWithActReward:
"""Return the observation as a dictionary of the type ChanneledLabelmapsObsWithActReward."""
return self.compute_from_slice(state.labels_2d_slice, state.action, state.last_reward)
return self.compute_from_slice(
state.labels_2d_slice,
state.normalized_action_arr,
state.last_reward,
)

def compute_from_slice(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/armscan_env/envs/state_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def from_normalized_array(

@dataclass(kw_only=True)
class LabelmapStateAction(StateAction):
action: np.ndarray
normalized_action_arr: np.ndarray
"""Array of shape (4,) representing two angles and two translations"""
labels_2d_slice: np.ndarray
"""Two-dimensional slice of the labelmap, i.e., an array of shape (N, M) with integer values.
Expand Down
Loading

0 comments on commit f835e2c

Please sign in to comment.