diff --git a/src/elexsolver/TransitionMatrixSolver.py b/src/elexsolver/TransitionMatrixSolver.py index 03da3e2..95e3125 100644 --- a/src/elexsolver/TransitionMatrixSolver.py +++ b/src/elexsolver/TransitionMatrixSolver.py @@ -71,8 +71,6 @@ def __solve(self, A: np.ndarray, B: np.ndarray, weights: np.ndarray) -> np.ndarr return transition_matrix.value def fit(self, X: np.ndarray, Y: np.ndarray, sample_weight: np.ndarray | None = None) -> np.ndarray: - self._check_data_type(X) - self._check_data_type(Y) self._check_any_element_nan_or_inf(X) self._check_any_element_nan_or_inf(Y) self._check_for_zero_units(X) @@ -86,9 +84,6 @@ def fit(self, X: np.ndarray, Y: np.ndarray, sample_weight: np.ndarray | None = N if X.shape[0] != Y.shape[0]: raise ValueError(f"Number of units in X ({X.shape[0]}) != number of units in Y ({Y.shape[0]}).") - X = self._rescale(X) - Y = self._rescale(Y) - weights = self._check_and_prepare_weights(X, Y, sample_weight) self.coefficients = self.__solve(X, Y, weights) diff --git a/src/elexsolver/TransitionSolver.py b/src/elexsolver/TransitionSolver.py index bb70949..77e2616 100644 --- a/src/elexsolver/TransitionSolver.py +++ b/src/elexsolver/TransitionSolver.py @@ -1,5 +1,4 @@ import logging -import warnings import numpy as np @@ -59,14 +58,6 @@ def predict(self, X: np.ndarray) -> np.ndarray: raise RuntimeError("Solver must be fit before prediction can be performed.") return X @ self.coefficients - def _check_data_type(self, A: np.ndarray): - """ - Make sure we're starting with count data which we'll standardize to percentages - by calling `self._rescale(A)` later. - """ - if not np.all(A.astype("int64") == A): - raise ValueError("Matrix must contain integers.") - def _check_for_zero_units(self, A: np.ndarray): """ If we have at least one unit whose columns are all zero, we can't continue. @@ -74,18 +65,6 @@ def _check_for_zero_units(self, A: np.ndarray): if np.any(np.sum(A, axis=1) == 0): raise ValueError("Matrix cannot contain any rows (units) where all columns (things) are zero.") - def _rescale(self, A: np.ndarray) -> np.ndarray: - """ - Rescale rows (units) to ensure they sum to 1 (100%). - """ - A = A.copy().astype(float) - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in divide") - A = (A.T / A.sum(axis=1)).T - - return np.nan_to_num(A, nan=0, posinf=0, neginf=0) - def _check_and_prepare_weights(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None) -> np.ndarray: """ If `weights` is not None, and `weights` has the same number of rows in both matrices `X` and `Y`, diff --git a/tests/test_transition_solver.py b/tests/test_transition_solver.py index 2adaf03..b7bc9aa 100644 --- a/tests/test_transition_solver.py +++ b/tests/test_transition_solver.py @@ -68,43 +68,6 @@ def test_check_for_zero_units_bad(): ts._check_for_zero_units(A) # pylint: disable=protected-access -@patch.object(TransitionSolver, "__abstractmethods__", set()) -def test_rescale_rescaled_numpy(): - A = np.ones((2, 2)).astype(int) - expected = np.array([[0.5, 0.5], [0.5, 0.5]]) - ts = TransitionSolver() - np.testing.assert_array_equal(ts._rescale(A), expected) # pylint: disable=protected-access - - -@patch.object(TransitionSolver, "__abstractmethods__", set()) -def test_rescale_rescaled_pandas(): - try: - import pandas # pylint: disable=import-outside-toplevel - - a_df = pandas.DataFrame(np.ones((2, 2)), columns=["A", "B"]).astype(int) - expected_df = pandas.DataFrame([[0.5, 0.5], [0.5, 0.5]], columns=["A", "B"]) - ts = TransitionSolver() - np.testing.assert_array_equal(expected_df, ts._rescale(a_df)) # pylint: disable=protected-access - except ImportError: - # pass this test through since pandas isn't a requirement for elex-solver - assert True - - -@patch.object(TransitionSolver, "__abstractmethods__", set()) -def test_check_data_type_good(): - A = np.array([[1, 2, 3], [4, 5, 6]]) - ts = TransitionSolver() - ts._check_data_type(A) # pylint: disable=protected-access - - -@patch.object(TransitionSolver, "__abstractmethods__", set()) -def test_check_data_type_bad(): - with pytest.raises(ValueError): - A = np.array([[0.1, 0.2, 0.3]]) - ts = TransitionSolver() - ts._check_data_type(A) # pylint: disable=protected-access - - @patch.object(TransitionSolver, "__abstractmethods__", set()) def test_check_and_prepare_weights_bad(): with pytest.raises(ValueError):