From 4eafefb50f1595c4fe5626663d55c8e769ab8b5d Mon Sep 17 00:00:00 2001 From: carlocagnetta Date: Fri, 28 Jun 2024 20:31:04 +0200 Subject: [PATCH] Fixed slicing manipulator action in env --- docs/02_notebooks/L5_linear_sweep.ipynb | 50 ++++++++++---------- pyproject.toml | 2 +- scripts/armscan_array_obs.py | 10 ++-- src/armscan_env/envs/labelmaps_navigation.py | 36 ++++---------- src/armscan_env/envs/state_action.py | 6 +-- src/armscan_env/volumes/slicing.py | 2 +- 6 files changed, 44 insertions(+), 62 deletions(-) diff --git a/docs/02_notebooks/L5_linear_sweep.ipynb b/docs/02_notebooks/L5_linear_sweep.ipynb index c4ab64d..956e2fe 100644 --- a/docs/02_notebooks/L5_linear_sweep.ipynb +++ b/docs/02_notebooks/L5_linear_sweep.ipynb @@ -5,17 +5,18 @@ "execution_count": null, "id": "a4e98c0276b6012d", "metadata": {}, + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" - ], - "outputs": [] + ] }, { "cell_type": "code", "execution_count": null, "id": "50b440b37fd9414b", "metadata": {}, + "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", @@ -36,8 +37,7 @@ "from tianshou.highlevel.env import EnvMode\n", "\n", "config = get_config()" - ], - "outputs": [] + ] }, { "cell_type": "markdown", @@ -52,6 +52,7 @@ "execution_count": null, "id": "9ed46c7b", "metadata": {}, + "outputs": [], "source": [ "def walk_through_env(\n", " env: LabelmapEnv,\n", @@ -122,27 +123,27 @@ "\n", " if show:\n", " plt.show()" - ], - "outputs": [] + ] }, { "cell_type": "code", "execution_count": null, "id": "da45ed45bb7b8f3b", "metadata": {}, + "outputs": [], "source": [ "volume_1 = sitk.ReadImage(config.get_labels_path(1))\n", "volume_2 = sitk.ReadImage(config.get_labels_path(2))\n", "img_array_1 = sitk.GetArrayFromImage(volume_1)\n", "img_array_2 = sitk.GetArrayFromImage(volume_2)" - ], - "outputs": [] + ] }, { "cell_type": "code", "execution_count": null, "id": "63dd92db3829d7db", "metadata": {}, + "outputs": [], "source": [ "volume_size = volume_1.GetSize()\n", "\n", @@ -158,41 +159,41 @@ " render_mode=\"animation\",\n", " n_stack=2,\n", ").create_env(EnvMode.WATCH)" - ], - "outputs": [] + ] }, { "cell_type": "code", "execution_count": null, "id": "16a139f61aaafd19", "metadata": {}, + "outputs": [], "source": [ "env_rollout = walk_through_env(env, 10)\n", "\n", "plot_rollout_rewards(env_rollout)" - ], - "outputs": [] + ] }, { "cell_type": "code", "execution_count": null, "id": "6cdf855cc85a743a", "metadata": {}, + "outputs": [], "source": [ "env.get_cur_animation_as_html()" - ], - "outputs": [] + ] }, { "cell_type": "code", "execution_count": null, "id": "519dde5f1cea8a5f", "metadata": {}, + "outputs": [], "source": [ "volume_size = volume_1.GetSize()\n", "\n", "projected_env = ArmscanEnvFactory(\n", - " name2volume={\"1\": volume_1},\n", + " name2volume={\"2\": volume_2},\n", " observation=ActionRewardObservation(action_shape=(1,)).to_array_observation(),\n", " slice_shape=(volume_size[0], volume_size[2]),\n", " reward_metric=LabelmapClusteringBasedReward(),\n", @@ -205,14 +206,14 @@ " project_actions_to=\"y\",\n", " apply_volume_transformation=True,\n", ").create_env(EnvMode.WATCH)" - ], - "outputs": [] + ] }, { "cell_type": "code", "execution_count": null, "id": "22877ab71fed2eb0", "metadata": {}, + "outputs": [], "source": [ "projected_env_rollout = walk_through_env(\n", " projected_env,\n", @@ -220,40 +221,39 @@ " render_title=\"Projected labelmap slice\",\n", ")\n", "plot_rollout_rewards(projected_env_rollout)" - ], - "outputs": [] + ] }, { "cell_type": "code", "execution_count": null, "id": "c2779884526e0716", "metadata": {}, + "outputs": [], "source": [ "print(\n", " \"Observed 'rewards': \\n\",\n", " [round(obs[1][-1], 4) for obs in projected_env_rollout.observations],\n", ")\n", "print(\"Env rewards: \\n\", [round(r, 4) for r in projected_env_rollout.rewards])" - ], - "outputs": [] + ] }, { "cell_type": "code", "execution_count": null, "id": "6ada94c94fe77de0", "metadata": {}, + "outputs": [], "source": [ "projected_env.get_cur_animation_as_html()" - ], - "outputs": [] + ] }, { "cell_type": "code", "execution_count": null, "id": "4fb82c1487521b11", "metadata": {}, - "source": [], - "outputs": [] + "outputs": [], + "source": [] } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index 8dbee5c..9644114 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,7 +154,7 @@ lint = ["_black_check", "_ruff_check", "_ruff_check_nb"] _poetry_install_sort_plugin = "poetry self add poetry-plugin-sort" _poetry_sort = "poetry sort" clean-nbs = "python docs/nbstripout.py" -format = ["_black_format", "_ruff_format", "_ruff_format_nb"] +format = ["_black_format", "_ruff_format", "_ruff_format_nb", "_poetry_install_sort_plugin", "_poetry_sort"] _autogen_rst = "python docs/autogen_rst.py" _sphinx_build = "sphinx-build -W -b html docs docs/_build" _jb_generate_toc = "python docs/create_toc.py" diff --git a/scripts/armscan_array_obs.py b/scripts/armscan_array_obs.py index cb8f78d..77210ff 100644 --- a/scripts/armscan_array_obs.py +++ b/scripts/armscan_array_obs.py @@ -33,14 +33,14 @@ sampling_config = SamplingConfig( num_epochs=10, - step_per_epoch=100000, - num_train_envs=-1, + step_per_epoch=10, + num_train_envs=1, num_test_envs=1, - buffer_size=100000, + buffer_size=10, batch_size=256, step_per_collect=200, update_per_step=2, - start_timesteps=5000, + start_timesteps=1, start_timesteps_random=True, ) @@ -61,7 +61,7 @@ termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.1), reward_metric=LabelmapClusteringBasedReward(), project_actions_to="y", - apply_volume_transformation=True + apply_volume_transformation=True, ) experiment = ( diff --git a/src/armscan_env/envs/labelmaps_navigation.py b/src/armscan_env/envs/labelmaps_navigation.py index 81f72fe..2b550cd 100644 --- a/src/armscan_env/envs/labelmaps_navigation.py +++ b/src/armscan_env/envs/labelmaps_navigation.py @@ -16,7 +16,7 @@ from armscan_env.envs.rewards import LabelmapClusteringBasedReward from armscan_env.envs.state_action import LabelmapStateAction, ManipulatorAction from armscan_env.util.visualizations import show_clusters -from armscan_env.volumes.slicing import slice_volume, transform_volume +from armscan_env.volumes.slicing import EulerTransform, slice_volume, transform_volume from celluloid import Camera from IPython.core.display import HTML from matplotlib import pyplot as plt @@ -243,10 +243,7 @@ def _get_slice_from_action(self, action: np.ndarray | ManipulatorAction) -> np.n sliced_volume = slice_volume( volume=self.cur_labelmap_volume, slice_shape=self._slice_shape, - z_rotation=manipulator_action.rotation[0], - x_rotation=manipulator_action.rotation[1], - x_trans=manipulator_action.translation[0], - y_trans=manipulator_action.translation[1], + action=manipulator_action, ) return sitk.GetArrayFromImage(sliced_volume)[:, 0, :].T @@ -272,32 +269,17 @@ def apply_volume_transformation( volume: sitk.Image, optimal_action: ManipulatorAction, ) -> (sitk.Image, ManipulatorAction): # type: ignore - small_random_z_rotation = np.random.uniform(-20, 20) - small_random_x_rotation = np.random.uniform(-5, 5) - small_random_x_translation = np.random.uniform(-25, 25) - small_random_y_translation = np.random.uniform(-5, 5) - transformed_optimal_action = ManipulatorAction( - rotation=( - optimal_action.rotation[0] + small_random_z_rotation, - optimal_action.rotation[1] + small_random_x_rotation, - ), - translation=( - optimal_action.translation[0], - optimal_action.translation[1] + small_random_y_translation, - ), + volume_transformation = ManipulatorAction( + rotation=(np.random.uniform(-20, 20), np.random.uniform(-5, 5)), + translation=(np.random.uniform(-5, 5), np.random.uniform(-5, 5)), + ) + transformed_optimal_action = EulerTransform(volume_transformation).transform_action( + optimal_action, ) - if self.rotation_bounds: - bounds = list(self.rotation_bounds) - bounds[0] += abs(small_random_z_rotation) - bounds[1] += abs(small_random_x_rotation) - self.rotation_bounds = tuple(bounds) # type: ignore return ( transform_volume( volume=volume, - z_rotation=small_random_z_rotation, - x_rotation=small_random_x_rotation, - x_trans=small_random_x_translation, - y_trans=small_random_y_translation, + action=volume_transformation, ), transformed_optimal_action, ) diff --git a/src/armscan_env/envs/state_action.py b/src/armscan_env/envs/state_action.py index defcf90..495863b 100644 --- a/src/armscan_env/envs/state_action.py +++ b/src/armscan_env/envs/state_action.py @@ -31,7 +31,7 @@ def to_normalized_array( rotation = np.zeros(2) translation = np.zeros(2) if self.translation[0] < 0 or self.translation[1] < 0: - log.info("Projecting to positive because negative defined translation") + log.debug("Projecting to positive because negative defined translation") self.project_to_positive() for i in range(2): if rotation_bounds[i] == 0.0: @@ -84,7 +84,7 @@ def project_to_positive(self) -> None: """Project the action to the positive octant.""" tx, ty = self.translation thz, thx = self.rotation - log.info(f"Translation before projection: {self.translation}") + log.debug(f"Translation before projection: {self.translation}") while tx < 0 or ty < 0: if tx < 0: ty = (np.tan(np.deg2rad(thz)) * (-tx)) + ty @@ -93,7 +93,7 @@ def project_to_positive(self) -> None: tx = ((1 / np.tan(np.deg2rad(thz))) * (-ty)) + tx ty = 0 translation = (tx, ty) - log.info(f"Translation after projection: {translation}") + log.debug(f"Translation after projection: {translation}") self.translation = translation diff --git a/src/armscan_env/volumes/slicing.py b/src/armscan_env/volumes/slicing.py index 6c85d25..604d90b 100644 --- a/src/armscan_env/volumes/slicing.py +++ b/src/armscan_env/volumes/slicing.py @@ -102,7 +102,7 @@ def transform_action(self, relative_action: ManipulatorAction) -> ManipulatorAct translation=new_action_translation, ) - log.info( + log.debug( f"Random transformation: {self.action}\n" f"Original action: {relative_action}\n" f"Transformed action: {transformed_action}\n",