Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 10, 2024
1 parent 31c6069 commit fd5736c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 34 deletions.
24 changes: 16 additions & 8 deletions pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,16 +857,20 @@ def all_vstructures(G: nx.DiGraph, as_edges: bool = False):
vstructs.add((p1, node, p2)) # type: ignore
return vstructs


def check_back_arrow(G: ADMG, X, Y: set):

out = set()

for elem in Y:
if not (G.has_edge(X,elem,G.bidirected_edge_name) or G.has_edge(elem,X,G.directed_edge_name)):
if not (
G.has_edge(X, elem, G.bidirected_edge_name) or G.has_edge(elem, X, G.directed_edge_name)
):
out.update(elem)

return out


def get_X_neighbors(G, X: set):

out = []
Expand All @@ -877,7 +881,7 @@ def get_X_neighbors(G, X: set):
to_remove = X.intersection(elem_possible_neighbors)
elem_neighbors = elem_possible_neighbors - to_remove

if len(elem_neighbors) != 0:
if len(elem_neighbors) != 0:
temp = dict()
count = 0
temp[count] = elem
Expand All @@ -888,6 +892,7 @@ def get_X_neighbors(G, X: set):

return out


def recursively_find_pd_paths(G, X, paths, Y):

counter = 0
Expand All @@ -900,7 +905,7 @@ def recursively_find_pd_paths(G, X, paths, Y):

if len(nbr_possible) == 0:
new_paths.append(paths[i].copy())

possible_end = nbr_possible.intersection(Y)

if len(possible_end) != 0:
Expand All @@ -910,19 +915,22 @@ def recursively_find_pd_paths(G, X, paths, Y):
new_paths.append(temp_path)

remaining_nodes = nbr_possible - possible_end
remaining_nodes = remaining_nodes - remaining_nodes.intersection(paths[i].values()) - remaining_nodes.intersection(X)
remaining_nodes = (
remaining_nodes
- remaining_nodes.intersection(paths[i].values())
- remaining_nodes.intersection(X)
)

temp_arr = []
for elem in remaining_nodes:
temp_paths = paths[i].copy()
temp_paths[len(temp_paths)] = elem
temp_paths = paths[i].copy()
temp_paths[len(temp_paths)] = elem
temp_arr.append(temp_paths)

new_paths.extend(recursively_find_pd_paths(G, X, temp_arr, Y))

return new_paths



def possibly_directed_path(G, X: Optional[Set] = None, Y: Optional[Set] = None):

Expand Down
56 changes: 30 additions & 26 deletions pywhy_graphs/algorithms/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,10 +509,9 @@ def test_possibly_directed():
Y = {"H"}
X = {"Y"}


correct = [{0: 'Y', 1: 'X', 2: 'Z', 3: 'H'}]
correct = [{0: "Y", 1: "X", 2: "Z", 3: "H"}]
out = pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[0] == out[0]

admg = ADMG()
admg.add_edge("A", "X", admg.directed_edge_name)
Expand All @@ -523,10 +522,10 @@ def test_possibly_directed():
Y = {"H"}
X = {"Y", "A"}

correct = [{0: 'A', 1: 'X', 2: 'Z', 3: 'H'}, {0: 'Y', 1: 'X', 2: 'Z', 3: 'H'}]
correct = [{0: "A", 1: "X", 2: "Z", 3: "H"}, {0: "Y", 1: "X", 2: "Z", 3: "H"}]
pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct[0] == out[0]
assert correct[1] == out[1]

admg = ADMG()
admg.add_edge("X", "A", admg.directed_edge_name)
Expand All @@ -537,10 +536,9 @@ def test_possibly_directed():
Y = {"H"}
X = {"Y", "A"}

correct = [{0: 'Y', 1: 'X', 2: 'Z', 3: 'H'}]
correct = [{0: "Y", 1: "X", 2: "Z", 3: "H"}]
pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]

assert correct[0] == out[0]

admg = ADMG()
admg.add_edge("X", "A", admg.directed_edge_name)
Expand All @@ -552,10 +550,9 @@ def test_possibly_directed():
Y = {"H", "K"}
X = {"Y", "A"}

correct = [{0: 'Y', 1: 'X', 2: 'Z', 3: 'H'}]
correct = [{0: "Y", 1: "X", 2: "Z", 3: "H"}]
pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]

assert correct[0] == out[0]

admg = ADMG()
admg.add_edge("A", "X", admg.directed_edge_name)
Expand All @@ -567,13 +564,17 @@ def test_possibly_directed():
Y = {"H", "K"}
X = {"Y", "A"}

correct = [{0: 'A', 1: 'X', 2: 'Z', 3: 'H'}, {0: 'A', 1: 'X', 2: 'Z', 3: 'K'}, {0: 'Y', 1: 'X', 2: 'Z', 3: 'H'}, {0: 'Y', 1: 'X', 2: 'Z', 3: 'K'}]
correct = [
{0: "A", 1: "X", 2: "Z", 3: "H"},
{0: "A", 1: "X", 2: "Z", 3: "K"},
{0: "Y", 1: "X", 2: "Z", 3: "H"},
{0: "Y", 1: "X", 2: "Z", 3: "K"},
]
pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct[2] == out[2]
assert correct[3] == out[3]

assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct[2] == out[2]
assert correct[3] == out[3]

admg = ADMG()
admg.add_edge("A", "G", admg.directed_edge_name)
Expand All @@ -586,11 +587,10 @@ def test_possibly_directed():
Y = {"H", "K"}
X = {"Y", "A"}

correct = [{0: 'A', 1: 'G', 2: 'C', 3: 'H'}, {0: 'Y', 1: 'X', 2: 'Z', 3: 'K'}]
correct = [{0: "A", 1: "G", 2: "C", 3: "H"}, {0: "Y", 1: "X", 2: "Z", 3: "K"}]
pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[1] == out[1]

assert correct[0] == out[0]
assert correct[1] == out[1]

admg = ADMG()
admg.add_edge("A", "G", admg.directed_edge_name)
Expand All @@ -604,8 +604,12 @@ def test_possibly_directed():
Y = {"H", "K"}
X = {"Y", "A"}

correct = [{0: 'A', 1: 'G', 2: 'C', 3: 'H'}, {0: 'Y', 1: 'X', 2: 'Z', 3: 'K'}, {0: 'Y', 1: 'X', 2: 'Z', 3: 'C', 4: 'H'}]
correct = [
{0: "A", 1: "G", 2: "C", 3: "H"},
{0: "Y", 1: "X", 2: "Z", 3: "K"},
{0: "Y", 1: "X", 2: "Z", 3: "C", 4: "H"},
]
pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct[2] == out[2]
assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct[2] == out[2]

0 comments on commit fd5736c

Please sign in to comment.