Merge pull request #645 from Sichao25/pp
Refactor `fate()` with Trajectory class
Xiaojieqiu authored Feb 20, 2024
2 parents dcd96b5 + 0eba49c commit e401e9b
Showing 5 changed files with 265 additions and 222 deletions.
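For context: the public entry point this commit refactors is dyn.pd.fate. Below is a minimal usage sketch, not part of the diff — the example dataset and keyword values are illustrative assumptions, and it presumes an AnnData already run through dynamo's standard vector-field pipeline:

import dynamo as dyn

adata = dyn.sample_data.hematopoiesis()  # assumed example dataset; any dynamo-processed AnnData works
dyn.pd.fate(
    adata,
    init_cells=adata.obs_names[:10].to_list(),  # cells whose fates are integrated
    basis="umap",            # low-dimensional basis carrying the vector field
    direction="both",        # integrate forward and backward in time
    interpolation_num=250,   # points sampled along each predicted path
    inverse_transform=True,  # exercises the _inverse_transform() helper introduced below
)
# Predictions are stored in adata.uns under a "fate_<basis>"-style key.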
128 changes: 71 additions & 57 deletions dynamo/prediction/fate.py
@@ -139,63 +139,8 @@ def fate(
     )

     exprs = None
-    if basis == "pca" and inverse_transform:
-        Qkey = "PCs"
-        if type(prediction) == list:
-            exprs = [vector_transformation(cur_pred.T, adata.uns[Qkey]) for cur_pred in prediction]
-            high_p_n = exprs[0].shape[1]
-        else:
-            exprs = vector_transformation(prediction.T, adata.uns[Qkey])
-            high_p_n = exprs.shape[1]
-
-        if adata.var.use_for_dynamics.sum() == high_p_n:
-            valid_genes = adata.var_names[adata.var.use_for_dynamics]
-        else:
-            valid_genes = adata.var_names[adata.var.use_for_transition]
-
-    elif basis == "umap" and inverse_transform:
-        # this requires umap 0.4; reverse project to PCA space.
-        if hasattr(prediction, "ndim"):
-            if prediction.ndim == 1:
-                prediction = prediction[None, :]
-
-        params = adata.uns["umap_fit"]
-        umap_fit = construct_mapper_umap(
-            params["X_data"],
-            n_components=params["umap_kwargs"]["n_components"],
-            metric=params["umap_kwargs"]["metric"],
-            min_dist=params["umap_kwargs"]["min_dist"],
-            spread=params["umap_kwargs"]["spread"],
-            max_iter=params["umap_kwargs"]["max_iter"],
-            alpha=params["umap_kwargs"]["alpha"],
-            gamma=params["umap_kwargs"]["gamma"],
-            negative_sample_rate=params["umap_kwargs"]["negative_sample_rate"],
-            init_pos=params["umap_kwargs"]["init_pos"],
-            random_state=params["umap_kwargs"]["random_state"],
-            umap_kwargs=params["umap_kwargs"],
-        )
-
-        PCs = adata.uns["PCs"].T
-        exprs = []
-
-        for cur_pred in prediction:
-            expr = umap_fit.inverse_transform(cur_pred.T)
-
-            # further reverse project back to raw expression space
-            if PCs.shape[0] == expr.shape[1]:
-                expr = np.expm1(expr @ PCs + adata.uns["pca_mean"])
-
-            exprs.append(expr)
-
-        if adata.var.use_for_dynamics.sum() == exprs[0].shape[1]:
-            valid_genes = adata.var_names[adata.var.use_for_dynamics]
-        elif adata.var.use_for_transition.sum() == exprs[0].shape[1]:
-            valid_genes = adata.var_names[adata.var.use_for_transition]
-        else:
-            raise Exception(
-                "looks like a customized set of genes is used for pca analysis of the adata. "
-                "Try rerunning pca analysis with default settings for this function to work."
-            )
+    if inverse_transform:
+        exprs, valid_genes = _inverse_transform(adata=adata, prediction=prediction, basis=basis, Qkey=Qkey)

     adata.uns[fate_key] = {
         "init_states": init_states,
@@ -309,6 +254,75 @@ def _fate(
     return t, prediction


+def _inverse_transform(
+    adata: AnnData,
+    prediction: Union[np.ndarray, List[np.ndarray]],
+    basis: str = "umap",
+    Qkey: str = "PCs",
+) -> Tuple[Union[np.ndarray, List[np.ndarray]], np.ndarray]:
+    """Inverse transform the low-dimensional vector field prediction back to the high-dimensional space."""
+    if basis == "pca":
+        if type(prediction) == list:
+            exprs = [vector_transformation(cur_pred.T, adata.uns[Qkey]) for cur_pred in prediction]
+            high_p_n = exprs[0].shape[1]
+        else:
+            exprs = vector_transformation(prediction.T, adata.uns[Qkey])
+            high_p_n = exprs.shape[1]
+
+        if adata.var.use_for_dynamics.sum() == high_p_n:
+            valid_genes = adata.var_names[adata.var.use_for_dynamics]
+        else:
+            valid_genes = adata.var_names[adata.var.use_for_transition]
+
+    elif basis == "umap":
+        # this requires umap 0.4; reverse project to PCA space.
+        if hasattr(prediction, "ndim"):
+            if prediction.ndim == 1:
+                prediction = prediction[None, :]
+
+        params = adata.uns["umap_fit"]
+        umap_fit = construct_mapper_umap(
+            params["X_data"],
+            n_components=params["umap_kwargs"]["n_components"],
+            metric=params["umap_kwargs"]["metric"],
+            min_dist=params["umap_kwargs"]["min_dist"],
+            spread=params["umap_kwargs"]["spread"],
+            max_iter=params["umap_kwargs"]["max_iter"],
+            alpha=params["umap_kwargs"]["alpha"],
+            gamma=params["umap_kwargs"]["gamma"],
+            negative_sample_rate=params["umap_kwargs"]["negative_sample_rate"],
+            init_pos=params["umap_kwargs"]["init_pos"],
+            random_state=params["umap_kwargs"]["random_state"],
+            umap_kwargs=params["umap_kwargs"],
+        )
+
+        PCs = adata.uns[Qkey].T
+        exprs = []
+
+        for cur_pred in prediction:
+            expr = umap_fit.inverse_transform(cur_pred.T)
+
+            # further reverse project back to raw expression space
+            if PCs.shape[0] == expr.shape[1]:
+                expr = np.expm1(expr @ PCs + adata.uns["pca_mean"])
+
+            exprs.append(expr)
+
+        if adata.var.use_for_dynamics.sum() == exprs[0].shape[1]:
+            valid_genes = adata.var_names[adata.var.use_for_dynamics]
+        elif adata.var.use_for_transition.sum() == exprs[0].shape[1]:
+            valid_genes = adata.var_names[adata.var.use_for_transition]
+        else:
+            raise Exception(
+                "looks like a customized set of genes is used for pca analysis of the adata. "
+                "Try rerunning pca analysis with default settings for this function to work."
+            )
+    else:
+        raise ValueError(f"Inverse transform with basis {basis} is not supported.")
+
+    return exprs, valid_genes
+
+
 def fate_bias(
     adata: AnnData,
     group: str,
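For reference, the contract of the new private helper, shown with toy shapes — a sketch only; the array sizes are made up, and it assumes dynamo's convention that adata.uns["PCs"] holds the (n_genes, n_pcs) loading matrix:

import numpy as np

n_pcs, n_times = 30, 250
prediction = np.random.randn(n_pcs, n_times)  # one integrated path in PCA space

exprs, valid_genes = _inverse_transform(
    adata,       # AnnData carrying adata.uns["PCs"] plus the use_for_* gene masks
    prediction,  # (n_pcs, n_time_points) array, or a list of such arrays
    basis="pca",
    Qkey="PCs",
)
# exprs comes back as (n_time_points, n_genes); valid_genes labels its columns.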
4 changes: 2 additions & 2 deletions dynamo/prediction/least_action_path.py
@@ -15,8 +15,8 @@
     vector_field_function_transformation,
     vector_transformation,
 )
-from .trajectory import GeneTrajectory, Trajectory
-from .utils import arclength_sampling_n, find_elbow
+from .trajectory import arclength_sampling_n, GeneTrajectory, Trajectory
+from .utils import find_elbow


 class LeastActionPath(Trajectory):
7 changes: 2 additions & 5 deletions dynamo/prediction/state_graph.py
@@ -15,11 +15,8 @@
 from ..tools.Markov import DiscreteTimeMarkovChain
 from ..tools.utils import fetch_states
 from ..vectorfield import vector_field_function
-from .utils import (
-    arclength_sampling,
-    integrate_streamline,
-    remove_redundant_points_trajectory,
-)
+from .trajectory import arclength_sampling, remove_redundant_points_trajectory
+from .utils import integrate_streamline

 # from sklearn.preprocessing import OrdinalEncoder

178 changes: 177 additions & 1 deletion dynamo/prediction/trajectory.py
@@ -1,14 +1,15 @@
 from typing import Callable, List, Tuple, Union

 import numpy as np
+import scipy
+from scipy.interpolate import interp1d

 from ..dynamo_logger import LoggerManager
 from ..tools.utils import flatten
 from ..utils import expr_to_pca, pca_to_expr
 from ..vectorfield.scVectorField import DifferentiableVectorField
 from ..vectorfield.topography import dup_osc_idx_iter
 from ..vectorfield.utils import angle, normalize_vectors
-from .utils import arclength_sampling_n


 class Trajectory:
@@ -141,6 +142,89 @@ def resample(self, n_points: int, tol: float = 1e-4, inplace: bool = True) -> Tu

         return X, t

+    def archlength_sampling(
+        self,
+        sol: scipy.integrate._ivp.common.OdeSolution,
+        interpolation_num: int,
+        integration_direction: str,
+    ):
+        """Sample the curve using arc length sampling.
+
+        Args:
+            sol: The ODE solution from scipy.integrate.solve_ivp.
+            interpolation_num: The number of points to interpolate the curve at.
+            integration_direction: The direction to integrate the curve in. Can be "forward", "backward", or "both".
+        """
+        tau, x = self.t, self.X.T
+        idx = dup_osc_idx_iter(x, max_iter=100, tol=x.ptp(0).mean() / 1000)[0]
+
+        # idx = dup_osc_idx_iter(x)
+        x = x[:idx]
+        _, arclen, _ = remove_redundant_points_trajectory(x, tol=1e-4, output_discard=True)
+        cur_Y, alen, self.t = arclength_sampling_n(x, num=interpolation_num + 1, t=tau[:idx])
+        self.t = self.t[1:]
+        cur_Y = cur_Y[:, 1:]
+
+        if integration_direction == "both":
+            neg_t_len = sum(np.array(self.t) < 0)
+
+        self.X = (
+            sol(self.t)
+            if integration_direction != "both"
+            else np.hstack(
+                (
+                    sol[0](self.t[:neg_t_len]),
+                    sol[1](self.t[neg_t_len:]),
+                )
+            )
+        )
+
+    def logspace_sampling(
+        self,
+        sol: scipy.integrate._ivp.common.OdeSolution,
+        interpolation_num: int,
+        integration_direction: str,
+    ):
+        """Sample the curve using logspace sampling.
+
+        Args:
+            sol: The ODE solution from scipy.integrate.solve_ivp.
+            interpolation_num: The number of points to interpolate the curve at.
+            integration_direction: The direction to integrate the curve in. Can be "forward", "backward", or "both".
+        """
+        tau, x = self.t, self.X.T
+        neg_tau, pos_tau = tau[tau < 0], tau[tau >= 0]
+
+        if len(neg_tau) > 0:
+            t_0, t_1 = (
+                -(
+                    np.logspace(
+                        0,
+                        np.log10(abs(min(neg_tau)) + 1),
+                        interpolation_num,
+                    )
+                )
+                - 1,
+                np.logspace(0, np.log10(max(pos_tau) + 1), interpolation_num) - 1,
+            )
+            self.t = np.hstack((t_0[::-1], t_1))
+        else:
+            self.t = np.logspace(0, np.log10(max(tau) + 1), interpolation_num) - 1
+
+        if integration_direction == "both":
+            neg_t_len = sum(np.array(self.t) < 0)
+
+        self.X = (
+            sol(self.t)
+            if integration_direction != "both"
+            else np.hstack(
+                (
+                    sol[0](self.t[:neg_t_len]),
+                    sol[1](self.t[neg_t_len:]),
+                )
+            )
+        )

     def interpolate(self, t: np.ndarray, **interp_kwargs) -> np.ndarray:
         """Interpolate the curve at new time values.
@@ -427,3 +511,95 @@ def select_gene(self, genes, arr=None, axis=None):
             raise Exception("Cannot select genes since `self.genes` is `None`.")

         return np.array(y)
+
+
+def arclength_sampling_n(X, num, t=None):
+    arclen = np.cumsum(np.linalg.norm(np.diff(X, axis=0), axis=1))
+    arclen = np.hstack((0, arclen))
+
+    z = np.linspace(arclen[0], arclen[-1], num)
+    X_ = interp1d(arclen, X, axis=0)(z)
+    if t is not None:
+        t_ = interp1d(arclen, t)(z)
+        return X_, arclen[-1], t_
+    else:
+        return X_, arclen[-1]
+
+
+def remove_redundant_points_trajectory(X, tol=1e-4, output_discard=False):
+    """Remove consecutive data points that are too close to each other."""
+    X = np.atleast_2d(X)
+    discard = np.zeros(len(X), dtype=bool)
+    if X.shape[0] > 1:
+        for i in range(len(X) - 1):
+            dist = np.linalg.norm(X[i + 1] - X[i])
+            if dist < tol:
+                discard[i + 1] = True
+        X = X[~discard]
+
+    arclength = 0
+
+    x0 = X[0]
+    for i in range(1, len(X)):
+        tangent = X[i] - x0 if i == 1 else X[i] - X[i - 1]
+        d = np.linalg.norm(tangent)
+
+        arclength += d
+
+    if output_discard:
+        return (X, arclength, discard)
+    else:
+        return (X, arclength)
+
+
+def arclength_sampling(X, step_length, n_steps: int, t=None):
+    """Uniformly sample data points on an arc curve generated from vector field predictions."""
+    Y = []
+    x0 = X[0]
+    T = [] if t is not None else None
+    t0 = t[0] if t is not None else None
+    i = 1
+    terminate = False
+    arclength = 0
+
+    def _calculate_new_point():
+        x = x0 if j == i else X[j - 1]
+        cur_y = x + (step_length - L) * tangent / d
+
+        if t is not None:
+            cur_tau = t0 if j == i else t[j - 1]
+            cur_tau += (step_length - L) / d * (t[j] - cur_tau)
+            T.append(cur_tau)
+        else:
+            cur_tau = None
+
+        Y.append(cur_y)
+
+        return cur_y, cur_tau
+
+    while i < len(X) - 1 and not terminate:
+        L = 0
+        for j in range(i, len(X)):
+            tangent = X[j] - x0 if j == i else X[j] - X[j - 1]
+            d = np.linalg.norm(tangent)
+            if L + d >= step_length:
+                y, tau = _calculate_new_point()
+                t0 = tau if t is not None else None
+                x0 = y
+                i = j
+                break
+            else:
+                L += d
+        if j == len(X) - 1:
+            i += 1
+        arclength += step_length
+        if L + d < step_length:
+            terminate = True
+
+    if len(Y) < n_steps:
+        _, _ = _calculate_new_point()
+
+    if T is not None:
+        return np.array(Y), arclength, T
+    else:
+        return np.array(Y), arclength
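A quick toy check of the sampling helpers that moved here from utils.py — a sketch assuming they are imported from dynamo.prediction.trajectory; the circle data is made up:

import numpy as np

from dynamo.prediction.trajectory import (
    arclength_sampling_n,
    remove_redundant_points_trajectory,
)

theta = np.linspace(0, 2 * np.pi, 500)
X = np.column_stack((np.cos(theta), np.sin(theta)))  # unit circle, 500 points

# Drop near-duplicate points first so the cumulative arc length stays strictly
# increasing, which keeps the interp1d call inside arclength_sampling_n well behaved.
X_clean, arclen, discard = remove_redundant_points_trajectory(X, tol=1e-4, output_discard=True)

# Resample 50 points equally spaced in arc length, carrying the time stamps along.
X_eq, total_len, t_eq = arclength_sampling_n(X_clean, num=50, t=theta[~discard])

print(X_eq.shape)           # (50, 2)
print(round(total_len, 3))  # close to 2 * pi for the unit circle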
(Diff for the fifth changed file did not load.)
