-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from KlugerLab/10-check-gene_expression-matrix…
…-is-positive 10 check gene expression matrix is positive
- Loading branch information
Showing
8 changed files
with
138 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from typing import Optional | ||
|
||
import numpy as np | ||
|
||
|
||
def validate_not_none(obj, obj_name: str = 'input'): | ||
if obj is None: | ||
raise ValueError(f"{obj_name} is None") | ||
|
||
|
||
def validate_matrix( | ||
m: np.array, | ||
obj_name: str = 'input', | ||
nrows: Optional[int] = None, | ||
ncols: Optional[int] = None, | ||
shape: Optional[tuple[int, int]] = None, | ||
square: Optional[bool] = None, | ||
min_size: Optional[int] = 1, | ||
min_value: Optional = None, | ||
max_value: Optional = None, | ||
): | ||
""" | ||
Validates an input matrix | ||
@param m: the input matrix | ||
@param obj_name: the name of the object, for error reporting | ||
@param min_size: Minimum matrix size in each dimension. Defaults to 1, and will raise for an empty matrix | ||
@param nrows: Number of rows in the matrix | ||
@param ncols: Number of rows in the matrix | ||
@param shape: the expected shape of the matrix | ||
@param square: If True, ensures the matrix is square. If False, ensures the matrix is not | ||
@param min_value: Minimum value for each element | ||
@param max_value: Maximum value for each element | ||
""" | ||
validate_not_none(m, obj_name=obj_name) | ||
|
||
if len(m.shape) != 2: | ||
raise ValueError(f"{obj_name} is not a matrix. Shape: {m.shape}") | ||
mr, mc = m.shape | ||
|
||
if nrows is not None: | ||
if mr != nrows: | ||
raise ValueError(f"{obj_name} does not have {nrows} rows. Shape: {m.shape}") | ||
|
||
if ncols is not None: | ||
if mc != ncols: | ||
raise ValueError(f"{obj_name} does not have {ncols} columns. Shape: {m.shape}") | ||
|
||
if shape is not None: | ||
if m.shape != shape: | ||
raise ValueError(f"{obj_name} does not have shape {shape}. Shape: {m.shape}") | ||
|
||
if square is True: | ||
if mr != mc: | ||
raise ValueError(f"{obj_name} is not a square matrix. Shape: {m.shape}") | ||
elif square is False: | ||
if mr == mc: | ||
raise ValueError(f"{obj_name} is a square matrix. Shape: {m.shape}") | ||
|
||
if min_size is not None: | ||
for s in m.shape: | ||
if s < min_size: | ||
raise ValueError(f"{obj_name} does not have enough elements. Min_size: {min_size}, Shape: {m.shape}") | ||
|
||
if min_value is not None: | ||
if m.min() < min_value: | ||
raise ValueError(f"{obj_name} should not have values less than {min_value}. Minimum found: {m.min()}") | ||
|
||
if max_value is not None: | ||
if m.max() > max_value: | ||
raise ValueError(f"{obj_name} should not have values greater than {max_value}. Maximum found: {m.max()}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
from gene_trajectory.util.input_validation import validate_matrix | ||
|
||
|
||
class InputValidationTestCase(unittest.TestCase): | ||
def test_validate_matrix(self): | ||
m = np.array([[1, 2], [3, 4]]) | ||
|
||
validate_matrix(m, min_value=1, max_value=4, square=True, shape=(2, 2)) | ||
|
||
with self.assertRaisesRegexp(ValueError, '.*does not have 3 rows.*'): | ||
validate_matrix(m, nrows=3) | ||
with self.assertRaisesRegexp(ValueError, '.*does not have 8 columns.*'): | ||
validate_matrix(m, ncols=8) | ||
with self.assertRaisesRegexp(ValueError, '.*does not have shape \\(1, 1\\)'): | ||
validate_matrix(m, shape=(1, 1)) | ||
with self.assertRaisesRegexp(ValueError, '.*Min_size: 3.*'): | ||
validate_matrix(m, min_size=3) | ||
with self.assertRaisesRegexp(ValueError, '.*should not have values less than 5.*'): | ||
validate_matrix(m, min_value=5) | ||
with self.assertRaisesRegexp(ValueError, '.*should not have values greater than 1.*'): | ||
validate_matrix(m, max_value=1) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |