From f6f21ca2480762630f07744f37986b015ccc9898 Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Tue, 27 Aug 2024 16:04:19 -0700 Subject: [PATCH 1/2] feat: support merge cnt after correction (#78) Co-authored-by: anna-grim --- .../graph_utils.py | 6 +- .../skeleton_metric.py | 101 +++----- .../swc_utils.py | 232 ++++++++---------- 3 files changed, 150 insertions(+), 189 deletions(-) diff --git a/src/segmentation_skeleton_metrics/graph_utils.py b/src/segmentation_skeleton_metrics/graph_utils.py index 885f4ec..32eab2a 100644 --- a/src/segmentation_skeleton_metrics/graph_utils.py +++ b/src/segmentation_skeleton_metrics/graph_utils.py @@ -15,7 +15,7 @@ from segmentation_skeleton_metrics import utils -MIN_CNT = 30 +MIN_CNT = 20 # --- Update graph structure --- @@ -217,7 +217,7 @@ def compute_run_lengths(graph): return np.array(run_lengths) -def compute_run_length(graph, apply=True): +def compute_run_length(graph, img_coords_bool=True): """ Computes path length of graph. @@ -236,7 +236,7 @@ def compute_run_length(graph, apply=True): for i, j in nx.dfs_edges(graph): xyz_1 = graph.nodes[i]["xyz"] xyz_2 = graph.nodes[j]["xyz"] - if apply: + if img_coords_bool: xyz_1 = utils.to_world(xyz_1) xyz_2 = utils.to_world(xyz_2) path_length += get_dist(xyz_1, xyz_2) diff --git a/src/segmentation_skeleton_metrics/skeleton_metric.py b/src/segmentation_skeleton_metrics/skeleton_metric.py index baee423..7fa503a 100644 --- a/src/segmentation_skeleton_metrics/skeleton_metric.py +++ b/src/segmentation_skeleton_metrics/skeleton_metric.py @@ -16,7 +16,6 @@ from time import time from zipfile import ZipFile -import networkx as nx import numpy as np import tensorstore as ts from scipy.spatial import KDTree @@ -30,6 +29,7 @@ ) from segmentation_skeleton_metrics.merge_detection import find_sites +ANISOTROPY = [0.748, 0.748, 1.0] MERGE_DIST_THRESHOLD = 20 @@ -54,6 +54,7 @@ def __init__( target_swc_paths, anisotropy=[1.0, 1.0, 1.0], connections_path=None, + merged_ids_path=None, ignore_boundary_mistakes=False, output_dir=None, pred_swc_paths=None, @@ -75,9 +76,12 @@ def __init__( anisotropy : list[float], optional Image to real-world coordinates scaling factors applied to swc files at "target_swc_paths". The default is [1.0, 1.0, 1.0]. - connections_path : list[tuple] + connections_path : str, optional Path to a txt file containing pairs of segment ids of segments - that were merged into a single segment. + that were merged into a single segment. The default is None. + merged_ids_path : str, optional + Path to txt file that contains segment ids that correspond to + merge mistakes. The default is None. ignore_boundary_mistakes : bool, optional Indication of whether to ignore mistakes near boundary of bounding box. The default is False. @@ -94,7 +98,10 @@ def __init__( The purpose of this argument is to account for segments that were removed due to thresholding by path length. The default is None. save_projections: bool, optional - ... + Indication of whether to save fragments that 'project' onto the + ground truth neurons (i.e. there exists a node in a graph from + "self.graphs" that is labeled with a given fragment id. The + default is None. save_sites, : bool, optional Indication of whether to write merge sites to an swc file. The default is False. @@ -236,42 +243,42 @@ def set_labels(self, graph): label_to_nodes[label] = set([i]) return label_to_nodes - def read_label(self, coord): + def read_label(self, voxel): """ - Gets label at image coordinates "coord". 
+ Gets label at image coordinate "voxel". Parameters ---------- - coord : tuple[int] - Coordinates that indexes into "self.label_mask". + voxel : tuple[int] + Image coordinate that indexes into "self.label_mask". Returns ------- int - Label at image coordinates "coord". + Label of voxel. """ if type(self.label_mask) == ts.TensorStore: - return int(self.label_mask[coord].read().result()) + return int(self.label_mask[voxel].read().result()) else: - return self.label_mask[coord] + return self.label_mask[voxel] - def get_label(self, coord, return_node=False): + def get_label(self, voxel, return_node=False): """ - Gets label of voxel at "coord". + Gets label of voxel at "voxel". Parameters ---------- - coord : numpy.ndarray + voxel : numpy.ndarray Image coordinate of voxel to be read. Returns ------- int - Label of voxel at "coord". + Label of voxel. """ - label = self.read_label(coord) + label = self.read_label(voxel) if return_node: return return_node, self.validate(label) else: @@ -381,7 +388,10 @@ def load_fragments(self): # Read fragments t0 = time() print("Loading Fragments") - reader = swc_utils.Reader(return_graphs=True) + anisotropy = [1.0 / a_i for a_i in ANISOTROPY] # hard coded + reader = swc_utils.Reader( + anisotropy=anisotropy, img_coords_bool=False, return_graphs=True + ) fragment_graphs = reader.load_from_local_zip(self.pred_swc_paths) # Filter fragments @@ -414,6 +424,8 @@ def filter_fragments(self, fragment_graphs): """ labels = set.union(*list(self.graph_to_labels.values())) + if self.inv_label_map: + labels = set.union(*[self.inv_label_map[l] for l in labels]) return {l: fragment_graphs[l] for l in labels if l in fragment_graphs} def init_fragment_arrays(self): @@ -524,6 +536,7 @@ def compute_projected_run_lengths(self): # Compute run lengths t0 = time() + print("Computing Run Lengths") for key, labels in self.graph_to_labels.items(): target_rl = self.get_run_length(key) projected_rl = self.compute_projected_run_length( @@ -642,6 +655,8 @@ def detect_merges(self): self.merged_cnts = self.init_counter() self.merged_percent = self.init_counter() self.merged_labels = set() + + # Check whether to delete prexisting merges # add conditional that checks whether detected merges file exists # to adjust for scenario when evaluating a corrected segmentation @@ -681,58 +696,20 @@ def is_valid_merge(self, key_1, key_2, label): return True if is_valid else False def count_merges(self, key, kdtree): - for label in self.graph_to_labels[key]: + # Get labels + labels = self.graph_to_labels[key] + if self.inv_label_map: + labels = set.union(*[self.inv_label_map[l] for l in labels]) + + for label in labels: if label in self.fragment_arrays: for xyz in self.fragment_arrays[label][::4]: d, _ = kdtree.query(xyz, k=1) - if d > 40: + if d > 100: self.merge_cnt[key] += 1 self.merged_labels.add(label) break - def localize_merges_old(self, detected_merges): - """ - Searches for exact site where a merge occurs. - - Parameters - ---------- - detected_merges : list[tuple[str, str, int]] - Merge sites indicated by a pair of keys and label. 
- - Returns - ------- - None - - """ - with ProcessPoolExecutor() as executor: - # Assign processes - processes = [] - for (key_1, key_2), label in detected_merges: - processes.append( - executor.submit( - merge_detection.localize, - self.graphs[key_1], - self.graphs[key_2], - self.key_to_label_to_nodes[key_1][label], - self.key_to_label_to_nodes[key_2][label], - MERGE_DIST_THRESHOLD, - ((key_1, key_2), label), - ) - ) - - # Compile results - cnt = 1 - for i, process in enumerate(as_completed(processes)): - # Check site - site, d = process.result() - if d < MERGE_DIST_THRESHOLD: - self.save_merge_site(site[0], site[1]) - - # Report process - if i >= cnt * len(processes) * 0.02: - utils.progress_bar(i + 1, len(processes)) - cnt += 1 - def save_merge_site(self, xyz_1, xyz_2): """ Saves the site where a merge is located by writing the xyz coordinates diff --git a/src/segmentation_skeleton_metrics/swc_utils.py b/src/segmentation_skeleton_metrics/swc_utils.py index fa8c6c2..189a520 100644 --- a/src/segmentation_skeleton_metrics/swc_utils.py +++ b/src/segmentation_skeleton_metrics/swc_utils.py @@ -34,7 +34,11 @@ class Reader: """ def __init__( - self, anisotropy=[1.0, 1.0, 1.0], min_size=0, return_graphs=False + self, + anisotropy=[1.0, 1.0, 1.0], + img_coords_bool=True, + min_size=0, + return_graphs=False, ): """ Initializes a Reader object that loads swc files. @@ -45,6 +49,9 @@ def __init__( Image to world scaling factors applied to xyz coordinates to account for anisotropy of the microscope. The default is [1.0, 1.0, 1.0]. + img_coords_bool : bool, optional + Indication of whether node xyz coordinates coorespond to voxels or + world. The default is True. min_size : int, optional Threshold on the number of nodes in swc file. Only swc files with more than "min_size" nodes are stored in "xyz_coords". The default @@ -59,6 +66,7 @@ def __init__( """ self.anisotropy = anisotropy + self.img_coords_bool = img_coords_bool self.min_size = min_size self.return_graphs = return_graphs @@ -69,7 +77,7 @@ def load_from_local_paths(self, swc_paths): Paramters --------- - swc_paths : list or dict + swc_paths : list List of paths to swc files stored on the local machine. Returns @@ -85,10 +93,10 @@ def load_from_local_paths(self, swc_paths): if len(content) > self.min_size: key = utils.get_id(path) if self.return_graphs: - swc_dict[key] = get_graph(content, self.anisotropy) + swc_dict[key] = self.get_graph(content) swc_dict[key].graph["filename"] = os.path.basename(path) else: - swc_dict[key] = get_coords(content, self.anisotropy) + swc_dict[key] = self.get_coords(content) return swc_dict def load_from_local_zip(self, zip_path): @@ -111,7 +119,6 @@ def load_from_local_zip(self, zip_path): that swc file. 
""" - anisotropy = [1.0 / val for val in ANISOTROPY] # hard coded cnt = 1 swc_dict = dict() with ZipFile(zip_path, "r") as zip: @@ -123,10 +130,10 @@ def load_from_local_zip(self, zip_path): if len(content) > self.min_size: key = utils.get_id(f) if self.return_graphs: - swc_dict[key] = get_graph(content, anisotropy) + swc_dict[key] = self.get_graph(content) swc_dict[key].graph["filename"] = f else: - swc_dict[key] = get_coords(content, anisotropy) + swc_dict[key] = self.get_coords(content) # Report progress if i >= cnt * chunk_size: @@ -233,14 +240,105 @@ def load_from_cloud_zipped_file(self, zip_file, path): if len(content) > self.min_size: key = utils.get_id(path) if self.return_graphs: - graph = get_graph(content, self.anisotropy) + graph = self.get_graph(content) graph.graph["filename"] = os.path.basename(path) return {key: graph} else: - return {key: get_coords(content, self.anisotropy)} + return {key: self.get_coords(content)} else: return dict() + def get_coords(self, content): + """ + Gets the xyz coords from the an swc file that has been read and stored + as "content". + + Parameters + ---------- + content : list[str] + Entries from swc where each item is the text string from an swc. + anisotropy : list[float] + Image to world scaling factors applied to xyz coordinates to + account for anisotropy of the microscope. + + Returns + ------- + numpy.ndarray + xyz coords from an swc file. + + """ + coords_list = [] + offset = [0, 0, 0] + for line in content: + if line.startswith("# OFFSET"): + parts = line.split() + offset = self.read_xyz(parts[2:5], offset) + if not line.startswith("#"): + parts = line.split() + coords_list.append(self.read_xyz(parts[2:5], offset)) + return np.array(coords_list) + + def read_xyz(self, xyz_str, offset): + """ + Reads the xyz coordinates from an swc file, then transforms the + coordinates with respect to "anisotropy" and "offset". + + Parameters + ---------- + xyz_str : str + xyz coordinate stored in a str. + offset : list[int] + Offset of xyz coordinates in swc file. + + Returns + ------- + numpy.ndarray + xyz coordinates of an entry from an swc file. + + """ + xyz = np.zeros((3)) + for i in range(3): + xyz[i] = self.anisotropy[i] * (float(xyz_str[i]) + offset[i]) + return xyz.astype(int) + + def get_graph(self, content): + """ + Reads an swc file and builds an undirected graph from it. + + Parameters + ---------- + path : str + Path to swc file to be read. + + Returns + ------- + networkx.Graph + Graph built from an swc file. + + """ + # Build Gaph + graph = nx.Graph() + offset = [0, 0, 0] + for line in content: + if line.startswith("# OFFSET"): + parts = line.split() + offset = self.read_xyz(parts[2:5]) + if not line.startswith("#"): + parts = line.split() + child = int(parts[0]) + parent = int(parts[-1]) + xyz = self.read_xyz(parts[2:5], offset=offset) + graph.add_node(child, xyz=xyz) + if parent != -1: + graph.add_edge(parent, child) + + # Set graph-level attributes + graph.graph["number_of_edges"] = graph.number_of_edges() + graph.graph["run_length"] = gutils.compute_run_length( + graph, self.img_coords_bool + ) + return graph + # -- write -- def save(path, xyz_1, xyz_2, color=None): @@ -309,7 +407,7 @@ def to_zipped_swc(zip_writer, graph, color=None): Parameters ---------- - zip_writer : ... + zip_writer : zipfile.ZipFile ... graph : networkx.Graph Graph to be written to an swc file. 
@@ -349,117 +447,3 @@ def to_zipped_swc(zip_writer, graph, color=None): # Finish zip_writer.writestr(graph.graph["filename"], text_buffer.getvalue()) - - -# -- utils -- -def get_coords(content, anisotropy=[1.0, 1.0, 1.0]): - """ - Gets the xyz coords from the an swc file that has been read and stored as - "content". - - Parameters - ---------- - content : list[str] - Entries in swc file where each entry is the text string from an swc. - anisotropy : list[float] - Image to world scaling factors applied to xyz coordinates to account - for anisotropy of the microscope. - - Returns - ------- - numpy.ndarray - xyz coords from an swc file. - - """ - coords_list = [] - offset = [0, 0, 0] - for line in content: - if line.startswith("# OFFSET"): - parts = line.split() - offset = read_xyz(parts[2:5], anisotropy, offset) - if not line.startswith("#"): - parts = line.split() - coords_list.append(read_xyz(parts[2:5], anisotropy, offset)) - return np.array(coords_list) - - -def read_xyz(xyz, anisotropy, offset): - """ - Reads the xyz coordinates from an swc file, then transforms the - coordinates with respect to "anisotropy" and "offset". - - Parameters - ---------- - xyz : str - xyz coordinate stored in a str. - anisotropy : list[float] - Image to real-world coordinates scaling factors applied to "xyz". - offset : list[int] - Offset of xyz coordinates in swc file. - - Returns - ------- - numpy.ndarray - xyz coordinates of an entry from an swc file. - - """ - xyz = [float(xyz[i]) + offset[i] for i in range(3)] - return np.array([xyz[i] * anisotropy[i] for i in range(3)], dtype=int) - - -def get_graph(content, anisotropy=[1.0, 1.0, 1.0]): - """ - Reads an swc file and builds an undirected graph from it. - - Parameters - ---------- - path : str - Path to swc file to be read. - anisotropy : list[float], optional - Image to real-world coordinates scaling factors for (x, y, z) that is - applied to swc files. The default is [1.0, 1.0, 1.0]. - - Returns - ------- - networkx.Graph - Graph built from an swc file. - - """ - # Build Gaph - graph = nx.Graph() - offset = [0, 0, 0] - for line in content: - if line.startswith("# OFFSET"): - parts = line.split() - offset = read_xyz(parts[2:5], anisotropy) - if not line.startswith("#"): - parts = line.split() - child = int(parts[0]) - parent = int(parts[-1]) - xyz = read_xyz(parts[2:5], anisotropy, offset=offset) - graph.add_node(child, xyz=xyz) - if parent != -1: - graph.add_edge(parent, child) - - # Set graph-level attributes - graph.graph["number_of_edges"] = graph.number_of_edges() - graph.graph["run_length"] = gutils.compute_run_length(graph) - return graph - - -def to_voxels(xyz): - """ - Converts coordinates from world to voxels. - - Parameters - ---------- - xyz : numpy.ndarray - Coordinate to be converted. - - Returns - ------- - tuple - Converted coordinates. 
- - """ - return tuple([xyz[i] * 1.0 / ANISOTROPY[i] for i in range(3)]) From cddb6c761196dd34493134a61f8927d9dcc2eade Mon Sep 17 00:00:00 2001 From: github-actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 23:04:34 +0000 Subject: [PATCH 2/2] ci: version bump [skip actions] --- src/segmentation_skeleton_metrics/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/segmentation_skeleton_metrics/__init__.py b/src/segmentation_skeleton_metrics/__init__.py index a8ed27a..d24735e 100644 --- a/src/segmentation_skeleton_metrics/__init__.py +++ b/src/segmentation_skeleton_metrics/__init__.py @@ -2,4 +2,4 @@ Package to evaluate a predicted segmentation. """ -__version__ = "4.5.5" +__version__ = "4.6.0"
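
A minimal usage sketch of the swc_utils.Reader API added in PATCH 1/2, mirroring how SkeletonMetric.load_fragments() now builds fragment graphs. This is a sketch, not part of the change set: the zip path and voxel size are placeholders, and the import path is assumed from the package layout shown in this diff.

    from segmentation_skeleton_metrics import swc_utils

    voxel_size = [0.748, 0.748, 1.0]  # placeholder; same values as the hard-coded ANISOTROPY constant
    reader = swc_utils.Reader(
        anisotropy=[1.0 / v for v in voxel_size],  # scale swc coordinates into voxel space, as in load_fragments()
        img_coords_bool=False,                     # skip utils.to_world() when computing run lengths
        return_graphs=True,                        # build networkx graphs instead of xyz arrays
    )
    fragment_graphs = reader.load_from_local_zip("fragments.zip")  # placeholder path
    for segment_id, graph in fragment_graphs.items():
        print(segment_id, graph.graph["run_length"], graph.graph["number_of_edges"])

Passing the inverse voxel size together with img_coords_bool=False appears to keep the fragment coordinates in the voxel space used elsewhere in SkeletonMetric, so the distances queried in count_merges() stay comparable to its threshold.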