Skip to content

Commit

Permalink
fixed extent in notebooks and rendering
Browse files Browse the repository at this point in the history
add cropped volumes
  • Loading branch information
carlocagnetta committed Aug 13, 2024
1 parent 8f1c75a commit 329fdcf
Show file tree
Hide file tree
Showing 12 changed files with 105 additions and 413 deletions.
2 changes: 1 addition & 1 deletion docs/02_notebooks/L0_MRI_and_Labelmaps.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
")# Visualization of MRI Data and Labelmaps"
"# Visualization of MRI Data and Labelmaps"
]
},
{
Expand Down
10 changes: 10 additions & 0 deletions docs/02_notebooks/L4_normalize_volumes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@
"sliced_img = sitk.GetArrayFromImage(sliced_volume)\n",
"cluster = TissueClusters.from_labelmap_slice(sliced_img.T)\n",
"show_clusters(cluster, sliced_img.T, extent=transversal_extent)\n",
"print(f\"Slice value range: {np.min(sliced_img)} - {np.max(sliced_img)}\")\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
Expand All @@ -174,6 +176,14 @@
"reward = anatomy_based_rwd(cluster)\n",
"print(reward)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b0966298f2bc74c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
1 change: 0 additions & 1 deletion docs/02_notebooks/L5_random_volume_transformations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@
"sliced_img = sitk.GetArrayFromImage(sliced_volume)\n",
"print(f\"Slice value range: {np.min(sliced_img)} - {np.max(sliced_img)}\")\n",
"\n",
"extent_xz = (0, size[0], size[2], 0)\n",
"cluster = TissueClusters.from_labelmap_slice(sliced_img.T)\n",
"show_clusters(cluster, sliced_img.T)\n",
"reward = anatomy_based_rwd(cluster)\n",
Expand Down
62 changes: 16 additions & 46 deletions docs/02_notebooks/L6_environment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@
"from armscan_env.clustering import TissueClusters\n",
"from armscan_env.config import get_config\n",
"from armscan_env.envs.labelmaps_navigation import (\n",
" LabelmapClusteringBasedReward,\n",
" ArmscanEnv,\n",
" LabelmapClusteringBasedReward,\n",
" LabelmapEnvTerminationCriterion,\n",
")\n",
"from armscan_env.envs.observations import (\n",
" ActionRewardObservation,\n",
" LabelmapSliceAsChannelsObservation,\n",
")\n",
"from armscan_env.envs.rewards import anatomy_based_rwd\n",
Expand Down Expand Up @@ -87,6 +86,11 @@
"z = [0, -5, 0, 0, 5, 15, 19.3, -10, 0, 0, 0, 5, -8, 8, 0, -10, -10, 10, 19.3]\n",
"o = volumes[0].GetOrigin()\n",
"slice_shape = (volumes[0].GetSize()[0], volumes[0].GetSize()[2])\n",
"size = np.array(volumes[0].GetSize()) * np.array(volumes[0].GetSpacing())\n",
"\n",
"transversal_extent = (0, size[0], 0, size[2])\n",
"longitudinal_extent = (0, size[1], 0, size[2])\n",
"frontal_extent = (0, size[0], size[1], 0)\n",
"\n",
"\n",
"# Sample functions for demonstration\n",
Expand All @@ -106,9 +110,9 @@
"\n",
"for i in range(len(t)):\n",
" # Subplot 1: Image with dashed line\n",
" ax1.imshow(img_array_1[40, :, :])\n",
" x_dash = np.arange(img_array_1.shape[2])\n",
" b = volumes[0].TransformPhysicalPointToIndex([o[0], o[1] + t[i], o[2]])[1]\n",
" ax1.imshow(img_array_1[40, :, :], extent=frontal_extent)\n",
" x_dash = np.arange(size[0])\n",
" b = t[i]\n",
" y_dash = linear_function(x_dash, np.tan(np.deg2rad(z[i])), b)\n",
" ax1.set_title(f\"Section {0}\")\n",
" line = ax1.plot(x_dash, y_dash, linestyle=\"--\", color=\"red\")[0]\n",
Expand All @@ -120,12 +124,16 @@
" action=ManipulatorAction(rotation=(z[i], 0.0), translation=(0.0, t[i])),\n",
" )\n",
" sliced_img = sitk.GetArrayFromImage(sliced_volume).T\n",
" ax2.imshow(sliced_img.T, aspect=6, origin=\"lower\")\n",
" ax2.imshow(\n",
" sliced_img.T,\n",
" origin=\"lower\",\n",
" extent=transversal_extent,\n",
" )\n",
" ax2.set_title(f\"Slice {i}\")\n",
"\n",
" # OBSERVATION\n",
" clusters = TissueClusters.from_labelmap_slice(sliced_img)\n",
" ax3 = show_clusters(clusters, sliced_img, ax3)\n",
" ax3 = show_clusters(clusters, sliced_img, ax3, extent=transversal_extent)\n",
" ax3.set_title(f\"Clusters {i}\")\n",
"\n",
" # REWARD\n",
Expand Down Expand Up @@ -194,7 +202,7 @@
" termination_criterion=LabelmapEnvTerminationCriterion(),\n",
" max_episode_len=10,\n",
" rotation_bounds=(30.0, 10.0),\n",
" translation_bounds=(None, None),\n",
" translation_bounds=(0.0, None),\n",
" render_mode=\"animation\",\n",
" apply_volume_transformation=True,\n",
")\n",
Expand All @@ -216,44 +224,6 @@
"HTML(animation.to_jshtml())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"volume_size = volumes[0].GetSize()\n",
"\n",
"env = ArmscanEnv(\n",
" name2volume={\"1\": volumes[6]},\n",
" observation=ActionRewardObservation(action_shape=(2,)).to_array_observation(),\n",
" slice_shape=(volume_size[0], volume_size[2]),\n",
" reward_metric=LabelmapClusteringBasedReward(),\n",
" termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.05),\n",
" max_episode_len=10,\n",
" rotation_bounds=(30.0, 10.0),\n",
" translation_bounds=(None, None),\n",
" render_mode=\"animation\",\n",
" project_actions_to=\"zy\",\n",
" apply_volume_transformation=True,\n",
")\n",
"\n",
"observation, info = env.reset()\n",
"for _ in range(50):\n",
" action = env.action_space.sample()\n",
" epsilon = 0.1\n",
" if np.random.rand() > epsilon:\n",
" observation, reward, terminated, truncated, info = env.step(action)\n",
" else:\n",
" observation, reward, terminated, truncated, info = env.step_to_optimal_state()\n",
" env.render()\n",
"\n",
" if terminated or truncated:\n",
" observation, info = env.reset(reset_render=False)\n",
"animation = env.get_cur_animation()\n",
"env.get_cur_animation_as_html()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
2 changes: 1 addition & 1 deletion docs/02_notebooks/L7_linear_sweep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
"from armscan_env.config import get_config\n",
"from armscan_env.envs.base import EnvRollout\n",
"from armscan_env.envs.labelmaps_navigation import (\n",
" LabelmapClusteringBasedReward,\n",
" ArmscanEnv,\n",
" LabelmapClusteringBasedReward,\n",
" LabelmapEnvTerminationCriterion,\n",
")\n",
"from armscan_env.envs.observations import (\n",
Expand Down
Loading

0 comments on commit 329fdcf

Please sign in to comment.