Skip to content

Commit

Permalink
Add src files for slicing, clustering, reward and visualizing
Browse files Browse the repository at this point in the history
  • Loading branch information
carlocagnetta committed Mar 8, 2024
1 parent 6f89886 commit d24621f
Show file tree
Hide file tree
Showing 5 changed files with 357 additions and 2 deletions.
133 changes: 133 additions & 0 deletions src/armscan_env/clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from typing import Any

import numpy as np
from numpy import dtype, ndarray
from scipy.ndimage import label
from sklearn.cluster import DBSCAN


def find_clusters(tissue_value: int, slice: np.ndarray) -> list[dict]:
"""Find clusters of a given tissue in a slice
:param tissue_value: value of the tissue to cluster
:param slice: image slice to cluster
:return: list of clusters and their centers.
"""
# Create a binary mask based on the threshold
binary_mask = slice == tissue_value

# Check if there are tissues with given label
if np.all(binary_mask is False):
print("No tissues to cluster. Please set values using set_values method.")
return []

# Label connected components in the binary mask
labeled_array, num_clusters = label(binary_mask)

# Extract clusters and their centers
cluster_data = []

for cluster_label in range(num_clusters):
cluster_indices = np.where(labeled_array == cluster_label + 1)
# Calculate the center of the cluster
center_x = np.mean(cluster_indices[0])
center_y = np.mean(cluster_indices[1])
center = (center_x, center_y)

# Save both the cluster and center under the same key
cluster_data.append(
{
"cluster": np.array(
list(zip(cluster_indices[0], cluster_indices[1], strict=False)),
),
"center": center,
},
)

return cluster_data


def cluster_iter(tissues: dict, slice: np.ndarray) -> dict:
"""Find clusters of all tissues in a slice
:param tissues: dictionary of tissues and their values
:param slice: image slice to cluster
:return: dictionary of tissues and their clusters.
"""
# store clusters of tissues in a dict
tissues_clusters = {}

for tissue in tissues:
print(f"Finding {tissue} clusters, with value {tissues[tissue]}:")
tissues_clusters[tissue] = find_clusters(tissues[tissue], slice)

print(f"Found {len(tissues_clusters[tissue])} clusters\n")
print("---------------------------------------\n")
return tissues_clusters


def find_DBSCAN_clusters(
tissue_value: int,
slice: np.ndarray,
eps: float,
min_samples: int,
) -> list[Any] | list[dict[str, ndarray[Any, dtype[Any]] | Any]]:
"""Find clusters of a given tissue in a slice using DBSCAN
:param tissue_value: value of the tissue to cluster
:param slice: image slice to cluster
:param eps: The maximum distance between two samples for one to be considered as in the neighborhood of the other.
:param min_samples: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
:return: list of clusters and their centers.
"""
# binary filter for the tissue_value
binary_mask = slice == tissue_value

# Check if there are tissues with given tissue_value
if np.all(binary_mask == 0):
print("No tissues to cluster with given value.")
return []

# find label positions, upon which clustering wil be defined
label_positions = np.array(list(zip(*np.where(binary_mask))))
# define clusterer
clusterer = DBSCAN(eps=eps, min_samples=min_samples)

# find cluster prediction
clusters = clusterer.fit_predict(label_positions)
n_clusters = (
len(np.unique(clusters)) - 1
) # noise cluster has label -1, we don't take it into account
print(f"Found {n_clusters} clusters")

# Extract clusters and their centers
cluster_data = []

for cluster in range(n_clusters):
label_to_pos_array = label_positions[clusters == cluster] # get positions of each cluster
cluster_centers = np.mean(label_to_pos_array, axis=0) # mean of each column
# Save both the cluster and center under the same key
cluster_data.append({"cluster": label_to_pos_array, "center": cluster_centers})

return cluster_data


# TODO: set different parameters for each tissue
def DBSCAN_cluster_iter(tissues: dict, slice: np.ndarray, eps: float, min_samples: int) -> dict:
"""Find clusters of all tissues in a slice using DBSCAN
:param tissues: dictionary of tissues and their values
:param slice: image slice to cluster
:param eps: The maximum distance between two samples for one to be considered as in the neighborhood of the other.
:param min_samples: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
:return: dictionary of tissues and their clusters.
"""
# store clusters of tissues in a dict
tissues_clusters = {}

for tissue in tissues:
print(f"Finding {tissue} clusters, with value {tissues[tissue]}:")
# find clusters for each tissue
tissues_clusters[tissue] = find_DBSCAN_clusters(tissues[tissue], slice, eps, min_samples)

# print the identified clusters and their centers
for index, data in enumerate(tissues_clusters[tissue]):
print(f"Center of {tissue} cluster {index}: {data['center']}")
print("---------------------------------------\n")
return tissues_clusters
2 changes: 1 addition & 1 deletion src/armscan_env/envs/labelmaps_navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class LabelmapStateAction(StateAction):


class LabelmapSliceObservation(ArrayObservation[LabelmapStateAction]):
def __init__(self, slice_shape: tuple[int, int]):
def __init__(self, slice_shape: tuple[int, int], render_mode: str | None = None):
""":param slice_shape: slices will be cropped to this shape (we need a consistent observation space)."""
self._slice_shape = slice_shape
self._observation_space = gym.spaces.Box(low=0, high=1, shape=slice_shape)
Expand Down
88 changes: 88 additions & 0 deletions src/armscan_env/envs/rewards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# my custom loss function for image navigation
from typing import Any

import numpy as np


def anatomy_based_rwd(tissue_clusters: dict, n_landmarks: list = [5, 2, 1]) -> np.ndarray[Any, np.dtype[
np.floating[Any]]]:
"""Calculate the reward based on the presence and location of anatomical landmarks
:param tissue_clusters: dictionary of tissues and their clusters
:param n_landmarks: number of landmarks for each tissue
:return: reward value.
"""
print("####################################################")
print("Calculating loss function:")

# Presence of landmark tissues:

bones_loss = abs(len(tissue_clusters["bones"]) - n_landmarks[0])
ligament_loss = abs(len(tissue_clusters["tendons"]) - n_landmarks[1])
ulnar_loss = abs(len(tissue_clusters["ulnar"]) - n_landmarks[2])

landmark_loss = bones_loss + ligament_loss + ulnar_loss

# Absence of landmarks:
missing_landmark_loss = 0

# Location of landmarks:
location_loss = 1

# There must be bones:
if len(tissue_clusters["bones"]) != 0:
# Get centers of tissue clusters:
bones_centers = [cluster["center"] for _, cluster in enumerate(tissue_clusters["bones"])]
bones_centers_mean = np.mean(bones_centers, axis=0)

# There must be tendons:
if len(tissue_clusters["tendons"]) != 0:
# Get centers of tissue clusters:
ligament_centers = [
cluster["center"] for _, cluster in enumerate(tissue_clusters["tendons"])
]
ligament_centers_mean = np.mean(ligament_centers, axis=0)

# Check the orientation of the arm:
# The bones center might be over or under the tendons center depending on the origin
if bones_centers_mean[0] > ligament_centers_mean[0]:
print("Orientation: bones over tendons")
orientation = -1
else:
print("Orientation: bones under tendons")
orientation = 1

# There must be one ulnar artery:
if len(tissue_clusters["ulnar"]) == 1:
# There must be only one ulnar tissue so there is no need to take the mean
ulnar_center = tissue_clusters["ulnar"][0]["center"]

# Ulnar artery must be over tendons in the positive orientation:
if orientation * ulnar_center[0] > orientation * ligament_centers_mean[0]:
location_loss = 0
else:
print("Ulnar center not where expected")

# if no ulnar artery
else:
missing_landmark_loss = 1
print("No ulnar artery found")
# if no tendons
else:
missing_landmark_loss = 2
print("No tendons found")
# if no bones:
else:
missing_landmark_loss = 3
print("No bones found")

# Loss is bounded between 0 and 1
loss = (1 / 3) * (0.1 * landmark_loss + (1 / 3) * missing_landmark_loss + location_loss)

print(f"Landmark loss: {landmark_loss}")
print(f"Missing landmark loss: {missing_landmark_loss}")
print(f"Location loss: {location_loss}")
print(f"Total loss: {loss}")

print("#################################################### \n")

return loss
2 changes: 1 addition & 1 deletion src/armscan_env/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def slice_volume(
:param volume: 3D volume to be sliced
:return: the sliced volume.
"""
# Euler transformation
# Euler's transformation
# Rotation is defined by three rotations around z1, x2, z2 axis
th_z1 = np.deg2rad(z_rotation)
th_x2 = np.deg2rad(x_rotation)
Expand Down
134 changes: 134 additions & 0 deletions src/armscan_env/util/visualizations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.image import AxesImage


def _show(
slices: list,
start: int,
lap: int,
col: int = 5,
cmap: str | None = None,
aspect: int = 6,
) -> AxesImage:
"""Function to display row of image slices
:param slices: list of image slices
:param start: starting slice number
:param lap: number of slices to skip
:param col: number of columns to display
:param cmap: color map to use
:param aspect: aspect ratio of each image
:return: None.
"""
rows = -(-len(slices) // col)
fig, ax = plt.subplots(rows, col, figsize=(15, 2 * rows))
# Flatten the ax array to simplify indexing
ax = ax.flatten()
for i, slice in enumerate(slices):
ax[i].imshow(slice, cmap=cmap, origin="lower", aspect=aspect)
ax[i].set_title(f"Slice {start - i * lap}") # Set titles if desired
# Adjust layout to prevent overlap of titles
plt.tight_layout()
return ax


def show_slices(
data: np.ndarray,
start: int,
end: int,
lap: int,
col: int = 5,
cmap: str | None = None,
aspect: int = 6,
) -> AxesImage:
"""Function to display row of image slices
:param data: 3D image data
:param start: starting slice number
:param end: ending slice number
:param lap: number of slices to skip
:param col: number of columns to display
:param cmap: color map to use
:param aspect: aspect ratio of each image
:return: None.
"""
it = 0
slices = []
for slice in range(start, 0, -lap):
it += 1
slices.append(data[:, slice, :])
if it == end:
break
return _show(slices, start, lap, col, cmap, aspect)


def show_cluster_centers(tissue_clusters: dict, slice: np.ndarray, ax: AxesImage = None) -> AxesImage:
"""Plot the centers of the clusters of all tissues in a slice
:param tissue_clusters: dictionary of tissues and their clusters
:param slice: image slice to cluster
:param ax: axis to plot on
:return: None.
"""
ax = ax or plt.gca()

for tissue in tissue_clusters:
for _label, data in enumerate(tissue_clusters[tissue]):
# plot clusters with different colors
ax.scatter(
data["center"][1],
data["center"][0],
color="red",
marker="*",
s=20,
) # plot centers

ax.imshow(slice, aspect=6, origin="lower")
return ax


def show_clusters(tissue_clusters: dict, slice: np.ndarray, ax: AxesImage = None) -> AxesImage:
"""Plot the clusters of all tissues in a slice
:param tissue_clusters: dictionary of tissues and their clusters
:param slice: image slice to cluster
:param ax: axis to plot on
:return: None.
"""
ax = ax or plt.gca()

# create an empty array for cluster labels
cluster_labels = slice.copy()

for tissue in tissue_clusters:
for label, data in enumerate(tissue_clusters[tissue]):
# plot clusters with different colors
cluster_labels[tuple(data["cluster"].T)] = (label + 1) * 10
ax.scatter(data["center"][1], data["center"][0], color="red", marker="*", s=20)

ax.imshow(cluster_labels, aspect=6, origin="lower")
return ax


def show_only_clusters(tissue_clusters: dict, slice: np.ndarray, ax: AxesImage = None) -> AxesImage:
"""Plot only the clusters of all tissues in a slice
:param tissue_clusters: dictionary of tissues and their clusters
:param slice: image slice to cluster
:param ax: axis to plot on
:return: None.
"""
ax = ax or plt.gca()

# create an empty array for cluster labels
cluster_labels = np.ones_like(slice) * 0

for tissue in tissue_clusters:
for label, data in enumerate(tissue_clusters[tissue]):
# plot clusters with different colors
cluster_labels[tuple(data["cluster"].T)] = (label + 1) * 10
ax.scatter(
data["center"][1],
data["center"][0],
color="red",
marker="*",
s=20,
) # plot centers
ax.imshow(cluster_labels, aspect=6, origin="lower")
return ax

0 comments on commit d24621f

Please sign in to comment.