Skip to content

Commit

Permalink
Switched to set<tuple>
Browse files Browse the repository at this point in the history
Signed-off-by: Aryan Roy <[email protected]>
  • Loading branch information
aryan26roy committed Aug 15, 2024
1 parent 0acd084 commit f0ef279
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 84 deletions.
42 changes: 21 additions & 21 deletions pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def check_back_arrow(G: ADMG, X, Y: set):

def get_X_neighbors(G, X: set):

out = []
out = set()

for elem in X:
elem_neighbors = set(G.neighbors(elem))
Expand All @@ -883,53 +883,52 @@ def get_X_neighbors(G, X: set):

if len(elem_neighbors) != 0:
for nbh in elem_neighbors:
temp = dict()
temp[0] = elem
temp[1] = nbh
out.append(temp)
temp = (elem,)
temp = temp + (nbh,)
out.add(temp)
return out


def recursively_find_pd_paths(G, X, paths, Y):

counter = 0
new_paths = []
new_paths = set()

for i in range(len(paths)):
cur_elem = paths[i][list(paths[i].keys())[-1]]
for elem in paths:
cur_elem = elem[-1]

if cur_elem in Y:
new_paths.append(paths[i])
new_paths.add(elem)
continue

nbr_temp = G.neighbors(cur_elem)
nbr_possible = check_back_arrow(G, cur_elem, nbr_temp)

if len(nbr_possible) == 0:
new_paths.append(paths[i].copy())
new_paths = new_paths + (elem,)

possible_end = nbr_possible.intersection(Y)

if len(possible_end) != 0:
for elem in possible_end:
temp_path = paths[i].copy()
temp_path[len(temp_path)] = elem
new_paths.append(temp_path)
for nbr in possible_end:
temp_path = elem
temp_path = temp_path + (nbr,)
new_paths.add(temp_path)

remaining_nodes = nbr_possible - possible_end
remaining_nodes = (
remaining_nodes
- remaining_nodes.intersection(paths[i].values())
- remaining_nodes.intersection(set(elem))
- remaining_nodes.intersection(X)
)

temp_arr = []
for elem in remaining_nodes:
temp_paths = paths[i].copy()
temp_paths[len(temp_paths)] = elem
temp_arr.append(temp_paths)
temp_set = set()
for nbr in remaining_nodes:
temp_paths = elem
temp_paths = temp_paths + (nbr,)
temp_set.add(temp_paths)

new_paths.extend(recursively_find_pd_paths(G, X, temp_arr, Y))
new_paths.update(recursively_find_pd_paths(G, X, temp_set, Y))

return new_paths

Expand All @@ -950,5 +949,6 @@ def possibly_directed_path(G, X: Optional[Set] = None, Y: Optional[Set] = None):
x_neighbors.append(temp)

path_list = recursively_find_pd_paths(G, X, x_neighbors, Y)
print(path_list)

return path_list
88 changes: 25 additions & 63 deletions pywhy_graphs/algorithms/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,9 @@ def test_possibly_directed():
Y = {"H"}
X = {"Y"}

correct = [{0: "Y", 1: "X", 2: "Z", 3: "H"}]
correct = {('Y', 'X', 'Z', 'H')}
out = pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct == out

admg = ADMG()
admg.add_edge("A", "X", admg.directed_edge_name)
Expand All @@ -522,10 +522,9 @@ def test_possibly_directed():
Y = {"H"}
X = {"Y", "A"}

correct = [{0: "A", 1: "X", 2: "Z", 3: "H"}, {0: "Y", 1: "X", 2: "Z", 3: "H"}]
correct = {('Y', 'X', 'Z', 'H'), ('A', 'X', 'Z', 'H')}
out = pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct == out

admg = ADMG()
admg.add_edge("X", "A", admg.directed_edge_name)
Expand All @@ -536,9 +535,9 @@ def test_possibly_directed():
Y = {"H"}
X = {"Y", "A"}

correct = [{0: "Y", 1: "X", 2: "Z", 3: "H"}]
correct = {('Y', 'X', 'Z', 'H')}
out = pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct == out

admg = ADMG()
admg.add_edge("X", "A", admg.directed_edge_name)
Expand All @@ -550,9 +549,9 @@ def test_possibly_directed():
Y = {"H", "K"}
X = {"Y", "A"}

correct = [{0: "Y", 1: "X", 2: "Z", 3: "H"}]
correct = {('Y', 'X', 'Z', 'H')}
out = pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct == out

admg = ADMG()
admg.add_edge("A", "X", admg.directed_edge_name)
Expand All @@ -564,17 +563,9 @@ def test_possibly_directed():
Y = {"H", "K"}
X = {"Y", "A"}

correct = [
{0: "A", 1: "X", 2: "Z", 3: "H"},
{0: "A", 1: "X", 2: "Z", 3: "K"},
{0: "Y", 1: "X", 2: "Z", 3: "H"},
{0: "Y", 1: "X", 2: "Z", 3: "K"},
]
correct = {('Y', 'X', 'Z', 'K'), ('A', 'X', 'Z', 'K'), ('Y', 'X', 'Z', 'H'), ('A', 'X', 'Z', 'H')}
out = pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct[2] == out[2]
assert correct[3] == out[3]
assert correct == out

admg = ADMG()
admg.add_edge("A", "G", admg.directed_edge_name)
Expand All @@ -587,10 +578,9 @@ def test_possibly_directed():
Y = {"H", "K"}
X = {"Y", "A"}

correct = [{0: "A", 1: "G", 2: "C", 3: "H"}, {0: "Y", 1: "X", 2: "Z", 3: "K"}]
correct = {('Y', 'X', 'Z', 'K'), ('A', 'G', 'C', 'H')}
out = pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct == out

admg = ADMG()
admg.add_edge("A", "G", admg.directed_edge_name)
Expand All @@ -604,15 +594,9 @@ def test_possibly_directed():
Y = {"H", "K"}
X = {"Y", "A"}

correct = [
{0: "A", 1: "G", 2: "C", 3: "H"},
{0: "Y", 1: "X", 2: "Z", 3: "K"},
{0: "Y", 1: "X", 2: "Z", 3: "C", 4: "H"},
]
correct = {('Y', 'X', 'Z', 'K'), ('Y', 'X', 'Z', 'C', 'H'), ('A', 'G', 'C', 'H')}
out = pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct[2] == out[2]
assert correct == out

admg = ADMG()
admg.add_edge("A", "G", admg.directed_edge_name)
Expand All @@ -623,12 +607,9 @@ def test_possibly_directed():
Y = {"G", "H"}
X = {"A", "K"}

correct = [{0: "K", 1: "G"}, {0: "K", 1: "H"}, {0: "A", 1: "G"}, {0: "A", 1: "H"}]
correct = {('K', 'H'), ('K', 'G'), ('A', 'G'), ('A', 'H')}
out = pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct[2] == out[2]
assert correct[3] == out[3]
assert correct == out

admg = ADMG()
admg.add_edge("A", "G", admg.directed_edge_name)
Expand All @@ -642,13 +623,12 @@ def test_possibly_directed():
Y = {"H", "K"}
X = {"Y", "A"}

correct = [
{0: "A", 1: "G", 2: "C", 3: "H"},
{0: "Y", 1: "X", 2: "Z", 3: "K"},
]
correct = {
("A","G","C","H"),
("Y","X","Z","K"),
}
out = pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct == out

admg = ADMG()
admg.add_edge("A", "G", admg.directed_edge_name)
Expand All @@ -662,16 +642,9 @@ def test_possibly_directed():
Y = {"H", "K"}
X = {"Y", "A"}

correct = [
{0: "A", 1: "G", 2: "C", 3: "H"},
{0: "Y", 1: "X", 2: "Z", 3: "K"},
{0: "Y", 1: "X", 2: "Z", 3: "C", 4: "H"},
{0: "A", 1: "G", 2: "C", 3: "Z", 4: "K"},
]
correct = {('Y', 'X', 'Z', 'K'), ('A', 'G', 'C', 'H')}
out = pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct[2] == out[2]
assert correct == out

admg = PAG()
admg.add_edge("A", "G", admg.directed_edge_name)
Expand All @@ -686,17 +659,6 @@ def test_possibly_directed():
Y = {"H", "K"}
X = {"Y", "A"}

correct = [
[
{0: "Y", 1: "X", 2: "Z", 3: "K"},
{0: "Y", 1: "X", 2: "Z", 3: "C", 4: "H"},
{0: "A", 1: "G", 2: "C", 3: "H"},
{0: "A", 1: "G", 2: "C", 3: "Z", 4: "K"},
]
]

correct = {('Y', 'X', 'Z', 'K'), ('Y', 'X', 'Z', 'C', 'H'), ('A', 'G', 'C', 'H'), ('A', 'G', 'C', 'Z', 'K')}
out = pywhy_graphs.possibly_directed_path(admg, X, Y)
assert correct[0] == out[0]
assert correct[1] == out[1]
assert correct[2] == out[2]
assert correct[3] == out[3]
assert correct == out

0 comments on commit f0ef279

Please sign in to comment.