From fd442b329c07f0def5f51bee24ab99409de8a4bb Mon Sep 17 00:00:00 2001 From: lennybronner Date: Tue, 24 Sep 2024 22:22:57 -0400 Subject: [PATCH] updated tests --- tests/conftest.py | 6 ++++++ tests/test_linear_solver.py | 20 +++++++++++++++++++- tests/test_ols.py | 14 ++++++++++++++ tests/test_quantile.py | 11 +++++++++++ 4 files changed, 50 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index ca274686..0271ca64 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import sys import pandas as pd +import numpy as np import pytest _TEST_FOLDER = os.path.dirname(__file__) @@ -40,3 +41,8 @@ def random_data_no_weights(get_fixture): @pytest.fixture(scope="session") def random_data_weights(get_fixture): return get_fixture("random_data_n100_p5_12549_weights.csv") + +@pytest.fixture(scope="session") +def rng(): + seed = 8232 + return np.random.default_rng(seed=seed) \ No newline at end of file diff --git a/tests/test_linear_solver.py b/tests/test_linear_solver.py index 98e1f163..4cc06d86 100644 --- a/tests/test_linear_solver.py +++ b/tests/test_linear_solver.py @@ -2,9 +2,27 @@ import pytest from elexsolver.LinearSolver import LinearSolver - +from elexsolver.QuantileRegressionSolver import QuantileRegressionSolver def test_fit(): solver = LinearSolver() with pytest.raises(NotImplementedError): solver.fit(np.ndarray((5, 3)), np.ndarray((1, 3))) + + +################## +# Test residuals # +################## +def test_residuals_without_weights(rng): + x = rng.normal(size=(100, 5)) + beta = rng.normal(size=(5, 1)) + y = x @ beta + + # we need an a subclass of LinearSolver to actually run a fit + reg = QuantileRegressionSolver() + reg.fit(x, y, fit_intercept=False) + reg.predict(x) + + residuals_train = reg.residuals(x, y, K=None, center=False) + residuals_K = reg.residuals(x, y, K=10, center=False) + import pdb; pdb.set_trace() \ No newline at end of file diff --git a/tests/test_ols.py b/tests/test_ols.py index bea6876b..875037e7 100644 --- a/tests/test_ols.py +++ b/tests/test_ols.py @@ -30,6 +30,20 @@ def test_basic2(): preds = lm.predict(x) assert all(np.abs(preds - [6.666667, 6.666667, 6.666667, 15]) <= TOL) +def test_cache(): + lm = OLSRegressionSolver() + x = np.asarray([[1, 1], [1, 1], [1, 1], [1, 2]]) + y = np.asarray([3, 8, 9, 15]) + lm.fit(x, y, fit_intercept=True, cache=False) + + assert lm.normal_eqs is None + assert lm.hat_vals is None + + lm.fit(x, y, fit_intercept=True, cache=True) + + assert lm.normal_eqs is not None + assert lm.hat_vals is not None + assert lm.coefficients is not None ###################### # Intermediate tests # diff --git a/tests/test_quantile.py b/tests/test_quantile.py index a9386156..7738501e 100644 --- a/tests/test_quantile.py +++ b/tests/test_quantile.py @@ -55,6 +55,17 @@ def test_basic_upper(): preds = quantreg.predict(x) np.testing.assert_array_equal(preds, [[9], [9], [9], [15]]) +def test_cache(): + quantreg = QuantileRegressionSolver() + tau = 0.9 + x = np.asarray([[1, 1], [1, 1], [1, 1], [1, 2]]) + y = np.asarray([3, 8, 9, 15]) + quantreg.fit(x, y, tau, cache=False) + + assert quantreg.coefficients == [] + + quantreg.fit(x, y, tau, cache=True) + assert len(quantreg.coefficients) > 0 ###################### # Intermediate tests #