Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix volume transformation, add observation wrapper #7

Merged
merged 38 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
5138bd7
fix observations on MultiBoxSpace
carlocagnetta Jun 28, 2024
e0eee64
fix volume transformation,
carlocagnetta Jun 28, 2024
2e55354
fix clustering
carlocagnetta Jun 28, 2024
5622434
new volumes src
carlocagnetta Jun 28, 2024
893e4c9
AddRewardDetailsWrapper
carlocagnetta Jun 28, 2024
59bb759
fix test slicing
carlocagnetta Jun 28, 2024
7c93fd3
fix test slicing
carlocagnetta Jun 28, 2024
50c516f
fix experiment parameters
carlocagnetta Jun 28, 2024
7a19241
merged
carlocagnetta Jun 28, 2024
4eafefb
Fixed slicing manipulator action in env
carlocagnetta Jun 28, 2024
d5f0299
Add by mistake
carlocagnetta Jun 28, 2024
5050cdb
fix optimal action 2
carlocagnetta Jun 30, 2024
1a259e2
transforming optimal action, and referencing back when slicing
carlocagnetta Jun 30, 2024
d54aaf3
Actions: improve triggers
Jul 1, 2024
a6c6958
Added a safety feature to TransformedVolume
Jul 1, 2024
f005d7f
small notebook fixes
carlocagnetta Jul 1, 2024
569cb65
Add notebooks folder
carlocagnetta Jul 1, 2024
b72f0ed
ToDo: caching
carlocagnetta Jul 3, 2024
11ffb07
add notebooks to poe tasks
carlocagnetta Jul 3, 2024
683ed47
changed random volume transformation to take action as parameter
carlocagnetta Jul 3, 2024
0c0bbde
spelling
carlocagnetta Jul 3, 2024
eecf76b
using Framestack before AddRewardDetails, including flatte observatio…
carlocagnetta Jul 3, 2024
517421d
script for debugging
carlocagnetta Jul 3, 2024
972bd05
loading volumes and standardizing volumes spacing
carlocagnetta Jul 4, 2024
a6ccc46
spelling
carlocagnetta Jul 4, 2024
b57769b
fixup! loading volumes and standardizing volumes spacing
carlocagnetta Jul 4, 2024
46b182c
fixup! loading volumes and standardizing volumes spacing
carlocagnetta Jul 4, 2024
fb06717
test optimal action
carlocagnetta Jul 4, 2024
ed1f525
Changed slicing to use standard sitk transform
carlocagnetta Jul 12, 2024
9c39076
add 2 DoF
carlocagnetta Jul 12, 2024
d1c068c
Add more volumes
carlocagnetta Jul 15, 2024
09ea174
turned volume
carlocagnetta Jul 15, 2024
7362d4f
tuned labelmaps
carlocagnetta Jul 15, 2024
f659d59
Add more volumes
carlocagnetta Jul 15, 2024
511a051
add volume 11
carlocagnetta Jul 18, 2024
70d9de8
Create ImageVolume class for labelmaps with set optimal action
carlocagnetta Jul 18, 2024
baebae8
Merge branch 'volume_rand_transformation' of https://github.com/appli…
carlocagnetta Jul 18, 2024
a2070a5
Fixed clustering, restricted termination, added volume
carlocagnetta Jul 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions .github/workflows/lint_and_docs.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
name: PEP8, Types and Docs Check

on: [push, pull_request]

on:
pull_request:
branches:
- main
push:
branches:
- main
workflow_dispatch:
inputs:
debug_enabled:
type: boolean
description: 'Run the build with tmate debugging enabled (https://github.com/marketplace/actions/debugging-with-tmate)'
required: false
default: false
jobs:
check:
runs-on: ubuntu-latest
Expand All @@ -12,6 +24,10 @@ jobs:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
# Enable tmate debugging of manually-triggered workflows if the input option was provided
- name: Setup tmate session
uses: mxschmitt/action-tmate@v3
if: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled }}
- name: Cancel previous run
uses: styfle/[email protected]
with:
Expand Down
8 changes: 7 additions & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
name: Ubuntu

on: [push, pull_request]
on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cpu:
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ repos:
language: system
- id: clean-nbs
name: clean-nbs
entry: poetry run python docs/nbstripout.py
entry: poetry run python nbstripout.py
language: system
3 changes: 3 additions & 0 deletions data/labels/00011_labels.nii
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/labels/00013_labels.nii
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/labels/00017_labels.nii
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/labels/00018_labels.nii
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/labels/00035_labels.nii
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/labels/00042_labels.nii
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/mri/00011.nii
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/mri/00013.nii
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/mri/00017.nii
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/mri/00018.nii
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/mri/00035.nii
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/mri/00042.nii
Git LFS file not shown
4 changes: 2 additions & 2 deletions docs/02_notebooks/L0_MRI_and_Labelmaps.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
"metadata": {},
"outputs": [],
"source": [
"mri_1 = sitk.ReadImage(config.get_mri_path(1))\n",
"mri_1 = sitk.ReadImage(config.get_single_mri_path(1))\n",
"mri_1_data = sitk.GetArrayFromImage(mri_1)\n",
"print(f\"{mri_1_data.shape=}\")"
]
Expand Down Expand Up @@ -106,7 +106,7 @@
"metadata": {},
"outputs": [],
"source": [
"mri_1_label = sitk.ReadImage(config.get_labels_path(1))\n",
"mri_1_label = sitk.ReadImage(config.get_single_labelmap_path(1))\n",
"mri_1_label_data = sitk.GetArrayFromImage(mri_1_label)\n",
"print(f\"{mri_1_label_data.shape =}\")"
]
Expand Down
2 changes: 1 addition & 1 deletion docs/02_notebooks/L1_simple_clustering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"metadata": {},
"outputs": [],
"source": [
"mri_1_label = sitk.ReadImage(config.get_labels_path(1))\n",
"mri_1_label = sitk.ReadImage(config.get_single_labelmap_path(1))\n",
"mri_1_label_data = sitk.GetArrayFromImage(mri_1_label)\n",
"print(f\"{mri_1_label_data.shape=}\")"
]
Expand Down
60 changes: 36 additions & 24 deletions docs/02_notebooks/L2_DBSCAN_clustering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"metadata": {},
"outputs": [],
"source": [
"path_to_mri = config.get_labels_path(1)\n",
"path_to_mri = config.get_single_labelmap_path(1)\n",
"mri_1_label = sitk.ReadImage(path_to_mri)\n",
"mri_1_label_data = sitk.GetArrayFromImage(mri_1_label)\n",
"print(f\"{mri_1_label_data.shape=}\")"
Expand Down Expand Up @@ -207,18 +207,23 @@
"zero_loss_indices = np.where(np.array(sweep_loss) == 0)[0]\n",
"print(f\"{len(zero_loss_indices)} indices return a zero loss: \", zero_loss_indices)\n",
"\n",
"fig, axes = plt.subplots(2, 4, figsize=(21, 7))\n",
"axes = axes.flatten()\n",
"for i, idx in enumerate(zero_loss_indices):\n",
" axes[i] = show_clusters(\n",
" tissue_clusters=clusters_list[idx],\n",
" slice=mri_1_label_data[:, idx, :].T,\n",
" aspect=6,\n",
" ax=axes[i],\n",
" )\n",
" axes[i].set_title(f\"Index: {idx}, Loss: {sweep_loss[idx]:.2f}\")\n",
"nrows = 2\n",
"ncols = len(zero_loss_indices) // nrows\n",
"indices_to_display = nrows * ncols\n",
"\n",
"plt.show()"
"if indices_to_display > 0:\n",
" fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(21, 7))\n",
" axes = axes.flatten()\n",
" for i, idx in enumerate(zero_loss_indices[:indices_to_display]):\n",
" axes[i] = show_clusters(\n",
" tissue_clusters=clusters_list[idx],\n",
" slice=mri_1_label_data[:, idx, :].T,\n",
" aspect=6,\n",
" ax=axes[i],\n",
" )\n",
" axes[i].set_title(f\"Index: {idx}, Loss: {sweep_loss[idx]:.2f}\")\n",
"\n",
" plt.show()"
]
},
{
Expand All @@ -243,7 +248,7 @@
"\n",
"for i in range(mri_1_label_data.shape[1]):\n",
" clusters = TissueClusters.from_labelmap_slice(mri_1_label_data[:, i, :].T)\n",
" loss = anatomy_based_rwd(clusters, n_landmarks=[7, 2, 1])\n",
" loss = anatomy_based_rwd(clusters, n_landmarks=[7, 3, 1])\n",
" if loss == 0:\n",
" zero_loss_clusters.append(clusters)\n",
" print(f\"Loss for slice {i}: {loss}\")\n",
Expand Down Expand Up @@ -289,21 +294,28 @@
"metadata": {},
"outputs": [],
"source": [
"# TODO: reduce duplication with printing above, move to a function\n",
"\n",
"zero_loss_indices = np.where(np.array(sweep_loss) == 0)[0]\n",
"print(f\"{len(zero_loss_indices)} indices return a zero loss: \", zero_loss_indices)\n",
"\n",
"fig, axes = plt.subplots(2, 4, figsize=(21, 7))\n",
"axes = axes.flatten()\n",
"for i, idx in enumerate(zero_loss_indices):\n",
" axes[i] = show_clusters(\n",
" tissue_clusters=zero_loss_clusters[i],\n",
" slice=mri_1_label_data[:, idx, :].T,\n",
" aspect=6,\n",
" ax=axes[i],\n",
" )\n",
" axes[i].set_title(f\"Index: {idx}, Loss: {sweep_loss[idx]:.2f}\")\n",
"nrows = 2\n",
"ncols = len(zero_loss_indices) // nrows\n",
"indices_to_display = nrows * ncols\n",
"\n",
"plt.show()"
"if indices_to_display > 0:\n",
" fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(21, 7))\n",
" axes = axes.flatten()\n",
" for i, idx in enumerate(zero_loss_indices[:indices_to_display]):\n",
" axes[i] = show_clusters(\n",
" tissue_clusters=zero_loss_clusters[i],\n",
" slice=mri_1_label_data[:, idx, :].T,\n",
" aspect=6,\n",
" ax=axes[i],\n",
" )\n",
" axes[i].set_title(f\"Index: {idx}, Loss: {sweep_loss[idx]:.2f}\")\n",
"\n",
" plt.show()"
]
}
],
Expand Down
40 changes: 17 additions & 23 deletions docs/02_notebooks/L3_slicing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import SimpleITK as sitk\n",
"from armscan_env.clustering import TissueClusters\n",
"from armscan_env.config import get_config\n",
"from armscan_env.envs.rewards import anatomy_based_rwd\n",
"from armscan_env.envs.state_action import ManipulatorAction\n",
"from armscan_env.util.visualizations import show_clusters\n",
"from armscan_env.volumes.volumes import ImageVolume\n",
"from celluloid import Camera\n",
"from IPython.core.display import HTML\n",
"\n",
"config = get_config()"
Expand All @@ -46,7 +52,7 @@
"metadata": {},
"outputs": [],
"source": [
"volume = sitk.ReadImage(config.get_labels_path(1))\n",
"volume = sitk.ReadImage(config.get_single_labelmap_path(1))\n",
"img_array = sitk.GetArrayFromImage(volume)\n",
"print(f\"{img_array.shape=}\")"
]
Expand Down Expand Up @@ -164,7 +170,6 @@
"# (cosine of the angle between the normal vector and the x axis: z-rotation)\n",
"w = int(abs(volume_size[0] // e1[0]))\n",
"\n",
"\n",
"print(f\" {h=},\\n {w=}\")"
]
},
Expand Down Expand Up @@ -252,17 +257,12 @@
"metadata": {},
"outputs": [],
"source": [
"from armscan_env.slicing import slice_volume\n",
"\n",
"sliced_volume = slice_volume(\n",
" z_rotation=19.3,\n",
" x_rotation=0.0,\n",
" x_trans=0.0,\n",
" y_trans=140.0,\n",
" volume=volume,\n",
"volume = ImageVolume(volume)\n",
"sliced_volume = volume.get_volume_slice(\n",
" action=ManipulatorAction(rotation=(19.3, 0), translation=(0, 140)),\n",
" slice_shape=(volume.GetSize()[0], volume.GetSize()[2]),\n",
")\n",
"sliced_img = sitk.GetArrayFromImage(sliced_volume)[:, 0, :]\n",
"sliced_img = sitk.GetArrayFromImage(sliced_volume)\n",
"print(f\"Slice value range: {np.min(sliced_img)} - {np.max(sliced_img)}\")\n",
"\n",
"slice = sliced_img\n",
Expand All @@ -276,10 +276,6 @@
"metadata": {},
"outputs": [],
"source": [
"from armscan_env.clustering import TissueClusters\n",
"from armscan_env.envs.rewards import anatomy_based_rwd\n",
"from armscan_env.util.visualizations import show_clusters\n",
"\n",
"clusters = TissueClusters.from_labelmap_slice(sliced_img.T)\n",
"show_clusters(clusters, sliced_img.T)\n",
"reward = anatomy_based_rwd(clusters, (4, 2, 1))\n",
Expand All @@ -294,8 +290,6 @@
},
"outputs": [],
"source": [
"from celluloid import Camera\n",
"\n",
"# Demonstration of arbitrary slicing\n",
"t = [160, 155, 150, 148, 146, 142, 140, 140, 115, 120, 125, 125, 130, 130, 135, 138, 140, 140, 140]\n",
"z = [0, -5, 0, 0, 5, 15, 19.3, -10, 0, 0, 0, 5, -8, 8, 0, -10, -10, 10, 19.3]\n",
Expand All @@ -320,14 +314,14 @@
" ax1.plot(x_dash, y_dash, linestyle=\"--\", color=\"red\")\n",
"\n",
" # Subplot 2: Function image\n",
" sliced_volume = slice_volume(\n",
" volume=volume,\n",
" sliced_volume = volume.get_volume_slice(\n",
" slice_shape=(volume.GetSize()[0], volume.GetSize()[2]),\n",
" z_rotation=z[i],\n",
" x_rotation=0,\n",
" y_trans=t[i],\n",
" action=ManipulatorAction(\n",
" rotation=(z[i], 0),\n",
" translation=(0, t[i]),\n",
" ),\n",
" )\n",
" sliced_img = sitk.GetArrayFromImage(sliced_volume)[:, 0, :]\n",
" sliced_img = sitk.GetArrayFromImage(sliced_volume)\n",
" ax2.set_title(f\"Slice {i}\")\n",
" ax2.imshow(sliced_img, aspect=6)\n",
" camera.snap()\n",
Expand Down
Loading
Loading