Skip to content

Commit

Permalink
finished pdag2alldags with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
csquires committed Jul 31, 2018
1 parent cb59c1c commit bd7140a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 38 deletions.
45 changes: 29 additions & 16 deletions causaldag/classes/pdag.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def __init__(self, dag_or_pdag, known_arcs=set()):
self._known_arcs = dag_or_pdag._known_arcs | known_arcs
self._undirected_neighbors = defaultdict(set, dag_or_pdag.undirected_neighbors)

def __eq__(self, other):
same_nodes = self._nodes == other._nodes
same_arcs = self._arcs == other._arcs
same_edges = self._edges == other._edges

return same_nodes and same_arcs and same_edges

def __str__(self):
substrings = []
for node in self._nodes:
Expand Down Expand Up @@ -251,40 +258,46 @@ def to_dag(self):

def all_dags(self, verbose=False):
amat = self.to_amat()
all_dags = []
_all_dags_helper(amat, amat, all_dags, verbose=verbose)
node_list = list(amat.index)
all_dags = set()
_all_dags_helper(amat, amat, node_list, all_dags, verbose=verbose)
return all_dags


def _all_dags_helper(full_amat, curr_submatrix, all_dags, verbose=False):
def _all_dags_helper(full_amat, curr_submatrix, node_list, all_dags, verbose=False):
if curr_submatrix.sum().sum() == 0:
all_dags.append(full_amat)
print(full_amat)
print(curr_submatrix)
arcs = frozenset((node_list[i], node_list[j]) for (i, j), val in np.ndenumerate(full_amat) if val==1)
all_dags.add(arcs)
if verbose: print('=== APPENDING ===')
if verbose: print(arcs)
if verbose: print('=================')
return

if verbose: print(full_amat)
nchildren = ((curr_submatrix - curr_submatrix.T) > 0).sum(axis=1)
if verbose: print('nchildren\n', nchildren)
sink_ixs = (nchildren == 0).nonzero()[0]
sinks = curr_submatrix.index[sink_ixs]
print('sinks', sinks)

if verbose: print(set(sinks))
for sink in sinks:
children_ixs = curr_submatrix.loc[sink].nonzero()[0]
children = set(full_amat.index[children_ixs])
children = set(curr_submatrix.index[children_ixs])
parent_ixs = curr_submatrix[sink].nonzero()[0]
parents = set(full_amat.index[parent_ixs])
print(sink, children, parents)
parents = set(curr_submatrix.index[parent_ixs])

undirected_nbrs = list(children & parents)
sink_nbrs = children | parents
get_neighbors = lambda n: set(curr_submatrix.index[curr_submatrix[n].nonzero()[0]]) | set(curr_submatrix.index[curr_submatrix[n].nonzero()[0]])

nbrs_of_undirected_nbrs = (get_neighbors(nbr) for nbr in undirected_nbrs)
nbrs_of_undirected_nbrs = list(nbrs_of_undirected_nbrs)
print(nbrs_of_undirected_nbrs)
if len(undirected_nbrs) > 0 and all((sink_nbrs - {nbr}).issubset(nbrs_of_nbr) for nbr, nbrs_of_nbr in zip(undirected_nbrs, nbrs_of_undirected_nbrs)):
print('Removing sink node', sink)
if len(sink_nbrs) > 0 and all((sink_nbrs - {nbr}).issubset(nbrs_of_nbr) for nbr, nbrs_of_nbr in zip(undirected_nbrs, nbrs_of_undirected_nbrs)):
if verbose: print('Removing sink node', sink, 'and edges to', sink_nbrs, 'from:\n', full_amat)
new_full_amat = full_amat.copy()
new_full_amat.loc[sink] = 0
curr_submatrix = curr_submatrix.drop(sink, axis=0).drop(sink, axis=1)
_all_dags_helper(new_full_amat, curr_submatrix, all_dags, verbose=verbose)
new_full_amat.loc[sink][sink_nbrs] = 0
new_submatrix = curr_submatrix.drop(sink, axis=0).drop(sink, axis=1)
_all_dags_helper(new_full_amat, new_submatrix, node_list, all_dags, verbose=verbose)



59 changes: 37 additions & 22 deletions tests/test_pdag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import causaldag as cd
import numpy as np
import itertools as itr


class TestDAG(TestCase):
Expand Down Expand Up @@ -112,34 +113,48 @@ def test_interventional_cpdag(self):
# self.assertEqual(cpdag.arcs, {(1, 2), (2, 3)})
# self.assertEqual(cpdag.edges, set())

# def test_pdag2alldags_complete3(self):
# dag = cd.DAG(arcs={(1, 2), (1, 3), (2, 3)})
# cpdag = dag.cpdag()
# dags = cpdag.all_dags()
# self.assertEqual(len(dags), 6)
# for dag in dags:
# self.assertEqual(len(dag.arcs), 3)
# true_possible_arcs = {
# frozenset({(1, 2), (1, 3), (2, 3)}),
# frozenset({(1, 2), (1, 3), (3, 2)}), # flip 2->3
# frozenset({(1, 2), (3, 1), (3, 2)}), # flip 1->3
# frozenset({(2, 1), (3, 1), (3, 2)}), # flip 1->2
# frozenset({(2, 1), (3, 1), (2, 3)}), # flip 3->2
# frozenset({(2, 1), (1, 3), (2, 3)}), # flip 3->1
# }
# self.assertEqual(true_possible_arcs, {frozenset(d.arcs) for d in dags})

def test_pdag2alldags_chain3(self):
def test_pdag2alldags_3nodes_complete(self):
dag = cd.DAG(arcs={(1, 2), (1, 3), (2, 3)})
cpdag = dag.cpdag()
dags = cpdag.all_dags(verbose=False)
self.assertEqual(len(dags), 6)
for dag in dags:
self.assertEqual(len(dag), 3)
true_possible_arcs = {
frozenset({(1, 2), (1, 3), (2, 3)}),
frozenset({(1, 2), (1, 3), (3, 2)}), # flip 2->3
frozenset({(1, 2), (3, 1), (3, 2)}), # flip 1->3
frozenset({(2, 1), (3, 1), (3, 2)}), # flip 1->2
frozenset({(2, 1), (3, 1), (2, 3)}), # flip 3->2
frozenset({(2, 1), (1, 3), (2, 3)}), # flip 3->1
}
self.assertEqual(true_possible_arcs, dags)

def test_pdag2alldags_3nodes_chain(self):
dag = cd.DAG(arcs={(1, 2), (2, 3)})
cpdag = dag.cpdag()
dags = cpdag.all_dags(verbose=True)
print('dags', dags)
dags = cpdag.all_dags(verbose=False)
true_possible_arcs = {
frozenset({(1, 2), (2, 3)}),
frozenset({(2, 1), (2, 3)}),
frozenset({(1, 2), (3, 2)}),
frozenset({(2, 1), (3, 2)}),
}
self.assertEqual(true_possible_arcs, {frozenset(d.arcs) for d in dags})
self.assertEqual(true_possible_arcs, dags)

def test_pdag2alldags_5nodes(self):
dag = cd.DAG(arcs={(1, 2), (2, 3), (1, 3), (2, 4), (2, 5), (3, 5), (4, 5)})
cpdag = dag.cpdag()
dags = cpdag.all_dags()
for arcs in dags:
dag2 = cd.DAG(arcs=set(arcs))
cpdag2 = dag2.cpdag()
self.assertEqual(cpdag, cpdag2)

def test_pdag2alldags_8nodes_complete(self):
dag = cd.DAG(arcs={(i, j) for i, j in itr.combinations(range(6), 2)})
cpdag = dag.cpdag()
dags = cpdag.all_dags()
self.assertEqual(len(dags), np.prod(range(1, 7)))

def test_optimal_intervention(self):
dag = cd.DAG(arcs={(1, 2), (1, 3), (2, 3)})
Expand Down

0 comments on commit bd7140a

Please sign in to comment.