Skip to content

Commit

Permalink
[ENH] Add the ability to convert a DAG to an MAG (#96)
Browse files Browse the repository at this point in the history
* Add is_maximal function
* add DAG to MAG function

---------

Signed-off-by: Aryan Roy <[email protected]>
Co-authored-by: Adam Li <[email protected]>
  • Loading branch information
aryan26roy and adam2392 authored Sep 26, 2023
1 parent 37ede7f commit 9f3e202
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 0 deletions.
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ causal graph operations.
.. autosummary::
:toctree: generated/

dag_to_mag
valid_mag
has_adc
inducing_path
Expand Down
1 change: 1 addition & 0 deletions doc/whats_new/v0.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Version 0.2
Changelog
---------
- |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`)
- |Feature| Implement and test functions to convert a DAG to an MAG, by `Aryan Roy`_ (:pr:`96`)

Code and Documentation Contributors
-----------------------------------
Expand Down
120 changes: 120 additions & 0 deletions pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"inducing_path",
"has_adc",
"valid_mag",
"dag_to_mag",
"is_maximal",
]


Expand Down Expand Up @@ -567,6 +569,9 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None):
if node_x == node_y:
raise ValueError("The source and destination nodes are the same.")

if (node_x in L) or (node_y in L) or (node_x in S) or (node_y in S):
return (False, [])

edges = G.edges()

# XXX: fix this when graphs are refactored to only check for directed/bidirected edge types
Expand Down Expand Up @@ -703,3 +708,118 @@ def valid_mag(G: ADMG, L: set = None, S: set = None):
return False

return True


def dag_to_mag(G, L: Set = None, S: Set = None):
    """Convert a DAG to a valid MAG.

    The algorithm is defined in :footcite:`Zhang2008` on page 1877.

    Parameters
    ----------
    G : Graph
        The graph.
    L : Set
        Nodes that are ignored on the path. Defaults to an empty set.
    S : Set
        Nodes that are always conditioned on. Defaults to an empty set.

    Returns
    -------
    mag : Graph
        The MAG.
    """
    if L is None:
        L = set()

    if S is None:
        S = set()

    # Two nodes are adjacent in the MAG only if there is an inducing path
    # between them, so test every pair of distinct nodes.
    all_nodes = set(G.nodes)
    adjacent_pairs = []
    seen_pairs = set()

    for source in all_nodes:
        for dest in all_nodes - {source}:
            pair = frozenset((source, dest))
            if pair in seen_pairs:
                # Already known to be adjacent; no need to query the
                # reverse direction as well.
                continue
            if inducing_path(G, source, dest, L, S)[0] is True:
                seen_pairs.add(pair)
                adjacent_pairs.append((source, dest))

    def _ancestors_of(nodes):
        # Union of the directed-subgraph ancestors of every node in ``nodes``.
        found: Set = set()
        for node in nodes:
            found.update(_directed_sub_graph_ancestors(G, node))
        return found

    # The ancestors of S are shared by every pair, so compute them once.
    selection = set(S)
    ancestors_of_s = _ancestors_of(selection)

    mag = ADMG()

    for node_a, node_b in adjacent_pairs:
        # Ancestors of {A} U S and {B} U S. The node must be wrapped in a
        # set: ``S.union(node)`` would iterate over the node itself (e.g. the
        # characters of a string name) or fail for non-iterable nodes.
        ans_a = ancestors_of_s | _ancestors_of({node_a} - selection)
        ans_b = ancestors_of_s | _ancestors_of({node_b} - selection)

        a_in_ans_b = node_a in ans_b
        b_in_ans_a = node_b in ans_a

        if a_in_ans_b and not b_in_ans_a:
            # A is in ansB and B is not in ansA: A -> B
            mag.add_edge(node_a, node_b, mag.directed_edge_name)
        elif b_in_ans_a and not a_in_ans_b:
            # B is in ansA and A is not in ansB: A <- B
            mag.add_edge(node_b, node_a, mag.directed_edge_name)
        elif not a_in_ans_b and not b_in_ans_a:
            # neither is in the other's ancestor set: A <-> B
            mag.add_edge(node_b, node_a, mag.bidirected_edge_name)
        else:
            # each is in the other's ancestor set: A - B
            mag.add_edge(node_b, node_a, mag.undirected_edge_name)

    return mag


def is_maximal(G, L: Set = None, S: Set = None):
    """Check whether the graph is maximal.

    Every pair of non-adjacent nodes is tested with ``inducing_path``; the
    graph is reported as not maximal as soon as one such pair has an
    inducing path between them.

    Parameters
    ----------
    G : Graph
        The graph.
    L : Set
        Forwarded to ``inducing_path``. Defaults to an empty set.
    S : Set
        Forwarded to ``inducing_path``. Defaults to an empty set.

    Returns
    -------
    is_maximal : bool
        A boolean indicating whether the provided graph is maximal or not.
    """
    L = set() if L is None else L
    S = set() if S is None else S

    nodes = set(G.nodes)
    visited_pairs = set()

    for source in nodes:
        # Candidates are all nodes that are neither neighbors of ``source``
        # nor ``source`` itself.
        non_adjacent = nodes - set(G.neighbors(source))
        non_adjacent.remove(source)

        for dest in non_adjacent:
            pair = frozenset({source, dest})
            if pair in visited_pairs:
                continue
            visited_pairs.add(pair)

            result = inducing_path(G, source, dest, L, S)
            if result[0] is True:
                return False

    return True
124 changes: 124 additions & 0 deletions pywhy_graphs/algorithms/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,30 @@ def test_inducing_path_corner_cases():

assert pywhy_graphs.inducing_path(admg, "X", "Y", L, S)[0]

# X -> Z <- Y, A <- B <- Z
admg = ADMG()
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Y", "Z", admg.directed_edge_name)
admg.add_edge("Z", "B", admg.directed_edge_name)
admg.add_edge("B", "A", admg.directed_edge_name)

L = {"X"}
S = {"A"}

assert not pywhy_graphs.inducing_path(admg, "X", "Y", L, S)[0]

# X -> Z <- Y, A <- B <- Z
admg = ADMG()
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Y", "Z", admg.directed_edge_name)
admg.add_edge("Z", "B", admg.directed_edge_name)
admg.add_edge("B", "A", admg.directed_edge_name)

L = {}
S = {"A", "Y"}

assert not pywhy_graphs.inducing_path(admg, "X", "Y", L, S)[0]


def test_is_collider():
# Z -> X -> A <- B -> Y; H -> A
Expand Down Expand Up @@ -348,3 +372,103 @@ def test_valid_mag():
admg.add_edge("H", "J", admg.undirected_edge_name)

assert not pywhy_graphs.valid_mag(admg) # there is an undirected edge between H and J


def test_dag_to_mag():
    """Check ``dag_to_mag`` on three small DAGs with latent/selection sets."""
    # A -> E -> S
    # H -> E , H -> R
    admg = ADMG()
    admg.add_edge("A", "E", admg.directed_edge_name)
    admg.add_edge("E", "S", admg.directed_edge_name)
    admg.add_edge("H", "E", admg.directed_edge_name)
    admg.add_edge("H", "R", admg.directed_edge_name)

    S = {"S"}
    L = {"H"}

    out_mag = pywhy_graphs.dag_to_mag(admg, L, S)
    assert pywhy_graphs.is_maximal(out_mag)
    assert not pywhy_graphs.has_adc(out_mag)
    out_edges = out_mag.edges()
    assert ("A", "R") in out_edges["directed"]
    assert ("E", "R") in out_edges["directed"]
    assert len(out_edges["directed"]) == 2
    assert ("A", "E") in out_edges["undirected"]

    # Without latent or selection nodes, all original edges must be kept.
    out_mag = pywhy_graphs.dag_to_mag(admg)
    dir_edges = list(out_mag.edges()["directed"])

    assert ("A", "E") in dir_edges
    assert ("E", "S") in dir_edges
    assert ("H", "E") in dir_edges
    assert ("H", "R") in dir_edges

    # A -> E <- H -> S
    # H -> R
    admg = ADMG()
    admg.add_edge("A", "E", admg.directed_edge_name)
    admg.add_edge("H", "S", admg.directed_edge_name)
    admg.add_edge("H", "E", admg.directed_edge_name)
    admg.add_edge("H", "R", admg.directed_edge_name)

    S = {"S"}
    L = {"H"}

    out_mag = pywhy_graphs.dag_to_mag(admg, L, S)
    assert pywhy_graphs.is_maximal(out_mag)
    assert not pywhy_graphs.has_adc(out_mag)
    out_edges = out_mag.edges()

    assert ("A", "E") in out_edges["directed"]
    assert len(out_edges["directed"]) == 1
    assert ("E", "R") in out_edges["bidirected"]

    # P -> S -> L <- G
    # G -> S <- I <- J
    # J -> S
    admg = ADMG()
    admg.add_edge("P", "S", admg.directed_edge_name)
    admg.add_edge("S", "L", admg.directed_edge_name)
    admg.add_edge("G", "S", admg.directed_edge_name)
    admg.add_edge("G", "L", admg.directed_edge_name)
    admg.add_edge("I", "S", admg.directed_edge_name)
    admg.add_edge("J", "I", admg.directed_edge_name)
    admg.add_edge("J", "S", admg.directed_edge_name)

    S = set()
    L = {"J"}

    out_mag = pywhy_graphs.dag_to_mag(admg, L, S)
    assert pywhy_graphs.is_maximal(out_mag)
    assert not pywhy_graphs.has_adc(out_mag)
    out_edges = out_mag.edges()
    dir_edges = list(out_edges["directed"])
    assert ("G", "S") in dir_edges
    assert ("G", "L") in dir_edges
    assert ("S", "L") in dir_edges
    assert ("I", "S") in dir_edges
    assert ("P", "S") in dir_edges
    assert len(dir_edges) == 5


def test_is_maximal():
    """The graph below is expected to be non-maximal when Y is in L."""
    # X <- Y <-> Z <-> H; Z -> X
    admg = ADMG()
    admg.add_edge("Y", "X", admg.directed_edge_name)
    admg.add_edge("Z", "X", admg.directed_edge_name)
    admg.add_edge("Z", "Y", admg.bidirected_edge_name)
    admg.add_edge("Z", "H", admg.bidirected_edge_name)

    # ``{}`` would be an empty dict; the parameter is typed as a set.
    S = set()
    L = {"Y"}
    assert not pywhy_graphs.is_maximal(admg, L, S)

0 comments on commit 9f3e202

Please sign in to comment.