diff --git a/corneto/_graph.py b/corneto/_graph.py index d87fb69..e5dbd8d 100644 --- a/corneto/_graph.py +++ b/corneto/_graph.py @@ -1,6 +1,6 @@ import abc import pickle -from collections import OrderedDict +from collections import OrderedDict, deque from copy import deepcopy from enum import Enum from itertools import chain @@ -632,6 +632,32 @@ def opener(file, mode="r"): with opener(filename, "rb") as f: return pickle.load(f) + def toposort(self): + # Topological sort using Kahn's algorithm + in_degree = {v: len(set(self.predecessors(v))) for v in self._get_vertices()} + + # Initialize queue with nodes having zero in-degree + queue = deque([v for v in in_degree.keys() if in_degree[v] == 0]) + + result = [] + + while queue: + v = queue.popleft() + result.append(v) + + # Decrease the in-degree of successor nodes by 1 + for successor in self.successors(v): + in_degree[successor] -= 1 + if in_degree[successor] == 0: + queue.append(successor) + + # Check if topological sort is possible (i.e., graph has no cycles) + if len(result) == self.num_vertices: + return result + else: + raise ValueError("Graph contains a cycle, so topological sort is not possible.") + + def reachability_analysis( self, input_nodes, diff --git a/corneto/_ml.py b/corneto/_ml.py index 0ff2193..8717a89 100755 --- a/corneto/_ml.py +++ b/corneto/_ml.py @@ -18,18 +18,6 @@ def _load_keras(): except ImportError as e: raise e -def _concat_indexes(layer, indexes, keras): - if len(indexes) > 1: - if len(set(indexes)) == layer.shape[1]: - subset = layer - else: - slices = [layer[:, j : (j + 1)] for j in indexes] - subset = keras.layers.Concatenate()(slices) - else: - j = list(indexes)[0] - subset = layer[:, j : (j + 1)] - return subset - def toposort(G): # Topological sort using Kahn's algorithm in_degree = {v: len(set(G.predecessors(v))) for v in G._get_vertices()} @@ -56,7 +44,6 @@ def toposort(G): raise ValueError("Graph contains a cycle, so topological sort is not possible.") - def index_selector(): keras = _load_keras() @@ -101,8 +88,6 @@ def get_config(self): return IndexSelector - - def build_dagnn( G, input_nodes, @@ -130,7 +115,7 @@ def build_dagnn( )(input_layer) if unit_norm_input: input_layer = keras.layers.UnitNormalization()(input_layer) - vertices = toposort(G) + vertices = G.toposort() input_index = {v: i for i, v in enumerate(input_nodes)} kernel_reg, bias_reg = None, None if kernel_reg_l1 > 0 or kernel_reg_l2 > 0: @@ -212,133 +197,5 @@ def build_dagnn( model = keras.Model(inputs=input_layer, outputs=output_layer) return model - -def create_dagnn( - G: BaseGraph, - input_nodes, - output_nodes, - bias_reg_l1=0, - bias_reg_l2=0, - kernel_reg_l1=0, - kernel_reg_l2=0, - batch_norm_input=True, - batch_norm_center=False, - batch_norm_scale=False, - unit_norm_input=False, - dropout=0.20, - min_inputs_for_dropout=2, - activation_attribute="activation", - default_hidden_activation="sigmoid", - default_output_activation="sigmoid", - verbose=False, -): - keras = _load_keras() - input_layer = keras.Input(shape=(len(input_nodes),), name="inputs") - if batch_norm_input: - input_layer = keras.layers.BatchNormalization( - center=batch_norm_center, scale=batch_norm_scale - )(input_layer) - if unit_norm_input: - input_layer = keras.layers.UnitNormalization()(input_layer) - #if nonneg_unit_norm_input: - # input_layer = keras.layers.Lambda(lambda x: (1 + x) / 2)(input_layer) - input_index = {v: i for i, v in enumerate(input_nodes)} - queue = list(input_nodes) - neurons = {} - concat_cache = {} - while len(queue) > 0: - v = queue.pop(0) - for s in G.successors(v): - if s not in neurons: - queue.append(s) - else: - continue - if s not in input_index: - n_inputs = [] - s_idx_inputs = set() - s_neu_inputs = set() - for p in G.predecessors(s): - if p in input_index: - idx = input_index[p] - s_idx_inputs.add(idx) - else: - s_neu_inputs.add(p) - # Check if all neuron inputs are created - if len(s_neu_inputs) > 0: - if not all([p in neurons for p in s_neu_inputs]): - continue - # Now check if there is a cached concatenation - # for the inputs of this neuron - if len(s_idx_inputs) > 0: - s_idx_inputs = frozenset(s_idx_inputs) - if s_idx_inputs in concat_cache: - n_inputs.append(concat_cache[s_idx_inputs]) - else: - subset_inputs = _concat_indexes( - input_layer, s_idx_inputs, keras - ) - concat_cache[s_idx_inputs] = subset_inputs - n_inputs.append(subset_inputs) - if len(s_neu_inputs) > 0: - s_neu_inputs = frozenset(s_neu_inputs) - if s_neu_inputs in concat_cache: - n_inputs.append(concat_cache[s_neu_inputs]) - else: - if len(s_neu_inputs) > 1: - subset_inputs = keras.layers.Concatenate()( - [neurons[p] for p in s_neu_inputs] - ) - concat_cache[s_neu_inputs] = subset_inputs - else: - subset_inputs = neurons[list(s_neu_inputs)[0]] - n_inputs.append(subset_inputs) - if len(n_inputs) > 1: - neuron_inputs = keras.layers.Concatenate(name=f"{s}_c")(n_inputs) - else: - neuron_inputs = n_inputs[0] - if dropout > 0 and len(n_inputs) >= min_inputs_for_dropout: - neuron_inputs = keras.layers.Dropout(dropout)(neuron_inputs) - - # Create the neuron. - default_act = ( - default_hidden_activation - if s not in output_nodes - else default_output_activation - ) - act = G.get_attr_vertex(s).get(activation_attribute, default_act) - # ElasticNet regularization - kernel_reg, bias_reg = None, None - if kernel_reg_l1 > 0 or kernel_reg_l2 > 0: - kernel_reg = keras.regularizers.l1_l2( - l1=bias_reg_l1, l2=bias_reg_l2 - ) - if bias_reg_l1 > 0 or bias_reg_l2 > 0: - bias_reg = keras.regularizers.l1_l2( - l1=kernel_reg_l1, l2=kernel_reg_l2 - ) - neuron = keras.layers.Dense( - 1, - activation=act, - kernel_regularizer=kernel_reg, - bias_regularizer=bias_reg, - name=s, - ) - x = neuron(neuron_inputs) - neurons[s] = x - if verbose: - print( - f"{s} ({act}) > {len(s_idx_inputs)} data input(s), {len(s_neu_inputs)} neuron input(s)" - ) - # Create the model - if len(output_nodes) == 1: - output_layer = neurons[output_nodes[0]] - else: - output_layer = keras.layers.Concatenate(name="output_layer")( - [neurons[v] for v in output_nodes] - ) - model = keras.Model(inputs=input_layer, outputs=output_layer) - return model - - def plot_model(model): return _load_keras().utils.plot_model(model) diff --git a/tests/test_graph.py b/tests/test_graph.py index 5581df7..85004ae 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -274,6 +274,21 @@ def test_graph_bfs_rev(): assert dist[4] == 2 assert 5 not in dist +def test_graph_toposort(): + g = Graph() + g.add_edge("a", "b") + g.add_edge("a", "c") + g.add_edge("c", "b") + g.add_edge("c", "d") + g.add_edge("c", "e") + g.add_edge("b", "d") + g.add_edge("d", "e") + order = g.toposort() + assert order.index("a") < order.index("b") + assert order.index("a") < order.index("c") + assert order.index("c") < order.index("d") + assert order.index("d") < order.index("e") + assert order.index("b") < order.index("d") def test_incidence_single_edge_single_source_vertex(): g = Graph()