Skip to content

Commit

Permalink
Update ml and carnival
Browse files Browse the repository at this point in the history
  • Loading branch information
pablormier committed Dec 21, 2024
1 parent 3959914 commit d25e3cd
Show file tree
Hide file tree
Showing 4 changed files with 915 additions and 510 deletions.
77 changes: 76 additions & 1 deletion corneto/_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
from corneto._graph import BaseGraph
from collections import deque

import numpy as np
try:
import pandas as pd
_PANDAS_AVAILABLE = True
except ImportError:
_PANDAS_AVAILABLE = False

def _load_keras():
# Check if keras_core is installed,
# then find the best backend, prioritizing JAX then TF then Pytorch:
Expand All @@ -18,11 +25,79 @@ def _load_keras():
except ImportError as e:
raise e

def kfold_nonzero_splits(
    data,
    n_splits: int = 5,
    shuffle: bool = True,
    random_state: int = 42
):
    """Generate K-fold train/validation masks over the nonzero cells of `data`.

    Every cell for which ``data != 0`` is treated as a labeled sample. The
    labeled cells are partitioned into ``n_splits`` folds; for each fold a
    pair of structures with the same shape as ``data`` is produced in which
    the left-out cells are replaced by NaN. Cells equal to zero are never
    modified.

    Parameters
    ----------
    data : array-like or pd.DataFrame
        Two-dimensional input (the original documentation describes it as
        features x samples with values in {-1, 0, 1}). Integer and boolean
        inputs are converted to float so that NaN can be stored.
    n_splits : int, optional
        Number of folds (default=5).
    shuffle : bool, optional
        Shuffle the labeled cells before splitting (default=True).
    random_state : int, optional
        Random seed for reproducibility (default=42). Only used when
        ``shuffle`` is True.

    Yields
    ------
    train, val : tuple of np.ndarray or pd.DataFrame
        - Both have the same shape as `data`; DataFrames keep the index and
          columns of the input.
        - In `train`, all cells that belong to the validation fold are NaN.
        - In `val`, all labeled cells outside the validation fold are NaN.

    Notes
    -----
    This is a generator: scikit-learn is imported, and any error is raised,
    only when the first fold is requested.
    """
    # Imported lazily so scikit-learn is only required when this is used.
    from sklearn.model_selection import KFold

    is_pandas = _PANDAS_AVAILABLE and isinstance(data, pd.DataFrame)
    arr = data.to_numpy() if is_pandas else np.asarray(data)

    # Integer/boolean arrays cannot hold NaN (assignment would raise), so
    # promote them to float. Float and object inputs are left as they are.
    if arr.dtype.kind in "biu":
        arr = arr.astype(float)

    # Coordinates of the labeled (nonzero) cells, one row per cell.
    rows, cols = np.where(arr != 0)
    labeled_positions = np.column_stack((rows, cols))

    # scikit-learn rejects a seed when shuffle is disabled, so only forward
    # random_state when it actually has an effect.
    kf = KFold(
        n_splits=n_splits,
        shuffle=shuffle,
        random_state=random_state if shuffle else None,
    )

    for train_idx, val_idx in kf.split(labeled_positions):
        train_copy = arr.copy()
        val_copy = arr.copy()

        # Hide the validation cells from the training view, and vice versa.
        train_copy[rows[val_idx], cols[val_idx]] = np.nan
        val_copy[rows[train_idx], cols[train_idx]] = np.nan

        if is_pandas:
            train_copy = pd.DataFrame(
                train_copy, index=data.index, columns=data.columns
            )
            val_copy = pd.DataFrame(
                val_copy, index=data.index, columns=data.columns
            )

        yield train_copy, val_copy


def toposort(G):
# Topological sort using Kahn's algorithm
# See: https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.dag.topological_sort.html
in_degree = {v: len(set(G.predecessors(v))) for v in G._get_vertices()}

# Initialize queue with nodes having zero in-degree
# Initialize queue with nodes having zero in-degrees
queue = deque([v for v in in_degree.keys() if in_degree[v] == 0])

result = []
Expand Down
36 changes: 32 additions & 4 deletions corneto/methods/carnival.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
get_interactions,
)
from corneto.methods.signaling import create_flow_graph, signflow

import warnings

def _info(s, show=True):
if show:
Expand Down Expand Up @@ -491,6 +491,8 @@ def create_flow_carnival_v4(
upper_bound_flow=1000,
penalty_on="signal", # or "flow"
slack_reg=False,
set_perturbation_values=True,
fix_input_values=True,
backend=cn.DEFAULT_BACKEND,
):
At, Ah = get_incidence_matrices_of_edges(G)
Expand Down Expand Up @@ -572,13 +574,39 @@ def create_flow_carnival_v4(
[G.V.index(v) for exp in exp_list for v in exp_list[exp]["input"]]
)
perturbation_values = np.array(
[val for exp in exp_list for val in exp_list[exp]["input"].values()]
[int(np.sign(val)) for exp in exp_list for val in exp_list[exp]["input"].values()]
)

# Set the perturbations to the given values
P += V[vertex_indexes, :] == perturbation_values[:, None]
if set_perturbation_values:
warnings.warn("Using set_perturbation_values, please disable since behavior differs from original carnival")
nonzero_mask = perturbation_values != 0
nonzero_vertex_indexes = vertex_indexes[nonzero_mask]
nonzero_perturbation_values = perturbation_values[nonzero_mask]
# Assign the perturbations only to the nonzero ones
P += V[nonzero_vertex_indexes, :] == nonzero_perturbation_values[:, None]

all_vertices = G.V
all_inputs = [k for k in all_vertices if len(list(G.predecessors(k)))==0]

for i, exp in enumerate(exp_list):
# measurements:
# Any input not indicated in the condition must be blocked
if not set_perturbation_values:
# Block flow from any input not in the set of valid inputs
# for the given condition
m_inputs = list(exp_list[exp]["input"].keys())
for v_input in all_inputs:
if v_input not in m_inputs:
P += V[all_vertices.index(v_input), i] == 0
else:
input_value = int(exp_list[exp]["input"][v_input])
if input_value != 0:
if fix_input_values:
if input_value == -1 or input_value == 1:
P += V[all_vertices.index(v_input), i] == input_value
else:
raise ValueError(f"Invalid value for input vertex {v_input}: {input_value} (only -1, 0 or 1)")

m_nodes = list(exp_list[exp]["output"].keys())
m_values = np.array(list(exp_list[exp]["output"].values()))
m_nodes_positions = [G.V.index(key) for key in m_nodes]
Expand Down
Loading

0 comments on commit d25e3cd

Please sign in to comment.