From f67c408f241588fcd62967ee9fe45e47ded76470 Mon Sep 17 00:00:00 2001
From: Anna Grim <108307071+anna-grim@users.noreply.github.com>
Date: Thu, 29 Aug 2024 15:25:57 -0700
Subject: [PATCH] minor upds (#80)

Co-authored-by: anna-grim
---
 .../graph_utils.py     |  2 +-
 .../merge_detection.py | 89 -------------------
 .../skeleton_metric.py | 55 ++++++++----
 3 files changed, 38 insertions(+), 108 deletions(-)
 delete mode 100644 src/segmentation_skeleton_metrics/merge_detection.py

diff --git a/src/segmentation_skeleton_metrics/graph_utils.py b/src/segmentation_skeleton_metrics/graph_utils.py
index 32eab2a..c6ef02d 100644
--- a/src/segmentation_skeleton_metrics/graph_utils.py
+++ b/src/segmentation_skeleton_metrics/graph_utils.py
@@ -15,7 +15,7 @@
 from segmentation_skeleton_metrics import utils
 
 
-MIN_CNT = 20
+MIN_CNT = 30
 
 
 # --- Update graph structure ---
diff --git a/src/segmentation_skeleton_metrics/merge_detection.py b/src/segmentation_skeleton_metrics/merge_detection.py
deleted file mode 100644
index 45f7fd1..0000000
--- a/src/segmentation_skeleton_metrics/merge_detection.py
+++ /dev/null
@@ -1,89 +0,0 @@
-"""
-Created on Wed April 8 20:30:00 2024
-
-@author: Anna Grim
-@email: anna.grim@alleninstitute.org
-
-
-"""
-
-import numpy as np
-from scipy.spatial.distance import euclidean as get_dist
-
-
-def find_sites(graphs, get_labels):
-    """
-    Detects merges between ground truth graphs which are considered to be
-    potential merge sites.
-
-    Parameters
-    ----------
-    graphs : dict
-        Dictionary where the keys are graph ids and values are graphs.
-    get_labels : func
-        Gets the label of a node in "graphs".
-
-    Returns
-    -------
-    merge_ids : set[tuple]
-        Set of tuples containing a tuple of graph ids and common label between
-        the graphs.
-
-    """
-    merge_ids = set()
-    visited = set()
-    for key_1 in graphs.keys():
-        for key_2 in graphs.keys():
-            keys = frozenset((key_1, key_2))
-            if key_1 != key_2 and keys not in visited:
-                visited.add(keys)
-                intersection = get_labels(key_1).intersection(
-                    get_labels(key_2)
-                )
-                for label in intersection:
-                    merge_ids.add((keys, label))
-    return merge_ids
-
-
-def localize(graph_1, graph_2, merged_1, merged_2, dist_threshold, merge_id):
-    """
-    Finds the closest pair of xyz coordinates from "merged_1" and "merged_2".
-
-    Parameters
-    ----------
-    graph_1 : networkx.Graph
-        Graph with potential merge.
-    graph_2 : networkx.Graph
-        Graph with potential merge.
-    merged_1 : set
-        Nodes contained in "graph_1" with same labels as nodes in "merged_2".
-    merged_2 : set
-        Nodes contained in "graph_2" with same labels as nodes in "merged_1".
-    dist_threshold : float
-        Distance that determines whether two graphs contain a merge site.
-    merge_id : tuple
-        Tuple containing keys corresponding to "graph_1" and "graph_2" along
-        the common label between them.
-
-    Returns
-    -------
-    xyz_pair : list[numpy.ndarray]
-        Closest pair of xyz coordinates from "merged_1" and "merged_2".
-    min_dist : float
-        Distance between xyz coordinates in "xyz_pair".
-
-    """
-    min_dist = np.inf
-    xyz_pair = list()
-    for i in merged_1:
-        for j in merged_2:
-            xyz_i = graph_1.nodes[i]["xyz"]
-            xyz_j = graph_2.nodes[j]["xyz"]
-            if get_dist(xyz_i, xyz_j) < min_dist:
-                min_dist = get_dist(xyz_i, xyz_j)
-                xyz_pair = [xyz_i, xyz_j]
-    if min_dist < dist_threshold:
-        print("Merge Detected:", merge_id, xyz_pair, min_dist)
-        return merge_id, xyz_pair, min_dist
-    return xyz_pair, min_dist
diff --git a/src/segmentation_skeleton_metrics/skeleton_metric.py b/src/segmentation_skeleton_metrics/skeleton_metric.py
index 7fa503a..4683a1a 100644
--- a/src/segmentation_skeleton_metrics/skeleton_metric.py
+++ b/src/segmentation_skeleton_metrics/skeleton_metric.py
@@ -8,11 +8,7 @@
 
 """
 import os
-from concurrent.futures import (
-    ProcessPoolExecutor,
-    ThreadPoolExecutor,
-    as_completed,
-)
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from time import time
 from zipfile import ZipFile
 
@@ -21,13 +17,7 @@
 from scipy.spatial import KDTree
 
 from segmentation_skeleton_metrics import graph_utils as gutils
-from segmentation_skeleton_metrics import (
-    merge_detection,
-    split_detection,
-    swc_utils,
-    utils,
-)
-from segmentation_skeleton_metrics.merge_detection import find_sites
+from segmentation_skeleton_metrics import split_detection, swc_utils, utils
 
 ANISOTROPY = [0.748, 0.748, 1.0]
 MERGE_DIST_THRESHOLD = 20
@@ -60,7 +50,6 @@ def __init__(
         pred_swc_paths=None,
         valid_labels=None,
         save_projections=False,
-        save_sites=False,
     ):
         """
         Constructs skeleton metric object that evaluates the quality of a
@@ -102,9 +91,6 @@ def __init__(
             ground truth neurons (i.e. there exists a node in a graph from
             "self.graphs" that is labeled with a given fragment id. The
             default is None.
-        save_sites, : bool, optional
-            Indication of whether to write merge sites to an swc file. The
-            default is False.
 
         Returns
         -------
@@ -117,7 +103,6 @@ def __init__(
         self.ignore_boundary_mistakes = ignore_boundary_mistakes
         self.output_dir = output_dir
         self.pred_swc_paths = pred_swc_paths
-        self.save_sites = save_sites
 
         # Labels and Graphs
         assert type(valid_labels) is set if valid_labels else True
@@ -700,7 +685,7 @@ def count_merges(self, key, kdtree):
         labels = self.graph_to_labels[key]
         if self.inv_label_map:
             labels = set.union(*[self.inv_label_map[l] for l in labels])
-        
+
         for label in labels:
             if label in self.fragment_arrays:
                 for xyz in self.fragment_arrays[label][::4]:
@@ -1054,6 +1039,40 @@ def init_tracker(self):
 # -- utils --
 
 
+def find_sites(graphs, get_labels):
+    """
+    Detects pairs of ground truth graphs that share a label, which are
+    considered to be potential merge sites.
+
+    Parameters
+    ----------
+    graphs : dict
+        Dictionary where the keys are graph ids and the values are graphs.
+    get_labels : func
+        Returns the set of labels associated with a given graph id.
+
+    Returns
+    -------
+    merge_ids : set[tuple]
+        Set of tuples, each containing a frozenset of two graph ids and a
+        label common to both graphs.
+
+    """
+    merge_ids = set()
+    visited = set()
+    for key_1 in graphs.keys():
+        for key_2 in graphs.keys():
+            keys = frozenset((key_1, key_2))
+            if key_1 != key_2 and keys not in visited:
+                visited.add(keys)
+                intersection = get_labels(key_1).intersection(
+                    get_labels(key_2)
+                )
+                for label in intersection:
+                    merge_ids.add((keys, label))
+    return merge_ids
+
+
 def generate_result(keys, stats):
     """
     Reorders items in "stats" with respect to the order defined by "keys".
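
Usage note (not part of the patch): the sketch below shows how the relocated find_sites helper could be exercised on its own, assuming the package is importable as segmentation_skeleton_metrics after this change. The graph ids and label sets are made up for illustration; get_labels only needs to map a graph id to a set of labels.

    # Hypothetical example; the ids and label sets below are illustrative only.
    from segmentation_skeleton_metrics.skeleton_metric import find_sites

    # find_sites only reads the keys of "graphs", so placeholder values suffice here.
    graphs = {"gt-1": None, "gt-2": None, "gt-3": None}
    labels = {"gt-1": {10, 42}, "gt-2": {42, 77}, "gt-3": {5}}

    sites = find_sites(graphs, lambda key: labels[key])
    print(sites)  # e.g. {(frozenset({'gt-1', 'gt-2'}), 42)} -- both graphs touch label 42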