diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1b78278d1..8c64e2b1d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,12 +27,12 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - name: Setup Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.9" architecture: "x64" - name: Install Poetry ${{ matrix.poetry-version }} - uses: abatilo/actions-poetry@v2.3.0 + uses: abatilo/actions-poetry@v2.4.0 with: poetry-version: ${{ matrix.poetry-version }} - name: Install Poetry Dynamic Versioning Plugin @@ -69,12 +69,12 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} architecture: "x64" - name: Install Poetry ${{ matrix.poetry-version }} - uses: abatilo/actions-poetry@v2.3.0 + uses: abatilo/actions-poetry@v2.4.0 with: poetry-version: ${{ matrix.poetry-version }} - name: Install Poetry Dynamic Versioning Plugin @@ -92,7 +92,7 @@ jobs: run: poetry build - name: Upload package distribution files if: ${{ matrix.os == 'ubuntu' && matrix.python-version == '3.11' }} - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: package path: dist @@ -130,12 +130,12 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} architecture: "x64" - name: Install Poetry ${{ matrix.poetry-version }} - uses: abatilo/actions-poetry@v2.3.0 + uses: abatilo/actions-poetry@v2.4.0 with: poetry-version: ${{ matrix.poetry-version }} - name: Install Poetry Dynamic Versioning Plugin @@ -158,6 +158,7 @@ jobs: if: "matrix.os == 'ubuntu'" shell: bash run: | + sudo apt-get update sudo apt-get install nvidia-cuda-toolkit 
nvidia-cuda-toolkit-gcc - name: Run pytest # headless via Xvfb on linux @@ -180,7 +181,7 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.9 architecture: "x64" @@ -193,7 +194,7 @@ jobs: echo "RELEASE_VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_ENV echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV - name: Download package distribution files - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: package path: dist diff --git a/README.md b/README.md index c2cdd7a3f..1cce990db 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ To install the package from github, clone the repository and then `cd` into the Pywhy-Graphs is always looking for new contributors to help make the package better, whether it is algorithms, documentation, examples of graph usage, and more! Contributing to Pywhy-Graphs will be rewarding because you will contribute to a much needed package for causal inference. -See our [contributing guide](https://github.com/py-why/pywhy-graphs/CONTRIBUTING.md) for more details. +See our [contributing guide](https://github.com/py-why/pywhy-graphs/blob/main/CONTRIBUTING.md) for more details. # Citing diff --git a/doc/api.rst b/doc/api.rst index d5974823a..9297ebb33 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -61,6 +61,28 @@ causal graph operations. find_connected_pairs add_all_snode_combinations compute_invariant_domains_per_node + is_semi_directed_path + all_semi_directed_paths + +:mod:`pywhy_graphs.algorithms`: Algorithms for dealing with CPDAGs +================================================================== +With Markov equivalence classes of DAGs in a Markovian SCM setting, we obtain +a partially directed acyclic graph (PDAG), which may be completed (CPDAG). +If we want to generate a consistent DAG extension (i.e. 
Markov equivalent) of a CPDAG +then we may use some of the algorithms described here. Or perhaps one may want to +convert a DAG to its corresponding CPDAG. + +.. currentmodule:: pywhy_graphs.algorithms + +.. autosummary:: + :toctree: generated/ + + pdag_to_dag + dag_to_cpdag + pdag_to_cpdag + order_edges + label_edges + Conversions between other package's causal graphs ================================================= diff --git a/doc/installation.md b/doc/installation.md index 5e208c6ab..10b312bdb 100644 --- a/doc/installation.md +++ b/doc/installation.md @@ -1,9 +1,10 @@ Installation ============ -**pywhy-graphs** supports Python >= 3.8. +**pywhy-graphs** closely follows the NetworkX dependencies and thus supports Python >= 3.9. -## Installing with ``pip`` +Installing with ``pip`` +----------------------- **pywhy-graphs** is available [on PyPI](https://pypi.org/project/pywhy-graphs/). Just run @@ -12,7 +13,8 @@ Installation # or if you use poetry which is recommended poetry add pywhy-graphs -## Installing from source +Installing from source +---------------------- To install **pywhy-graphs** from source, first clone [the repository](https://github.com/py-why/pywhy-graphs): diff --git a/doc/reference/algorithms/index.rst b/doc/reference/algorithms/index.rst index 270638c59..5c65472a6 100644 --- a/doc/reference/algorithms/index.rst +++ b/doc/reference/algorithms/index.rst @@ -63,3 +63,15 @@ Algorithms for handling acyclicity :toctree: ../../generated/ acyclification + + +*************************************** +Semi-directed (possibly-directed) Paths +*************************************** + +.. automodule:: pywhy_graphs.algorithms.semi_directed_paths +.. 
autosummary:: + :toctree: ../../generated/ + + all_semi_directed_paths + is_semi_directed_path diff --git a/doc/references.bib b/doc/references.bib index 2ba181a6f..e9a8833c1 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -17,6 +17,16 @@ @article{bareinboim_causal_2016 pages = {7345--7352} } +@article{chickering2002learning, + title = {Learning equivalence classes of Bayesian-network structures}, + author = {Chickering, David Maxwell}, + journal = {The Journal of Machine Learning Research}, + volume = {2}, + pages = {445--498}, + year = {2002}, + publisher = {JMLR} +} + @article{Colombo2012, author = {Diego Colombo and Marloes H. Maathuis and Markus Kalisch and Thomas S. Richardson}, title = {{Learning high-dimensional directed acyclic graphs with latent and selection variables}}, diff --git a/doc/whats_new/v0.2.rst b/doc/whats_new/v0.2.rst index dfb4309fb..5a746cbb8 100644 --- a/doc/whats_new/v0.2.rst +++ b/doc/whats_new/v0.2.rst @@ -29,6 +29,8 @@ Changelog - |Feature| Implement and test functions to convert a DAG to MAG, by `Aryan Roy`_ (:pr:`96`) - |Feature| Implement and test functions to convert a PAG to MAG, by `Aryan Roy`_ (:pr:`93`) - |API| Remove support for Python 3.8 by `Adam Li`_ (:pr:`99`) +- |Feature| Implement a suite of functions for finding and checking semi-directed paths on a mixed-edge graph, by `Adam Li`_ (:pr:`101`) +- |Feature| Implement functions for converting between a DAG and PDAG and CPDAG for generating consistent extensions of a CPDAG for example. 
These functions are :func:`pywhy_graphs.algorithms.pdag_to_cpdag`, :func:`pywhy_graphs.algorithms.pdag_to_dag` and :func:`pywhy_graphs.algorithms.dag_to_cpdag`, by `Adam Li`_ (:pr:`102`) Code and Documentation Contributors ----------------------------------- diff --git a/examples/mixededge/plot_mixed_edge_graph.py b/examples/mixededge/plot_mixed_edge_graph.py index e811475fc..3d10dbe43 100644 --- a/examples/mixededge/plot_mixed_edge_graph.py +++ b/examples/mixededge/plot_mixed_edge_graph.py @@ -34,12 +34,8 @@ # %% # Construct a MixedEdgeGraph # -------------------------- -# Using the ``MixedEdgeGraph``, we can represent a causal graph -# with two different kinds of edges. To create the graph, we -# use networkx ``nx.DiGraph`` class to represent directed edges, -# and ``nx.Graph`` class to represent edges without directions (i.e. -# bidirected edges). The edge types are then specified, so the mixed edge -# graph object knows which graphs are associated with which types of edges. +# Here we demonstrate how to construct a mixed edge graph +# by composing networkx graphs. 
directed_G = nx.DiGraph( [ @@ -60,7 +56,6 @@ name="IV Graph", ) -# Compute the multipartite_layout using the "layer" node attribute pos = nx.spring_layout(G) # we can then visualize the mixed-edge graph diff --git a/examples/visualization/draw_and_compare_graphs_with_same_layout.py b/examples/visualization/draw_and_compare_graphs_with_same_layout.py index f33ebf00c..88aef90b2 100644 --- a/examples/visualization/draw_and_compare_graphs_with_same_layout.py +++ b/examples/visualization/draw_and_compare_graphs_with_same_layout.py @@ -34,6 +34,8 @@ cpdag.orient_uncertain_edge("x", "y") cpdag.orient_uncertain_edge("xy", "y") cpdag.orient_uncertain_edge("z", "y") + +# create a PAG from the CPDAG, with all undirected edges pag = PAG() pag.add_edges_from(G.edges, cpdag.undirected_edge_name) diff --git a/pyproject.toml b/pyproject.toml index 5acf30e6d..66524b946 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ exclude_dirs = ["tests"] [tool.black] line-length = 100 -target-version = ['py38'] +target-version = ['py39'] include = '\.pyi?$' extend-exclude = ''' ( @@ -102,10 +102,10 @@ readme = "README.md" classifiers = [ 'Development Status :: 4 - Beta', 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11' + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12' ] keywords = ['causality', 'graphs', 'causal-inference', 'graphical-model'] diff --git a/pywhy_graphs/algorithms/__init__.py b/pywhy_graphs/algorithms/__init__.py index 87b9ed01d..ddc2edd53 100644 --- a/pywhy_graphs/algorithms/__init__.py +++ b/pywhy_graphs/algorithms/__init__.py @@ -1,4 +1,6 @@ +from .cpdag import * # noqa: F403 from .cyclic import * # noqa: F403 from .generic import * # noqa: F403 from .multidomain import * # noqa: F403 from .pag import * # noqa: F403 +from .semi_directed_paths import * # noqa: F403 
diff --git a/pywhy_graphs/algorithms/cpdag.py b/pywhy_graphs/algorithms/cpdag.py new file mode 100644 index 000000000..558dd710a --- /dev/null +++ b/pywhy_graphs/algorithms/cpdag.py @@ -0,0 +1,281 @@ +from enum import Enum + +import networkx as nx + +import pywhy_graphs as pg + +__all__ = ["pdag_to_dag", "dag_to_cpdag", "pdag_to_cpdag", "order_edges", "label_edges"] + + +class EDGELABELS(Enum): + """Edge labels for a CPDAG.""" + + COMPELLED = "compelled" + REVERSIBLE = "reversible" + UNKNOWN = "unknown" + + +def is_clique(G, nodelist): + H = G.subgraph(nodelist) + n = len(nodelist) + return H.size() == n * (n - 1) / 2 + + +def order_edges(G: nx.DiGraph): + """Find total ordering of the edges of DAG G. + + A total ordering is a topological sorting of the nodes, and then + ordering all possible edges according to Algorithm 4 in + :footcite:`chickering2002learning`. The edges are sorted such that + the edges obey the topological sorting of the nodes, and they + are also sorted such that the source node of the edge is ordered based + on the topological sort as well. + + Parameters + ---------- + G : DAG + A directed acyclic graph. + + Returns + ------- + DAG + The same graph ``G``, modified in place so every edge has an integer ``"order"`` attribute. + + References + ---------- + .. 
footbibliography:: + """ + if not nx.is_directed_acyclic_graph(G): + raise ValueError("G must be a directed acyclic graph") + nx.set_edge_attributes(G, None, "order") + ordered_nodes = list(nx.topological_sort(G)) + + idx = 0 + + while any([G[u][v]["order"] is None for u, v in G.edges]): + # get all edges that are still not ordered + unordered_edges = [(u, v) for u, v in G.edges if G[u][v]["order"] is None] + + # get the lowest order unlabeled edge's destination node + y = sorted(unordered_edges, key=lambda x: ordered_nodes.index(x[1]))[-1][-1] + + # find the highest order node such that x -> y is not ordered + unlabeled_y_parent_edges = [u for u in G.predecessors(y) if G[u][y]["order"] is None] + x = sorted(unlabeled_y_parent_edges, key=lambda x: ordered_nodes.index(x))[0] + + # label the edge order + G[x][y]["order"] = idx + idx += 1 + + return G + + +def label_edges(G: nx.DiGraph): + """Label compelled and reversible edges of a DAG G. + + Label the edges of a DAG G as either compelled or reversible. Compelled + edges are edges that are compelled to be directed in a consistent + extension of G. Reversible edges are edges that are not required + to be directed in a consistent extension of G. For full details, + see Algorithm 5 in :footcite:`chickering2002learning`. + + Parameters + ---------- + G : DAG + The directed acyclic graph to label. + + Returns + ------- + DAG + The labelled DAG with edge attribute ``"label"`` as either + ``"compelled"`` or ``"reversible"``. + + References + ---------- + .. 
footbibliography:: + """ + if not nx.is_directed_acyclic_graph(G): + raise ValueError("G must be a directed acyclic graph") + if not all([G[u][v].get("order") is not None for u, v in G.edges]): + raise ValueError("G must have all edges ordered via the `order` attribute") + + nx.set_edge_attributes(G, EDGELABELS.UNKNOWN, "label") + + while any([edge[-1] == EDGELABELS.UNKNOWN for edge in G.edges.data("label")]): + # find the lowest order edge with an unknown label + unknown_edges = [ + (src, target) + for src, target, label in G.edges.data("label") + if label == EDGELABELS.UNKNOWN + ] + unknown_edges.sort(key=lambda edge: G.edges[edge]["order"]) + x, y = unknown_edges[-1] + + # now find every edge w -> x that is labeled as compelled + w_nodes = [w for w in G.predecessors(x) if G[w][x]["label"] == EDGELABELS.COMPELLED] + continue_while_loop = False + for node in w_nodes: + # For all compelled edges w -> x, if there is no edge w -> y, + # we can label x -> y, and every other edge into y, as compelled + if not G.has_edge(node, y): + for src, target in G.in_edges(y): + G[src][target]["label"] = EDGELABELS.COMPELLED + + # now, we start over at the beginning of the while loop + continue_while_loop = True + break + else: + # w -> y is compelled, since there is an edge w -> x that is compelled + # so w is a confounder + G[node][y]["label"] = EDGELABELS.COMPELLED + + if continue_while_loop: + continue + + # now, we check if there is an edge z -> y such that: + # 1. z != x + # 2. 
z is not a parent of x + # If so, then label all unknown edges into y (including x -> y) + # as compelled + # otherwise, label all unknown edges with reversible label + z_exists = len([z for z in G.predecessors(y) if z != x and not G.has_edge(z, x)]) + for src, target in G.in_edges(y): + if G[src][target]["label"] == EDGELABELS.UNKNOWN: + if z_exists: + G[src][target]["label"] = EDGELABELS.COMPELLED + else: + G[src][target]["label"] = EDGELABELS.REVERSIBLE + return G + + +def pdag_to_dag(G): + """Compute consistent extension of given PDAG resulting in a DAG. + + Implements the algorithm described in Figure 11 of :footcite:`chickering2002learning`. + + Parameters + ---------- + G : CPDAG + A partially directed acyclic graph. + + Returns + ------- + DAG + A directed acyclic graph. + + References + ---------- + .. footbibliography:: + """ + if set(["directed", "undirected"]) != set(G.edge_types): + raise ValueError("Only directed and undirected edges are allowed in a CPDAG") + + G = G.copy() + dir_G: nx.DiGraph = G.get_graphs(edge_type="directed") + undir_G: nx.Graph = G.get_graphs(edge_type="undirected") + full_undir_G: nx.Graph = G.to_undirected() + + # initialize a DAG for the consistent extension + dag = nx.DiGraph(dir_G) + + nodes_memo = {node: None for node in G.nodes} + found = False + + while len(nodes_memo) > 0: + found = False + idx = 0 + + nodes = list(nodes_memo.keys()) + + # select a node, x, which: + # 1. has no outgoing edges + # 2. 
all undirected neighbors are adjacent to all its adjacent nodes + while not found and idx < len(nodes): + # check that there are no outgoing edges for said node + node_is_sink = dir_G.out_degree(nodes[idx]) == 0 + + if not node_is_sink: + idx += 1 + continue + + # since there are no outgoing edges, all directed adjacencies are parent nodes + # now check that all undirected neighbors are adjacent to all its adjacent nodes + undir_nbrs = list(undir_G.neighbors(nodes[idx])) + nearby_is_clique = False + if len(undir_nbrs) != 0: + parents = dir_G.predecessors(nodes[idx]) + # adj = full_undir_G.neighbors(nodes[idx]) + undir_nbrs_and_parents = set(undir_nbrs).union(set(parents)) + nearby_is_clique = is_clique(full_undir_G, undir_nbrs_and_parents) + + if len(undir_nbrs) == 0 or nearby_is_clique: + found = True + + # now, we orient all undirected edges between x and its neighbors + # such that ``nbr -> x`` + for nbr in undir_nbrs: + dag.add_edge(nbr, nodes[idx], edge_type="directed") + + # remove x from the "graph" and memoization + del nodes_memo[nodes[idx]] + dir_G.remove_node(nodes[idx]) + undir_G.remove_node(nodes[idx]) + full_undir_G.remove_node(nodes[idx]) + else: + idx += 1 + + # if no node satisfies condition 1 and 2, then the PDAG does not + # admit a consistent extension + if not found: + print(nodes_memo) + raise ValueError(f"No consistent extension found for PDAG: {G}, {G.edges()}") + return dag + + +def dag_to_cpdag(G): + """Convert a DAG to a CPDAG. + + Creates a CPDAG from a DAG. + + Parameters + ---------- + G : DAG + Directed acyclic graph. 
+ """ + G = order_edges(G) + G = label_edges(G) + + # now construct CPDAG + cpdag = pg.CPDAG() + + # for all compelled edges, add a directed edge + compelled_edges = [ + (u, v) for u, v, label in G.edges.data("label") if label == EDGELABELS.COMPELLED + ] + cpdag.add_edges_from(compelled_edges, edge_type="directed") + + # for all reversible edges, add an undirected edge + reversible_edges = [ + (u, v) for u, v, label in G.edges.data("label") if label == EDGELABELS.REVERSIBLE + ] + cpdag.add_edges_from(reversible_edges, edge_type="undirected") + + return cpdag + + +def pdag_to_cpdag(G): + """Convert a PDAG to a CPDAG. + + Parameters + ---------- + G : PDAG + A partially directed acyclic graph that is not completed. + + Returns + ------- + CPDAG + A completed partially directed acyclic graph. + """ + dag = pdag_to_dag(G) + + return dag_to_cpdag(dag) diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index 2c336df06..f0011c655 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -1,3 +1,4 @@ +from itertools import combinations from typing import List, Optional, Set, Union import networkx as nx @@ -17,6 +18,7 @@ "valid_mag", "dag_to_mag", "is_maximal", + "all_vstructures", ] @@ -823,3 +825,33 @@ def is_maximal(G, L: Optional[Set] = None, S: Optional[Set] = None): else: continue return True + + +def all_vstructures(G: nx.DiGraph, as_edges: bool = False): + """Generate all v-structures in the graph. + + Parameters + ---------- + G : DiGraph + A directed graph. + as_edges : bool + Whether to return the v-structures as (parent, child) edges rather than as 3-tuples. + + Returns + ------- + vstructs : set + If ``as_edges`` is False, a set of v-structures in the graph encoded as the + (parent_1, child, parent_2) tuple with child being an unshielded collider. + Otherwise, a set of tuples of the form (parent, child), which are part of + v-structures in the graph. 
+ """ + vstructs = set() + for node in G.nodes: + for p1, p2 in combinations(G.predecessors(node), 2): + if p1 not in G.predecessors(p2) and p2 not in G.predecessors(p1): + if as_edges: + vstructs.add((p1, node)) + vstructs.add((p2, node)) + else: + vstructs.add((p1, node, p2)) # type: ignore + return vstructs diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 5526b1ade..bc0f3cd2a 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -30,9 +30,9 @@ def _possibly_directed(G: PAG, i: Node, j: Node, reverse: bool = False): - """Check that path is possibly directed. + """Check that edge is possibly directed. - A possibly directed path is one of the form: + A possibly directed edge is one of the form: - ``i -> j`` - ``i o-> j`` - ``i o-o j`` @@ -67,7 +67,7 @@ def _possibly_directed(G: PAG, i: Node, j: Node, reverse: bool = False): # the direct check checks for i *-> j or i <-* j # i <-> j is also checked - # everything else is valid + # everything else is valid; i.e. i -- j, or i o-o j if direct_check or G.has_edge(i, j, G.bidirected_edge_name): return False return True diff --git a/pywhy_graphs/algorithms/semi_directed_paths.py b/pywhy_graphs/algorithms/semi_directed_paths.py new file mode 100644 index 000000000..29d6fef4f --- /dev/null +++ b/pywhy_graphs/algorithms/semi_directed_paths.py @@ -0,0 +1,188 @@ +import networkx as nx + +from ..config import EdgeType +from ..typing import Node + +__all__ = [ + "is_semi_directed_path", + "all_semi_directed_paths", +] + + +def _empty_generator(): + yield from () + + +def is_semi_directed_path(G, nodes): + """Returns True if and only if `nodes` form a semi-directed path in `G`. 
+ + A *semi-directed path* in a graph is a nonempty sequence of nodes in which + no node appears more than once in the sequence, each adjacent + pair of nodes in the sequence is adjacent in the graph and where each + pair of adjacent nodes does not contain a directed endpoint in the direction + towards the start of the sequence. + + That is ``(a -> b o-> c <-> d -> e)`` is not a semi-directed path from ``a`` to ``e`` + because ``d *-> c`` is a directed endpoint in the direction towards ``a``. + + Parameters + ---------- + G : graph + A mixed-edge graph. + nodes : list + A list of one or more nodes in the graph `G`. + + Returns + ------- + bool + Whether the given list of nodes represents a semi-directed path in `G`. + + Notes + ----- + This function is very similar to networkx's + :func:`networkx.algorithms.simple_paths.is_simple_path` function. + """ + # The empty list is not a valid path. Could also return + # NetworkXPointlessConcept here. + if len(nodes) == 0: + return False + + # If the list is a single node, just check that the node is actually + # in the graph. + if len(nodes) == 1: + return nodes[0] in G + + # check that all nodes in the list are in the graph, if at least one + # is not in the graph, then this is not a semi-directed path + if not all(n in G for n in nodes): + return False + + # If the list contains repeated nodes, then it's not a semi-directed path + if len(set(nodes)) != len(nodes): + return False + + # Test that each adjacent pair of nodes is adjacent and that there + # is no directed endpoint towards the beginning of the sequence. + for idx in range(len(nodes) - 1): + u, v = nodes[idx], nodes[idx + 1] + if G.has_edge(v, u, EdgeType.DIRECTED.value) or G.has_edge(v, u, EdgeType.BIDIRECTED.value): + return False + elif not G.has_edge(u, v): + return False + return True + + +def all_semi_directed_paths(G, source: Node, target: Node, cutoff: int = None): + """Generate all semi-directed paths from source to target in G. 
+ + A semi-directed path is a path from ``source`` to ``target`` in that + no end-point is directed from ``target`` to ``source``. I.e. + ``target *-> source`` does not exist. + + Parameters + ---------- + G : Graph + The graph. + source : Node + The source node. + target : Node + The target node. + cutoff : integer, optional + Depth to stop the search. Only paths of length <= cutoff are returned. + + Notes + ----- + This algorithm is very similar to networkx's + :func:`networkx.algorithms.simple_paths.all_simple_paths` function. + + This algorithm uses a modified depth-first search to generate the + paths [1]_. A single path can be found in $O(V+E)$ time but the + number of semi-directed paths in a graph can be very large, e.g. $O(n!)$ in + the complete graph of order $n$. + + This function does not check that a path exists between `source` and + `target`. For large graphs, this may result in very long runtimes. + Consider using `has_path` to check that a path exists between `source` and + `target` before calling this function on large graphs. + + References + ---------- + .. [1] R. Sedgewick, "Algorithms in C, Part 5: Graph Algorithms", + Addison Wesley Professional, 3rd ed., 2001. + """ + if source not in G: + raise nx.NodeNotFound("source node %s not in graph" % source) + if target in G: + targets = {target} + else: + try: + targets = set(target) # type: ignore + except TypeError: + raise nx.NodeNotFound("target node %s not in graph" % target) + if source in targets: + return _empty_generator() + if cutoff is None: + cutoff = len(G) - 1 + if cutoff < 1: + return _empty_generator() + if cutoff is None: + cutoff = len(G) - 1 + + return _all_semi_directed_paths_graph(G, source, targets, cutoff) + + +def _all_semi_directed_paths_graph( + G, source, targets, cutoff, directed_edge_name="directed", bidirected_edge_name="bidirected" +): + """See networkx's all_simple_paths function. 
+ + This performs a depth-first search for all semi-directed paths from source to target. + """ + # memoize each node that was already visited + visited = {source: True} + + # iterate over neighbors of source + stack = [iter(G.neighbors(source))] + + # if source has no neighbors, then prev_nodes should be None + prev_nodes = [source] + + while stack: + # get the iterator through nbrs for the current node + nbrs = stack[-1] + prev_node = prev_nodes[-1] + nbr = next(nbrs, None) + + # The first condition guarantees that there is not a directed endpoint + # along the path from source to target that points towards source. + if ( + G.has_edge(nbr, prev_node, directed_edge_name) + or G.has_edge(nbr, prev_node, bidirected_edge_name) + ) and nbr not in visited: + # If we've found a directed edge from child to prev_node, + # that we haven't visited, then we don't need to continue down this path + continue + elif nbr is None: + # once all children are visited, pop the stack + # and remove the child from the visited set + stack.pop() + visited.popitem() + prev_nodes.pop() + elif len(visited) < cutoff: + if nbr in visited: + continue + if nbr in targets: + # we've found a path to a target + yield list(visited) + [nbr] + visited[nbr] = True + if targets - set(visited.keys()): # expand stack until find all targets + stack.append(iter(G.neighbors(nbr))) + prev_nodes.append(nbr) + else: + visited.popitem() # maybe other ways to child + else: # len(visited) == cutoff: + for target in (targets & (set(nbrs) | {nbr})) - set(visited.keys()): + yield list(visited) + [target] + stack.pop() + visited.popitem() + prev_nodes.pop() diff --git a/pywhy_graphs/algorithms/tests/test_cpdag.py b/pywhy_graphs/algorithms/tests/test_cpdag.py new file mode 100644 index 000000000..472628cab --- /dev/null +++ b/pywhy_graphs/algorithms/tests/test_cpdag.py @@ -0,0 +1,241 @@ +import networkx as nx +import numpy as np +import pytest + +import pywhy_graphs.networkx as pywhy_nx +from pywhy_graphs.algorithms 
import all_vstructures +from pywhy_graphs.algorithms.cpdag import ( + EDGELABELS, + dag_to_cpdag, + label_edges, + order_edges, + pdag_to_cpdag, + pdag_to_dag, +) +from pywhy_graphs.testing import assert_mixed_edge_graphs_isomorphic + +seed = 12345 +rng = np.random.default_rng(seed) + + +class TestOrderEdges: + def test_order_edges_errors(self): + G = nx.DiGraph() + + # 1 -> 2 -> 4 -> 5 + # 1 -> 3 -> 4 + # so topological sort is: (1, 2, 3, 4, 5) + G.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]) + # now test when there is a cycle + G.add_edge(5, 1) + with pytest.raises(ValueError, match="G must be a directed acyclic graph"): + order_edges(G) + + def test_order_edges(self): + # Example usage: + G = nx.DiGraph() + + # 1 -> 2 -> 4 -> 5 + # 1 -> 3 -> 4 + G.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]) + G = order_edges(G) + + expected_order = [ + (1, 2, {"order": 4}), + (1, 3, {"order": 3}), + (2, 4, {"order": 1}), + (3, 4, {"order": 2}), + (4, 5, {"order": 0}), + ] + assert set(G.edges.data(data="order")) == set( + [(src, target, order["order"]) for src, target, order in expected_order] + ) + + # Add a string as a node + # 5 -> 3 -> 1 -> 2 -> 'a'; 1 -> 'b' + G = nx.DiGraph() + G.add_edges_from([(5, 3), (3, 1), (1, 2), (2, "a"), (1, "b")]) + G = order_edges(G) + + expected_order = [ + (5, 3, {"order": 4}), + (3, 1, {"order": 3}), + (1, 2, {"order": 2}), + (1, "b", {"order": 1}), + (2, "a", {"order": 0}), + ] + assert set(G.edges.data(data="order")) == set( + [(src, target, order["order"]) for src, target, order in expected_order] + ) + + def test_order_edges_ex1(self): + G = nx.DiGraph() + + # 1 -> 3; 1 -> 4; 1 -> 5; + # 2 -> 3; 2 -> 4; 2 -> 5; + # 3 -> 4; 3 -> 5; + # 4 -> 5; + G.add_edges_from([(1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]) + G = order_edges(G) + expected_order = [ + (1, 3, {"order": 7}), + (1, 4, {"order": 4}), + (1, 5, {"order": 0}), + (2, 3, {"order": 8}), + (2, 4, {"order": 5}), + (2, 5, {"order": 
1}), + (3, 4, {"order": 6}), + (3, 5, {"order": 2}), + (4, 5, {"order": 3}), + ] + assert set(G.edges.data(data="order")) == set( + [(src, target, order["order"]) for src, target, order in expected_order] + ) + + +class TestLabelEdges: + def test_label_edges_raises_error_for_non_dag(self): + # Test that label_edges raises a ValueError for a non-DAG + G = nx.DiGraph([(1, 2), (2, 3), (3, 1)]) # A cyclic graph + with pytest.raises(ValueError, match="G must be a directed acyclic graph"): + label_edges(G) + + def test_label_edges_raises_error_for_unordered_edges(self): + # Test that label_edges raises a ValueError for unordered edges + G = nx.DiGraph([(1, 2), (2, 3)]) + with pytest.raises( + ValueError, match="G must have all edges ordered via the `order` attribute" + ): + label_edges(G) + + def test_label_edges_all_compelled(self): + # Create an example DAG for testing + G = nx.DiGraph() + + # 1 -> 3; 3 -> 4; 3 -> 5 + # 2 -> 3; + # 4 -> 5 + G.add_edges_from([(1, 3), (2, 3), (3, 4), (3, 5), (4, 5)]) + nx.set_edge_attributes(G, None, "order") + G = order_edges(G) + labeled_graph = label_edges(G) + + expected_labels = { + (1, 3): EDGELABELS.COMPELLED, + (2, 3): EDGELABELS.COMPELLED, + (3, 4): EDGELABELS.COMPELLED, + (3, 5): EDGELABELS.COMPELLED, + (4, 5): EDGELABELS.REVERSIBLE, + } + for edge, expected_label in expected_labels.items(): + assert labeled_graph[edge[0]][edge[1]]["label"] == expected_label, ( + f"Edge {edge} has label {labeled_graph[edge[0]][edge[1]]['label']}, " + f"but expected {expected_label}" + ) + + +class TestPDAGtoDAG: + def test_pdag_to_dag_errors(self): + G = nx.DiGraph() + G.add_edge("A", "Z") + G.add_edges_from([("A", "B"), ("B", "A"), ("B", "Z"), ("X", "Y"), ("Z", "X")]) + + # add non-CPDAG supported edges + G = pywhy_nx.MixedEdgeGraph(graphs=[G], edge_types=["directed"]) + G.add_edge_type(nx.DiGraph(), "circle") + G.add_edge("Z", "A", edge_type="circle") + G.add_edge("A", "B", edge_type="circle") + G.add_edge("B", "A", edge_type="circle") + 
G.add_edge("B", "Z", edge_type="circle") + with pytest.raises( + ValueError, match="Only directed and undirected edges are allowed in a CPDAG" + ): + pdag_to_dag(G) + + def test_pdag_to_dag_inconsistent(self): + # 1 -- 3; 1 -> 4; + # 2 -> 3; + # 4 -> 3 + # Note: this PDAG is inconsistent because it would create a v-structure, or a cycle + # by orienting the undirected edge 1 -- 3 + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[nx.DiGraph(), nx.Graph()], edge_types=["directed", "undirected"] + ) + pdag.add_edge(1, 3, edge_type="undirected") + pdag.add_edges_from([(1, 4), (2, 3), (4, 3)], edge_type="directed") + with pytest.raises(ValueError, match="No consistent extension found"): + pdag_to_dag(pdag) + + def test_pdag_to_dag_already_dag(self): + # 1 -> 2; 1 -> 3 + # 2 -> 3 + # 4 -> 3 + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[nx.DiGraph(), nx.Graph()], edge_types=["directed", "undirected"] + ) + pdag.add_edges_from([(1, 2), (1, 3), (2, 3), (4, 3)], edge_type="directed") + G = pdag_to_dag(pdag) + assert nx.is_isomorphic(G, pdag.get_graphs("directed")) + + def test_pdag_to_dag_0(self): + # 1 -- 3; + # 2 -> 3; 2 -> 4 + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[nx.DiGraph(), nx.Graph()], edge_types=["directed", "undirected"] + ) + + pdag.add_edge(1, 3, edge_type="undirected") + pdag.add_edges_from([(2, 3), (2, 4)], edge_type="directed") + + G = pdag_to_dag(pdag) + + # add a directed edge from 3 to 1 + pdag.remove_edge(1, 3, edge_type="undirected") + pdag.add_edge(3, 1, edge_type="directed") + + assert nx.is_isomorphic(G, pdag.get_graphs("directed")) + + def test_pdag_to_dag_1(self): + # 1 -- 3; + # 2 -> 1; 2 -> 4 + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[nx.DiGraph(), nx.Graph()], edge_types=["directed", "undirected"] + ) + + pdag.add_edge(1, 3, edge_type="undirected") + pdag.add_edges_from([(2, 1), (2, 4)], edge_type="directed") + + G = pdag_to_dag(pdag) + pdag.remove_edge(1, 3, edge_type="undirected") + pdag.add_edge(1, 3, edge_type="directed") + + assert 
nx.is_isomorphic(G, pdag.get_graphs("directed")) + + def test_pdag_to_cpdag(self): + # construct a random DAG + n = 10 + p = 0.4 + random_graph = nx.fast_gnp_random_graph(n, p, directed=True, seed=seed) + dag = nx.DiGraph([(u, v) for (u, v) in random_graph.edges() if u < v]) + + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[dag.copy(), nx.Graph()], edge_types=["directed", "undirected"] + ) + + # now we construct the set of undirected edges that do not belong + # to any unshielded collider (i.e. v-structure) + vstructs = all_vstructures(dag, as_edges=True) + + # we apply a random orientation for a subset of the undirected edges + for edge in dag.edges: + if edge not in vstructs: + if rng.binomial(1, 0.3): + pdag.remove_edge(*edge) + pdag.add_edge(*edge, edge_type="undirected") + + # now, we can convert the DAG to CPDAG and also convert the PDAG to a CPDAG + # they should be equivalent + cpdag = dag_to_cpdag(dag) + cpdag_from_pdag = pdag_to_cpdag(pdag) + + assert_mixed_edge_graphs_isomorphic(cpdag, cpdag_from_pdag) diff --git a/pywhy_graphs/algorithms/tests/test_generic.py b/pywhy_graphs/algorithms/tests/test_generic.py index 8b7e6c375..09218a334 100644 --- a/pywhy_graphs/algorithms/tests/test_generic.py +++ b/pywhy_graphs/algorithms/tests/test_generic.py @@ -3,6 +3,7 @@ import pywhy_graphs from pywhy_graphs import ADMG +from pywhy_graphs.algorithms import all_vstructures def test_convert_to_latent_confounder_errors(): @@ -468,3 +469,30 @@ def test_is_maximal(): S = {} L = {"Y"} assert not pywhy_graphs.is_maximal(admg, L, S) + + +def test_all_vstructures(): + # Create a directed graph + G = nx.DiGraph() + G.add_edges_from([(1, 2), (3, 2), (4, 2)]) + + # Generate the v-structures + v_structs_edges = all_vstructures(G, as_edges=True) + v_structs_tuples = all_vstructures(G, as_edges=False) + + # Assert that the returned values are as expected + assert len(v_structs_edges) == 3 + assert len(v_structs_tuples) == 3 + assert (1, 2) in v_structs_edges or (2, 1) in
v_structs_edges + assert (3, 2) in v_structs_edges or (2, 3) in v_structs_edges + assert (1, 2, 3) in v_structs_tuples or (3, 2, 1) in v_structs_tuples + assert (4, 2, 3) in v_structs_tuples or (3, 2, 4) in v_structs_tuples + + G.remove_node(2) + # Generate the v-structures + v_structs_edges = all_vstructures(G, as_edges=True) + v_structs_tuples = all_vstructures(G, as_edges=False) + + # Assert that the returned values are as expected + assert len(v_structs_edges) == 0 + assert len(v_structs_tuples) == 0 diff --git a/pywhy_graphs/algorithms/tests/test_semi_directed_paths.py b/pywhy_graphs/algorithms/tests/test_semi_directed_paths.py new file mode 100644 index 000000000..8f4f10c97 --- /dev/null +++ b/pywhy_graphs/algorithms/tests/test_semi_directed_paths.py @@ -0,0 +1,157 @@ +import networkx as nx +import pytest + +import pywhy_graphs.networkx as pywhy_nx +from pywhy_graphs.algorithms import all_semi_directed_paths, is_semi_directed_path + + +# Fixture to create a sample mixed-edge graph for testing +@pytest.fixture +def sample_mixed_edge_graph(): + directed_G = nx.DiGraph([("X", "Y"), ("Z", "X")]) + bidirected_G = nx.Graph([("X", "Y")]) + directed_G.add_nodes_from(bidirected_G.nodes) + bidirected_G.add_nodes_from(directed_G.nodes) + G = pywhy_nx.MixedEdgeGraph( + graphs=[directed_G, bidirected_G], edge_types=["directed", "bidirected"], name="IV Graph" + ) + + G.add_edge_type(nx.DiGraph(), "circle") + G.add_edge("A", "Z", edge_type="directed") + G.add_edge("Z", "A", edge_type="circle") + G.add_edge("A", "B", edge_type="circle") + G.add_edge("B", "A", edge_type="circle") + G.add_edge("B", "Z", edge_type="circle") + return G + + +class TestIsSemiDirectedPath: + def test_empty_path_not_semi_directed(self, sample_mixed_edge_graph): + G = sample_mixed_edge_graph + assert not is_semi_directed_path(G, []) + + def test_single_node_path(self, sample_mixed_edge_graph): + G = sample_mixed_edge_graph + assert is_semi_directed_path(G, ["X"]) + + def 
test_nonexistent_node_path(self, sample_mixed_edge_graph): + G = sample_mixed_edge_graph + assert not is_semi_directed_path(G, ["1", "2"]) + + def test_repeated_nodes_path(self, sample_mixed_edge_graph): + G = sample_mixed_edge_graph + assert not is_semi_directed_path(G, ["X", "Y", "X"]) + + def test_non_connected_path(self, sample_mixed_edge_graph): + G = sample_mixed_edge_graph + assert not is_semi_directed_path(G, ["A", "X"]) + + def test_valid_semi_directed_path(self, sample_mixed_edge_graph): + G = sample_mixed_edge_graph + assert is_semi_directed_path(G, ["Z", "X"]) + assert is_semi_directed_path(G, ["A", "Z", "X"]) + + def test_invalid_semi_directed_path(self, sample_mixed_edge_graph): + G = sample_mixed_edge_graph + assert not is_semi_directed_path(G, ["Y", "X"]) + + # there is a bidirected edge between X and Y + assert not is_semi_directed_path(G, ["X", "Y"]) + assert not is_semi_directed_path(G, ["Z", "X", "Y"]) + + +def test_node_not_in_graph(): + G = nx.Graph() + G.add_edge("X", "Y") + with pytest.raises(nx.NodeNotFound): + all_semi_directed_paths(G, "A", "X") + + with pytest.raises(nx.NodeNotFound): + all_semi_directed_paths(G, "X", 1) + + +def test_target_is_single_node_in_graph(sample_mixed_edge_graph): + G = sample_mixed_edge_graph + source = "X" + paths = all_semi_directed_paths(G, source, "Y") + assert list(paths) == [] + + +def test_source_same_as_target(sample_mixed_edge_graph): + G = sample_mixed_edge_graph + source = "X" + paths = all_semi_directed_paths(G, source, source) + assert list(paths) == [] + + +def test_cutoff_none(sample_mixed_edge_graph): + G = sample_mixed_edge_graph + source = "Z" + paths = all_semi_directed_paths(G, source, "X", cutoff=None) + assert list(paths) == [["Z", "X"]] + + +def test_cutoff_less_than_one(sample_mixed_edge_graph): + G = sample_mixed_edge_graph + source = "X" + paths = all_semi_directed_paths(G, source, "Y", cutoff=0) + assert list(paths) == [] + + +def test_empty_paths(sample_mixed_edge_graph): + G = 
sample_mixed_edge_graph + source = "1" + target = "B" + with pytest.raises(nx.NodeNotFound, match=f"source node {source} not in graph"): + all_semi_directed_paths(G, source, target) + + G.add_node(source) + G.add_node(target) + paths = all_semi_directed_paths(G, source, target) + assert list(paths) == [] + + +def test_no_paths(sample_mixed_edge_graph): + G = sample_mixed_edge_graph + source = "Y" + target = "X" + cutoff = 3 + paths = all_semi_directed_paths(G, source, target, cutoff) + assert list(paths) == [] + + +def test_multiple_paths(sample_mixed_edge_graph): + G = sample_mixed_edge_graph + + source = "A" + target = "X" + cutoff = 3 + paths = all_semi_directed_paths(G, source, target, cutoff) + paths = list(paths) + + dig = nx.path_graph(5, create_using=nx.DiGraph()) + G.add_edges_from(dig.edges(), edge_type="directed") + G.add_edge("A", 0, edge_type="circle") + + assert len(paths) == 2 + assert all(path in paths for path in [["A", "Z", "X"], ["A", "B", "Z", "X"]]) + + # for a short cutoff, there is only one path + cutoff = 2 + paths = all_semi_directed_paths(G, source, target, cutoff) + assert all(path in paths for path in [["A", "Z", "X"]]) + + # for an even shorter cutoff, there are no paths now + cutoff = 1 + paths = all_semi_directed_paths(G, source, target, cutoff) + assert list(paths) == [] + + +def test_long_cutoff(sample_mixed_edge_graph): + G = sample_mixed_edge_graph + source = "Z" + target = "X" + cutoff = 10 # Cutoff longer than the actual path length + print(G.edges()) + paths = all_semi_directed_paths(G, source, target, cutoff) + assert list(paths) == [[source, target]] diff --git a/pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py b/pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py index 022348046..e091a03d9 100644 --- a/pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py +++ b/pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py @@ -11,7 +11,15 @@ def mixed_edge_moral_graph( undirected_edge_name="undirected", 
bidirected_edge_name="bidirected", ): - """Return the moral graph from an ancestral graph in :math:`O(|V|^2)`. + """Return the moral graph from an ancestral graph. + + A moral graph is a graph where all edges are undirected and an edge + between two nodes, ``u`` and ``v``, exists if there is a v-structure + ``u -> w <- v``, where ``u`` and ``v`` are not adjacent. An ancestral + graph is a mixed edge graph with directed, bidirected, and undirected + edges. + + The algorithm runs in :math:`O(|V|^2)`. Parameters ---------- diff --git a/pywhy_graphs/networkx/classes/mixededge.py b/pywhy_graphs/networkx/classes/mixededge.py index f90723973..a9fc41cc4 100644 --- a/pywhy_graphs/networkx/classes/mixededge.py +++ b/pywhy_graphs/networkx/classes/mixededge.py @@ -112,6 +112,12 @@ def __init__(self, graphs=None, edge_types=None, **attr): # load graph attributes (must be after convert) self.graph.update(attr) + # XXX: experimental. Fix this in doc string once finalized. + # make dynamic property names for the edges, (i.e. circle_edges, + # directed_edges, undirected_edges) + for edge_type_name in self.edge_types: + setattr(self, f"{edge_type_name}_edges", self.get_graphs(edge_type_name).edges) + def __str__(self): """Returns a short summary of the graph. 
diff --git a/pywhy_graphs/viz/draw.py b/pywhy_graphs/viz/draw.py index 934fa39f4..3e9d4dd8d 100644 --- a/pywhy_graphs/viz/draw.py +++ b/pywhy_graphs/viz/draw.py @@ -135,7 +135,7 @@ def draw( # an edge case of drawing graphs is the undirected Markov network if hasattr(G, "undirected_edges"): undirected_edges = G.undirected_edges - elif isinstance(G, nx.Graph): + elif isinstance(G, nx.Graph) and not isinstance(G, nx.DiGraph): undirected_edges = G.edges() if hasattr(G, "bidirected_edges"): bidirected_edges = G.bidirected_edges diff --git a/pywhy_graphs/viz/tests/test_draw.py b/pywhy_graphs/viz/tests/test_draw.py index 1dfa19e1a..82b940ef1 100644 --- a/pywhy_graphs/viz/tests/test_draw.py +++ b/pywhy_graphs/viz/tests/test_draw.py @@ -6,6 +6,26 @@ from pywhy_graphs.viz import draw, timeseries_layout +def test_draw_digraph(): + """ + Ensure the generated graph is a directed graph. + + The number of edges between any two nodes should always be one. + """ + # create a dummy graph x --> y <-- z and z --> x + graph = nx.DiGraph([("x", "y"), ("z", "y"), ("z", "x")]) + + # draw the graphs + dot = draw(graph) + + # assert that the produced graph is a directed graph + assert "digraph" in dot.source + + regex_pattern = "(?=(x -> y))" + matches = re.findall(regex_pattern, dot.source) + assert len(matches) == 1 + + def test_draw_pos_is_fully_given(): """ Ensure the Graphviz pos="x,y!" attribute is generated by the draw function