Skip to content

Commit

Permalink
Feat remove merges (#81)
Browse files Browse the repository at this point in the history
* minor upds

* refactor: simplified

---------

Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Sep 1, 2024
1 parent 6e5572b commit 408064b
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 405 deletions.
124 changes: 33 additions & 91 deletions src/segmentation_skeleton_metrics/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
@email: [email protected]
"""
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from random import sample

Expand All @@ -15,8 +16,6 @@

from segmentation_skeleton_metrics import utils

MIN_CNT = 30


# --- Update graph structure ---
def delete_nodes(graph, delete_label, return_cnt=False):
Expand Down Expand Up @@ -75,7 +74,7 @@ def upd_labels(graph, nodes, label):
Returns
-------
graph : networkx.Graph
networkx.Graph
Updated graph.
"""
Expand All @@ -84,30 +83,41 @@ def upd_labels(graph, nodes, label):
return graph


def store_labels(graph):
def init_label_to_nodes(graph, filter_bool=False, key=None):
"""
Iterates over all nodes in "graph" and stores the label and node id in
a dictionary called "label_to_node".
Initializes a dictionary that maps a label to nodes with that label.
Parameters
----------
graph : networkx.Graph
Graph to be updated
filter_bool : bool, optional
Indication of whether to filter labels that occur less frequently than
a predefined minimum count (MIN_CNT). The default is False.
key : str
Graph ID of "graph".
Returns
-------
label_to_node : dict
Dictionary that stores the label and node id.
dict
Dictionary that maps a label to nodes with that label.
"""
label_to_node = dict()
for i in graph.nodes:
label = graph.nodes[i]["label"]
if label in label_to_node.keys():
label_to_node[label].add(i)
else:
label_to_node[label] = set([i])
return label_to_node
# Initialize dictionary
label_to_nodes = defaultdict(set)
node_to_label = nx.get_node_attributes(graph, "label")
for i, label in node_to_label.items():
label_to_nodes[label].add(i)

# Filter labels (if applicable)
if filter_bool:
label_to_nodes = utils.filter_dict(label_to_nodes)

# Finish
if key:
return key, label_to_nodes
else:
return label_to_nodes


def get_node_labels(graphs):
Expand All @@ -128,51 +138,21 @@ def get_node_labels(graphs):
"""
with ProcessPoolExecutor() as executor:
# Assign processes
processes = list()
for key, graph in graphs.items():
processes.append(executor.submit(parse_node_labels, graph, key))
processes.append(
executor.submit(init_label_to_nodes, graph, True, key)
)

# Store results
graph_to_labels = dict()
for cnt, process in enumerate(as_completed(processes)):
graph_to_labels.update(process.result())
key, label_to_nodes = process.result()
graph_to_labels[key] = set(label_to_nodes.keys())
return graph_to_labels


def parse_node_labels(graph, key):
"""
Parses and filters node labels from the given graph based on whether they
occur less frequently than a predefined minimum count (MIN_CNT).
frequencies.
Parameters
----------
graph : networkx.Graph
Graph containing nodes with labels.
key : hashable
Key under which the resulting set of labels will be stored in the
returned dictionary.
Returns
-------
dict
A dictionary that maps "key" to the set of labels that occur more
frequently than the global variable "MIN_CNT" in the graph.
"""
# Main
label_to_cnt = dict()
for i in graph.nodes:
label = graph.nodes[i]["label"]
if label in label_to_cnt and label != 0:
label_to_cnt[label] += 1
else:
label_to_cnt[label] = 0

# Filter
keep = [l for l, cnt in label_to_cnt.items() if cnt > MIN_CNT]
return {key: set([l for l in label_to_cnt.keys() if l in keep])}


# -- eval tools --
def count_splits(graph):
"""
Expand Down Expand Up @@ -263,26 +243,6 @@ def to_xyz_array(graph):
return np.array([xyz_coords[i] for i in graph.nodes])


def get_coord(graph, i):
"""
Gets xyz image coordinates of node "i".
Parameters
----------
graph : networkx.Graph
Graph to be queried.
i : int
Node of "graph".
Returns
-------
tuple
The xyz image coordinates of node "i".
"""
return tuple(graph.nodes[i]["xyz"])


def sample_leaf(graph):
"""
Samples leaf node from "graph".
Expand All @@ -300,21 +260,3 @@ def sample_leaf(graph):
"""
leafs = [i for i in graph.nodes if graph.degree[i] == 1]
return sample(leafs, 1)[0]


def sample_node(graph):
"""
Samples a node from "graph".
Parameters
----------
graph : networkx.Graph
Graph to be sampled from.
Returns
-------
int
Node.
"""
return sample(list(graph.nodes), 1)[0]
Loading

0 comments on commit 408064b

Please sign in to comment.