Skip to content

Commit

Permalink
Add shape check to Dataset initialization (#106)
Browse files Browse the repository at this point in the history
* Check input array shapes when Dataset is initialized

* Add test for nimare.utils._check_inputs_shape

* Cover two missing lines by the test

* Add test when n or v is None

* Remove extra clause
  • Loading branch information
JulioAPeraza authored Jun 12, 2022
1 parent 0a99840 commit 19af399
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 1 deletion.
5 changes: 5 additions & 0 deletions .zenodo.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
"affiliation": "Florida International University",
"orcid": "0000-0001-9813-3167"
},
{
"name": "Peraza, Julio A.",
"affiliation": "Florida International University",
"orcid": "0000-0003-3816-5903"
},
{
"name": "Nichols, Thomas E.",
"affiliation": "Big Data Institute, University of Oxford",
Expand Down
6 changes: 5 additions & 1 deletion pymare/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pandas as pd

from pymare.utils import _listify
from pymare.utils import _check_inputs_shape, _listify

from .estimators import (
DerSimonianLaird,
Expand Down Expand Up @@ -94,6 +94,10 @@ def __init__(
self.X = X
self.X_names = names

_check_inputs_shape(self.y, self.X, "y", "X", row=True)
_check_inputs_shape(self.y, self.v, "y", "v", row=True, column=True)
_check_inputs_shape(self.y, self.n, "y", "n", row=True, column=True)

def _get_predictors(self, X, names, add_intercept):
if X is None and not add_intercept:
raise ValueError(
Expand Down
30 changes: 30 additions & 0 deletions pymare/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,40 @@
"""Tests for pymare.utils."""
import os.path as op

import numpy as np
import pytest

from pymare import utils


def test_get_resource_path():
"""Test nimare.utils.get_resource_path."""
print(utils.get_resource_path())
assert op.isdir(utils.get_resource_path())


def test_check_inputs_shape():
"""Test nimare.utils._check_inputs_shape."""
n_rows = 5
n_columns = 4
n_pred = 3
y = np.random.randint(1, 100, size=(n_rows, n_columns))
v = np.random.randint(1, 100, size=(n_rows + 1, n_columns))
n = np.random.randint(1, 100, size=(n_rows, n_columns))
X = np.random.randint(1, 100, size=(n_rows, n_pred))
X_names = [f"X{x}" for x in range(n_pred)]

utils._check_inputs_shape(y, X, "y", "X", row=True)
utils._check_inputs_shape(y, n, "y", "n", row=True, column=True)
utils._check_inputs_shape(X, np.array(X_names)[None, :], "X", "X_names", column=True)

# Raise error if the number of rows and columns of v don't match y
with pytest.raises(ValueError):
utils._check_inputs_shape(y, v, "y", "v", row=True, column=True)

# Raise error if neither row or column is True
with pytest.raises(ValueError):
utils._check_inputs_shape(y, n, "y", "n")

# Dataset may be initialized with n or v as None
utils._check_inputs_shape(y, None, "y", "n", row=True, column=True)
36 changes: 36 additions & 0 deletions pymare/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,39 @@ def _listify(obj):
This provides a simple way to accept flexible arguments.
"""
return obj if isinstance(obj, (list, tuple, type(None), np.ndarray)) else [obj]


def _check_inputs_shape(param1, param2, param1_name, param2_name, row=False, column=False):
"""Check whether 'param1' and 'param2' have the same shape.
Parameters
----------
param1 : array
param2 : array
param1_name : str
param2_name : str
row : bool, default to False.
column : bool, default to False.
"""
if (param1 is not None) and (param2 is not None):
if row and not column:
shape1 = param1.shape[0]
shape2 = param2.shape[0]
message = "rows"
elif column and not row:
shape1 = param1.shape[1]
shape2 = param2.shape[1]
message = "columns"
elif row and column:
shape1 = param1.shape
shape2 = param2.shape
message = "rows and columns"
else:
raise ValueError("At least one of the two parameters (row or column) should be True.")

if shape1 != shape2:
raise ValueError(
f"{param1_name} and {param2_name} should have the same number of {message}. "
f"You provided {param1_name} with shape {param1.shape} and {param2_name} "
f"with shape {param2.shape}."
)

0 comments on commit 19af399

Please sign in to comment.