Skip to content

Commit

Permalink
ToDo: caching
Browse files Browse the repository at this point in the history
  • Loading branch information
carlocagnetta committed Jul 3, 2024
1 parent 569cb65 commit b72f0ed
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 2 deletions.
6 changes: 5 additions & 1 deletion notebooks/random_volume_transformations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
"from armscan_env.clustering import TissueClusters\n",
"from armscan_env.envs.state_action import ManipulatorAction\n",
"from armscan_env.util.visualizations import show_clusters\n",
"from armscan_env.volumes.slicing import EulerTransform, get_volume_slice, create_transformed_volume\n",
"from armscan_env.volumes.slicing import (\n",
" EulerTransform,\n",
" create_transformed_volume,\n",
" get_volume_slice,\n",
")\n",
"\n",
"config = config.get_config()"
],
Expand Down
4 changes: 4 additions & 0 deletions src/armscan_env/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class DataCluster:
datapoints: list[tuple[float, float]] | np.ndarray
center: tuple[np.floating[Any], np.floating[Any]]

# ToDo: Make a custom __hash__ method for the class that deals with lists


@dataclass(kw_only=True)
class TissueClusters:
Expand All @@ -46,6 +48,8 @@ class TissueClusters:
tendons: list[DataCluster]
ulnar: list[DataCluster]

# ToDo: Make a custom __hash__ method for the class that deals with lists

def get_cluster_for_label(self, label: TissueLabel) -> list[DataCluster]:
"""Get the clusters for a given tissue label."""
match label:
Expand Down
1 change: 1 addition & 0 deletions src/armscan_env/envs/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
log = logging.getLogger(__name__)


# ToDo: make a cache for the function
def anatomy_based_rwd(
tissue_clusters: TissueClusters,
n_landmarks: Sequence[int] = (4, 2, 1),
Expand Down
3 changes: 2 additions & 1 deletion src/armscan_env/volumes/slicing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Any

import numpy as np
import SimpleITK as sitk
Expand Down Expand Up @@ -126,7 +127,7 @@ class TransformedVolume(sitk.Image):
Should only ever be instantiated by `create_transformed_volume`.
"""

def __init__(self, *args, transformation_action: ManipulatorAction, _private: int):
def __init__(self, *args: Any, transformation_action: ManipulatorAction, _private: int):
if _private != 42:
raise ValueError(
"TransformedVolume should only be instantiated by create_transformed_volume.",
Expand Down

0 comments on commit b72f0ed

Please sign in to comment.