Skip to content

Commit

Permalink
[ENH] Add the ability to find Proper Possibly Directed Paths (#112)
Browse files Browse the repository at this point in the history
* Added function to return a list of possibly directed paths between two nodes

---------

Signed-off-by: Aryan Roy <[email protected]>
Co-authored-by: Adam Li <[email protected]>
  • Loading branch information
aryan26roy and adam2392 authored Aug 26, 2024
1 parent 68de868 commit 053a58b
Show file tree
Hide file tree
Showing 3 changed files with 367 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/whats_new/v0.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Changelog
- |Feature| Implement functions for converting between a DAG and PDAG and CPDAG for generating consistent extensions of a CPDAG for example. These functions are :func:`pywhy_graphs.algorithms.pdag_to_cpdag`, :func:`pywhy_graphs.algorithms.pdag_to_dag` and :func:`pywhy_graphs.algorithms.dag_to_cpdag`, by `Adam Li`_ (:pr:`102`)
- |API| Remove poetry based setup, by `Adam Li`_ (:pr:`110`)
- |Feature| Implement and test function to validate PAG, by `Aryan Roy`_ (:pr:`100`)
- |Feature| Implement and test function to find all the proper possibly directed paths, by `Aryan Roy`_ (:pr:`112`)

Code and Documentation Contributors
-----------------------------------
Expand Down
186 changes: 186 additions & 0 deletions pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"dag_to_mag",
"is_maximal",
"all_vstructures",
"proper_possibly_directed_path",
]


Expand Down Expand Up @@ -855,3 +856,188 @@ def all_vstructures(G: nx.DiGraph, as_edges: bool = False):
else:
vstructs.add((p1, node, p2)) # type: ignore
return vstructs


def _check_back_arrow(G: ADMG, X, Y: set):
"""Retrieve all the neigbors of X that do not have
an arrow pointing back to it.
Parameters
----------
G : DiGraph
A directed graph.
X : Node
Y : Set
A set of neigbors of X.
Returns
-------
out : set
A set of all the neighbors of X that do not have an arrow pointing
back to it.
"""
out = set()

for elem in Y:
if not (
G.has_edge(X, elem, G.bidirected_edge_name) or G.has_edge(elem, X, G.directed_edge_name)
):
out.update(elem)

return out


def _get_neighbors_of_set(G, X: set):
"""Retrieve all the neigbors of X when X has more than one element.
Note that if X is not a set, graph.neighbors(X) is sufficient.
Parameters
----------
G : DiGraph
A directed graph.
X : Set
Returns
-------
out : set
A set of all the neighbors of X.
"""

out = set()

for elem in X:
elem_neighbors = set(G.neighbors(elem))
elem_possible_neighbors = _check_back_arrow(G, elem, elem_neighbors)
to_remove = X.intersection(elem_possible_neighbors)
elem_neighbors = elem_possible_neighbors - to_remove

if len(elem_neighbors) != 0:
for nbh in elem_neighbors:
temp = (elem,)
temp = temp + (nbh,)
out.add(temp)
return out


def _recursively_find_pd_paths(G, X, paths, Y):
"""Recursively finds all the possibly directed paths for a given
graph.
Parameters
----------
G : DiGraph
A directed graph.
X : Set
Source.
paths : Set
Set of initial paths from X.
Y : Set
Destination
Returns
-------
out : set
A set of all the possibly directed paths.
"""

counter = 0
new_paths = set()

for elem in paths:
cur_elem = elem[-1]

if cur_elem in Y:
new_paths.add(elem)
continue

nbr_temp = G.neighbors(cur_elem)
nbr_possible = _check_back_arrow(G, cur_elem, nbr_temp)

if len(nbr_possible) == 0:
new_paths = new_paths + (elem,)

possible_end = nbr_possible.intersection(Y)

if len(possible_end) != 0:
for nbr in possible_end:
temp_path = elem
temp_path = temp_path + (nbr,)
new_paths.add(temp_path)

remaining_nodes = nbr_possible - possible_end
remaining_nodes = (
remaining_nodes
- remaining_nodes.intersection(set(elem))
- remaining_nodes.intersection(X)
)

temp_set = set()
for nbr in remaining_nodes:
temp_paths = elem
temp_paths = temp_paths + (nbr,)
temp_set.add(temp_paths)

new_paths.update(_recursively_find_pd_paths(G, X, temp_set, Y))

return new_paths


def proper_possibly_directed_path(G, X: Optional[Set], Y: Optional[Set]):
"""Find all the proper possibly directed paths in a graph. A proper possibly directed
path from X to Y is a set of edges with just the first node in X and none of the edges
with an arrow pointing back to X.
Parameters
----------
G : DiGraph
A directed graph.
X : Set
Source.
Y : Set
Destination
Returns
-------
out : set
A set of all the proper possibly directed paths.
Examples
--------
The function generates a set of tuples containing all the valid
proper possibly directed paths from X to Y.
>>> import pywhy_graphs
>>> from pywhy_graphs import PAG
>>> pag = PAG()
>>> pag.add_edge("A", "G", pag.directed_edge_name)
>>> pag.add_edge("G", "C", pag.directed_edge_name)
>>> pag.add_edge("C", "H", pag.directed_edge_name)
>>> pag.add_edge("Z", "C", pag.circle_edge_name)
>>> pag.add_edge("C", "Z", pag.circle_edge_name)
>>> pag.add_edge("Y", "X", pag.directed_edge_name)
>>> pag.add_edge("X", "Z", pag.directed_edge_name)
>>> pag.add_edge("Z", "K", pag.directed_edge_name)
>>> Y = {"H", "K"}
>>> X = {"Y", "A"}
>>> pywhy_graphs.proper_possibly_directed_path(pag, X, Y)
{('A', 'G', 'C', 'H'), ('Y', 'X', 'Z', 'C', 'H'), ('Y', 'X', 'Z', 'K'), ('A', 'G', 'C', 'Z', 'K')}
"""

if isinstance(X, set):
x_neighbors = _get_neighbors_of_set(G, X)
else:
nbr_temp = G.neighbors(X)
nbr_possible = _check_back_arrow(nbr_temp)
x_neighbors = []

for elem in nbr_possible:
temp = dict()
temp[0] = X
temp[1] = elem
x_neighbors.append(temp)

path_list = _recursively_find_pd_paths(G, X, x_neighbors, Y)

return path_list
181 changes: 180 additions & 1 deletion pywhy_graphs/algorithms/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

import pywhy_graphs
from pywhy_graphs import ADMG
from pywhy_graphs import ADMG, PAG
from pywhy_graphs.algorithms import all_vstructures


Expand Down Expand Up @@ -496,3 +496,182 @@ def test_all_vstructures():
# Assert that the returned values are as expected
assert len(v_structs_edges) == 0
assert len(v_structs_tuples) == 0


def test_proper_possibly_directed():
# X <- Y <-> Z <-> H; Z -> X

admg = ADMG()
admg.add_edge("Y", "X", admg.directed_edge_name)
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Z", "H", admg.directed_edge_name)

Y = {"H"}
X = {"Y"}

correct = {("Y", "X", "Z", "H")}
out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
assert correct == out

admg = ADMG()
admg.add_edge("A", "X", admg.directed_edge_name)
admg.add_edge("Y", "X", admg.directed_edge_name)
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Z", "H", admg.directed_edge_name)

Y = {"H"}
X = {"Y", "A"}

correct = {("Y", "X", "Z", "H"), ("A", "X", "Z", "H")}
out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
assert correct == out

admg = ADMG()
admg.add_edge("X", "A", admg.directed_edge_name)
admg.add_edge("Y", "X", admg.directed_edge_name)
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Z", "H", admg.directed_edge_name)

Y = {"H"}
X = {"Y", "A"}

correct = {("Y", "X", "Z", "H")}
out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
assert correct == out

admg = ADMG()
admg.add_edge("X", "A", admg.directed_edge_name)
admg.add_edge("Y", "X", admg.directed_edge_name)
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Z", "H", admg.directed_edge_name)
admg.add_edge("K", "Z", admg.directed_edge_name)

Y = {"H", "K"}
X = {"Y", "A"}

correct = {("Y", "X", "Z", "H")}
out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
assert correct == out

admg = ADMG()
admg.add_edge("A", "X", admg.directed_edge_name)
admg.add_edge("Y", "X", admg.directed_edge_name)
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Z", "H", admg.directed_edge_name)
admg.add_edge("Z", "K", admg.directed_edge_name)

Y = {"H", "K"}
X = {"Y", "A"}

correct = {
("Y", "X", "Z", "K"),
("A", "X", "Z", "K"),
("Y", "X", "Z", "H"),
("A", "X", "Z", "H"),
}
out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
assert correct == out

admg = ADMG()
admg.add_edge("A", "G", admg.directed_edge_name)
admg.add_edge("G", "C", admg.directed_edge_name)
admg.add_edge("C", "H", admg.directed_edge_name)
admg.add_edge("Y", "X", admg.directed_edge_name)
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Z", "K", admg.directed_edge_name)

Y = {"H", "K"}
X = {"Y", "A"}

correct = {("Y", "X", "Z", "K"), ("A", "G", "C", "H")}
out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
assert correct == out

admg = ADMG()
admg.add_edge("A", "G", admg.directed_edge_name)
admg.add_edge("G", "C", admg.directed_edge_name)
admg.add_edge("C", "H", admg.directed_edge_name)
admg.add_edge("Z", "C", admg.directed_edge_name)
admg.add_edge("Y", "X", admg.directed_edge_name)
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Z", "K", admg.directed_edge_name)

Y = {"H", "K"}
X = {"Y", "A"}

correct = {("Y", "X", "Z", "K"), ("Y", "X", "Z", "C", "H"), ("A", "G", "C", "H")}
out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
assert correct == out

admg = ADMG()
admg.add_edge("A", "G", admg.directed_edge_name)
admg.add_edge("A", "H", admg.directed_edge_name)
admg.add_edge("K", "G", admg.directed_edge_name)
admg.add_edge("K", "H", admg.directed_edge_name)

Y = {"G", "H"}
X = {"A", "K"}

correct = {("K", "H"), ("K", "G"), ("A", "G"), ("A", "H")}
out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
assert correct == out

admg = ADMG()
admg.add_edge("A", "G", admg.directed_edge_name)
admg.add_edge("G", "C", admg.directed_edge_name)
admg.add_edge("C", "H", admg.directed_edge_name)
admg.add_edge("Z", "C", admg.bidirected_edge_name)
admg.add_edge("Y", "X", admg.directed_edge_name)
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Z", "K", admg.directed_edge_name)

Y = {"H", "K"}
X = {"Y", "A"}

correct = {
("A", "G", "C", "H"),
("Y", "X", "Z", "K"),
}
out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
assert correct == out

admg = ADMG()
admg.add_edge("A", "G", admg.directed_edge_name)
admg.add_edge("G", "C", admg.directed_edge_name)
admg.add_edge("C", "H", admg.directed_edge_name)
admg.add_edge("Z", "C", admg.bidirected_edge_name)
admg.add_edge("Y", "X", admg.directed_edge_name)
admg.add_edge("X", "Z", admg.directed_edge_name)
admg.add_edge("Z", "K", admg.directed_edge_name)

Y = {"H", "K"}
X = {"Y", "A"}

correct = {("Y", "X", "Z", "K"), ("A", "G", "C", "H")}
out = pywhy_graphs.proper_possibly_directed_path(admg, X, Y)
assert correct == out


def test_ppdp_PAG():

pag = PAG()
pag.add_edge("A", "G", pag.directed_edge_name)
pag.add_edge("G", "C", pag.directed_edge_name)
pag.add_edge("C", "H", pag.directed_edge_name)
pag.add_edge("Z", "C", pag.circle_edge_name)
pag.add_edge("C", "Z", pag.circle_edge_name)
pag.add_edge("Y", "X", pag.directed_edge_name)
pag.add_edge("X", "Z", pag.directed_edge_name)
pag.add_edge("Z", "K", pag.directed_edge_name)

Y = {"H", "K"}
X = {"Y", "A"}

correct = {
("Y", "X", "Z", "K"),
("Y", "X", "Z", "C", "H"),
("A", "G", "C", "H"),
("A", "G", "C", "Z", "K"),
}
out = pywhy_graphs.proper_possibly_directed_path(pag, X, Y)
assert correct == out

0 comments on commit 053a58b

Please sign in to comment.