diff --git a/docs/02_notebooks/L0_MRI_and_Labelmaps.ipynb b/docs/02_notebooks/L0_MRI_and_Labelmaps.ipynb index 91fbca4..315da33 100644 --- a/docs/02_notebooks/L0_MRI_and_Labelmaps.ipynb +++ b/docs/02_notebooks/L0_MRI_and_Labelmaps.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - ")# Visualization of MRI Data and Labelmaps" + "# Visualization of MRI Data and Labelmaps" ] }, { diff --git a/docs/02_notebooks/L4_normalize_volumes.ipynb b/docs/02_notebooks/L4_normalize_volumes.ipynb index da9b023..fb2962f 100644 --- a/docs/02_notebooks/L4_normalize_volumes.ipynb +++ b/docs/02_notebooks/L4_normalize_volumes.ipynb @@ -153,6 +153,8 @@ "sliced_img = sitk.GetArrayFromImage(sliced_volume)\n", "cluster = TissueClusters.from_labelmap_slice(sliced_img.T)\n", "show_clusters(cluster, sliced_img.T, extent=transversal_extent)\n", + "print(f\"Slice value range: {np.min(sliced_img)} - {np.max(sliced_img)}\")\n", + "plt.axis(\"off\")\n", "plt.show()" ] }, @@ -174,6 +176,14 @@ "reward = anatomy_based_rwd(cluster)\n", "print(reward)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b0966298f2bc74c", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/docs/02_notebooks/L5_random_volume_transformations.ipynb b/docs/02_notebooks/L5_random_volume_transformations.ipynb index f73a5ee..963ed39 100644 --- a/docs/02_notebooks/L5_random_volume_transformations.ipynb +++ b/docs/02_notebooks/L5_random_volume_transformations.ipynb @@ -121,7 +121,6 @@ "sliced_img = sitk.GetArrayFromImage(sliced_volume)\n", "print(f\"Slice value range: {np.min(sliced_img)} - {np.max(sliced_img)}\")\n", "\n", - "extent_xz = (0, size[0], size[2], 0)\n", "cluster = TissueClusters.from_labelmap_slice(sliced_img.T)\n", "show_clusters(cluster, sliced_img.T)\n", "reward = anatomy_based_rwd(cluster)\n", diff --git a/docs/02_notebooks/L6_environment.ipynb b/docs/02_notebooks/L6_environment.ipynb index 048a949..6e6a69f 100644 --- a/docs/02_notebooks/L6_environment.ipynb +++ b/docs/02_notebooks/L6_environment.ipynb @@ -30,12 +30,11 @@ "from armscan_env.clustering import TissueClusters\n", "from armscan_env.config import get_config\n", "from armscan_env.envs.labelmaps_navigation import (\n", - " LabelmapClusteringBasedReward,\n", " ArmscanEnv,\n", + " LabelmapClusteringBasedReward,\n", " LabelmapEnvTerminationCriterion,\n", ")\n", "from armscan_env.envs.observations import (\n", - " ActionRewardObservation,\n", " LabelmapSliceAsChannelsObservation,\n", ")\n", "from armscan_env.envs.rewards import anatomy_based_rwd\n", @@ -87,6 +86,11 @@ "z = [0, -5, 0, 0, 5, 15, 19.3, -10, 0, 0, 0, 5, -8, 8, 0, -10, -10, 10, 19.3]\n", "o = volumes[0].GetOrigin()\n", "slice_shape = (volumes[0].GetSize()[0], volumes[0].GetSize()[2])\n", + "size = np.array(volumes[0].GetSize()) * np.array(volumes[0].GetSpacing())\n", + "\n", + "transversal_extent = (0, size[0], 0, size[2])\n", + "longitudinal_extent = (0, size[1], 0, size[2])\n", + "frontal_extent = (0, size[0], size[1], 0)\n", "\n", "\n", "# Sample functions for demonstration\n", @@ -106,9 +110,9 @@ "\n", "for i in range(len(t)):\n", " # Subplot 1: Image with dashed line\n", - " ax1.imshow(img_array_1[40, :, :])\n", - " x_dash = np.arange(img_array_1.shape[2])\n", - " b = volumes[0].TransformPhysicalPointToIndex([o[0], o[1] + t[i], o[2]])[1]\n", + " ax1.imshow(img_array_1[40, :, :], extent=frontal_extent)\n", + " x_dash = np.arange(size[0])\n", + " b = t[i]\n", " y_dash = linear_function(x_dash, np.tan(np.deg2rad(z[i])), b)\n", " ax1.set_title(f\"Section {0}\")\n", " line = ax1.plot(x_dash, y_dash, linestyle=\"--\", color=\"red\")[0]\n", @@ -120,12 +124,16 @@ " action=ManipulatorAction(rotation=(z[i], 0.0), translation=(0.0, t[i])),\n", " )\n", " sliced_img = sitk.GetArrayFromImage(sliced_volume).T\n", - " ax2.imshow(sliced_img.T, aspect=6, origin=\"lower\")\n", + " ax2.imshow(\n", + " sliced_img.T,\n", + " origin=\"lower\",\n", + " extent=transversal_extent,\n", + " )\n", " ax2.set_title(f\"Slice {i}\")\n", "\n", " # OBSERVATION\n", " clusters = TissueClusters.from_labelmap_slice(sliced_img)\n", - " ax3 = show_clusters(clusters, sliced_img, ax3)\n", + " ax3 = show_clusters(clusters, sliced_img, ax3, extent=transversal_extent)\n", " ax3.set_title(f\"Clusters {i}\")\n", "\n", " # REWARD\n", @@ -194,7 +202,7 @@ " termination_criterion=LabelmapEnvTerminationCriterion(),\n", " max_episode_len=10,\n", " rotation_bounds=(30.0, 10.0),\n", - " translation_bounds=(None, None),\n", + " translation_bounds=(0.0, None),\n", " render_mode=\"animation\",\n", " apply_volume_transformation=True,\n", ")\n", @@ -216,44 +224,6 @@ "HTML(animation.to_jshtml())" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "volume_size = volumes[0].GetSize()\n", - "\n", - "env = ArmscanEnv(\n", - " name2volume={\"1\": volumes[6]},\n", - " observation=ActionRewardObservation(action_shape=(2,)).to_array_observation(),\n", - " slice_shape=(volume_size[0], volume_size[2]),\n", - " reward_metric=LabelmapClusteringBasedReward(),\n", - " termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.05),\n", - " max_episode_len=10,\n", - " rotation_bounds=(30.0, 10.0),\n", - " translation_bounds=(None, None),\n", - " render_mode=\"animation\",\n", - " project_actions_to=\"zy\",\n", - " apply_volume_transformation=True,\n", - ")\n", - "\n", - "observation, info = env.reset()\n", - "for _ in range(50):\n", - " action = env.action_space.sample()\n", - " epsilon = 0.1\n", - " if np.random.rand() > epsilon:\n", - " observation, reward, terminated, truncated, info = env.step(action)\n", - " else:\n", - " observation, reward, terminated, truncated, info = env.step_to_optimal_state()\n", - " env.render()\n", - "\n", - " if terminated or truncated:\n", - " observation, info = env.reset(reset_render=False)\n", - "animation = env.get_cur_animation()\n", - "env.get_cur_animation_as_html()" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/docs/02_notebooks/L7_linear_sweep.ipynb b/docs/02_notebooks/L7_linear_sweep.ipynb index 3718657..61197da 100644 --- a/docs/02_notebooks/L7_linear_sweep.ipynb +++ b/docs/02_notebooks/L7_linear_sweep.ipynb @@ -24,8 +24,8 @@ "from armscan_env.config import get_config\n", "from armscan_env.envs.base import EnvRollout\n", "from armscan_env.envs.labelmaps_navigation import (\n", - " LabelmapClusteringBasedReward,\n", " ArmscanEnv,\n", + " LabelmapClusteringBasedReward,\n", " LabelmapEnvTerminationCriterion,\n", ")\n", "from armscan_env.envs.observations import (\n", diff --git a/notebooks/plots.ipynb b/notebooks/plots.ipynb deleted file mode 100644 index bcce502..0000000 --- a/notebooks/plots.ipynb +++ /dev/null @@ -1,337 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# ruff: noqa\n", - "# type: ignore\n", - "# Ruff! Please spare me! I just need to plot the results...\n", - "import os\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "\n", - "\n", - "def add_line_to_plot(data_path, label=None):\n", - " # import all csvs in the folder into each df\n", - " files = os.listdir(data_path)\n", - "\n", - " dfs = []\n", - " for file in files:\n", - " if file.endswith(\".csv\"):\n", - " df = pd.read_csv(os.path.join(data_path, file))\n", - " dfs.append(df)\n", - "\n", - " # concatenate all dataframes into one\n", - " df = pd.concat(dfs)\n", - "\n", - " # calculate median and standard deviation of the loss values\n", - " median = df.groupby(\"Step\")[\"Value\"].median().reset_index()\n", - " std = df.groupby(\"Step\")[\"Value\"].std().reset_index()\n", - " upper_bound = median[\"Value\"] + std[\"Value\"]\n", - " lower_bound = median[\"Value\"] - std[\"Value\"]\n", - "\n", - " # add the median loss line to the plot\n", - " sns.lineplot(x=\"Step\", y=\"Value\", data=median, linewidth=2, label=label, alpha=0.8)\n", - " plt.fill_between(median[\"Step\"], upper_bound, lower_bound, alpha=0.2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Reach\n", - "# create the plot\n", - "plt.figure(figsize=(5, 4))\n", - "plt.xlabel(\"Timesteps\")\n", - "plt.ylabel(\"Mean Reward per Episode\")\n", - "# plt.yscale('log')\n", - "plt.title(\"Reach\")\n", - "\n", - "# add lines to the plot\n", - "lines = [\n", - " (\"Baseline\", \"reach\"),\n", - " (\"Autoencoder\", \"reach_ae\"),\n", - " (\"MultiSegmenter\", \"reach_multiseg\"),\n", - "]\n", - "for label, data_path in lines:\n", - " add_line_to_plot(\"data/\" + data_path, label=label)\n", - "# add_line_to_plot('data3', label='Data 3')\n", - "\n", - "# show the plot\n", - "plt.legend()\n", - "# more grid lines in the background\n", - "plt.grid(which=\"both\", axis=\"both\", linestyle=\"--\", alpha=0.5)\n", - "# nicer colors\n", - "plt.style.use(\"seaborn\")\n", - "\n", - "plt.savefig(\"figures/reach.pdf\")\n", - "# plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Reach Success Rate\n", - "# create the plot\n", - "plt.figure(figsize=(5, 4))\n", - "plt.xlabel(\"Timesteps\")\n", - "plt.ylabel(\"Success Rate\")\n", - "# plt.yscale('log')\n", - "plt.title(\"Reach\")\n", - "\n", - "# add lines to the plot\n", - "lines = [\n", - " (\"Baseline\", \"reach_succ\"),\n", - " (\"Autoencoder\", \"reach_ae_succ\"),\n", - " (\"MultiSegmenter\", \"reach_multiseg_succ\"),\n", - "]\n", - "for label, data_path in lines:\n", - " add_line_to_plot(\"data/\" + data_path, label=label)\n", - "# add_line_to_plot('data3', label='Data 3')\n", - "\n", - "# show the plot\n", - "plt.legend()\n", - "# more grid lines in the background\n", - "plt.grid(which=\"both\", axis=\"both\", linestyle=\"--\", alpha=0.5)\n", - "# nicer colors\n", - "plt.style.use(\"seaborn\")\n", - "\n", - "plt.savefig(\"figures/reach_succ.pdf\")\n", - "# plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# PickAndPlace\n", - "# create the plot\n", - "plt.figure(figsize=(5, 4))\n", - "plt.xlabel(\"Timesteps\")\n", - "plt.ylabel(\"Mean Reward per Episode\")\n", - "# plt.yscale('log')\n", - "plt.title(\"PickAndPlace\")\n", - "\n", - "# add lines to the plot\n", - "lines = [\n", - " (\"Baseline\", \"pp\"),\n", - " # ('Autoencoder', 'reach_ae'),\n", - " (\"MultiSegmenter\", \"pp_multiseg\"),\n", - "]\n", - "for label, data_path in lines:\n", - " add_line_to_plot(\"data/\" + data_path, label=label)\n", - "# add_line_to_plot('data3', label='Data 3')\n", - "\n", - "# show the plot\n", - "plt.legend()\n", - "# more grid lines in the background\n", - "plt.grid(which=\"both\", axis=\"both\", linestyle=\"--\", alpha=0.5)\n", - "# nicer colors\n", - "plt.style.use(\"seaborn\")\n", - "\n", - "plt.savefig(\"figures/pp.pdf\")\n", - "# plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# create the plot\n", - "plt.figure(figsize=(5, 4))\n", - "plt.xlabel(\"Timesteps\")\n", - "plt.ylabel(\"Success Rate\")\n", - "# plt.yscale('log')\n", - "plt.title(\"PickAndPlace\")\n", - "\n", - "# add lines to the plot\n", - "lines = [\n", - " (\"Baseline\", \"pp_succ\"),\n", - " # ('Autoencoder', 'reach_ae'),\n", - " (\"MultiSegmenter\", \"pp_multiseg_succ\"),\n", - "]\n", - "for label, data_path in lines:\n", - " add_line_to_plot(\"data/\" + data_path, label=label)\n", - "# add_line_to_plot('data3', label='Data 3')\n", - "\n", - "# show the plot\n", - "plt.legend()\n", - "# more grid lines in the background\n", - "plt.grid(which=\"both\", axis=\"both\", linestyle=\"--\", alpha=0.5)\n", - "# nicer colors\n", - "plt.style.use(\"seaborn\")\n", - "\n", - "plt.savefig(\"figures/pp_succ.pdf\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# PickAndPlace\n", - "# create the plot\n", - "plt.figure(figsize=(5, 4))\n", - "plt.xlabel(\"Timesteps\")\n", - "plt.ylabel(\"Mean Reward per Episode\")\n", - "# plt.yscale('log')\n", - "plt.title(\"PegInHole\")\n", - "\n", - "# add lines to the plot\n", - "lines = [\n", - " (\"Baseline\", \"ph\"),\n", - " # ('Autoencoder', 'reach_ae'),\n", - " (\"MultiSegmenter\", \"ph_multiseg\"),\n", - "]\n", - "for label, data_path in lines:\n", - " add_line_to_plot(\"data/\" + data_path, label=label)\n", - "# add_line_to_plot('data3', label='Data 3')\n", - "\n", - "# show the plot\n", - "plt.legend()\n", - "# more grid lines in the background\n", - "plt.grid(which=\"both\", axis=\"both\", linestyle=\"--\", alpha=0.5)\n", - "# nicer colors\n", - "plt.style.use(\"seaborn\")\n", - "\n", - "plt.savefig(\"figures/ph.pdf\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.figure(figsize=(5, 4))\n", - "plt.xlabel(\"Timesteps\")\n", - "plt.ylabel(\"Success Rate\")\n", - "# plt.yscale('log')\n", - "plt.title(\"PegInHole\")\n", - "\n", - "# add lines to the plot\n", - "lines = [\n", - " (\"Baseline\", \"ph_succ\"),\n", - " # ('Autoencoder', 'reach_ae'),\n", - " (\"MultiSegmenter\", \"ph_multiseg_succ\"),\n", - "]\n", - "for label, data_path in lines:\n", - " add_line_to_plot(\"data/\" + data_path, label=label)\n", - "# add_line_to_plot('data3', label='Data 3')\n", - "\n", - "# show the plot\n", - "plt.legend()\n", - "# more grid lines in the background\n", - "plt.grid(which=\"both\", axis=\"both\", linestyle=\"--\", alpha=0.5)\n", - "# nicer colors\n", - "plt.style.use(\"seaborn\")\n", - "\n", - "plt.savefig(\"figures/ph_succ.pdf\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "\n", - "\n", - "def add_line_to_plot(data_path, label=None):\n", - " # import all csvs in the folder into each df\n", - " files = os.listdir(data_path)\n", - "\n", - " dfs = []\n", - " for file in files:\n", - " if file.endswith(\".csv\"):\n", - " df = pd.read_csv(os.path.join(data_path, file))\n", - " dfs.append(df)\n", - "\n", - " # concatenate all dataframes into one\n", - " df = pd.concat(dfs)\n", - "\n", - " # calculate median and standard deviation of the loss values\n", - " median = df.groupby(\"Step\")[\"Value\"].median().reset_index()\n", - " std = df.groupby(\"Step\")[\"Value\"].std().reset_index()\n", - " upper_bound = median[\"Value\"] + std[\"Value\"]\n", - " lower_bound = median[\"Value\"] - std[\"Value\"]\n", - "\n", - " # add the median loss line to the plot\n", - " sns.lineplot(x=\"Step\", y=\"Value\", data=median, linewidth=2, label=label, alpha=0.8)\n", - " plt.fill_between(median[\"Step\"], upper_bound, lower_bound, alpha=0.5)\n", - "\n", - "\n", - "# create the plot\n", - "plt.figure(figsize=(5, 4))\n", - "plt.xlabel(\"Training Steps\")\n", - "plt.ylabel(\"EMD loss (MSE feature loss)\")\n", - "plt.yscale(\"log\")\n", - "plt.title(\"Autoencoder\")\n", - "\n", - "# add lines to the plot\n", - "lines = [(\"Table\", \"table_ae\"), (\"Peg In Hole\", \"pih_ae\")]\n", - "for label, data_path in lines:\n", - " add_line_to_plot(\"data/\" + data_path, label=label)\n", - "# add_line_to_plot('data3', label='Data 3')\n", - "\n", - "# show the plot\n", - "plt.legend()\n", - "# more grid lines in the background\n", - "plt.grid(which=\"both\", axis=\"both\", linestyle=\"--\", alpha=0.5)\n", - "# nicer colors\n", - "plt.style.use(\"seaborn\")\n", - "\n", - "plt.savefig(\"figures/loss.pdf\")\n", - "# plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/poetry.lock b/poetry.lock index f6ef455..1c90d30 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -3171,6 +3171,7 @@ optional = false python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, @@ -3191,6 +3192,7 @@ files = [ {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, @@ -5710,4 +5712,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "74ead410223f2caed225c4476c584c4be59ff55ebd98ff353d914fc1f62a1555" +content-hash = "23e66a7418443ff4710e919fff8757d7e8c1bb15c593914809201f574e2847a4" diff --git a/src/armscan_env/config.py b/src/armscan_env/config.py index b1dfe5e..af081f0 100644 --- a/src/armscan_env/config.py +++ b/src/armscan_env/config.py @@ -34,6 +34,20 @@ def get_labelmaps_basedir(self) -> str: check_existence=True, ) + def get_cropped_labelmaps_basedir(self) -> str: + return self._adjusted_path( + os.path.join(self.data, "cropped"), + relative=False, + check_existence=True, + ) + + def get_single_cropped_labelmap_path(self, labelmap_file_id: int) -> str: + single_labelmap_path = os.path.join( + self.get_cropped_labelmaps_basedir(), + f"{labelmap_file_id:05d}_cropped.nii", + ) + return self._adjusted_path(single_labelmap_path, relative=False, check_existence=True) + def count_labels(self) -> int: labels_dir = self.get_labelmaps_basedir() return len( diff --git a/src/armscan_env/envs/labelmaps_navigation.py b/src/armscan_env/envs/labelmaps_navigation.py index 21ca2bb..a4ac39e 100644 --- a/src/armscan_env/envs/labelmaps_navigation.py +++ b/src/armscan_env/envs/labelmaps_navigation.py @@ -426,7 +426,7 @@ def get_cur_state_plot( raise RuntimeError("The labelmap volume must not be None, did you call reset?") volume = self.cur_labelmap_volume - o = volume.GetOrigin() + volume.GetOrigin() img_array = sitk.GetArrayFromImage(volume) action = self.get_manipulator_action_from_normalized_action( self.cur_state_action.normalized_action_arr, @@ -434,31 +434,34 @@ def get_cur_state_plot( translation = action.translation rotation = action.rotation + size = tuple(volume.GetSize()) * np.array(volume.GetSpacing()) + transversal_extent = (0, size[0], 0, size[2]) + longitudinal_extent = (0, size[2], size[1], 0) + frontal_extent = (0, size[0], size[1], 0) + # Subplot 1: from the top iz = volume.GetSize()[2] // 2 - ax1.imshow(img_array[iz, :, :]) - x_dash = np.arange(img_array.shape[2]) - b = volume.TransformPhysicalPointToIndex( - [o[0] + translation[0], o[1] + translation[1], o[2]], - )[1] - b_x = b + np.tan(np.deg2rad(rotation[1])) * iz + ax1.imshow(img_array[iz, :, :], extent=frontal_extent) + x_dash = np.arange(size[0]) + b = translation[1] + b_x = b + np.tan(np.deg2rad(rotation[1])) * (size[2] // 2) y_dash = np.tan(np.deg2rad(rotation[0])) * x_dash + b_x - y_dash = np.clip(y_dash, 0, img_array.shape[1] - 1) + y_dash = np.clip(y_dash, 0, size[1] - 1) ax1.plot(x_dash, y_dash, linestyle="--", color="red") - ax1.set_title(f"Slice cut (labelmap name: {self.cur_labelmap_name})") + ax1.set_title("Slice cut") # Subplot 2: from the side ix = volume.GetSize()[0] // 2 - ax2.imshow(img_array[:, :, ix].T) - z_dash = np.arange(img_array.shape[0]) - b_z = b + np.tan(np.deg2rad(rotation[0])) * ix + ax2.imshow(img_array[:, :, ix].T, extent=longitudinal_extent) + z_dash = np.arange(size[2]) + b_z = b + np.tan(np.deg2rad(rotation[0])) * (size[0] // 2) y_dash_2 = np.tan(np.deg2rad(rotation[1])) * z_dash + b_z - y_dash_2 = np.clip(y_dash_2, 0, img_array.shape[1] - 1) + y_dash_2 = np.clip(y_dash_2, 0, size[1] - 1) ax2.plot(z_dash, y_dash_2, linestyle="--", color="red") # ACTION sliced_img = self.cur_state_action.labels_2d_slice - ax3.imshow(sliced_img.T, origin="lower", aspect=6) + ax3.imshow(sliced_img.T, origin="lower", extent=transversal_extent, aspect=2) txt = ( "Slice taken at position:\n" @@ -474,7 +477,7 @@ def get_cur_state_plot( # OBSERVATION clusters = TissueClusters.from_labelmap_slice(self.cur_state_action.labels_2d_slice) - show_clusters(clusters, sliced_img, ax5) + show_clusters(clusters, sliced_img, ax5, extent=transversal_extent, aspect=2) # REWARD ax5.text(0, 0, f"Reward: {self.cur_reward:.2f}", fontsize=12, color="red") diff --git a/src/armscan_env/util/visualizations.py b/src/armscan_env/util/visualizations.py index 9ffd896..2815fc1 100644 --- a/src/armscan_env/util/visualizations.py +++ b/src/armscan_env/util/visualizations.py @@ -16,7 +16,7 @@ def _show( cmap: str | None, axis: bool, **imshow_kwargs: Any, -) -> AxesImage | Axes: +) -> np.ndarray[Any, Any]: """Function to display row of image slices. :param slices: list of image slices @@ -33,13 +33,17 @@ def _show( if isinstance(slices[0], np.ndarray) and isinstance(slices[0].shape, tuple): extent = (0, slices[0].shape[0], 0, slices[0].shape[1]) else: - raise TypeError("Expected slice to be a numpy array with a shape attribute of type tuple.") + raise TypeError( + "Expected slice to be a numpy array with a shape attribute of type tuple.", + ) rows = -(-len(slices) // col) fig, ax = plt.subplots(rows, col, figsize=(15, 2 * rows)) # Flatten the ax array to simplify indexing - ax = ax.flatten() + ax = ax.flatten() if isinstance(ax, np.ndarray) else np.array(ax) for i, slice in enumerate(slices): + if i >= ax.size: + break ax[i].imshow(slice, cmap=cmap, origin="lower", extent=extent, **imshow_kwargs) ax[i].set_title(f"Slice {start - i * lap}") # Set titles if desired ax[i].axis("off") if not axis else None # Turn off axis if desired @@ -58,7 +62,7 @@ def show_slices( cmap: str | None = None, axis: bool = False, **imshow_kwargs: Any, -) -> AxesImage | Axes: +) -> np.ndarray[Any, Any]: """Function to display row of image slices. :param data: 3D image data @@ -103,7 +107,9 @@ def show_clusters( if isinstance(slice, np.ndarray) and isinstance(slice.shape, tuple): extent = (0, slice.shape[0], 0, slice.shape[1]) else: - raise TypeError("Expected slice to be a numpy array with a shape attribute of type tuple.") + raise TypeError( + "Expected slice to be a numpy array with a shape attribute of type tuple.", + ) cluster_labels = slice.copy() # Calculate the scaling factors based on the extent and slice shape diff --git a/src/armscan_env/volumes/loading.py b/src/armscan_env/volumes/loading.py index 5b5a9de..5d3e83f 100644 --- a/src/armscan_env/volumes/loading.py +++ b/src/armscan_env/volumes/loading.py @@ -58,6 +58,24 @@ def load_all_labelmaps(cls, normalize_spacing: bool = True) -> list[ImageVolume] volumes = normalize_sitk_volumes_to_highest_spacing(volumes) return volumes + def load_cropped_labelmap(self) -> ImageVolume: + cropped_volume = sitk.ReadImage( + config.get_single_cropped_labelmap_path(self.get_labelmap_id()), + ) + volume = sitk.ReadImage(config.get_single_labelmap_path(self.get_labelmap_id())) + cropped_y = (volume.GetSize()[1] - cropped_volume.GetSize()[1]) * volume.GetSpacing()[1] + optimal_action = self.get_optimal_action() + cropped_opt_y = optimal_action.translation[1] - cropped_y + optimal_action.translation = (0, cropped_opt_y) + return ImageVolume(cropped_volume, optimal_action=optimal_action) + + @classmethod + def load_all_cropped_labelmaps(cls, normalize_spacing: bool = True) -> list[ImageVolume]: + volumes = [labelmap.load_cropped_labelmap() for labelmap in list(cls)[:4]] + if normalize_spacing: + volumes = normalize_sitk_volumes_to_highest_spacing(volumes) + return volumes + def normalize_sitk_volumes_to_highest_spacing( volumes: list[ImageVolume], # n_spacing: tuple[float, float, float], @@ -104,10 +122,14 @@ def normalize_sitk_volumes_to_highest_spacing( def load_sitk_volumes( normalize: bool = True, + cropped: bool = False, ) -> list[ImageVolume]: """Load a SimpleITK volume from a file. :param normalize: whether to normalize the volumes to a single spacing + :param cropped: whether to load the cropped volumes for simplified experiment :return: the loaded volume """ + if cropped: + return RegisteredLabelmap.load_all_cropped_labelmaps(normalize_spacing=normalize) return RegisteredLabelmap.load_all_labelmaps(normalize_spacing=normalize) diff --git a/src/armscan_env/wrapper.py b/src/armscan_env/wrapper.py index 8483cb1..08771b7 100644 --- a/src/armscan_env/wrapper.py +++ b/src/armscan_env/wrapper.py @@ -79,7 +79,7 @@ def action(self, action: ActType) -> np.ndarray: class PatchedFrameStackObservation(PatchedWrapper): - """Had to copy-paste and adjust. + r"""Had to copy-paste and adjust. The inheriting from `RecordConstructorArgs` in original FrameStack is not possible together with overridden getattr, which we however need in order to not become crazy. @@ -100,6 +100,7 @@ class PatchedFrameStackObservation(PatchedWrapper): :param excluded_observation_keys: the keys of the observations to exclude from stacking. The observations with these keys will be passed through without stacking. Can only be used with Dict observation spaces. """ + def __init__( self, env: Env[ObsType, ActType], @@ -272,8 +273,8 @@ def observation( class ObsRewardHeapItem(Generic[ObsType]): - """Heap of the best rewards and their corresponding observations. - """ + """Heap of the best rewards and their corresponding observations.""" + def __init__(self, obs: ObsType, reward: float): self.obs = obs self.reward = reward @@ -307,6 +308,7 @@ class ObsHeap(Generic[ObsType]): :param padding_value: the value to use for padding :param padding_item: the item to use for padding """ + def __init__( self, max_size: int, @@ -349,6 +351,7 @@ class BestActionRewardMemory(AddObservationsWrapper): :param env: The environment to wrap. :param n_best: Number of best states to keep track of. """ + @property def additional_obs_space(self) -> gym.spaces: return self._additional_obs_space @@ -410,8 +413,8 @@ def get_additional_obs_array(self) -> np.ndarray: class ArmscanEnvFactory(EnvFactory): - """Factory for creating ArmscanEnv environments, making use of various wrappers. - """ + """Factory for creating ArmscanEnv environments, making use of various wrappers.""" + def __init__( self, name2volume: dict[str, sitk.Image],