Skip to content

Commit

Permalink
use dice instead of nr diff voxels in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Feb 6, 2024
1 parent 4326aec commit fbe088e
Showing 1 changed file with 49 additions and 30 deletions.
79 changes: 49 additions & 30 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,26 @@
import pandas as pd


def dice_score(y_true, y_pred):
intersect = np.sum(y_true * y_pred)
denominator = np.sum(y_true) + np.sum(y_pred)
f1 = (2 * intersect) / (denominator + 1e-6)
return f1


def dice_score_multilabel(y_true, y_pred):
"""
Calc dice for each class and then return the mean.
"""
dice_scores = []
for i in np.unique(y_true)[1:]:
gt = y_true == i
pred = y_pred == i
dice_scores.append(dice_score(gt, pred))
print(f"Dice scores per class: {dice_scores}") # only gets printed if the test fails
return np.mean(dice_scores)


class test_end_to_end(unittest.TestCase):

def setUp(self):
Expand All @@ -17,26 +37,26 @@ def setUp(self):
def test_prediction_multilabel(self):
img_ref = nib.load("tests/reference_files/example_seg.nii.gz").get_fdata()
img_new = nib.load("tests/unittest_prediction.nii.gz").get_fdata()
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 30
self.assertTrue(images_equal, f"multilabel prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")
# nr_of_diff_voxels = (img_ref != img_new).sum()
# images_equal = nr_of_diff_voxels < 100
dice = dice_score_multilabel(img_ref, img_new)
images_equal = dice > 0.99
self.assertTrue(images_equal, f"multilabel prediction not correct (dice: {dice:.6f})")

def test_prediction_liver_roi_subset(self):
img_ref = nib.load("tests/reference_files/example_seg_roi_subset.nii.gz").get_fdata()
img_new = nib.load("tests/unittest_prediction_roi_subset.nii.gz").get_fdata()
# prediction is not completely deterministic therefore allow for small differences
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 30
self.assertTrue(images_equal, f"roi subset prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")
dice = dice_score_multilabel(img_ref, img_new)
images_equal = dice > 0.99
self.assertTrue(images_equal, f"roi subset prediction not correct (dice: {dice:.6f})")

def test_prediction_fast(self):
for roi in ["liver", "vertebrae_L1"]:
img_ref = nib.load(f"tests/reference_files/example_seg_fast/{roi}.nii.gz").get_fdata()
img_new = nib.load(f"tests/unittest_prediction_fast/{roi}.nii.gz").get_fdata()
# prediction is not completely deterministic therefore allow for small differences
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 30
self.assertTrue(images_equal, f"{roi} fast prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")
dice = dice_score_multilabel(img_ref, img_new)
images_equal = dice > 0.99
self.assertTrue(images_equal, f"{roi} fast prediction not correct (dice: {dice:.6f})")

def test_preview(self):
preview_exists = os.path.exists("tests/unittest_prediction_fast/preview_total.png")
Expand All @@ -45,32 +65,31 @@ def test_preview(self):
def test_prediction_multilabel_fast(self):
img_ref = nib.load("tests/reference_files/example_seg_fast.nii.gz").get_fdata()
img_new = nib.load("tests/unittest_prediction_fast.nii.gz").get_fdata()
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 100
self.assertTrue(images_equal, f"multilabel prediction fast not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")
dice = dice_score_multilabel(img_ref, img_new)
images_equal = dice > 0.99
self.assertTrue(images_equal, f"multilabel prediction fast not correct (dice: {dice:.6f})")

def test_prediction_multilabel_fast_force_split(self):
img_ref = nib.load("tests/reference_files/example_seg_fast_force_split.nii.gz").get_fdata()
img_new = nib.load("tests/unittest_prediction_fast_force_split.nii.gz").get_fdata()
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 100
self.assertTrue(images_equal, f"force_split prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")

dice = dice_score_multilabel(img_ref, img_new)
images_equal = dice > 0.99
self.assertTrue(images_equal, f"force_split prediction not correct (nr_of_diff_voxels: {dice:.6f})")

def test_prediction_multilabel_fast_body_seg(self):
img_ref = nib.load("tests/reference_files/example_seg_fast_body_seg.nii.gz").get_fdata()
img_new = nib.load("tests/unittest_prediction_fast_body_seg.nii.gz").get_fdata()
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 100
self.assertTrue(images_equal, f"body_seg prediction fast not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")
dice = dice_score_multilabel(img_ref, img_new)
images_equal = dice > 0.99
self.assertTrue(images_equal, f"body_seg prediction fast not correct (dice: {dice:.6f})")

def test_lung_vessels(self):
for roi in ["lung_trachea_bronchia", "lung_vessels"]:
img_ref = nib.load(f"tests/reference_files/example_seg_lung_vessels/{roi}.nii.gz").get_fdata()
img_new = nib.load(f"tests/unittest_prediction/{roi}.nii.gz").get_fdata()
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 30
self.assertTrue(images_equal, f"{roi} prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")
dice = dice_score(img_ref, img_new)
images_equal = dice > 0.99
self.assertTrue(images_equal, f"{roi} prediction not correct (dice: {dice:.6f})")

def test_tissue_types_wo_license(self):
no_output_file = not os.path.exists("tests/unittest_no_license.nii.gz")
Expand All @@ -84,17 +103,17 @@ def test_tissue_types(self):
for roi in ["subcutaneous_fat", "skeletal_muscle", "torso_fat"]:
img_ref = nib.load(f"tests/reference_files/example_seg_tissue_types/{roi}.nii.gz").get_fdata()
img_new = nib.load(f"tests/unittest_prediction/{roi}.nii.gz").get_fdata()
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 30
self.assertTrue(images_equal, f"{roi} prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")
dice = dice_score(img_ref, img_new)
images_equal = dice > 0.99
self.assertTrue(images_equal, f"{roi} prediction not correct (dice: {dice:.6f})")

def test_appendicular_bones(self):
for roi in ["patella", "phalanges_hand"]:
img_ref = nib.load(f"tests/reference_files/example_seg_appendicular_bones/{roi}.nii.gz").get_fdata()
img_new = nib.load(f"tests/unittest_prediction/{roi}.nii.gz").get_fdata()
nr_of_diff_voxels = (img_ref != img_new).sum()
images_equal = nr_of_diff_voxels < 30
self.assertTrue(images_equal, f"{roi} prediction not correct (nr_of_diff_voxels: {nr_of_diff_voxels})")
dice = dice_score(img_ref, img_new)
images_equal = dice > 0.99
self.assertTrue(images_equal, f"{roi} prediction not correct (dice: {dice:.6f})")

def test_statistics(self):
stats_ref = json.load(open("tests/reference_files/example_seg_fast/statistics.json"))
Expand Down

0 comments on commit fbe088e

Please sign in to comment.