diff --git a/tests/test_recall.py b/tests/test_recall.py index 022d1ee..f37d32b 100644 --- a/tests/test_recall.py +++ b/tests/test_recall.py @@ -65,12 +65,20 @@ def test_recall_function(self): """Teest of ts_recall function. """ - real = np.array([1, 1, 0, 0, 0]) - pred = np.array([1, 1, 1, 1, 0]) + # test case1 + real = np.array([1, 0, 0, 0, 0]) + pred = np.array([1, 1, 0, 0, 0]) score = ts_recall(real, pred) self.assertEqual(score, 1.0) + # test case2 + real = np.array([1, 1, 0, 0, 0]) + pred = np.array([0, 0, 1, 1, 1]) + + score = ts_recall(real, pred) + self.assertEqual(score, 0.0) + def test_recall_function_with_list(self): """Teest of ts_recall function with list type arguments. """