diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 9f775405..2bbf49ef 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -15,6 +15,7 @@ has_adc, inducing_path, single_source_shortest_mixed_path, + valid_mag ) from pywhy_graphs.typing import Node, TsNode @@ -32,6 +33,7 @@ "is_definite_noncollider", "pag_to_mag", "legal_pag", + "valid_pag" ] @@ -1344,9 +1346,9 @@ def valid_pag(G: PAG): converted_mag = pag_to_mag(G) - valid_mag = valid_mag(converted_mag) + is_valid = valid_mag(converted_mag) - if valid_mag: + if is_valid: interim_bool = True # convert the mag back to a pag diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index 1ff45ad7..de3cea6f 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -3,7 +3,7 @@ import pytest import pywhy_graphs -from pywhy_graphs import PAG +from pywhy_graphs import PAG,ADMG from pywhy_graphs.algorithms import ( discriminating_path, is_definite_noncollider, @@ -725,7 +725,7 @@ def test_pag_to_mag(): ) -def test_pag_to_mag(): +def test_legal_pag(): # D o-o A o-> B <-o C pag = PAG() @@ -769,3 +769,23 @@ def test_pag_to_mag(): pag_bool = pywhy_graphs.legal_pag(pag) assert pag_bool is False + + +def test_valid_pag(): + + pag = PAG() + pag.add_edge("A", "D", pag.directed_edge_name) + pag.add_edge("A", "C", pag.circle_edge_name) + pag.add_edge("D", "A", pag.circle_edge_name) + pag.add_edge("B", "D", pag.directed_edge_name) + pag.add_edge("C", "D", pag.directed_edge_name) + pag.add_edge("D", "B", pag.circle_edge_name) + pag.add_edge("D", "C", pag.circle_edge_name) + pag.add_edge("C", "A", pag.circle_edge_name) + pag.add_edge("B", "A", pag.circle_edge_name) + pag.add_edge("A", "B", pag.circle_edge_name) + + # C o- A o-> D <-o B + # B o-o A o-o C o-> D + + assert pywhy_graphs.valid_pag(pag) is True