Skip to content

Commit

Permalink
add profiling to test performance of self-rolled graph vs networkx
Browse files Browse the repository at this point in the history
  • Loading branch information
csquires committed Aug 4, 2018
1 parent bd7140a commit 1b84c2a
Show file tree
Hide file tree
Showing 15 changed files with 166 additions and 40 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@ build/
.idea/
release.sh

.mypy_cache/


3 changes: 2 additions & 1 deletion causaldag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@

from .classes import *
from .loaders import *

from . import rand
from . import inference
2 changes: 1 addition & 1 deletion causaldag/classes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .dag import DAG
from .pdag import PDAG
from .gaussdag import GaussDAG
from .gaussdag import GaussDAG, GaussIntervention
21 changes: 15 additions & 6 deletions causaldag/classes/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def __init__(self, nodes=None, arcs=None):
self._parents = self._get_parents_dict()
self._children = self._get_children_dict()

def copy(self):
return DAG(nodes=self.nodes, arcs=self.arcs)

@property
def nodes(self):
return set(self._nodes)
Expand Down Expand Up @@ -323,14 +326,20 @@ def cpdag(self):

def interventional_cpdag(self, intervened_nodes, cpdag=None):
from .pdag import PDAG

if cpdag is None:
cpdag = self.cpdag()
dag_cut = self.copy()
for node in intervened_nodes:
for i, j in dag_cut.incoming_arcs(node):
dag_cut.remove_arc(i, j)
pdag = PDAG(dag_cut)
else:
cut_edges = set()
for node in intervened_nodes:
cut_edges.update(self.incident_arcs(node))
known_arcs = cut_edges | cpdag._known_arcs
pdag = PDAG(self, known_arcs=known_arcs)

cut_edges = set()
for node in intervened_nodes:
cut_edges.update(self.incident_arcs(node))
known_edges = cut_edges | cpdag._known_arcs
pdag = PDAG(self, known_arcs=known_edges)
pdag.remove_unprotected_orientations()
return pdag

Expand Down
80 changes: 58 additions & 22 deletions causaldag/classes/gaussdag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,42 @@
from causaldag.classes.dag import DAG
import numpy as np
from causaldag.utils import core_utils
from dataclasses import dataclass
from typing import Any, Dict, Union, Set, Tuple, List


@dataclass
class GaussIntervention:
mean: float
variance: float


class GaussDAG(DAG):
def __init__(self, nodes=None, arcs=None, weight_mat=None, variances=None):
if weight_mat is None:
super().__init__(nodes, arcs)
self._node_list = list(nodes)
self._weight_mat = np.zeros((len(nodes), len(nodes)))
else:
self._weight_mat = weight_mat.copy()
nnodes = weight_mat.shape[0]
if nodes is None:
self._node_list = list(range(nnodes))
if variances is None:
self._variances = np.ones(nnodes)

arcs = set()
for edge, val in np.ndenumerate(weight_mat):
if val != 0:
arcs.add(edge)
super().__init__(self._node_list, arcs)
def __init__(self, nodes: List, arcs: Union[Set[Tuple[Any, Any]], Dict[Tuple[Any, Any], float]], means=None, variances=None):
arcs_set = arcs if isinstance(arcs, set) else set(arcs.keys())
super().__init__(set(nodes), arcs_set)

self._node_list = nodes
self._node2ix = core_utils.ix_map_from_list(self._node_list)

self._weight_mat = np.zeros((len(nodes), len(nodes)))
for node1, node2 in arcs:
w = arcs[(node1, node2)] if isinstance(arcs, dict) else 1
self._weight_mat[self._node2ix[node1], self._node2ix[node2]] = w

self._variances = np.ones(len(nodes)) if variances is None else np.array(variances, dtype=float)
self._means = np.zeros((len(nodes))) if means is None else np.array(means)

self._precision = None
self._covariance = None

@classmethod
def from_weight_matrix(cls, weight_mat, nodes=None, means=None, variances=None):
nodes = nodes if nodes is not None else list(range(weight_mat.shape[0]))
arcs = {(i, j): w for (i, j), w in np.ndenumerate(weight_mat) if w != 0}
print(nodes, arcs)
return cls(nodes=nodes, arcs=arcs, means=means, variances=variances)

def set_arc_weight(self, i, j, val):
self._weight_mat[self._node2ix[i], self._node2ix[j]] = val
if val == 0 and (i, j) in self._arcs:
Expand Down Expand Up @@ -125,7 +135,7 @@ def save_gml(self, filename):
raise NotImplementedError

def to_amat(self):
raise NotImplementedError
return self.weight_mat

def cpdag(self):
raise NotImplementedError
Expand Down Expand Up @@ -161,11 +171,11 @@ def _ensure_covariance(self):
else:
self._covariance = id_min_a_inv.T @ np.diag(self._variances ** -1) @ id_min_a_inv

def sample(self, nsamples):
def sample(self, nsamples: int = 1) -> np.array:
samples = np.zeros((nsamples, len(self._nodes)))
noise = np.zeros((nsamples, len(self._nodes)))
for ix, var in enumerate(self._variances):
noise[:,ix] = np.random.normal(scale=var, size=nsamples)
noise[:, ix] = np.random.normal(scale=var, size=nsamples)
t = self.topological_sort()
for node in t:
ix = self._node2ix[node]
Expand All @@ -178,17 +188,43 @@ def sample(self, nsamples):
samples[:, ix] = noise[:, ix]
return samples

def sample_interventional(self, interventions: Dict[Any, GaussIntervention], nsamples: int = 1) -> np.array:
samples = np.zeros((nsamples, len(self._nodes)))
noise = np.zeros((nsamples, len(self._nodes)))

for ix, (node, mean, var) in enumerate(zip(self._node_list, self._means, self._variances)):
iv = interventions.get(node)
if iv is not None:
mean = iv.mean
var = iv.variance
noise[:, ix] = np.random.normal(loc=mean, scale=var, size=nsamples)

t = self.topological_sort()
for node in t:
ix = self._node2ix[node]
parents = self._parents[node]
if node not in interventions and len(parents) != 0:
parent_ixs = [self._node2ix[p] for p in self._parents[node]]
parent_vals = samples[:, parent_ixs]
samples[:, ix] = np.sum(parent_vals * self._weight_mat[parent_ixs, node], axis=1) + noise[:, ix]
else:
samples[:, ix] = noise[:, ix]

return samples


if __name__ == '__main__':
B = np.zeros((3, 3))
B[0, 1] = 1
B[0, 2] = -1
B[1, 2] = 4
gdag = GaussDAG(weight_mat=B)
gdag = GaussDAG.from_weight_matrix(B, means=[0, 0, 0], variances=[1, 1, 1])
s = gdag.sample(1000)
# print(gdag.arcs)
print(s.T @ s / 1000)
print(gdag.covariance)
s2 = gdag.sample_interventional({2: GaussIntervention(mean=0, variance=5)}, 1000)
print(s2.T @ s2 / 1000)



Expand Down
6 changes: 1 addition & 5 deletions causaldag/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1 @@
from .pcalg import pcalg
from .ges import ges



from . import structural
2 changes: 2 additions & 0 deletions causaldag/inference/structural/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .pcalg import pcalg
from .ges import ges
3 changes: 3 additions & 0 deletions causaldag/rand/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .graphs import *


6 changes: 3 additions & 3 deletions causaldag/random/_random.py → causaldag/rand/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import itertools as itr


def coin(p, size=1):
def _coin(p, size=1):
return np.random.binomial(1, p, size=size)


def directed_erdos(n, s, size=1):
if size == 1:
arcs = {(i, j) for i, j in itr.combinations(range(n), 2) if coin(s)}
bools = _coin(s, size=int(n*(n-1)/2))
arcs = {(i, j) for (i, j), b in zip(itr.combinations(range(n), 2), bools) if b}
return DAG(nodes=set(range(n)), arcs=arcs)
else:
return [directed_erdos(n, s) for _ in range(size)]
Expand All @@ -19,4 +20,3 @@ def directed_erdos(n, s, size=1):




1 change: 0 additions & 1 deletion causaldag/random/__init__.py

This file was deleted.

Empty file added profiling/__init__.py
Empty file.
52 changes: 52 additions & 0 deletions profiling/time_create_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import causaldag as cd
import networkx as nx
from profiling.time_dec import timed


@timed
def test_create_nx_small():
for _ in range(10000):
g = nx.DiGraph()
g.add_edge(1, 2)
g.add_edge(2, 3)
g.add_edge(1, 3)


@timed
def test_create_dag_small():
for _ in range(10000):
g = cd.DAG()
g.add_arc(1, 2)
g.add_arc(2, 3)
g.add_arc(1, 3)


test_create_nx_small()
test_create_dag_small()

import numpy as np
np.random.seed(1729)
nnodes_large = 1000
arcs = cd.rand.directed_erdos(nnodes_large, .5).arcs
print(len(arcs))


@timed
def test_create_nx_large():
for i in range(10):
print(i)
g = nx.DiGraph()
g.add_nodes_from(range(nnodes_large))
g.add_edges_from(arcs)


@timed
def test_create_dag_large():
for i in range(10):
print(i)
g = cd.DAG(nodes=range(nnodes_large), arcs=arcs)


test_create_nx_large()
test_create_dag_large()

14 changes: 14 additions & 0 deletions profiling/time_dec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from functools import wraps
import time


def timed(func):
"""This decorator prints the execution time for the decorated function."""
@wraps(func)
def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
end = time.time()
print("{} ran in {}s".format(func.__name__, round(end - start, 2)))
return result
return wrapper
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setuptools.setup(
name='causaldag',
version='0.1a.6',
version='0.1a.11',
description='Causal DAG manipulation and inference',
long_description='CausalDAG is a Python package for the creation, manipulation, and learning of Causal DAGs.',
url='http://github.com/storborg/funniest',
Expand Down
11 changes: 11 additions & 0 deletions tests/test_pdag.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,17 @@ def test_to_dag(self):
}
self.assertIn(dag2.arcs, true_possible_arcs)

def test_interventional_cpdag_no_obs(self):
dag = cd.DAG(arcs={(1, 2), (1, 3), (2, 3)})

icpdag1 = dag.interventional_cpdag([2])
self.assertEqual(icpdag1.arcs, {(1, 3), (2, 3)})
self.assertEqual(icpdag1.edges, set())

icpdag2 = dag.interventional_cpdag([1])
self.assertEqual(icpdag2.arcs, set())
self.assertEqual(icpdag2.edges, {(1, 2), (1, 3), (2, 3)})


if __name__ == '__main__':
unittest.main()
Expand Down

0 comments on commit 1b84c2a

Please sign in to comment.