Skip to content

Commit

Permalink
Fixed slicing manipulator action in env
Browse files Browse the repository at this point in the history
  • Loading branch information
carlocagnetta committed Jun 28, 2024
1 parent 7a19241 commit 4eafefb
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 62 deletions.
50 changes: 25 additions & 25 deletions docs/02_notebooks/L5_linear_sweep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
"execution_count": null,
"id": "a4e98c0276b6012d",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
],
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50b440b37fd9414b",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
Expand All @@ -36,8 +37,7 @@
"from tianshou.highlevel.env import EnvMode\n",
"\n",
"config = get_config()"
],
"outputs": []
]
},
{
"cell_type": "markdown",
Expand All @@ -52,6 +52,7 @@
"execution_count": null,
"id": "9ed46c7b",
"metadata": {},
"outputs": [],
"source": [
"def walk_through_env(\n",
" env: LabelmapEnv,\n",
Expand Down Expand Up @@ -122,27 +123,27 @@
"\n",
" if show:\n",
" plt.show()"
],
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da45ed45bb7b8f3b",
"metadata": {},
"outputs": [],
"source": [
"volume_1 = sitk.ReadImage(config.get_labels_path(1))\n",
"volume_2 = sitk.ReadImage(config.get_labels_path(2))\n",
"img_array_1 = sitk.GetArrayFromImage(volume_1)\n",
"img_array_2 = sitk.GetArrayFromImage(volume_2)"
],
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63dd92db3829d7db",
"metadata": {},
"outputs": [],
"source": [
"volume_size = volume_1.GetSize()\n",
"\n",
Expand All @@ -158,41 +159,41 @@
" render_mode=\"animation\",\n",
" n_stack=2,\n",
").create_env(EnvMode.WATCH)"
],
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16a139f61aaafd19",
"metadata": {},
"outputs": [],
"source": [
"env_rollout = walk_through_env(env, 10)\n",
"\n",
"plot_rollout_rewards(env_rollout)"
],
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6cdf855cc85a743a",
"metadata": {},
"outputs": [],
"source": [
"env.get_cur_animation_as_html()"
],
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "519dde5f1cea8a5f",
"metadata": {},
"outputs": [],
"source": [
"volume_size = volume_1.GetSize()\n",
"\n",
"projected_env = ArmscanEnvFactory(\n",
" name2volume={\"1\": volume_1},\n",
" name2volume={\"2\": volume_2},\n",
" observation=ActionRewardObservation(action_shape=(1,)).to_array_observation(),\n",
" slice_shape=(volume_size[0], volume_size[2]),\n",
" reward_metric=LabelmapClusteringBasedReward(),\n",
Expand All @@ -205,55 +206,54 @@
" project_actions_to=\"y\",\n",
" apply_volume_transformation=True,\n",
").create_env(EnvMode.WATCH)"
],
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22877ab71fed2eb0",
"metadata": {},
"outputs": [],
"source": [
"projected_env_rollout = walk_through_env(\n",
" projected_env,\n",
" 10,\n",
" render_title=\"Projected labelmap slice\",\n",
")\n",
"plot_rollout_rewards(projected_env_rollout)"
],
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2779884526e0716",
"metadata": {},
"outputs": [],
"source": [
"print(\n",
" \"Observed 'rewards': \\n\",\n",
" [round(obs[1][-1], 4) for obs in projected_env_rollout.observations],\n",
")\n",
"print(\"Env rewards: \\n\", [round(r, 4) for r in projected_env_rollout.rewards])"
],
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ada94c94fe77de0",
"metadata": {},
"outputs": [],
"source": [
"projected_env.get_cur_animation_as_html()"
],
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fb82c1487521b11",
"metadata": {},
"source": [],
"outputs": []
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ lint = ["_black_check", "_ruff_check", "_ruff_check_nb"]
_poetry_install_sort_plugin = "poetry self add poetry-plugin-sort"
_poetry_sort = "poetry sort"
clean-nbs = "python docs/nbstripout.py"
format = ["_black_format", "_ruff_format", "_ruff_format_nb"]
format = ["_black_format", "_ruff_format", "_ruff_format_nb", "_poetry_install_sort_plugin", "_poetry_sort"]
_autogen_rst = "python docs/autogen_rst.py"
_sphinx_build = "sphinx-build -W -b html docs docs/_build"
_jb_generate_toc = "python docs/create_toc.py"
Expand Down
10 changes: 5 additions & 5 deletions scripts/armscan_array_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@

sampling_config = SamplingConfig(
num_epochs=10,
step_per_epoch=100000,
num_train_envs=-1,
step_per_epoch=10,
num_train_envs=1,
num_test_envs=1,
buffer_size=100000,
buffer_size=10,
batch_size=256,
step_per_collect=200,
update_per_step=2,
start_timesteps=5000,
start_timesteps=1,
start_timesteps_random=True,
)

Expand All @@ -61,7 +61,7 @@
termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.1),
reward_metric=LabelmapClusteringBasedReward(),
project_actions_to="y",
apply_volume_transformation=True
apply_volume_transformation=True,
)

experiment = (
Expand Down
36 changes: 9 additions & 27 deletions src/armscan_env/envs/labelmaps_navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from armscan_env.envs.rewards import LabelmapClusteringBasedReward
from armscan_env.envs.state_action import LabelmapStateAction, ManipulatorAction
from armscan_env.util.visualizations import show_clusters
from armscan_env.volumes.slicing import slice_volume, transform_volume
from armscan_env.volumes.slicing import EulerTransform, slice_volume, transform_volume
from celluloid import Camera
from IPython.core.display import HTML
from matplotlib import pyplot as plt
Expand Down Expand Up @@ -243,10 +243,7 @@ def _get_slice_from_action(self, action: np.ndarray | ManipulatorAction) -> np.n
sliced_volume = slice_volume(
volume=self.cur_labelmap_volume,
slice_shape=self._slice_shape,
z_rotation=manipulator_action.rotation[0],
x_rotation=manipulator_action.rotation[1],
x_trans=manipulator_action.translation[0],
y_trans=manipulator_action.translation[1],
action=manipulator_action,
)
return sitk.GetArrayFromImage(sliced_volume)[:, 0, :].T

Expand All @@ -272,32 +269,17 @@ def apply_volume_transformation(
volume: sitk.Image,
optimal_action: ManipulatorAction,
) -> (sitk.Image, ManipulatorAction): # type: ignore
small_random_z_rotation = np.random.uniform(-20, 20)
small_random_x_rotation = np.random.uniform(-5, 5)
small_random_x_translation = np.random.uniform(-25, 25)
small_random_y_translation = np.random.uniform(-5, 5)
transformed_optimal_action = ManipulatorAction(
rotation=(
optimal_action.rotation[0] + small_random_z_rotation,
optimal_action.rotation[1] + small_random_x_rotation,
),
translation=(
optimal_action.translation[0],
optimal_action.translation[1] + small_random_y_translation,
),
volume_transformation = ManipulatorAction(
rotation=(np.random.uniform(-20, 20), np.random.uniform(-5, 5)),
translation=(np.random.uniform(-5, 5), np.random.uniform(-5, 5)),
)
transformed_optimal_action = EulerTransform(volume_transformation).transform_action(
optimal_action,
)
if self.rotation_bounds:
bounds = list(self.rotation_bounds)
bounds[0] += abs(small_random_z_rotation)
bounds[1] += abs(small_random_x_rotation)
self.rotation_bounds = tuple(bounds) # type: ignore
return (
transform_volume(
volume=volume,
z_rotation=small_random_z_rotation,
x_rotation=small_random_x_rotation,
x_trans=small_random_x_translation,
y_trans=small_random_y_translation,
action=volume_transformation,
),
transformed_optimal_action,
)
Expand Down
6 changes: 3 additions & 3 deletions src/armscan_env/envs/state_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def to_normalized_array(
rotation = np.zeros(2)
translation = np.zeros(2)
if self.translation[0] < 0 or self.translation[1] < 0:
log.info("Projecting to positive because negative defined translation")
log.debug("Projecting to positive because negative defined translation")
self.project_to_positive()
for i in range(2):
if rotation_bounds[i] == 0.0:
Expand Down Expand Up @@ -84,7 +84,7 @@ def project_to_positive(self) -> None:
"""Project the action to the positive octant."""
tx, ty = self.translation
thz, thx = self.rotation
log.info(f"Translation before projection: {self.translation}")
log.debug(f"Translation before projection: {self.translation}")
while tx < 0 or ty < 0:
if tx < 0:
ty = (np.tan(np.deg2rad(thz)) * (-tx)) + ty
Expand All @@ -93,7 +93,7 @@ def project_to_positive(self) -> None:
tx = ((1 / np.tan(np.deg2rad(thz))) * (-ty)) + tx
ty = 0
translation = (tx, ty)
log.info(f"Translation after projection: {translation}")
log.debug(f"Translation after projection: {translation}")
self.translation = translation


Expand Down
2 changes: 1 addition & 1 deletion src/armscan_env/volumes/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def transform_action(self, relative_action: ManipulatorAction) -> ManipulatorAct
translation=new_action_translation,
)

log.info(
log.debug(
f"Random transformation: {self.action}\n"
f"Original action: {relative_action}\n"
f"Transformed action: {transformed_action}\n",
Expand Down

0 comments on commit 4eafefb

Please sign in to comment.