diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 04a5a0b5c..bc687ae38 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -9,6 +9,26 @@ import pandas as pd +def dice_score(y_true, y_pred): + intersect = np.sum(y_true * y_pred) + denominator = np.sum(y_true) + np.sum(y_pred) + f1 = (2 * intersect) / (denominator + 1e-6) + return f1 + + +def dice_score_multilabel(y_true, y_pred): + """ + Calc dice for each class and then return the mean. + """ + dice_scores = [] + for i in np.unique(y_true)[1:]: + gt = y_true == i + pred = y_pred == i + dice_scores.append(dice_score(gt, pred)) + print(f"Dice scores per class: {dice_scores}") # only gets printed if the test fails + return np.mean(dice_scores) + + class test_end_to_end(unittest.TestCase): def setUp(self): @@ -17,26 +37,26 @@ def setUp(self): def test_prediction_multilabel(self): img_ref = nib.load("tests/reference_files/example_seg.nii.gz").get_fdata() img_new = nib.load("tests/unittest_prediction.nii.gz").get_fdata() - nr_of_diff_voxels = (img_ref != img_new).sum() - images_equal = nr_of_diff_voxels < 30 - self.assertTrue(images_equal, f"multilabel prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") + # nr_of_diff_voxels = (img_ref != img_new).sum() + # images_equal = nr_of_diff_voxels < 100 + dice = dice_score_multilabel(img_ref, img_new) + images_equal = dice > 0.99 + self.assertTrue(images_equal, f"multilabel prediction not correct (dice: {dice:.6f})") def test_prediction_liver_roi_subset(self): img_ref = nib.load("tests/reference_files/example_seg_roi_subset.nii.gz").get_fdata() img_new = nib.load("tests/unittest_prediction_roi_subset.nii.gz").get_fdata() - # prediction is not completely deterministic therefore allow for small differences - nr_of_diff_voxels = (img_ref != img_new).sum() - images_equal = nr_of_diff_voxels < 30 - self.assertTrue(images_equal, f"roi subset prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") + dice = dice_score_multilabel(img_ref, img_new) + images_equal = dice > 0.99 + self.assertTrue(images_equal, f"roi subset prediction not correct (dice: {dice:.6f})") def test_prediction_fast(self): for roi in ["liver", "vertebrae_L1"]: img_ref = nib.load(f"tests/reference_files/example_seg_fast/{roi}.nii.gz").get_fdata() img_new = nib.load(f"tests/unittest_prediction_fast/{roi}.nii.gz").get_fdata() - # prediction is not completely deterministic therefore allow for small differences - nr_of_diff_voxels = (img_ref != img_new).sum() - images_equal = nr_of_diff_voxels < 30 - self.assertTrue(images_equal, f"{roi} fast prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") + dice = dice_score_multilabel(img_ref, img_new) + images_equal = dice > 0.99 + self.assertTrue(images_equal, f"{roi} fast prediction not correct (dice: {dice:.6f})") def test_preview(self): preview_exists = os.path.exists("tests/unittest_prediction_fast/preview_total.png") @@ -45,32 +65,31 @@ def test_preview(self): def test_prediction_multilabel_fast(self): img_ref = nib.load("tests/reference_files/example_seg_fast.nii.gz").get_fdata() img_new = nib.load("tests/unittest_prediction_fast.nii.gz").get_fdata() - nr_of_diff_voxels = (img_ref != img_new).sum() - images_equal = nr_of_diff_voxels < 100 - self.assertTrue(images_equal, f"multilabel prediction fast not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") + dice = dice_score_multilabel(img_ref, img_new) + images_equal = dice > 0.99 + self.assertTrue(images_equal, f"multilabel prediction fast not correct (dice: {dice:.6f})") def test_prediction_multilabel_fast_force_split(self): img_ref = nib.load("tests/reference_files/example_seg_fast_force_split.nii.gz").get_fdata() img_new = nib.load("tests/unittest_prediction_fast_force_split.nii.gz").get_fdata() - nr_of_diff_voxels = (img_ref != img_new).sum() - images_equal = nr_of_diff_voxels < 100 - self.assertTrue(images_equal, f"force_split prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") - + dice = dice_score_multilabel(img_ref, img_new) + images_equal = dice > 0.99 + self.assertTrue(images_equal, f"force_split prediction not correct (nr_of_diff_voxels: {dice:.6f})") def test_prediction_multilabel_fast_body_seg(self): img_ref = nib.load("tests/reference_files/example_seg_fast_body_seg.nii.gz").get_fdata() img_new = nib.load("tests/unittest_prediction_fast_body_seg.nii.gz").get_fdata() - nr_of_diff_voxels = (img_ref != img_new).sum() - images_equal = nr_of_diff_voxels < 100 - self.assertTrue(images_equal, f"body_seg prediction fast not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") + dice = dice_score_multilabel(img_ref, img_new) + images_equal = dice > 0.99 + self.assertTrue(images_equal, f"body_seg prediction fast not correct (dice: {dice:.6f})") def test_lung_vessels(self): for roi in ["lung_trachea_bronchia", "lung_vessels"]: img_ref = nib.load(f"tests/reference_files/example_seg_lung_vessels/{roi}.nii.gz").get_fdata() img_new = nib.load(f"tests/unittest_prediction/{roi}.nii.gz").get_fdata() - nr_of_diff_voxels = (img_ref != img_new).sum() - images_equal = nr_of_diff_voxels < 30 - self.assertTrue(images_equal, f"{roi} prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") + dice = dice_score(img_ref, img_new) + images_equal = dice > 0.99 + self.assertTrue(images_equal, f"{roi} prediction not correct (dice: {dice:.6f})") def test_tissue_types_wo_license(self): no_output_file = not os.path.exists("tests/unittest_no_license.nii.gz") @@ -84,17 +103,17 @@ def test_tissue_types(self): for roi in ["subcutaneous_fat", "skeletal_muscle", "torso_fat"]: img_ref = nib.load(f"tests/reference_files/example_seg_tissue_types/{roi}.nii.gz").get_fdata() img_new = nib.load(f"tests/unittest_prediction/{roi}.nii.gz").get_fdata() - nr_of_diff_voxels = (img_ref != img_new).sum() - images_equal = nr_of_diff_voxels < 30 - self.assertTrue(images_equal, f"{roi} prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") + dice = dice_score(img_ref, img_new) + images_equal = dice > 0.99 + self.assertTrue(images_equal, f"{roi} prediction not correct (dice: {dice:.6f})") def test_appendicular_bones(self): for roi in ["patella", "phalanges_hand"]: img_ref = nib.load(f"tests/reference_files/example_seg_appendicular_bones/{roi}.nii.gz").get_fdata() img_new = nib.load(f"tests/unittest_prediction/{roi}.nii.gz").get_fdata() - nr_of_diff_voxels = (img_ref != img_new).sum() - images_equal = nr_of_diff_voxels < 30 - self.assertTrue(images_equal, f"{roi} prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})") + dice = dice_score(img_ref, img_new) + images_equal = dice > 0.99 + self.assertTrue(images_equal, f"{roi} prediction not correct (dice: {dice:.6f})") def test_statistics(self): stats_ref = json.load(open("tests/reference_files/example_seg_fast/statistics.json"))