-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add src files for slicing, clustering, reward and visualizing
- Loading branch information
1 parent
6f89886
commit d24621f
Showing
5 changed files
with
357 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
from typing import Any | ||
|
||
import numpy as np | ||
from numpy import dtype, ndarray | ||
from scipy.ndimage import label | ||
from sklearn.cluster import DBSCAN | ||
|
||
|
||
def find_clusters(tissue_value: int, slice: np.ndarray) -> list[dict]: | ||
"""Find clusters of a given tissue in a slice | ||
:param tissue_value: value of the tissue to cluster | ||
:param slice: image slice to cluster | ||
:return: list of clusters and their centers. | ||
""" | ||
# Create a binary mask based on the threshold | ||
binary_mask = slice == tissue_value | ||
|
||
# Check if there are tissues with given label | ||
if np.all(binary_mask is False): | ||
print("No tissues to cluster. Please set values using set_values method.") | ||
return [] | ||
|
||
# Label connected components in the binary mask | ||
labeled_array, num_clusters = label(binary_mask) | ||
|
||
# Extract clusters and their centers | ||
cluster_data = [] | ||
|
||
for cluster_label in range(num_clusters): | ||
cluster_indices = np.where(labeled_array == cluster_label + 1) | ||
# Calculate the center of the cluster | ||
center_x = np.mean(cluster_indices[0]) | ||
center_y = np.mean(cluster_indices[1]) | ||
center = (center_x, center_y) | ||
|
||
# Save both the cluster and center under the same key | ||
cluster_data.append( | ||
{ | ||
"cluster": np.array( | ||
list(zip(cluster_indices[0], cluster_indices[1], strict=False)), | ||
), | ||
"center": center, | ||
}, | ||
) | ||
|
||
return cluster_data | ||
|
||
|
||
def cluster_iter(tissues: dict, slice: np.ndarray) -> dict: | ||
"""Find clusters of all tissues in a slice | ||
:param tissues: dictionary of tissues and their values | ||
:param slice: image slice to cluster | ||
:return: dictionary of tissues and their clusters. | ||
""" | ||
# store clusters of tissues in a dict | ||
tissues_clusters = {} | ||
|
||
for tissue in tissues: | ||
print(f"Finding {tissue} clusters, with value {tissues[tissue]}:") | ||
tissues_clusters[tissue] = find_clusters(tissues[tissue], slice) | ||
|
||
print(f"Found {len(tissues_clusters[tissue])} clusters\n") | ||
print("---------------------------------------\n") | ||
return tissues_clusters | ||
|
||
|
||
def find_DBSCAN_clusters( | ||
tissue_value: int, | ||
slice: np.ndarray, | ||
eps: float, | ||
min_samples: int, | ||
) -> list[Any] | list[dict[str, ndarray[Any, dtype[Any]] | Any]]: | ||
"""Find clusters of a given tissue in a slice using DBSCAN | ||
:param tissue_value: value of the tissue to cluster | ||
:param slice: image slice to cluster | ||
:param eps: The maximum distance between two samples for one to be considered as in the neighborhood of the other. | ||
:param min_samples: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. | ||
:return: list of clusters and their centers. | ||
""" | ||
# binary filter for the tissue_value | ||
binary_mask = slice == tissue_value | ||
|
||
# Check if there are tissues with given tissue_value | ||
if np.all(binary_mask == 0): | ||
print("No tissues to cluster with given value.") | ||
return [] | ||
|
||
# find label positions, upon which clustering wil be defined | ||
label_positions = np.array(list(zip(*np.where(binary_mask)))) | ||
# define clusterer | ||
clusterer = DBSCAN(eps=eps, min_samples=min_samples) | ||
|
||
# find cluster prediction | ||
clusters = clusterer.fit_predict(label_positions) | ||
n_clusters = ( | ||
len(np.unique(clusters)) - 1 | ||
) # noise cluster has label -1, we don't take it into account | ||
print(f"Found {n_clusters} clusters") | ||
|
||
# Extract clusters and their centers | ||
cluster_data = [] | ||
|
||
for cluster in range(n_clusters): | ||
label_to_pos_array = label_positions[clusters == cluster] # get positions of each cluster | ||
cluster_centers = np.mean(label_to_pos_array, axis=0) # mean of each column | ||
# Save both the cluster and center under the same key | ||
cluster_data.append({"cluster": label_to_pos_array, "center": cluster_centers}) | ||
|
||
return cluster_data | ||
|
||
|
||
# TODO: set different parameters for each tissue | ||
def DBSCAN_cluster_iter(tissues: dict, slice: np.ndarray, eps: float, min_samples: int) -> dict: | ||
"""Find clusters of all tissues in a slice using DBSCAN | ||
:param tissues: dictionary of tissues and their values | ||
:param slice: image slice to cluster | ||
:param eps: The maximum distance between two samples for one to be considered as in the neighborhood of the other. | ||
:param min_samples: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. | ||
:return: dictionary of tissues and their clusters. | ||
""" | ||
# store clusters of tissues in a dict | ||
tissues_clusters = {} | ||
|
||
for tissue in tissues: | ||
print(f"Finding {tissue} clusters, with value {tissues[tissue]}:") | ||
# find clusters for each tissue | ||
tissues_clusters[tissue] = find_DBSCAN_clusters(tissues[tissue], slice, eps, min_samples) | ||
|
||
# print the identified clusters and their centers | ||
for index, data in enumerate(tissues_clusters[tissue]): | ||
print(f"Center of {tissue} cluster {index}: {data['center']}") | ||
print("---------------------------------------\n") | ||
return tissues_clusters |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# my custom loss function for image navigation | ||
from typing import Any | ||
|
||
import numpy as np | ||
|
||
|
||
def anatomy_based_rwd(tissue_clusters: dict, n_landmarks: list = [5, 2, 1]) -> np.ndarray[Any, np.dtype[ | ||
np.floating[Any]]]: | ||
"""Calculate the reward based on the presence and location of anatomical landmarks | ||
:param tissue_clusters: dictionary of tissues and their clusters | ||
:param n_landmarks: number of landmarks for each tissue | ||
:return: reward value. | ||
""" | ||
print("####################################################") | ||
print("Calculating loss function:") | ||
|
||
# Presence of landmark tissues: | ||
|
||
bones_loss = abs(len(tissue_clusters["bones"]) - n_landmarks[0]) | ||
ligament_loss = abs(len(tissue_clusters["tendons"]) - n_landmarks[1]) | ||
ulnar_loss = abs(len(tissue_clusters["ulnar"]) - n_landmarks[2]) | ||
|
||
landmark_loss = bones_loss + ligament_loss + ulnar_loss | ||
|
||
# Absence of landmarks: | ||
missing_landmark_loss = 0 | ||
|
||
# Location of landmarks: | ||
location_loss = 1 | ||
|
||
# There must be bones: | ||
if len(tissue_clusters["bones"]) != 0: | ||
# Get centers of tissue clusters: | ||
bones_centers = [cluster["center"] for _, cluster in enumerate(tissue_clusters["bones"])] | ||
bones_centers_mean = np.mean(bones_centers, axis=0) | ||
|
||
# There must be tendons: | ||
if len(tissue_clusters["tendons"]) != 0: | ||
# Get centers of tissue clusters: | ||
ligament_centers = [ | ||
cluster["center"] for _, cluster in enumerate(tissue_clusters["tendons"]) | ||
] | ||
ligament_centers_mean = np.mean(ligament_centers, axis=0) | ||
|
||
# Check the orientation of the arm: | ||
# The bones center might be over or under the tendons center depending on the origin | ||
if bones_centers_mean[0] > ligament_centers_mean[0]: | ||
print("Orientation: bones over tendons") | ||
orientation = -1 | ||
else: | ||
print("Orientation: bones under tendons") | ||
orientation = 1 | ||
|
||
# There must be one ulnar artery: | ||
if len(tissue_clusters["ulnar"]) == 1: | ||
# There must be only one ulnar tissue so there is no need to take the mean | ||
ulnar_center = tissue_clusters["ulnar"][0]["center"] | ||
|
||
# Ulnar artery must be over tendons in the positive orientation: | ||
if orientation * ulnar_center[0] > orientation * ligament_centers_mean[0]: | ||
location_loss = 0 | ||
else: | ||
print("Ulnar center not where expected") | ||
|
||
# if no ulnar artery | ||
else: | ||
missing_landmark_loss = 1 | ||
print("No ulnar artery found") | ||
# if no tendons | ||
else: | ||
missing_landmark_loss = 2 | ||
print("No tendons found") | ||
# if no bones: | ||
else: | ||
missing_landmark_loss = 3 | ||
print("No bones found") | ||
|
||
# Loss is bounded between 0 and 1 | ||
loss = (1 / 3) * (0.1 * landmark_loss + (1 / 3) * missing_landmark_loss + location_loss) | ||
|
||
print(f"Landmark loss: {landmark_loss}") | ||
print(f"Missing landmark loss: {missing_landmark_loss}") | ||
print(f"Location loss: {location_loss}") | ||
print(f"Total loss: {loss}") | ||
|
||
print("#################################################### \n") | ||
|
||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
from matplotlib.image import AxesImage | ||
|
||
|
||
def _show( | ||
slices: list, | ||
start: int, | ||
lap: int, | ||
col: int = 5, | ||
cmap: str | None = None, | ||
aspect: int = 6, | ||
) -> AxesImage: | ||
"""Function to display row of image slices | ||
:param slices: list of image slices | ||
:param start: starting slice number | ||
:param lap: number of slices to skip | ||
:param col: number of columns to display | ||
:param cmap: color map to use | ||
:param aspect: aspect ratio of each image | ||
:return: None. | ||
""" | ||
rows = -(-len(slices) // col) | ||
fig, ax = plt.subplots(rows, col, figsize=(15, 2 * rows)) | ||
# Flatten the ax array to simplify indexing | ||
ax = ax.flatten() | ||
for i, slice in enumerate(slices): | ||
ax[i].imshow(slice, cmap=cmap, origin="lower", aspect=aspect) | ||
ax[i].set_title(f"Slice {start - i * lap}") # Set titles if desired | ||
# Adjust layout to prevent overlap of titles | ||
plt.tight_layout() | ||
return ax | ||
|
||
|
||
def show_slices( | ||
data: np.ndarray, | ||
start: int, | ||
end: int, | ||
lap: int, | ||
col: int = 5, | ||
cmap: str | None = None, | ||
aspect: int = 6, | ||
) -> AxesImage: | ||
"""Function to display row of image slices | ||
:param data: 3D image data | ||
:param start: starting slice number | ||
:param end: ending slice number | ||
:param lap: number of slices to skip | ||
:param col: number of columns to display | ||
:param cmap: color map to use | ||
:param aspect: aspect ratio of each image | ||
:return: None. | ||
""" | ||
it = 0 | ||
slices = [] | ||
for slice in range(start, 0, -lap): | ||
it += 1 | ||
slices.append(data[:, slice, :]) | ||
if it == end: | ||
break | ||
return _show(slices, start, lap, col, cmap, aspect) | ||
|
||
|
||
def show_cluster_centers(tissue_clusters: dict, slice: np.ndarray, ax: AxesImage = None) -> AxesImage: | ||
"""Plot the centers of the clusters of all tissues in a slice | ||
:param tissue_clusters: dictionary of tissues and their clusters | ||
:param slice: image slice to cluster | ||
:param ax: axis to plot on | ||
:return: None. | ||
""" | ||
ax = ax or plt.gca() | ||
|
||
for tissue in tissue_clusters: | ||
for _label, data in enumerate(tissue_clusters[tissue]): | ||
# plot clusters with different colors | ||
ax.scatter( | ||
data["center"][1], | ||
data["center"][0], | ||
color="red", | ||
marker="*", | ||
s=20, | ||
) # plot centers | ||
|
||
ax.imshow(slice, aspect=6, origin="lower") | ||
return ax | ||
|
||
|
||
def show_clusters(tissue_clusters: dict, slice: np.ndarray, ax: AxesImage = None) -> AxesImage: | ||
"""Plot the clusters of all tissues in a slice | ||
:param tissue_clusters: dictionary of tissues and their clusters | ||
:param slice: image slice to cluster | ||
:param ax: axis to plot on | ||
:return: None. | ||
""" | ||
ax = ax or plt.gca() | ||
|
||
# create an empty array for cluster labels | ||
cluster_labels = slice.copy() | ||
|
||
for tissue in tissue_clusters: | ||
for label, data in enumerate(tissue_clusters[tissue]): | ||
# plot clusters with different colors | ||
cluster_labels[tuple(data["cluster"].T)] = (label + 1) * 10 | ||
ax.scatter(data["center"][1], data["center"][0], color="red", marker="*", s=20) | ||
|
||
ax.imshow(cluster_labels, aspect=6, origin="lower") | ||
return ax | ||
|
||
|
||
def show_only_clusters(tissue_clusters: dict, slice: np.ndarray, ax: AxesImage = None) -> AxesImage: | ||
"""Plot only the clusters of all tissues in a slice | ||
:param tissue_clusters: dictionary of tissues and their clusters | ||
:param slice: image slice to cluster | ||
:param ax: axis to plot on | ||
:return: None. | ||
""" | ||
ax = ax or plt.gca() | ||
|
||
# create an empty array for cluster labels | ||
cluster_labels = np.ones_like(slice) * 0 | ||
|
||
for tissue in tissue_clusters: | ||
for label, data in enumerate(tissue_clusters[tissue]): | ||
# plot clusters with different colors | ||
cluster_labels[tuple(data["cluster"].T)] = (label + 1) * 10 | ||
ax.scatter( | ||
data["center"][1], | ||
data["center"][0], | ||
color="red", | ||
marker="*", | ||
s=20, | ||
) # plot centers | ||
ax.imshow(cluster_labels, aspect=6, origin="lower") | ||
return ax |