Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix volume transformation, add observation wrapper #7

Merged
merged 38 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
5138bd7
fix observations on MultiBoxSpace
carlocagnetta Jun 28, 2024
e0eee64
fix volume transformation,
carlocagnetta Jun 28, 2024
2e55354
fix clustering
carlocagnetta Jun 28, 2024
5622434
new volumes src
carlocagnetta Jun 28, 2024
893e4c9
AddRewardDetailsWrapper
carlocagnetta Jun 28, 2024
59bb759
fix test slicing
carlocagnetta Jun 28, 2024
7c93fd3
fix test slicing
carlocagnetta Jun 28, 2024
50c516f
fix experiment parameters
carlocagnetta Jun 28, 2024
7a19241
merged
carlocagnetta Jun 28, 2024
4eafefb
Fixed slicing manipulator action in env
carlocagnetta Jun 28, 2024
d5f0299
Add by mistake
carlocagnetta Jun 28, 2024
5050cdb
fix optimal action 2
carlocagnetta Jun 30, 2024
1a259e2
transforming optimal action, and referencing back when slicing
carlocagnetta Jun 30, 2024
d54aaf3
Actions: improve triggers
Jul 1, 2024
a6c6958
Added a safety feature to TransformedVolume
Jul 1, 2024
f005d7f
small notebook fixes
carlocagnetta Jul 1, 2024
569cb65
Add notebooks folder
carlocagnetta Jul 1, 2024
b72f0ed
ToDo: caching
carlocagnetta Jul 3, 2024
11ffb07
add notebooks to poe tasks
carlocagnetta Jul 3, 2024
683ed47
changed random volume transformation to take action as parameter
carlocagnetta Jul 3, 2024
0c0bbde
spelling
carlocagnetta Jul 3, 2024
eecf76b
using Framestack before AddRewardDetails, including flatte observatio…
carlocagnetta Jul 3, 2024
517421d
script for debugging
carlocagnetta Jul 3, 2024
972bd05
loading volumes and standardizing volumes spacing
carlocagnetta Jul 4, 2024
a6ccc46
spelling
carlocagnetta Jul 4, 2024
b57769b
fixup! loading volumes and standardizing volumes spacing
carlocagnetta Jul 4, 2024
46b182c
fixup! loading volumes and standardizing volumes spacing
carlocagnetta Jul 4, 2024
fb06717
test optimal action
carlocagnetta Jul 4, 2024
ed1f525
Changed slicing to use standard sitk transform
carlocagnetta Jul 12, 2024
9c39076
add 2 DoF
carlocagnetta Jul 12, 2024
d1c068c
Add more volumes
carlocagnetta Jul 15, 2024
09ea174
turned volume
carlocagnetta Jul 15, 2024
7362d4f
tuned labelmaps
carlocagnetta Jul 15, 2024
f659d59
Add more volumes
carlocagnetta Jul 15, 2024
511a051
add volume 11
carlocagnetta Jul 18, 2024
70d9de8
Create ImageVolume class for labelmaps with set optimal action
carlocagnetta Jul 18, 2024
baebae8
Merge branch 'volume_rand_transformation' of https://github.com/appli…
carlocagnetta Jul 18, 2024
a2070a5
Fixed clustering, restricted termination, added volume
carlocagnetta Jul 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions docs/02_notebooks/L3_slicing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,11 @@
"metadata": {},
"outputs": [],
"source": [
"from armscan_env.slicing import slice_volume\n",
"from armscan_env.envs.state_action import ManipulatorAction\n",
"from armscan_env.volumes.slicing import slice_volume\n",
"\n",
"sliced_volume = slice_volume(\n",
" z_rotation=19.3,\n",
" x_rotation=0.0,\n",
" x_trans=0.0,\n",
" y_trans=140.0,\n",
" action=ManipulatorAction(rotation=(19.3, 0), translation=(0, 140)),\n",
" volume=volume,\n",
" slice_shape=(volume.GetSize()[0], volume.GetSize()[2]),\n",
")\n",
Expand Down Expand Up @@ -323,9 +321,10 @@
" sliced_volume = slice_volume(\n",
" volume=volume,\n",
" slice_shape=(volume.GetSize()[0], volume.GetSize()[2]),\n",
" z_rotation=z[i],\n",
" x_rotation=0,\n",
" y_trans=t[i],\n",
" action=ManipulatorAction(\n",
" rotation=(z[i], 0),\n",
" translation=(0, t[i]),\n",
" ),\n",
" )\n",
" sliced_img = sitk.GetArrayFromImage(sliced_volume)[:, 0, :]\n",
" ax2.set_title(f\"Slice {i}\")\n",
Expand Down
22 changes: 10 additions & 12 deletions docs/02_notebooks/L4_environment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@
"from armscan_env.clustering import TissueClusters\n",
"from armscan_env.config import get_config\n",
"from armscan_env.envs.rewards import anatomy_based_rwd\n",
"from armscan_env.slicing import slice_volume\n",
"from armscan_env.envs.state_action import ManipulatorAction\n",
"from armscan_env.util.visualizations import show_clusters\n",
"from armscan_env.volumes.slicing import slice_volume\n",
"from IPython.core.display import HTML\n",
"\n",
"config = get_config()"
Expand Down Expand Up @@ -110,8 +111,7 @@
" sliced_volume = slice_volume(\n",
" volume=volume_1,\n",
" slice_shape=slice_shape,\n",
" z_rotation=z[i],\n",
" y_trans=t[i],\n",
" action=ManipulatorAction(rotation=(z[i], 0.0), translation=(0.0, t[i])),\n",
" )\n",
" sliced_img = sitk.GetArrayFromImage(sliced_volume)[:, 0, :].T\n",
" ax2.imshow(sliced_img.T, aspect=6, origin=\"lower\")\n",
Expand Down Expand Up @@ -195,14 +195,19 @@
" termination_criterion=LabelmapEnvTerminationCriterion(),\n",
" max_episode_len=10,\n",
" rotation_bounds=(30.0, 10.0),\n",
" translation_bounds=(0.0, None),\n",
" translation_bounds=(None, None),\n",
" render_mode=\"animation\",\n",
" apply_volume_transformation=True,\n",
")\n",
"\n",
"observation, info = env.reset()\n",
"for _ in range(50):\n",
" action = env.action_space.sample()\n",
" observation, reward, terminated, truncated, info = env.step(action)\n",
" epsilon = 0.1\n",
" if np.random.rand() > epsilon:\n",
" observation, reward, terminated, truncated, info = env.step(action)\n",
" else:\n",
" observation, reward, terminated, truncated, info = env.step_to_optimal_state()\n",
" env.render()\n",
"\n",
" if terminated or truncated:\n",
Expand All @@ -219,13 +224,6 @@
"source": [
"HTML(animation.to_jshtml())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
2 changes: 1 addition & 1 deletion docs/02_notebooks/L5_linear_sweep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@
"volume_size = volume_1.GetSize()\n",
"\n",
"projected_env = ArmscanEnvFactory(\n",
" name2volume={\"1\": volume_1},\n",
" name2volume={\"2\": volume_2},\n",
" observation=ActionRewardObservation(action_shape=(1,)).to_array_observation(),\n",
" slice_shape=(volume_size[0], volume_size[2]),\n",
" reward_metric=LabelmapClusteringBasedReward(),\n",
Expand Down
7 changes: 4 additions & 3 deletions scripts/armscan_array_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
num_test_envs=1,
buffer_size=100000,
batch_size=256,
step_per_collect=10,
update_per_step=100,
start_timesteps=500,
step_per_collect=200,
update_per_step=2,
start_timesteps=5000,
start_timesteps_random=True,
)

Expand All @@ -61,6 +61,7 @@
termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.1),
reward_metric=LabelmapClusteringBasedReward(),
project_actions_to="y",
apply_volume_transformation=True,
)

experiment = (
Expand Down
2 changes: 1 addition & 1 deletion scripts/armscan_dqn_sac_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
num_epochs=1,
step_per_epoch=1000000,
num_train_envs=-1,
num_test_envs=10,
num_test_envs=1,
buffer_size=1000000,
batch_size=256,
step_per_collect=200,
Expand Down
4 changes: 2 additions & 2 deletions scripts/armscan_sac_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=1000000,
num_train_envs=-1,
num_test_envs=10,
num_train_envs=40,
num_test_envs=1,
buffer_size=1000000,
batch_size=256,
step_per_collect=200,
Expand Down
190 changes: 190 additions & 0 deletions scripts/random_volume_transformations.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
{
carlocagnetta marked this conversation as resolved.
Show resolved Hide resolved
"cells": [
{
"metadata": {},
"cell_type": "code",
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
],
"id": "60c69d9345beb9d0",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import SimpleITK as sitk\n",
"from armscan_env import config\n",
"from armscan_env.clustering import TissueClusters\n",
"from armscan_env.envs.state_action import ManipulatorAction\n",
"from armscan_env.util.visualizations import show_clusters\n",
"from armscan_env.volumes.slicing import EulerTransform, slice_volume, transform_volume\n",
"\n",
"config = config.get_config()"
],
"id": "bf5c60e86d1e8e19",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"volume = sitk.ReadImage(config.get_labels_path(1))\n",
"volume_img = sitk.GetArrayFromImage(volume)\n",
"plt.imshow(volume_img[40, :, :])\n",
"action = ManipulatorAction(rotation=(19, 0), translation=(0, 140))\n",
"\n",
"o = volume.GetOrigin()\n",
"x_dash = np.arange(volume_img.shape[2])\n",
"b = volume.TransformPhysicalPointToIndex([o[0], o[1] + action.translation[1], o[2]])[1]\n",
"y_dash = x_dash * np.tan(np.deg2rad(action.rotation[0])) + b\n",
"plt.plot(x_dash, y_dash, linestyle=\"--\", color=\"red\")\n",
"\n",
"plt.show()"
],
"id": "ae347cf3897968a6",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"sliced_volume = slice_volume(\n",
" action=action,\n",
" volume=volume,\n",
" slice_shape=(volume.GetSize()[0], volume.GetSize()[2]),\n",
")\n",
"sliced_img = sitk.GetArrayFromImage(sliced_volume)[:, 0, :]\n",
"print(f\"Slice value range: {np.min(sliced_img)} - {np.max(sliced_img)}\")\n",
"\n",
"slice = sliced_img\n",
"plt.imshow(slice, aspect=6)\n",
"plt.show()"
],
"id": "cb9c333a74781d5a",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"volume_transformation = ManipulatorAction(rotation=(19, 0), translation=(15, 15))\n",
"transformed_volume = transform_volume(volume, volume_transformation)\n",
"transformed_action = EulerTransform(volume_transformation).transform_action(action)"
],
"id": "26ffcc6d7dece611",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"transformed_img = sitk.GetArrayFromImage(transformed_volume)\n",
"plt.imshow(transformed_img[40, :, :])\n",
"\n",
"ot = transformed_volume.GetOrigin()\n",
"x_dash = np.arange(transformed_img.shape[2])\n",
"b = volume.TransformPhysicalPointToIndex([o[0], o[1] + transformed_action.translation[1], o[2]])[1]\n",
"y_dash = x_dash * np.tan(np.deg2rad(transformed_action.rotation[0])) + b\n",
"plt.plot(x_dash, y_dash, linestyle=\"--\", color=\"red\")\n",
"\n",
"plt.show()"
],
"id": "db20a3ff9556e8b4",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"sliced_transformed_volume = slice_volume(\n",
" action=transformed_action,\n",
" volume=transformed_volume,\n",
" slice_shape=(volume.GetSize()[0], volume.GetSize()[2]),\n",
")\n",
"sliced_img = sitk.GetArrayFromImage(sliced_transformed_volume)[:, 0, :]\n",
"print(f\"Slice value range: {np.min(sliced_img)} - {np.max(sliced_img)}\")\n",
"\n",
"slice = sliced_img\n",
"plt.imshow(slice, aspect=6)\n",
"plt.show()"
],
"id": "acda09e94c3f2f2b",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"volume_2 = sitk.ReadImage(config.get_labels_path(2))\n",
"volume_2_img = sitk.GetArrayFromImage(volume_2)\n",
"spacing = volume_2.GetSpacing()\n",
"plt.imshow(volume_2_img[51, :, :])\n",
"action_2 = ManipulatorAction(rotation=(5, 0), translation=(0, 112))\n",
"\n",
"o = volume_2.GetOrigin()\n",
"x_dash = np.arange(volume_2_img.shape[2])\n",
"b = volume_2.TransformPhysicalPointToIndex([o[0], o[1] + action_2.translation[1], o[2]])[1]\n",
"y_dash = x_dash * np.tan(np.deg2rad(action_2.rotation[0])) + b\n",
"plt.plot(x_dash, y_dash, linestyle=\"--\", color=\"red\")\n",
"\n",
"plt.show()"
],
"id": "fb6cfecff1cb7cd4",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"sliced_volume_2 = slice_volume(\n",
" action=action_2,\n",
" volume=volume_2,\n",
" slice_shape=(volume_2.GetSize()[0], volume_2.GetSize()[2]),\n",
")\n",
"sliced_img_2 = sitk.GetArrayFromImage(sliced_volume_2)[:, 0, :]\n",
"np.save(\"./array\", sliced_img_2)\n",
"\n",
"cluster = TissueClusters.from_labelmap_slice(sliced_img_2.T)\n",
"show_clusters(cluster, sliced_img_2.T, aspect=spacing[2] / spacing[0])\n",
"\n",
"plt.show()"
],
"id": "6462b823c7903838",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
11 changes: 6 additions & 5 deletions src/armscan_env/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def find_DBSCAN_clusters(self, labelmap_slice: np.ndarray) -> list["DataCluster"
case TissueLabel.BONES:
return find_DBSCAN_clusters(self, labelmap_slice, eps=4.1, min_samples=46)
case TissueLabel.TENDONS:
return find_DBSCAN_clusters(self, labelmap_slice, eps=4.1, min_samples=46)
return find_DBSCAN_clusters(self, labelmap_slice, eps=2.5, min_samples=15)
case TissueLabel.ULNAR:
return find_DBSCAN_clusters(self, labelmap_slice, eps=2.5, min_samples=18)
return find_DBSCAN_clusters(self, labelmap_slice, eps=2.0, min_samples=10)
case _:
raise ValueError(f"Unknown tissue label: {self}")

Expand Down Expand Up @@ -142,9 +142,10 @@ def find_DBSCAN_clusters(
label_positions = np.array(list(zip(*np.where(binary_mask), strict=True)))
clusterer = DBSCAN(eps=eps, min_samples=min_samples)
clusters = clusterer.fit_predict(label_positions)
n_clusters = (
len(np.unique(clusters)) - 1
) # noise cluster has label -1, we don't take it into account
if -1 in clusters:
n_clusters = len(np.unique(clusters)) - 1
else:
n_clusters = len(np.unique(clusters))
log.debug(f"Found {n_clusters} clusters")

cluster_list = []
Expand Down
16 changes: 7 additions & 9 deletions src/armscan_env/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,20 +91,18 @@ def __init__(self, array_observations: list[ArrayObservation[TStateAction]]):
def compute_observation(self, state: TStateAction) -> np.ndarray:
return np.concatenate(
[obs.compute_observation(state) for obs in self.array_observations],
axis=0,
axis=1,
)

@cached_property
def observation_space(self) -> gym.spaces.Box:
return self.concatenate_boxes([obs.observation_space for obs in self.array_observations])

@staticmethod
def concatenate_boxes(boxes: list[gym.spaces.Box]) -> gym.spaces.Box:
return gym.spaces.Box(
low=np.concatenate(
[obs.observation_space.low for obs in self.array_observations],
axis=0,
),
high=np.concatenate(
[obs.observation_space.high for obs in self.array_observations],
axis=0,
),
low=np.concatenate([box.low for box in boxes], axis=0),
high=np.concatenate([box.high for box in boxes], axis=0),
)


Expand Down
Loading
Loading