From 4326aec954f24cf0b142786c536c02a7e276d6ed Mon Sep 17 00:00:00 2001 From: wasserth Date: Mon, 5 Feb 2024 16:11:25 +0100 Subject: [PATCH] allow larger diff in tests; exit os tests if error --- tests/test_end_to_end.py | 27 ++++++++++++++++----------- tests/tests_os.py | 18 +++++++++++++----- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 58d3a6947..04a5a0b5c 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -45,29 +45,32 @@ def test_preview(self): def test_prediction_multilabel_fast(self): img_ref = nib.load("tests/reference_files/example_seg_fast.nii.gz").get_fdata() img_new = nib.load("tests/unittest_prediction_fast.nii.gz").get_fdata() - images_equal = np.array_equal(img_ref, img_new) - self.assertTrue(images_equal, "multilabel prediction fast not correct") + nr_of_diff_voxels = (img_ref != img_new).sum() + images_equal = nr_of_diff_voxels < 100 + self.assertTrue(images_equal, f"multilabel prediction fast not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") def test_prediction_multilabel_fast_force_split(self): img_ref = nib.load("tests/reference_files/example_seg_fast_force_split.nii.gz").get_fdata() img_new = nib.load("tests/unittest_prediction_fast_force_split.nii.gz").get_fdata() nr_of_diff_voxels = (img_ref != img_new).sum() - images_equal = nr_of_diff_voxels < 30 + images_equal = nr_of_diff_voxels < 100 self.assertTrue(images_equal, f"force_split prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") def test_prediction_multilabel_fast_body_seg(self): img_ref = nib.load("tests/reference_files/example_seg_fast_body_seg.nii.gz").get_fdata() img_new = nib.load("tests/unittest_prediction_fast_body_seg.nii.gz").get_fdata() - images_equal = np.array_equal(img_ref, img_new) - self.assertTrue(images_equal, "body_seg prediction fast not correct") + nr_of_diff_voxels = (img_ref != img_new).sum() + images_equal = nr_of_diff_voxels < 100 + self.assertTrue(images_equal, f"body_seg prediction fast not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") def test_lung_vessels(self): for roi in ["lung_trachea_bronchia", "lung_vessels"]: img_ref = nib.load(f"tests/reference_files/example_seg_lung_vessels/{roi}.nii.gz").get_fdata() img_new = nib.load(f"tests/unittest_prediction/{roi}.nii.gz").get_fdata() - images_equal = np.array_equal(img_ref, img_new) - self.assertTrue(images_equal, f"{roi} prediction not correct") + nr_of_diff_voxels = (img_ref != img_new).sum() + images_equal = nr_of_diff_voxels < 30 + self.assertTrue(images_equal, f"{roi} prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") def test_tissue_types_wo_license(self): no_output_file = not os.path.exists("tests/unittest_no_license.nii.gz") @@ -81,15 +84,17 @@ def test_tissue_types(self): for roi in ["subcutaneous_fat", "skeletal_muscle", "torso_fat"]: img_ref = nib.load(f"tests/reference_files/example_seg_tissue_types/{roi}.nii.gz").get_fdata() img_new = nib.load(f"tests/unittest_prediction/{roi}.nii.gz").get_fdata() - images_equal = np.array_equal(img_ref, img_new) - self.assertTrue(images_equal, f"{roi} prediction not correct") + nr_of_diff_voxels = (img_ref != img_new).sum() + images_equal = nr_of_diff_voxels < 30 + self.assertTrue(images_equal, f"{roi} prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") def test_appendicular_bones(self): for roi in ["patella", "phalanges_hand"]: img_ref = nib.load(f"tests/reference_files/example_seg_appendicular_bones/{roi}.nii.gz").get_fdata() img_new = nib.load(f"tests/unittest_prediction/{roi}.nii.gz").get_fdata() - images_equal = np.array_equal(img_ref, img_new) - self.assertTrue(images_equal, f"{roi} prediction not correct") + nr_of_diff_voxels = (img_ref != img_new).sum() + images_equal = nr_of_diff_voxels < 30 + self.assertTrue(images_equal, f"{roi} prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") def test_statistics(self): stats_ref = json.load(open("tests/reference_files/example_seg_fast/statistics.json")) diff --git a/tests/tests_os.py b/tests/tests_os.py index 2c6a07004..b691af01e 100755 --- a/tests/tests_os.py +++ b/tests/tests_os.py @@ -10,19 +10,21 @@ from totalsegmentator.python_api import totalsegmentator -if __name__ == "__main__": +def run_tests_and_exit_on_failure(): # Test python api # Test organ predictions - fast totalsegmentator("tests/reference_files/example_ct_sm.nii.gz", "tests/unittest_prediction_fast", fast=True, device="cpu") - pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_fast"]) + r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_fast"]) shutil.rmtree("tests/unittest_prediction_fast") + if r != 0: sys.exit("Test failed: test_prediction_fast") # Test python api - nifti input input_img = nib.load("tests/reference_files/example_ct_sm.nii.gz") totalsegmentator(input_img, "tests/unittest_prediction_fast", fast=True, device="cpu") - pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_fast"]) + r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_fast"]) shutil.rmtree("tests/unittest_prediction_fast") + if r != 0: sys.exit("Test failed: test_prediction_fast with Nifti input") # Test terminal # Test organ predictions - fast - multilabel @@ -31,10 +33,16 @@ file_in = os.path.join("tests", "reference_files", "example_ct_sm.nii.gz") file_out = os.path.join("tests", "unittest_prediction_fast.nii.gz") subprocess.call(f"TotalSegmentator -i {file_in} -o {file_out} --fast --ml -d cpu", shell=True) - pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel_fast"]) + r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel_fast"]) os.remove(file_out) + if r != 0: sys.exit("Test failed: test_prediction_multilabel_fast") # Test Dicom input totalsegmentator("tests/reference_files/example_ct_dicom", "tests/unittest_prediction_dicom.nii.gz", fast=True, ml=True, device="cpu") - pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_dicom"]) + r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_dicom"]) os.remove("tests/unittest_prediction_dicom.nii.gz") + if r != 0: sys.exit("Test failed: test_prediction_dicom") + + +if __name__ == "__main__": + run_tests_and_exit_on_failure() \ No newline at end of file