Skip to content

Commit

Permalink
Merge pull request #11 from KlugerLab/10-check-gene_expression-matrix…
Browse files Browse the repository at this point in the history
…-is-positive

10 check gene expression matrix is positive
  • Loading branch information
fra-pcmgf authored May 21, 2024
2 parents 35184f0 + 9989cff commit 2db6b93
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 8 deletions.
6 changes: 6 additions & 0 deletions gene_trajectory/coarse_grain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import scanpy as sc
from sklearn.cluster import KMeans

from gene_trajectory.util.input_validation import validate_matrix


def select_top_genes(
adata: sc.AnnData,
Expand Down Expand Up @@ -64,6 +66,10 @@ def coarse_grain(
:param random_seed: the random seed
:return: the updated cell embedding and gene expression matrices
"""
validate_matrix(gene_expression, obj_name='Gene Expression Matrix', min_value=0)
ncells, ngenes = gene_expression.shape
validate_matrix(cell_embedding, obj_name='Cell embedding', nrows=ncells)

if cluster is None:
k_means = KMeans(n_clusters=n, random_state=random_seed).fit(cell_embedding)
cluster = k_means.labels_ # noqa
Expand Down
7 changes: 6 additions & 1 deletion gene_trajectory/diffusion_map.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Union
import numpy as np

from gene_trajectory.util.input_validation import validate_matrix


def diffusion_map(
dist_mat: np.array,
Expand All @@ -19,6 +21,8 @@ def diffusion_map(
:param t: Number of diffusion times
:return: the diffusion embedding and the eigenvalues
"""
validate_matrix(dist_mat, square=True)

affinity_matrix_symm = get_symmetrized_affinity_matrix(dist_mat=dist_mat, k=k, sigma=sigma)
normalized_vec = np.sqrt(1 / affinity_matrix_symm.sum(axis=1))
affinity_matrix_norm = (affinity_matrix_symm * normalized_vec * normalized_vec[:, None])
Expand Down Expand Up @@ -50,7 +54,8 @@ def get_symmetrized_affinity_matrix(
:return:
"""
assert dist_mat.shape[0] == dist_mat.shape[1]
validate_matrix(dist_mat, square=True)

dists = np.nan_to_num(dist_mat, 1e-6) # noqa
k = min(k, dist_mat.shape[0])

Expand Down
9 changes: 8 additions & 1 deletion gene_trajectory/extract_gene_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd

from gene_trajectory.diffusion_map import diffusion_map, get_symmetrized_affinity_matrix
from gene_trajectory.util.input_validation import validate_matrix

logger = logging.getLogger()

Expand All @@ -28,6 +29,8 @@ def get_gene_embedding(
:param t: Number of diffusion times
:return: the diffusion embedding and the eigenvalues
"""
validate_matrix(dist_mat, square=True)

k = min(k, dist_mat.shape[0])
n_ev = min(n_ev + 1, dist_mat.shape[0])
diffu_emb, eigen_vals = diffusion_map(dist_mat=dist_mat, k=k, sigma=sigma, n_ev=n_ev, t=t)
Expand All @@ -47,6 +50,8 @@ def get_randow_walk_matrix(
:param k: Adaptive kernel bandwidth for each point set to be the distance to its `K`-th nearest neighbor
:return: Random-walk matrix
"""
validate_matrix(dist_mat, square=True)

affinity_matrix_symm = get_symmetrized_affinity_matrix(dist_mat=dist_mat, k=k)
normalized_vec = 1 / affinity_matrix_symm.sum(axis=1)
affinity_matrix_norm = (affinity_matrix_symm * normalized_vec[:, None])
Expand All @@ -67,7 +72,7 @@ def get_gene_pseudoorder(
:param max_id: Index of the terminal gene
:return: The pseudoorder
"""
assert dist_mat.shape[0] == dist_mat.shape[1]
validate_matrix(dist_mat, square=True)

emd = dist_mat[subset][:, subset]
dm_emb, _ = diffusion_map(emd)
Expand Down Expand Up @@ -108,6 +113,8 @@ def extract_gene_trajectory(
:param other: Label for genes not in a trajectory. Default: 'Other'
:return: A data frame indicating gene trajectories and gene ordering along each trajectory
"""
validate_matrix(dist_mat, square=True)

if np.isscalar(t_list):
if n is None:
raise ValueError(f'n should be specified if t_list is a number: {t_list}')
Expand Down
13 changes: 8 additions & 5 deletions gene_trajectory/gene_distance_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import ot
from tqdm import tqdm

from gene_trajectory.util.input_validation import validate_matrix
from gene_trajectory.util.shared_array import SharedArray, PartialStarApply

logger = logging.getLogger()
Expand Down Expand Up @@ -42,18 +43,20 @@ def cal_ot_mat(
:return: the distance matrix
"""
processes = int(processes) if isinstance(processes, float) else os.cpu_count()
n = gene_expr.shape[1]
validate_matrix(gene_expr, obj_name='Gene Expression Matrix', min_value=0)
ncells, ngenes = gene_expr.shape
validate_matrix(ot_cost, obj_name='Cost Matrix', shape=(ncells, ncells), min_value=0)

if show_progress_bar:
logger.info(f'Computing emd distance..')

if gene_pairs is None:
pairs = ((i, j) for i in range(0, n - 1) for j in range(i + 1, n))
npairs = (n * (n - 1)) // 2
pairs = ((i, j) for i in range(0, ngenes - 1) for j in range(i + 1, ngenes))
npairs = (ngenes * (ngenes - 1)) // 2
else:
pairs = gene_pairs
npairs = len(gene_pairs)

emd_mat = np.full((n, n), fill_value=np.NaN)
emd_mat = np.full((ngenes, ngenes), fill_value=np.NaN)

with SharedMemoryManager() as manager:
start_time = time.perf_counter()
Expand Down
70 changes: 70 additions & 0 deletions gene_trajectory/util/input_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Optional

import numpy as np


def validate_not_none(obj, obj_name: str = 'input'):
if obj is None:
raise ValueError(f"{obj_name} is None")


def validate_matrix(
m: np.array,
obj_name: str = 'input',
nrows: Optional[int] = None,
ncols: Optional[int] = None,
shape: Optional[tuple[int, int]] = None,
square: Optional[bool] = None,
min_size: Optional[int] = 1,
min_value: Optional = None,
max_value: Optional = None,
):
"""
Validates an input matrix
@param m: the input matrix
@param obj_name: the name of the object, for error reporting
@param min_size: Minimum matrix size in each dimension. Defaults to 1, and will raise for an empty matrix
@param nrows: Number of rows in the matrix
@param ncols: Number of rows in the matrix
@param shape: the expected shape of the matrix
@param square: If True, ensures the matrix is square. If False, ensures the matrix is not
@param min_value: Minimum value for each element
@param max_value: Maximum value for each element
"""
validate_not_none(m, obj_name=obj_name)

if len(m.shape) != 2:
raise ValueError(f"{obj_name} is not a matrix. Shape: {m.shape}")
mr, mc = m.shape

if nrows is not None:
if mr != nrows:
raise ValueError(f"{obj_name} does not have {nrows} rows. Shape: {m.shape}")

if ncols is not None:
if mc != ncols:
raise ValueError(f"{obj_name} does not have {ncols} columns. Shape: {m.shape}")

if shape is not None:
if m.shape != shape:
raise ValueError(f"{obj_name} does not have shape {shape}. Shape: {m.shape}")

if square is True:
if mr != mc:
raise ValueError(f"{obj_name} is not a square matrix. Shape: {m.shape}")
elif square is False:
if mr == mc:
raise ValueError(f"{obj_name} is a square matrix. Shape: {m.shape}")

if min_size is not None:
for s in m.shape:
if s < min_size:
raise ValueError(f"{obj_name} does not have enough elements. Min_size: {min_size}, Shape: {m.shape}")

if min_value is not None:
if m.min() < min_value:
raise ValueError(f"{obj_name} should not have values less than {min_value}. Minimum found: {m.min()}")

if max_value is not None:
if m.max() > max_value:
raise ValueError(f"{obj_name} should not have values greater than {max_value}. Maximum found: {m.max()}")
2 changes: 1 addition & 1 deletion tests/test_compute_gene_distance_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from gene_trajectory.compute_gene_distance_cmd import cal_ot_mat


class DiffusionMapTestCase(unittest.TestCase):
class ComputeGeneDistanceTestCase(unittest.TestCase):
gdm = np.array([
[0, 1, 2],
[1, 0, 2],
Expand Down
10 changes: 10 additions & 0 deletions tests/test_gene_distance_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ def test_gene_distance_shared(self):
mt = cal_ot_mat(ot_cost=self.gdm, gene_expr=self.gem.T, show_progress_bar=False)
np.testing.assert_almost_equal(self.expected_emd, mt, 6)

def test_gene_distance_input_validation(self):
with self.assertRaisesRegexp(ValueError, 'Cost Matrix does not have shape.*'):
cal_ot_mat(ot_cost=self.gdm, gene_expr=np.ones(shape=(6, 3)), show_progress_bar=False)

with self.assertRaisesRegexp(ValueError, 'Cost Matrix does not have shape.*'):
cal_ot_mat(ot_cost=np.ones(shape=(6, 3)), gene_expr=self.gem.T, show_progress_bar=False)

with self.assertRaisesRegexp(ValueError, 'Gene Expression Matrix should not have values less than 0.*'):
cal_ot_mat(ot_cost=np.ones(shape=(6, 3)), gene_expr=self.gem.T - 1, show_progress_bar=False)

def test_cal_ot_mat_gene_pairs(self):
exp = self.expected_emd.copy()
exp[0, 2] = exp[2, 0] = 900
Expand Down
29 changes: 29 additions & 0 deletions tests/test_input_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import unittest

import numpy as np

from gene_trajectory.util.input_validation import validate_matrix


class InputValidationTestCase(unittest.TestCase):
def test_validate_matrix(self):
m = np.array([[1, 2], [3, 4]])

validate_matrix(m, min_value=1, max_value=4, square=True, shape=(2, 2))

with self.assertRaisesRegexp(ValueError, '.*does not have 3 rows.*'):
validate_matrix(m, nrows=3)
with self.assertRaisesRegexp(ValueError, '.*does not have 8 columns.*'):
validate_matrix(m, ncols=8)
with self.assertRaisesRegexp(ValueError, '.*does not have shape \\(1, 1\\)'):
validate_matrix(m, shape=(1, 1))
with self.assertRaisesRegexp(ValueError, '.*Min_size: 3.*'):
validate_matrix(m, min_size=3)
with self.assertRaisesRegexp(ValueError, '.*should not have values less than 5.*'):
validate_matrix(m, min_value=5)
with self.assertRaisesRegexp(ValueError, '.*should not have values greater than 1.*'):
validate_matrix(m, max_value=1)


if __name__ == '__main__':
unittest.main()

0 comments on commit 2db6b93

Please sign in to comment.