-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* minor upds * refactor: simplified --------- Co-authored-by: anna-grim <[email protected]>
- Loading branch information
Showing
4 changed files
with
174 additions
and
405 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
@email: [email protected] | ||
""" | ||
from collections import defaultdict | ||
from concurrent.futures import ProcessPoolExecutor, as_completed | ||
from random import sample | ||
|
||
|
@@ -15,8 +16,6 @@ | |
|
||
from segmentation_skeleton_metrics import utils | ||
|
||
MIN_CNT = 30 | ||
|
||
|
||
# --- Update graph structure --- | ||
def delete_nodes(graph, delete_label, return_cnt=False): | ||
|
@@ -75,7 +74,7 @@ def upd_labels(graph, nodes, label): | |
Returns | ||
------- | ||
graph : networkx.Graph | ||
networkx.Graph | ||
Updated graph. | ||
""" | ||
|
@@ -84,30 +83,41 @@ def upd_labels(graph, nodes, label): | |
return graph | ||
|
||
|
||
def store_labels(graph): | ||
def init_label_to_nodes(graph, filter_bool=False, key=None): | ||
""" | ||
Iterates over all nodes in "graph" and stores the label and node id in | ||
a dictionary called "label_to_node". | ||
Initializes a dictionary that maps a label to nodes with that label. | ||
Parameters | ||
---------- | ||
graph : networkx.Graph | ||
Graph to be updated | ||
filter_bool : bool, optional | ||
Indication of whether to filter labels that occur less frequently than | ||
a predefined minimum count (MIN_CNT). The default is False. | ||
key : str | ||
Graph ID of "graph". | ||
Returns | ||
------- | ||
label_to_node : dict | ||
Dictionary that stores the label and node id. | ||
dict | ||
Dictionary that maps a label to nodes with that label. | ||
""" | ||
label_to_node = dict() | ||
for i in graph.nodes: | ||
label = graph.nodes[i]["label"] | ||
if label in label_to_node.keys(): | ||
label_to_node[label].add(i) | ||
else: | ||
label_to_node[label] = set([i]) | ||
return label_to_node | ||
# Initialize dictionary | ||
label_to_nodes = defaultdict(set) | ||
node_to_label = nx.get_node_attributes(graph, "label") | ||
for i, label in node_to_label.items(): | ||
label_to_nodes[label].add(i) | ||
|
||
# Filter labels (if applicable) | ||
if filter_bool: | ||
label_to_nodes = utils.filter_dict(label_to_nodes) | ||
|
||
# Finish | ||
if key: | ||
return key, label_to_nodes | ||
else: | ||
return label_to_nodes | ||
|
||
|
||
def get_node_labels(graphs): | ||
|
@@ -128,51 +138,21 @@ def get_node_labels(graphs): | |
""" | ||
with ProcessPoolExecutor() as executor: | ||
# Assign processes | ||
processes = list() | ||
for key, graph in graphs.items(): | ||
processes.append(executor.submit(parse_node_labels, graph, key)) | ||
processes.append( | ||
executor.submit(init_label_to_nodes, graph, True, key) | ||
) | ||
|
||
# Store results | ||
graph_to_labels = dict() | ||
for cnt, process in enumerate(as_completed(processes)): | ||
graph_to_labels.update(process.result()) | ||
key, label_to_nodes = process.result() | ||
graph_to_labels[key] = set(label_to_nodes.keys()) | ||
return graph_to_labels | ||
|
||
|
||
def parse_node_labels(graph, key): | ||
""" | ||
Parses and filters node labels from the given graph based on whether they | ||
occur less frequently than a predefined minimum count (MIN_CNT). | ||
frequencies. | ||
Parameters | ||
---------- | ||
graph : networkx.Graph | ||
Graph containing nodes with labels. | ||
key : hashable | ||
Key under which the resulting set of labels will be stored in the | ||
returned dictionary. | ||
Returns | ||
------- | ||
dict | ||
A dictionary that maps "key" to the set of labels that occur more | ||
frequently than the global variable "MIN_CNT" in the graph. | ||
""" | ||
# Main | ||
label_to_cnt = dict() | ||
for i in graph.nodes: | ||
label = graph.nodes[i]["label"] | ||
if label in label_to_cnt and label != 0: | ||
label_to_cnt[label] += 1 | ||
else: | ||
label_to_cnt[label] = 0 | ||
|
||
# Filter | ||
keep = [l for l, cnt in label_to_cnt.items() if cnt > MIN_CNT] | ||
return {key: set([l for l in label_to_cnt.keys() if l in keep])} | ||
|
||
|
||
# -- eval tools -- | ||
def count_splits(graph): | ||
""" | ||
|
@@ -263,26 +243,6 @@ def to_xyz_array(graph): | |
return np.array([xyz_coords[i] for i in graph.nodes]) | ||
|
||
|
||
def get_coord(graph, i): | ||
""" | ||
Gets xyz image coordinates of node "i". | ||
Parameters | ||
---------- | ||
graph : networkx.Graph | ||
Graph to be queried. | ||
i : int | ||
Node of "graph". | ||
Returns | ||
------- | ||
tuple | ||
The xyz image coordinates of node "i". | ||
""" | ||
return tuple(graph.nodes[i]["xyz"]) | ||
|
||
|
||
def sample_leaf(graph): | ||
""" | ||
Samples leaf node from "graph". | ||
|
@@ -300,21 +260,3 @@ def sample_leaf(graph): | |
""" | ||
leafs = [i for i in graph.nodes if graph.degree[i] == 1] | ||
return sample(leafs, 1)[0] | ||
|
||
|
||
def sample_node(graph): | ||
""" | ||
Samples a node from "graph". | ||
Parameters | ||
---------- | ||
graph : networkx.Graph | ||
Graph to be sampled from. | ||
Returns | ||
------- | ||
int | ||
Node. | ||
""" | ||
return sample(list(graph.nodes), 1)[0] |
Oops, something went wrong.