Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix volume transformation, add observation wrapper #7

Merged
merged 38 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
5138bd7
fix observations on MultiBoxSpace
carlocagnetta Jun 28, 2024
e0eee64
fix volume transformation,
carlocagnetta Jun 28, 2024
2e55354
fix clustering
carlocagnetta Jun 28, 2024
5622434
new volumes src
carlocagnetta Jun 28, 2024
893e4c9
AddRewardDetailsWrapper
carlocagnetta Jun 28, 2024
59bb759
fix test slicing
carlocagnetta Jun 28, 2024
7c93fd3
fix test slicing
carlocagnetta Jun 28, 2024
50c516f
fix experiment parameters
carlocagnetta Jun 28, 2024
7a19241
merged
carlocagnetta Jun 28, 2024
4eafefb
Fixed slicing manipulator action in env
carlocagnetta Jun 28, 2024
d5f0299
Add by mistake
carlocagnetta Jun 28, 2024
5050cdb
fix optimal action 2
carlocagnetta Jun 30, 2024
1a259e2
transforming optimal action, and referencing back when slicing
carlocagnetta Jun 30, 2024
d54aaf3
Actions: improve triggers
Jul 1, 2024
a6c6958
Added a safety feature to TransformedVolume
Jul 1, 2024
f005d7f
small notebook fixes
carlocagnetta Jul 1, 2024
569cb65
Add notebooks folder
carlocagnetta Jul 1, 2024
b72f0ed
ToDo: caching
carlocagnetta Jul 3, 2024
11ffb07
add notebooks to poe tasks
carlocagnetta Jul 3, 2024
683ed47
changed random volume transformation to take action as parameter
carlocagnetta Jul 3, 2024
0c0bbde
spelling
carlocagnetta Jul 3, 2024
eecf76b
using Framestack before AddRewardDetails, including flatte observatio…
carlocagnetta Jul 3, 2024
517421d
script for debugging
carlocagnetta Jul 3, 2024
972bd05
loading volumes and standardizing volumes spacing
carlocagnetta Jul 4, 2024
a6ccc46
spelling
carlocagnetta Jul 4, 2024
b57769b
fixup! loading volumes and standardizing volumes spacing
carlocagnetta Jul 4, 2024
46b182c
fixup! loading volumes and standardizing volumes spacing
carlocagnetta Jul 4, 2024
fb06717
test optimal action
carlocagnetta Jul 4, 2024
ed1f525
Changed slicing to use standard sitk transform
carlocagnetta Jul 12, 2024
9c39076
add 2 DoF
carlocagnetta Jul 12, 2024
d1c068c
Add more volumes
carlocagnetta Jul 15, 2024
09ea174
turned volume
carlocagnetta Jul 15, 2024
7362d4f
tuned labelmaps
carlocagnetta Jul 15, 2024
f659d59
Add more volumes
carlocagnetta Jul 15, 2024
511a051
add volume 11
carlocagnetta Jul 18, 2024
70d9de8
Create ImageVolume class for labelmaps with set optimal action
carlocagnetta Jul 18, 2024
baebae8
Merge branch 'volume_rand_transformation' of https://github.com/appli…
carlocagnetta Jul 18, 2024
a2070a5
Fixed clustering, restricted termination, added volume
carlocagnetta Jul 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions docs/02_notebooks/L3_slicing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,11 @@
"metadata": {},
"outputs": [],
"source": [
"from armscan_env.slicing import slice_volume\n",
"from armscan_env.envs.state_action import ManipulatorAction\n",
"from armscan_env.volumes.slicing import slice_volume\n",
"\n",
"sliced_volume = slice_volume(\n",
" z_rotation=19.3,\n",
" x_rotation=0.0,\n",
" x_trans=0.0,\n",
" y_trans=140.0,\n",
" action=ManipulatorAction(rotation=(19.3, 0), translation=(0, 140)),\n",
" volume=volume,\n",
" slice_shape=(volume.GetSize()[0], volume.GetSize()[2]),\n",
")\n",
Expand Down Expand Up @@ -323,9 +321,10 @@
" sliced_volume = slice_volume(\n",
" volume=volume,\n",
" slice_shape=(volume.GetSize()[0], volume.GetSize()[2]),\n",
" z_rotation=z[i],\n",
" x_rotation=0,\n",
" y_trans=t[i],\n",
" action=ManipulatorAction(\n",
" rotation=(z[i], 0),\n",
" translation=(0, t[i]),\n",
" ),\n",
" )\n",
" sliced_img = sitk.GetArrayFromImage(sliced_volume)[:, 0, :]\n",
" ax2.set_title(f\"Slice {i}\")\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/02_notebooks/L4_environment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
"from armscan_env.clustering import TissueClusters\n",
"from armscan_env.config import get_config\n",
"from armscan_env.envs.rewards import anatomy_based_rwd\n",
"from armscan_env.slicing import slice_volume\n",
"from armscan_env.util.visualizations import show_clusters\n",
"from armscan_env.volumes.slicing import slice_volume\n",
"from IPython.core.display import HTML\n",
"\n",
"config = get_config()"
Expand Down
48 changes: 24 additions & 24 deletions docs/02_notebooks/L5_linear_sweep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@
"execution_count": null,
"id": "a4e98c0276b6012d",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "50b440b37fd9414b",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
Expand All @@ -37,7 +36,8 @@
"from tianshou.highlevel.env import EnvMode\n",
"\n",
"config = get_config()"
]
],
"outputs": []
},
{
"cell_type": "markdown",
Expand All @@ -52,7 +52,6 @@
"execution_count": null,
"id": "9ed46c7b",
"metadata": {},
"outputs": [],
"source": [
"def walk_through_env(\n",
" env: LabelmapEnv,\n",
Expand Down Expand Up @@ -123,27 +122,27 @@
"\n",
" if show:\n",
" plt.show()"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "da45ed45bb7b8f3b",
"metadata": {},
"outputs": [],
"source": [
"volume_1 = sitk.ReadImage(config.get_labels_path(1))\n",
"volume_2 = sitk.ReadImage(config.get_labels_path(2))\n",
"img_array_1 = sitk.GetArrayFromImage(volume_1)\n",
"img_array_2 = sitk.GetArrayFromImage(volume_2)"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "63dd92db3829d7db",
"metadata": {},
"outputs": [],
"source": [
"volume_size = volume_1.GetSize()\n",
"\n",
Expand All @@ -159,36 +158,36 @@
" render_mode=\"animation\",\n",
" n_stack=2,\n",
").create_env(EnvMode.WATCH)"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "16a139f61aaafd19",
"metadata": {},
"outputs": [],
"source": [
"env_rollout = walk_through_env(env, 10)\n",
"\n",
"plot_rollout_rewards(env_rollout)"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "6cdf855cc85a743a",
"metadata": {},
"outputs": [],
"source": [
"env.get_cur_animation_as_html()"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "519dde5f1cea8a5f",
"metadata": {},
"outputs": [],
"source": [
"volume_size = volume_1.GetSize()\n",
"\n",
Expand All @@ -206,54 +205,55 @@
" project_actions_to=\"y\",\n",
" apply_volume_transformation=True,\n",
").create_env(EnvMode.WATCH)"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "22877ab71fed2eb0",
"metadata": {},
"outputs": [],
"source": [
"projected_env_rollout = walk_through_env(\n",
" projected_env,\n",
" 10,\n",
" render_title=\"Projected labelmap slice\",\n",
")\n",
"plot_rollout_rewards(projected_env_rollout)"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2779884526e0716",
"metadata": {},
"outputs": [],
"source": [
"print(\n",
" \"Observed 'rewards': \\n\",\n",
" [round(obs[1][-1], 4) for obs in projected_env_rollout.observations],\n",
")\n",
"print(\"Env rewards: \\n\", [round(r, 4) for r in projected_env_rollout.rewards])"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ada94c94fe77de0",
"metadata": {},
"outputs": [],
"source": [
"projected_env.get_cur_animation_as_html()"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fb82c1487521b11",
"metadata": {},
"outputs": [],
"source": []
"source": [],
"outputs": []
}
],
"metadata": {
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ lint = ["_black_check", "_ruff_check", "_ruff_check_nb"]
_poetry_install_sort_plugin = "poetry self add poetry-plugin-sort"
_poetry_sort = "poetry sort"
clean-nbs = "python docs/nbstripout.py"
format = ["_black_format", "_ruff_format", "_ruff_format_nb", "_poetry_install_sort_plugin", "_poetry_sort"]
carlocagnetta marked this conversation as resolved.
Show resolved Hide resolved
format = ["_black_format", "_ruff_format", "_ruff_format_nb"]
_autogen_rst = "python docs/autogen_rst.py"
_sphinx_build = "sphinx-build -W -b html docs docs/_build"
_jb_generate_toc = "python docs/create_toc.py"
Expand Down
7 changes: 4 additions & 3 deletions scripts/armscan_array_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
num_test_envs=1,
buffer_size=100000,
batch_size=256,
step_per_collect=10,
update_per_step=100,
start_timesteps=500,
step_per_collect=200,
update_per_step=2,
start_timesteps=5000,
start_timesteps_random=True,
)

Expand All @@ -61,6 +61,7 @@
termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.1),
reward_metric=LabelmapClusteringBasedReward(),
project_actions_to="y",
apply_volume_transformation=True
)

experiment = (
Expand Down
2 changes: 1 addition & 1 deletion scripts/armscan_dqn_sac_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
num_epochs=1,
step_per_epoch=1000000,
num_train_envs=-1,
num_test_envs=10,
num_test_envs=1,
buffer_size=1000000,
batch_size=256,
step_per_collect=200,
Expand Down
4 changes: 2 additions & 2 deletions scripts/armscan_sac_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=1000000,
num_train_envs=-1,
num_test_envs=10,
num_train_envs=40,
num_test_envs=1,
buffer_size=1000000,
batch_size=256,
step_per_collect=200,
Expand Down
Loading
Loading