Skip to content

Commit

Permalink
Added a _proper_pag function
Browse files Browse the repository at this point in the history
Signed-off-by: Aryan Roy <[email protected]>
  • Loading branch information
aryan26roy committed Jan 6, 2024
1 parent 0d09c33 commit f8b07e1
Showing 1 changed file with 56 additions and 7 deletions.
63 changes: 56 additions & 7 deletions pywhy_graphs/algorithms/pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

from pywhy_graphs import ADMG, CPDAG, PAG, StationaryTimeSeriesPAG
from pywhy_graphs.algorithms.generic import single_source_shortest_mixed_path, valid_mag
from pywhy_graphs.algorithms.generic import single_source_shortest_mixed_path, has_adc, inducing_path
from pywhy_graphs.typing import Node, TsNode

from dodiscover.constraint.fcialg import FCI
Expand Down Expand Up @@ -1180,12 +1180,61 @@ def pag_to_mag(graph):

return mag

def _proper_pag(G: PAG):

#check for acyclicity
#check for ancestrality
#check for maximality

def _proper_pag(G: PAG, L: Optional[set] = None, S: Optional[set] = None):
"""Checks if the provided graph is a valid Partial ancestral graph (MAG).
Parameters
----------
G : Graph
The graph.
Returns
-------
is_valid : bool
A boolean indicating whether the provided graph is a valid PAG or not.
"""

if L is None:
L = set()

if S is None:
S = set()

directed_sub_graph = G.sub_directed_graph()

all_nodes = set(G.nodes)

# check if there are any undirected edges or more than one edges b/w two nodes
for node in all_nodes:
nb = set(G.neighbors(node))
for elem in nb:
edge_data = G.get_edge_data(node, elem)
if (edge_data["bidirected"] is not None) and (edge_data["directed"] is not None):
return False

# check if there are any directed cyclces
try:
nx.find_cycle(directed_sub_graph) # raises a NetworkXNoCycle error
return False
except nx.NetworkXNoCycle:
pass

# check if there are any almost directed cycles
if has_adc(G): # if there is an ADC, it's not a valid MAG
return False

# check if there are any inducing paths between non-adjacent nodes

for source in all_nodes:
nb = set(G.neighbors(source))
cur_set = all_nodes - nb
cur_set.remove(source)
for dest in cur_set:
out = inducing_path(G, source, dest, L, S)
if out[0] is True:
return False

return True


Expand Down

0 comments on commit f8b07e1

Please sign in to comment.