Skip to content

Commit

Permalink
Added a test
Browse files Browse the repository at this point in the history
Signed-off-by: Aryan Roy <[email protected]>
  • Loading branch information
aryan26roy committed Jan 20, 2024
1 parent 8b7ac3a commit 4f83210
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
6 changes: 4 additions & 2 deletions pywhy_graphs/algorithms/pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
has_adc,
inducing_path,
single_source_shortest_mixed_path,
valid_mag
)
from pywhy_graphs.typing import Node, TsNode

Expand All @@ -32,6 +33,7 @@
"is_definite_noncollider",
"pag_to_mag",
"legal_pag",
"valid_pag"
]


Expand Down Expand Up @@ -1344,9 +1346,9 @@ def valid_pag(G: PAG):

converted_mag = pag_to_mag(G)

valid_mag = valid_mag(converted_mag)
is_valid = valid_mag(converted_mag)

if valid_mag:
if is_valid:
interim_bool = True

# convert the mag back to a pag
Expand Down
24 changes: 22 additions & 2 deletions pywhy_graphs/algorithms/tests/test_pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

import pywhy_graphs
from pywhy_graphs import PAG
from pywhy_graphs import PAG,ADMG
from pywhy_graphs.algorithms import (
discriminating_path,
is_definite_noncollider,
Expand Down Expand Up @@ -725,7 +725,7 @@ def test_pag_to_mag():
)


def test_pag_to_mag():
def test_legal_pag():
# D o-o A o-> B <-o C

pag = PAG()
Expand Down Expand Up @@ -769,3 +769,23 @@ def test_pag_to_mag():
pag_bool = pywhy_graphs.legal_pag(pag)

assert pag_bool is False


def test_valid_pag():

pag = PAG()
pag.add_edge("A", "D", pag.directed_edge_name)
pag.add_edge("A", "C", pag.circle_edge_name)
pag.add_edge("D", "A", pag.circle_edge_name)
pag.add_edge("B", "D", pag.directed_edge_name)
pag.add_edge("C", "D", pag.directed_edge_name)
pag.add_edge("D", "B", pag.circle_edge_name)
pag.add_edge("D", "C", pag.circle_edge_name)
pag.add_edge("C", "A", pag.circle_edge_name)
pag.add_edge("B", "A", pag.circle_edge_name)
pag.add_edge("A", "B", pag.circle_edge_name)

# C o- A o-> D <-o B
# B o-o A o-o C o-> D

assert pywhy_graphs.valid_pag(pag) is True

0 comments on commit 4f83210

Please sign in to comment.