Skip to content

Commit

Permalink
Fix ci?
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Oct 30, 2023
1 parent 506aab2 commit b151126
Show file tree
Hide file tree
Showing 24 changed files with 59 additions and 63 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ causal-learn = { version = "^0.1.2.8" }
ananke-causal = { version = "^0.3.3" }
pre-commit = "^3.0.4"
pandas = { version = "^1.1" } # needed for simulation
torch = { version="^2.0.0" }

[tool.poetry.group.style]
optional = true
Expand Down
16 changes: 8 additions & 8 deletions pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Set, Union
from typing import List, Optional, Set, Union

import networkx as nx

Expand All @@ -20,7 +20,9 @@
]


def is_node_common_cause(G: nx.DiGraph, node: Node, exclude_nodes: List[Node] = None) -> bool:
def is_node_common_cause(
G: nx.DiGraph, node: Node, exclude_nodes: Optional[List[Node]] = None
) -> bool:
"""Check if a node is a common cause within the graph.
Parameters
Expand Down Expand Up @@ -519,7 +521,7 @@ def _shortest_valid_path(
return (path_exists, path)


def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None):
def inducing_path(G, node_x: Node, node_y: Node, L: Optional[Set] = None, S: Optional[Set] = None):
"""Checks if an inducing path exists between two nodes.
An inducing path is defined in :footcite:`Zhang2008`.
Expand Down Expand Up @@ -599,7 +601,6 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None):

path_exists = False
for elem in x_neighbors:

visited = {node_x}
if elem not in visited:
path_exists, temp_path = _shortest_valid_path(
Expand Down Expand Up @@ -646,7 +647,7 @@ def has_adc(G):
return adc_present


def valid_mag(G: ADMG, L: set = None, S: set = None):
def valid_mag(G: ADMG, L: Optional[set] = None, S: Optional[set] = None):
"""Checks if the provided graph is a valid maximal ancestral graph (MAG).
A valid MAG as defined in :footcite:`Zhang2008` is a mixed edge graph that
Expand Down Expand Up @@ -710,7 +711,7 @@ def valid_mag(G: ADMG, L: set = None, S: set = None):
return True


def dag_to_mag(G, L: Set = None, S: Set = None):
def dag_to_mag(G, L: Optional[Set] = None, S: Optional[Set] = None):
"""Converts a DAG to a valid MAG.
The algorithm is defined in :footcite:`Zhang2008` on page 1877.
Expand Down Expand Up @@ -755,7 +756,6 @@ def dag_to_mag(G, L: Set = None, S: Set = None):
mag = ADMG()

for A, B in adj_nodes:

AuS = S.union(A)
BuS = S.union(B)

Expand Down Expand Up @@ -787,7 +787,7 @@ def dag_to_mag(G, L: Set = None, S: Set = None):
return mag


def is_maximal(G, L: Set = None, S: Set = None):
def is_maximal(G, L: Optional[Set] = None, S: Optional[Set] = None):
"""Checks to see if the graph is maximal.
Parameters:
Expand Down
12 changes: 6 additions & 6 deletions pywhy_graphs/algorithms/pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def uncovered_pd_path(


def pds(
graph: PAG, node_x: Node, node_y: Node = None, max_path_length: Optional[int] = None
graph: PAG, node_x: Node, node_y: Optional[Node] = None, max_path_length: Optional[int] = None
) -> Set[Node]:
"""Find all PDS sets between node_x and node_y.
Expand Down Expand Up @@ -712,7 +712,7 @@ def pds_path(
for comp in biconn_comp:
if (node_x, node_y) in comp or (node_y, node_x) in comp:
# add all unique nodes in the biconnected component
for (x, y) in comp:
for x, y in comp:
found_component.add(x)
found_component.add(y)
break
Expand Down Expand Up @@ -1030,7 +1030,7 @@ def _meek_rule3(graph: CPDAG, i: str, j: str) -> bool:
if graph.has_edge(i, j, graph.undirected_edge_name):
# For all the pairs of nodes adjacent to i,
# look for (k, l), such that j -> l and k -> l
for (k, l) in combinations(graph.neighbors(i), 2):
for k, l in combinations(graph.neighbors(i), 2):
# Skip if k and l are adjacent.
if l in graph.neighbors(k):
continue
Expand Down Expand Up @@ -1157,7 +1157,7 @@ def pag_to_mag(graph):
while flag:
undedges = temp_cpdag.undirected_edges
if len(undedges) != 0:
for (u, v) in undedges:
for u, v in undedges:
temp_cpdag.remove_edge(u, v, temp_cpdag.undirected_edge_name)
temp_cpdag.add_edge(u, v, temp_cpdag.directed_edge_name)
_apply_meek_rules(temp_cpdag)
Expand All @@ -1169,10 +1169,10 @@ def pag_to_mag(graph):

# construct the final MAG

for (u, v) in copy_graph.directed_edges:
for u, v in copy_graph.directed_edges:
mag.add_edge(u, v, mag.directed_edge_name)

for (u, v) in temp_cpdag.directed_edges:
for u, v in temp_cpdag.directed_edges:
mag.add_edge(u, v, mag.directed_edge_name)

return mag
2 changes: 1 addition & 1 deletion pywhy_graphs/algorithms/tests/test_cyclic.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_sigma_separated():
cyclic_G = pywhy_nx.MixedEdgeGraph(graphs=[cyclic_G], edge_types=["directed"])
cyclic_G.add_edge_type(nx.Graph(), edge_type="bidirected")

for (u, v) in combinations(cyclic_G.nodes, 2):
for u, v in combinations(cyclic_G.nodes, 2):
other_nodes = set(cyclic_G.nodes)
other_nodes.remove(u)
other_nodes.remove(v)
Expand Down
4 changes: 0 additions & 4 deletions pywhy_graphs/algorithms/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def test_convert_to_latent_confounder(graph_func):


def test_inducing_path():

admg = ADMG()

admg.add_edge("X", "Y", admg.directed_edge_name)
Expand Down Expand Up @@ -93,7 +92,6 @@ def test_inducing_path():


def test_inducing_path_wihtout_LandS():

admg = ADMG()

admg.add_edge("X", "Y", admg.directed_edge_name)
Expand All @@ -113,7 +111,6 @@ def test_inducing_path_wihtout_LandS():


def test_inducing_path_one_direction():

admg = ADMG()

admg.add_edge("A", "B", admg.directed_edge_name)
Expand Down Expand Up @@ -375,7 +372,6 @@ def test_valid_mag():


def test_dag_to_mag():

# A -> E -> S
# H -> E , H -> R
admg = ADMG()
Expand Down
5 changes: 2 additions & 3 deletions pywhy_graphs/algorithms/tests/test_pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_discriminating_path():
)

for u in pag.nodes:
for (a, c) in permutations(pag.neighbors(u), 2):
for a, c in permutations(pag.neighbors(u), 2):
found_discriminating_path, disc_path, _ = discriminating_path(
pag, u, a, c, max_path_length=100
)
Expand All @@ -193,7 +193,7 @@ def test_discriminating_path():
pag.remove_edge("x2", "x5", pag.directed_edge_name)
pag.add_edge("x5", "x2", pag.bidirected_edge_name)
for u in pag.nodes:
for (a, c) in permutations(pag.neighbors(u), 2):
for a, c in permutations(pag.neighbors(u), 2):
found_discriminating_path, disc_path, _ = discriminating_path(
pag, u, a, c, max_path_length=100
)
Expand Down Expand Up @@ -650,7 +650,6 @@ def test_pdst(pdst_graph):


def test_pag_to_mag():

# C o- A o-> D <-o B
# B o-o A o-o C o-> D

Expand Down
4 changes: 2 additions & 2 deletions pywhy_graphs/array/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Set
from typing import Dict, List, Optional, Set

import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -80,7 +80,7 @@ def get_summary_graph(arr: NDArray, arr_enum: str = "clearn"):


def array_to_lagged_links(
arr: NDArray, arr_idx: List[Node] = None, include_weights: bool = True
arr: NDArray, arr_idx: Optional[List[Node]] = None, include_weights: bool = True
) -> Dict[Node, List[Set]]:
"""Convert a time-series 3D array to a dictionary of lagged links.
Expand Down
2 changes: 1 addition & 1 deletion pywhy_graphs/classes/augmented.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def s_nodes(self) -> List[Node]:
"""Return set of S-nodes."""
return list(self.graph["S-nodes"].keys())

def add_s_node(self, domain_ids: Tuple, node_changes: Set[Node] = None):
def add_s_node(self, domain_ids: Tuple, node_changes: Optional[Set[Node]] = None):
if isinstance(node_changes, str) or not isinstance(node_changes, Iterable):
raise RuntimeError("The intervention set nodes must be an iterable set of node(s).")

Expand Down
8 changes: 5 additions & 3 deletions pywhy_graphs/classes/timeseries/conversion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional

import numpy as np

Expand All @@ -7,7 +7,7 @@
from .graph import StationaryTimeSeriesGraph


def tsgraph_to_numpy(G, var_order: List[Node] = None):
def tsgraph_to_numpy(G, var_order: Optional[List[Node]] = None):
"""Convert stationary timeseries graph to numpy array.
Parameters
Expand Down Expand Up @@ -44,7 +44,9 @@ def tsgraph_to_numpy(G, var_order: List[Node] = None):
return ts_graph_arr


def numpy_to_tsgraph(arr, var_order: List[Node] = None, create_using=StationaryTimeSeriesGraph):
def numpy_to_tsgraph(
arr, var_order: Optional[List[Node]] = None, create_using=StationaryTimeSeriesGraph
):
"""Convert 3D numpy array into a stationary time-series graph.
Parameters
Expand Down
4 changes: 3 additions & 1 deletion pywhy_graphs/classes/timeseries/mixededge.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import numpy as np

import pywhy_graphs.networkx as pywhy_nx
Expand Down Expand Up @@ -178,7 +180,7 @@ class StationaryTimeSeriesMixedEdgeGraph(TimeSeriesMixedEdgeGraph):
# supported graph types
graph_types = (StationaryTimeSeriesGraph, StationaryTimeSeriesDiGraph)

def __init__(self, graphs=None, edge_types=None, max_lag: int = None, **attr):
def __init__(self, graphs=None, edge_types=None, max_lag: Optional[int] = None, **attr):
super().__init__(graphs, edge_types, max_lag=max_lag, **attr)

def set_stationarity(self, stationary: bool):
Expand Down
2 changes: 1 addition & 1 deletion pywhy_graphs/export/pcalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def graph_to_pcalg(causal_graph):

# now map all values to their respective pcalg values
seen_idx = dict()
for (idx, jdx) in np.argwhere(clearn_arr != 0):
for idx, jdx in np.argwhere(clearn_arr != 0):
if (idx, jdx) in seen_idx or (jdx, idx) in seen_idx:
continue

Expand Down
2 changes: 0 additions & 2 deletions pywhy_graphs/export/tests/test_ananke.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@


def dag():

vertices = ["A", "B", "C", "D"]
di_edges = [("A", "B"), ("B", "C"), ("C", "D")]
graph = DAG(vertices=vertices, di_edges=di_edges)
Expand All @@ -19,7 +18,6 @@ def dag():


def admg():

vertices = ["A", "B", "C", "D"]
di_edges = [("A", "B"), ("B", "C"), ("C", "D")]
bi_edges = [("A", "C"), ("B", "D")]
Expand Down
8 changes: 6 additions & 2 deletions pywhy_graphs/functional/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def add_parent_function(G: nx.DiGraph, node: Node, func: Callable) -> nx.DiGraph


def add_noise_function(
G: nx.DiGraph, node: Node, distr_func: Callable, func: Callable = None
G: nx.DiGraph, node: Node, distr_func: Callable, func: Optional[Callable] = None
) -> nx.DiGraph:
"""Add function and distribution for a node's exogenous variable into the graph.
Expand Down Expand Up @@ -120,7 +120,11 @@ def add_soft_intervention_function(


def add_domain_shift_function(
G: AugmentedGraph, node: Node, s_node: Node, func: Callable = None, distr_func: Callable = None
G: AugmentedGraph,
node: Node,
s_node: Node,
func: Optional[Callable] = None,
distr_func: Optional[Callable] = None,
):
"""Add domain shift function for a node into the graph assuming invariant graph structure.
Expand Down
6 changes: 3 additions & 3 deletions pywhy_graphs/functional/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ def parent_func(*args):

def make_random_discrete_graph(
G: nx.DiGraph,
cardinality_lims: Dict[Any, List[int]] = None,
weight_lims: Dict[Any, List[int]] = None,
noise_ratio_lims: List[float] = None,
cardinality_lims: Optional[Dict[Any, List[int]]] = None,
weight_lims: Optional[Dict[Any, List[int]]] = None,
noise_ratio_lims: Optional[List[float]] = None,
overwrite: bool = False,
random_state=None,
) -> nx.DiGraph:
Expand Down
10 changes: 5 additions & 5 deletions pywhy_graphs/functional/linear.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, List, Set
from typing import Callable, List, Optional, Set

import networkx as nx
import numpy as np
Expand All @@ -11,10 +11,10 @@

def make_graph_linear_gaussian(
G: nx.DiGraph,
node_mean_lims: List[float] = None,
node_std_lims: List[float] = None,
edge_functions: List[Callable[[float], float]] = None,
edge_weight_lims: List[float] = None,
node_mean_lims: Optional[List[float]] = None,
node_std_lims: Optional[List[float]] = None,
edge_functions: Optional[List[Callable[[float], float]]] = None,
edge_weight_lims: Optional[List[float]] = None,
random_state=None,
) -> nx.DiGraph:
r"""Convert an existing DAG to a linear Gaussian graphical model.
Expand Down
4 changes: 2 additions & 2 deletions pywhy_graphs/functional/multidomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def make_graph_multidomain(
n_invariances_to_try: int = 1,
node_mean_lims: Optional[List[float]] = None,
node_std_lims: Optional[List[float]] = None,
edge_functions: List[Callable[[float], float]] = None,
edge_functions: Optional[List[Callable[[float], float]]] = None,
edge_weight_lims: Optional[List[float]] = None,
random_state=None,
) -> nx.DiGraph:
Expand Down Expand Up @@ -263,7 +263,7 @@ def sample_multidomain_lin_functions(
G: AugmentedGraph,
node_mean_lims: Optional[List[float]] = None,
node_std_lims: Optional[List[float]] = None,
edge_functions: List[Callable[[float], float]] = None,
edge_functions: Optional[List[Callable[[float], float]]] = None,
edge_weight_lims: Optional[List[float]] = None,
random_state=None,
):
Expand Down
3 changes: 2 additions & 1 deletion pywhy_graphs/functional/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
from typing import Optional

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -217,7 +218,7 @@ def _preprocess_parameter_inputs(
edge_functions,
edge_weight_lims,
multi_domain: bool = False,
n_domains: int = None,
n_domains: Optional[int] = None,
):
"""Helper function to preprocess common parameter inputs for sampling functional graphs.
Expand Down
2 changes: 0 additions & 2 deletions pywhy_graphs/networkx/algorithms/causal/m_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def m_separated(
G_bidirected = G.get_graphs(edge_type=bidirected_edge_name)

while forward_deque or backward_deque:

if backward_deque:
node = backward_deque.popleft()
backward_visited.add(node)
Expand Down Expand Up @@ -151,7 +150,6 @@ def m_separated(
# Consider if *-> node <-* is opened due to conditioning on collider,
# or descendant of collider
if node in an_z:

if has_directed:
# add <- edges to backward deque
for x, _ in G_directed.in_edges(nbunch=node):
Expand Down
Loading

0 comments on commit b151126

Please sign in to comment.