Skip to content

Commit

Permalink
use pandas while getting all pdags
Browse files Browse the repository at this point in the history
  • Loading branch information
csquires committed Jul 30, 2018
1 parent f11250e commit cb59c1c
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 41 deletions.
103 changes: 74 additions & 29 deletions causaldag/classes/pdag.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,23 @@ def __str__(self):
substrings.append('[{node}|{parents}:{nbrs}]'.format(node=node, parents=parents_str, nbrs=nbrs_str))
return ''.join(substrings)

def remove_node(self, node):
    """Remove ``node`` and everything incident to it from the graph.

    The node is dropped from ``_nodes``, every arc/edge that mentions it is
    dropped from ``_arcs``/``_edges``, and it is removed from the adjacency
    sets of its children, parents and undirected neighbours.

    :param node: the node to remove.
    :raises KeyError: if ``node`` is not in ``_nodes``.
    """
    self._nodes.remove(node)
    self._arcs = {(i, j) for i, j in self._arcs if i != node and j != node}
    self._edges = {(i, j) for i, j in self._edges if i != node and j != node}
    for child in self._children[node]:
        self._parents[child].remove(node)
        self._neighbors[child].remove(node)
    for parent in self._parents[node]:
        self._children[parent].remove(node)
        self._neighbors[parent].remove(node)
    # Only *undirected* neighbours list ``node`` in their
    # ``_undirected_neighbors`` set; walking all of ``_neighbors[node]`` here
    # would raise KeyError on a parent or child.  They also list it in
    # ``_neighbors``, which has to be cleaned up as well.
    for u_nbr in self._undirected_neighbors[node]:
        self._undirected_neighbors[u_nbr].remove(node)
        self._neighbors[u_nbr].remove(node)
    # ``pop`` with a default works whether or not an entry was ever created
    # for ``node`` in each map.
    self._parents.pop(node, None)
    self._children.pop(node, None)
    self._neighbors.pop(node, None)
    self._undirected_neighbors.pop(node, None)

@property
def nodes(self):
    """Return a fresh ``set`` built from ``self._nodes``.

    The result is a copy, so callers may mutate it without affecting the
    graph's own node collection.
    """
    node_copy = set(self._nodes)
    return node_copy
Expand Down Expand Up @@ -191,55 +208,83 @@ def add_known_arcs(self, arcs):
raise NotImplementedError

def to_amat(self, node_list=None):
    """Return the adjacency matrix of this graph as a ``pandas.DataFrame``.

    Entry ``[i, j]`` is 1 when there is an arc ``i -> j``.  An undirected
    edge between ``i`` and ``j`` sets both ``[i, j]`` and ``[j, i]``.  All
    other entries are 0.

    :param node_list: labels to use, in order, for rows and columns.
        Defaults to ``sorted(self._nodes)``.
    :return: an integer DataFrame whose index and columns are ``node_list``.
    :raises KeyError: if an arc or edge mentions a node not in ``node_list``.
    """
    # Imported here rather than at module level, as in the original change.
    import pandas as pd

    if node_list is None:
        node_list = sorted(self._nodes)

    node2ix = {node: i for i, node in enumerate(node_list)}
    # Size the matrix from the labels actually used for indexing so that it
    # always matches the DataFrame's index/columns.
    size = len(node_list)
    amat = np.zeros([size, size], dtype=int)
    for source, target in self._arcs:
        amat[node2ix[source], node2ix[target]] = 1
    for i, j in self._edges:
        amat[node2ix[i], node2ix[j]] = 1
        amat[node2ix[j], node2ix[i]] = 1

    return pd.DataFrame(amat, index=node_list, columns=node_list)

def _possible_sinks(self):
return {node for node in self._nodes if len(self._children[node]) == 0}

def _neighbors_covered(self, node):
return {node2: self.neighbors[node2] - {node} == self.neighbors[node] for node2 in self._nodes}

def all_dags(self):
def to_dag(self):
    """Return one DAG obtained by orienting the undirected edges of this graph.

    Works on a copy of the graph and repeatedly picks a node that

    * has no children, and
    * for each of its undirected neighbours ``u``, has all of its other
      neighbours contained in ``u``'s neighbours,

    then records an arc into that node from each of its neighbours and
    removes the node from the copy.  The loop stops when the copy has no
    arcs or edges left, or when no such node exists; in the latter case the
    arcs collected so far are returned.

    :return: a ``DAG`` built from the collected arcs.
    """
    from .dag import DAG

    pdag2 = self.copy()
    arcs = set()

    def is_removable_sink(n):
        if len(pdag2._children[n]) != 0:
            return False
        nbrs = pdag2._neighbors[n]
        return all(
            (nbrs - {u_nbr}).issubset(pdag2._neighbors[u_nbr])
            for u_nbr in pdag2._undirected_neighbors[n]
        )

    while len(pdag2._edges) + len(pdag2._arcs) != 0:
        # ``None`` as the default: a bare ``next`` would raise StopIteration
        # and the check below could never trigger.
        sink = next((n for n in pdag2._nodes if is_removable_sink(n)), None)
        if sink is None:
            break
        # The sink has no children, so every remaining neighbour is either a
        # parent or an undirected neighbour; both become arcs into the sink.
        # Using only the undirected neighbours would drop existing arcs.
        arcs.update((nbr, sink) for nbr in pdag2._neighbors[sink])
        pdag2.remove_node(sink)

    return DAG(arcs=arcs)

def all_dags(self, verbose=False):
    """Collect the results produced by ``_all_dags_helper`` for this graph.

    The graph is converted with ``to_amat`` and the same matrix is handed to
    the helper both as the matrix to orient and as the matrix of nodes still
    to process.

    :param verbose: forwarded to ``_all_dags_helper``.
    :return: the list filled in by ``_all_dags_helper``.
    """
    amat = self.to_amat()
    dags = []  # local name chosen so it does not shadow the method name
    _all_dags_helper(amat, amat, dags, verbose=verbose)
    return dags


def _all_dags_helper(curr_graph, curr_dags, sinked_nodes=None):
if sinked_nodes is None:
sinked_nodes = set()

print(curr_graph)
print(sinked_nodes)
if len(curr_graph._edges) == 0:
print('here')
curr_dags.append(curr_graph)
return

sinks = {node for node in curr_graph._nodes if len(curr_graph._children[node] - sinked_nodes) == 0}
def _all_dags_helper(full_amat, curr_submatrix, all_dags, verbose=False):
if curr_submatrix.sum().sum() == 0:
all_dags.append(full_amat)
print(full_amat)
print(curr_submatrix)
nchildren = ((curr_submatrix - curr_submatrix.T) > 0).sum(axis=1)
sink_ixs = (nchildren == 0).nonzero()[0]
sinks = curr_submatrix.index[sink_ixs]
print('sinks', sinks)
for sink in sinks:
undirected_nbrs = curr_graph._undirected_neighbors[sink]
all_undirected_protected = all(
(curr_graph._neighbors[sink] - {nbr}) == (curr_graph._neighbors[nbr] - {sink})
for nbr in undirected_nbrs
)
if all_undirected_protected and len(undirected_nbrs) != 0:
sinked_nodes = set(sinked_nodes)
sinked_nodes.add(sink)

new_graph = curr_graph.copy()
for nbr in undirected_nbrs:
new_graph._replace_edge_with_arc((nbr, sink))
_all_dags_helper(new_graph, curr_dags, sinked_nodes)
children_ixs = curr_submatrix.loc[sink].nonzero()[0]
children = set(full_amat.index[children_ixs])
parent_ixs = curr_submatrix[sink].nonzero()[0]
parents = set(full_amat.index[parent_ixs])
print(sink, children, parents)

undirected_nbrs = list(children & parents)
sink_nbrs = children | parents
get_neighbors = lambda n: set(curr_submatrix.index[curr_submatrix[n].nonzero()[0]]) | set(curr_submatrix.index[curr_submatrix[n].nonzero()[0]])

nbrs_of_undirected_nbrs = (get_neighbors(nbr) for nbr in undirected_nbrs)
nbrs_of_undirected_nbrs = list(nbrs_of_undirected_nbrs)
print(nbrs_of_undirected_nbrs)
if len(undirected_nbrs) > 0 and all((sink_nbrs - {nbr}).issubset(nbrs_of_nbr) for nbr, nbrs_of_nbr in zip(undirected_nbrs, nbrs_of_undirected_nbrs)):
print('Removing sink node', sink)
new_full_amat = full_amat.copy()
new_full_amat.loc[sink] = 0
curr_submatrix = curr_submatrix.drop(sink, axis=0).drop(sink, axis=1)
_all_dags_helper(new_full_amat, curr_submatrix, all_dags, verbose=verbose)



47 changes: 35 additions & 12 deletions tests/test_pdag.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,50 @@ def test_interventional_cpdag(self):
# self.assertEqual(cpdag.arcs, {(1, 2), (2, 3)})
# self.assertEqual(cpdag.edges, set())

def test_pdag2alldags_complete3(self):
dag = cd.DAG(arcs={(1, 2), (1, 3), (2, 3)})
# def test_pdag2alldags_complete3(self):
# dag = cd.DAG(arcs={(1, 2), (1, 3), (2, 3)})
# cpdag = dag.cpdag()
# dags = cpdag.all_dags()
# self.assertEqual(len(dags), 6)
# for dag in dags:
# self.assertEqual(len(dag.arcs), 3)
# true_possible_arcs = {
# frozenset({(1, 2), (1, 3), (2, 3)}),
# frozenset({(1, 2), (1, 3), (3, 2)}), # flip 2->3
# frozenset({(1, 2), (3, 1), (3, 2)}), # flip 1->3
# frozenset({(2, 1), (3, 1), (3, 2)}), # flip 1->2
# frozenset({(2, 1), (3, 1), (2, 3)}), # flip 3->2
# frozenset({(2, 1), (1, 3), (2, 3)}), # flip 3->1
# }
# self.assertEqual(true_possible_arcs, {frozenset(d.arcs) for d in dags})

def test_pdag2alldags_chain3(self):
    # The CPDAG of the chain 1 -> 2 -> 3 is the undirected path 1 - 2 - 3.
    dag = cd.DAG(arcs={(1, 2), (2, 3)})
    cpdag = dag.cpdag()
    dags = cpdag.all_dags()
    # Every orientation of the path except the collider 1 -> 2 <- 3, which
    # introduces a v-structure the chain does not have.
    true_possible_arcs = {
        frozenset({(1, 2), (2, 3)}),
        frozenset({(2, 1), (2, 3)}),
        frozenset({(2, 1), (3, 2)}),
    }
    # NOTE(review): assumes each element returned by all_dags exposes
    # ``.arcs`` -- confirm against the current return type of all_dags.
    self.assertEqual(true_possible_arcs, {frozenset(d.arcs) for d in dags})

def test_optimal_intervention(self):
    # Complete DAG on three nodes with ordering 1, 2, 3.
    complete_dag = cd.DAG(arcs={(1, 2), (1, 3), (2, 3)})
    chosen_node = complete_dag.optimal_intervention()
    self.assertEqual(chosen_node, 2)

def test_to_dag(self):
    # The CPDAG of the chain 1 -> 2 -> 3 is the undirected path 1 - 2 - 3.
    dag = cd.DAG(arcs={(1, 2), (2, 3)})
    cpdag = dag.cpdag()
    dag2 = cpdag.to_dag()
    # Members of the equivalence class: every orientation of the path except
    # the collider 1 -> 2 <- 3, which would add a v-structure.
    true_possible_arcs = {
        frozenset({(1, 2), (2, 3)}),
        frozenset({(2, 1), (2, 3)}),
        frozenset({(2, 1), (3, 2)}),
    }
    self.assertIn(dag2.arcs, true_possible_arcs)


if __name__ == '__main__':
unittest.main()
Expand Down

0 comments on commit cb59c1c

Please sign in to comment.