Skip to content

Commit

Permalink
changed random volume transformation to take action as parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
carlocagnetta committed Jul 3, 2024
1 parent 11ffb07 commit 683ed47
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 14 deletions.
34 changes: 22 additions & 12 deletions src/armscan_env/envs/labelmaps_navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ def get_optimal_action_array(self) -> np.ndarray:
full_action_arr = self.get_full_optimal_action_array()
return full_action_arr[self._get_projected_action_arr_idx()]

def step_to_optimal_state(self) -> tuple[TObs, float, bool, bool, dict[str, Any]]:
def step_to_optimal_state(
self,
) -> tuple[Observation[LabelmapStateAction, Any], float, bool, bool, dict[str, Any]]:
return self.step(self.get_optimal_action())

@property
Expand Down Expand Up @@ -271,24 +273,30 @@ def compute_next_state(
def apply_volume_transformation(
self,
volume: sitk.Image,
volume_transformation_action: ManipulatorAction,
optimal_action: ManipulatorAction,
) -> (sitk.Image, ManipulatorAction): # type: ignore
volume_transformation = ManipulatorAction(
rotation=(np.random.uniform(-20, 20), np.random.uniform(-5, 5)),
translation=(np.random.uniform(-5, 5), np.random.uniform(-5, 5)),
)
transformed_optimal_action = EulerTransform(volume_transformation).transform_action(
"""Apply a random transformation to the volume and to the optimal action. The transformation is a random rotation
and translation. The bounds of the rotation are updated if they have already been set. The translation bounds are
computed from the volume size in the 'sample_initial_state' method.
:param volume: the volume to transform
:param volume_transformation_action: the transformation action to apply to the volume
:param optimal_action: the optimal action for the volume to transform accordingly
:return: the transformed volume and the transformed optimal action
"""
transformed_optimal_action = EulerTransform(volume_transformation_action).transform_action(
optimal_action,
)
if self.rotation_bounds:
bounds = list(self.rotation_bounds)
bounds[0] += abs(volume_transformation.rotation[0])
bounds[1] += abs(volume_transformation.rotation[1])
bounds[0] += abs(volume_transformation_action.rotation[0])
bounds[1] += abs(volume_transformation_action.rotation[1])
self.rotation_bounds = tuple(bounds) # type: ignore
return (
create_transformed_volume(
volume=volume,
transformation_action=volume_transformation,
transformation_action=volume_transformation_action,
),
transformed_optimal_action,
)
Expand All @@ -303,9 +311,11 @@ def sample_initial_state(self) -> LabelmapStateAction:
volume_optimal_action = deepcopy(_VOL_NAME_TO_OPTIMAL_ACTION[sampled_image_name])

if self._apply_volume_transformation:
volume_transformation_action = ManipulatorAction.sample()
self._cur_labelmap_volume, self._cur_optimal_action = self.apply_volume_transformation(
self.name2volume[sampled_image_name],
volume_optimal_action,
volume=self.name2volume[sampled_image_name],
volume_transformation_action=volume_transformation_action,
optimal_action=volume_optimal_action,
)
else:
self._cur_labelmap_volume = self.name2volume[sampled_image_name]
Expand Down Expand Up @@ -361,7 +371,7 @@ def get_cur_manipulator_action(self) -> ManipulatorAction:
def step(
self,
action: np.ndarray | ManipulatorAction,
) -> tuple[TObs, float, bool, bool, dict[str, Any]]:
) -> tuple[Observation[LabelmapStateAction, Any], float, bool, bool, dict[str, Any]]:
if isinstance(action, ManipulatorAction):
action = action.to_normalized_array(self.rotation_bounds, self.translation_bounds)
return super().step(action)
Expand Down
32 changes: 30 additions & 2 deletions src/armscan_env/envs/state_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def to_normalized_array(
rotation = np.zeros(2)
translation = np.zeros(2)
if self.translation[0] < 0 or self.translation[1] < 0:
log.debug("Projecting to positive because negative defined translation")
log.debug(
"Action contains a negative translation, out of bounds.\n"
"Projecting the origin of the viewing plane to positive octant.",
)
self.project_to_positive()
for i in range(2):
if rotation_bounds[i] == 0.0:
Expand Down Expand Up @@ -81,7 +84,16 @@ def from_normalized_array(
return cls(rotation=tuple(rotation), translation=tuple(translation)) # type: ignore

def project_to_positive(self) -> None:
"""Project the action to the positive octant."""
"""Project the action to the positive octant.
This is needed when transformin the optimal action accordingly to the random volume transformation.
It might be, that for a negative translation and/or a negative z-rotation, the coordinates defining the
optimal action land in negative space. Since the action defines a coordinate frame which infers a plane
(x-z plane, y normal to the plane), assuming that this plane is still intercepting the positive octant,
it is possible to redefine the action in positive coordinates by projecting it into the positive octant.
It needs to be tested, that the volume transformations keep the optimal action in a reachable space.
Volume transformations are used for data augmentation only, so can be defined in the most convenient way.
"""
tx, ty = self.translation
thz, thx = self.rotation
log.debug(f"Translation before projection: {self.translation}")
Expand All @@ -96,6 +108,22 @@ def project_to_positive(self) -> None:
log.debug(f"Translation after projection: {translation}")
self.translation = translation

@classmethod
def sample(
cls,
rotation_range: tuple[float, float] = (20.0, 5.0),
translation_range: tuple[float, float] = (5.0, 5.0),
) -> Self:
rotation = (
np.random.uniform(-rotation_range[0], rotation_range[0]),
np.random.uniform(-rotation_range[1], rotation_range[1]),
)
translation = (
np.random.uniform(-translation_range[0], translation_range[0]),
np.random.uniform(-translation_range[1], translation_range[1]),
)
return cls(rotation=rotation, translation=translation)


@dataclass(kw_only=True)
class LabelmapStateAction(StateAction):
Expand Down

0 comments on commit 683ed47

Please sign in to comment.