diff --git a/nigsp/tests/test_metrics.py b/nigsp/tests/test_metrics.py index 8f4da29..3dd5900 100644 --- a/nigsp/tests/test_metrics.py +++ b/nigsp/tests/test_metrics.py @@ -3,7 +3,7 @@ import numpy as np from numpy.random import rand -from pytest import raises +from pytest import raises, warns from nigsp.operations import metrics from nigsp.utils import prepare_ndim_iteration @@ -73,13 +73,22 @@ def test_gsdi(): def test_smoothness(): - signal = rand(10, 2) + s1 = rand(10) + s2 = rand(10, 2) + s3 = rand(2, 10) laplacian = rand(10, 10) - expected_smoothness = np.dot(signal.T, np.dot(laplacian, signal)) - computed_smoothness = metrics.smoothness(laplacian, signal) + expected_smoothness1 = np.dot(s1.T, np.dot(laplacian, s1)) + expected_smoothness2 = np.dot(s2.T, np.dot(laplacian, s2)) + expected_smoothness3 = np.dot(s3, np.dot(laplacian, s3.T)) + + computed_smoothness1 = metrics.smoothness(laplacian, s1) + computed_smoothness2 = metrics.smoothness(laplacian, s2) + computed_smoothness3 = metrics.smoothness(laplacian, s3) - assert (expected_smoothness == computed_smoothness).all() + assert (expected_smoothness1 == computed_smoothness1).all() + assert (expected_smoothness2 == computed_smoothness2).all() + assert (expected_smoothness3 == computed_smoothness3).all() # ### Break tests