Skip to content

Commit

Permalink
BROKEN testing the requirement of integer data and the forced-rescaling here
Browse files Browse the repository at this point in the history
  • Loading branch information
dmnapolitano committed Mar 27, 2024
1 parent f714ef4 commit 4edb1fd
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 63 deletions.
5 changes: 0 additions & 5 deletions src/elexsolver/TransitionMatrixSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def __solve(self, A: np.ndarray, B: np.ndarray, weights: np.ndarray) -> np.ndarr
return transition_matrix.value

def fit(self, X: np.ndarray, Y: np.ndarray, sample_weight: np.ndarray | None = None) -> np.ndarray:
self._check_data_type(X)
self._check_data_type(Y)
self._check_any_element_nan_or_inf(X)
self._check_any_element_nan_or_inf(Y)
self._check_for_zero_units(X)
Expand All @@ -86,9 +84,6 @@ def fit(self, X: np.ndarray, Y: np.ndarray, sample_weight: np.ndarray | None = N
if X.shape[0] != Y.shape[0]:
raise ValueError(f"Number of units in X ({X.shape[0]}) != number of units in Y ({Y.shape[0]}).")

X = self._rescale(X)
Y = self._rescale(Y)

weights = self._check_and_prepare_weights(X, Y, sample_weight)

self.coefficients = self.__solve(X, Y, weights)
Expand Down
21 changes: 0 additions & 21 deletions src/elexsolver/TransitionSolver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import warnings

import numpy as np

Expand Down Expand Up @@ -59,33 +58,13 @@ def predict(self, X: np.ndarray) -> np.ndarray:
raise RuntimeError("Solver must be fit before prediction can be performed.")
return X @ self.coefficients

def _check_data_type(self, A: np.ndarray):
"""
Make sure we're starting with count data which we'll standardize to percentages
by calling `self._rescale(A)` later.
"""
if not np.all(A.astype("int64") == A):
raise ValueError("Matrix must contain integers.")

def _check_for_zero_units(self, A: np.ndarray):
"""
If we have at least one unit whose columns are all zero, we can't continue.
"""
if np.any(np.sum(A, axis=1) == 0):
raise ValueError("Matrix cannot contain any rows (units) where all columns (things) are zero.")

def _rescale(self, A: np.ndarray) -> np.ndarray:
"""
Rescale rows (units) to ensure they sum to 1 (100%).
"""
A = A.copy().astype(float)

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in divide")
A = (A.T / A.sum(axis=1)).T

return np.nan_to_num(A, nan=0, posinf=0, neginf=0)

def _check_and_prepare_weights(self, X: np.ndarray, Y: np.ndarray, weights: np.ndarray | None) -> np.ndarray:
"""
If `weights` is not None, and `weights` has the same number of rows in both matrices `X` and `Y`,
Expand Down
37 changes: 0 additions & 37 deletions tests/test_transition_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,43 +68,6 @@ def test_check_for_zero_units_bad():
ts._check_for_zero_units(A) # pylint: disable=protected-access


@patch.object(TransitionSolver, "__abstractmethods__", set())
def test_rescale_rescaled_numpy():
    """A 2x2 integer matrix of ones should rescale to 0.5 in every cell."""
    counts = np.ones((2, 2), dtype=int)
    want = np.full((2, 2), 0.5)
    solver = TransitionSolver()
    got = solver._rescale(counts)  # pylint: disable=protected-access
    np.testing.assert_array_equal(got, want)


@patch.object(TransitionSolver, "__abstractmethods__", set())
def test_rescale_rescaled_pandas():
    """`_rescale` should accept a pandas DataFrame of counts and return row proportions."""
    # pandas isn't a requirement for elex-solver: report the test as skipped
    # when it is missing instead of silently passing, and avoid swallowing an
    # ImportError raised by the code under test.
    pandas = pytest.importorskip("pandas")

    a_df = pandas.DataFrame(np.ones((2, 2)), columns=["A", "B"]).astype(int)
    expected_df = pandas.DataFrame([[0.5, 0.5], [0.5, 0.5]], columns=["A", "B"])
    ts = TransitionSolver()
    np.testing.assert_array_equal(expected_df, ts._rescale(a_df))  # pylint: disable=protected-access


@patch.object(TransitionSolver, "__abstractmethods__", set())
def test_check_data_type_good():
    """An all-integer matrix must pass the data type check without raising."""
    int_matrix = np.array(
        [
            [1, 2, 3],
            [4, 5, 6],
        ]
    )
    solver = TransitionSolver()
    solver._check_data_type(int_matrix)  # pylint: disable=protected-access


@patch.object(TransitionSolver, "__abstractmethods__", set())
def test_check_data_type_bad():
    """A matrix with fractional values must be rejected with a ValueError."""
    A = np.array([[0.1, 0.2, 0.3]])
    ts = TransitionSolver()
    # Only the call under test sits inside `pytest.raises`, so a ValueError
    # raised while building the fixtures cannot make this test pass by accident.
    with pytest.raises(ValueError):
        ts._check_data_type(A)  # pylint: disable=protected-access


@patch.object(TransitionSolver, "__abstractmethods__", set())
def test_check_and_prepare_weights_bad():
with pytest.raises(ValueError):
Expand Down

0 comments on commit 4edb1fd

Please sign in to comment.