Skip to content

Commit

Permalink
allow larger diff in tests; exit os tests if error
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Feb 5, 2024
1 parent 52123a4 commit 4326aec
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 16 deletions.
27 changes: 16 additions & 11 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,29 +45,32 @@ def test_preview(self):
def test_prediction_multilabel_fast(self):
img_ref = nib.load("tests/reference_files/example_seg_fast.nii.gz").get_fdata()
img_new = nib.load("tests/unittest_prediction_fast.nii.gz").get_fdata()
images_equal = np.array_equal(img_ref, img_new)
self.assertTrue(images_equal, "multilabel prediction fast not correct")
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 100
self.assertTrue(images_equal, f"multilabel prediction fast not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")

def test_prediction_multilabel_fast_force_split(self):
img_ref = nib.load("tests/reference_files/example_seg_fast_force_split.nii.gz").get_fdata()
img_new = nib.load("tests/unittest_prediction_fast_force_split.nii.gz").get_fdata()
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 30
images_equal = nr_of_diff_voxels < 100
self.assertTrue(images_equal, f"force_split prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")


def test_prediction_multilabel_fast_body_seg(self):
img_ref = nib.load("tests/reference_files/example_seg_fast_body_seg.nii.gz").get_fdata()
img_new = nib.load("tests/unittest_prediction_fast_body_seg.nii.gz").get_fdata()
images_equal = np.array_equal(img_ref, img_new)
self.assertTrue(images_equal, "body_seg prediction fast not correct")
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 100
self.assertTrue(images_equal, f"body_seg prediction fast not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")

def test_lung_vessels(self):
for roi in ["lung_trachea_bronchia", "lung_vessels"]:
img_ref = nib.load(f"tests/reference_files/example_seg_lung_vessels/{roi}.nii.gz").get_fdata()
img_new = nib.load(f"tests/unittest_prediction/{roi}.nii.gz").get_fdata()
images_equal = np.array_equal(img_ref, img_new)
self.assertTrue(images_equal, f"{roi} prediction not correct")
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 30
self.assertTrue(images_equal, f"{roi} prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")

def test_tissue_types_wo_license(self):
no_output_file = not os.path.exists("tests/unittest_no_license.nii.gz")
Expand All @@ -81,15 +84,17 @@ def test_tissue_types(self):
for roi in ["subcutaneous_fat", "skeletal_muscle", "torso_fat"]:
img_ref = nib.load(f"tests/reference_files/example_seg_tissue_types/{roi}.nii.gz").get_fdata()
img_new = nib.load(f"tests/unittest_prediction/{roi}.nii.gz").get_fdata()
images_equal = np.array_equal(img_ref, img_new)
self.assertTrue(images_equal, f"{roi} prediction not correct")
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 30
self.assertTrue(images_equal, f"{roi} prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")

def test_appendicular_bones(self):
for roi in ["patella", "phalanges_hand"]:
img_ref = nib.load(f"tests/reference_files/example_seg_appendicular_bones/{roi}.nii.gz").get_fdata()
img_new = nib.load(f"tests/unittest_prediction/{roi}.nii.gz").get_fdata()
images_equal = np.array_equal(img_ref, img_new)
self.assertTrue(images_equal, f"{roi} prediction not correct")
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 30
self.assertTrue(images_equal, f"{roi} prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")

def test_statistics(self):
stats_ref = json.load(open("tests/reference_files/example_seg_fast/statistics.json"))
Expand Down
18 changes: 13 additions & 5 deletions tests/tests_os.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@
from totalsegmentator.python_api import totalsegmentator


if __name__ == "__main__":
def run_tests_and_exit_on_failure():

# Test python api
# Test organ predictions - fast
totalsegmentator("tests/reference_files/example_ct_sm.nii.gz", "tests/unittest_prediction_fast", fast=True, device="cpu")
pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_fast"])
r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_fast"])
shutil.rmtree("tests/unittest_prediction_fast")
if r != 0: sys.exit("Test failed: test_prediction_fast")

# Test python api - nifti input
input_img = nib.load("tests/reference_files/example_ct_sm.nii.gz")
totalsegmentator(input_img, "tests/unittest_prediction_fast", fast=True, device="cpu")
pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_fast"])
r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_fast"])
shutil.rmtree("tests/unittest_prediction_fast")
if r != 0: sys.exit("Test failed: test_prediction_fast with Nifti input")

# Test terminal
# Test organ predictions - fast - multilabel
Expand All @@ -31,10 +33,16 @@
file_in = os.path.join("tests", "reference_files", "example_ct_sm.nii.gz")
file_out = os.path.join("tests", "unittest_prediction_fast.nii.gz")
subprocess.call(f"TotalSegmentator -i {file_in} -o {file_out} --fast --ml -d cpu", shell=True)
pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel_fast"])
r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel_fast"])
os.remove(file_out)
if r != 0: sys.exit("Test failed: test_prediction_multilabel_fast")

# Test Dicom input
totalsegmentator("tests/reference_files/example_ct_dicom", "tests/unittest_prediction_dicom.nii.gz", fast=True, ml=True, device="cpu")
pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_dicom"])
r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_dicom"])
os.remove("tests/unittest_prediction_dicom.nii.gz")
if r != 0: sys.exit("Test failed: test_prediction_dicom")


if __name__ == "__main__":
run_tests_and_exit_on_failure()

0 comments on commit 4326aec

Please sign in to comment.