Skip to content

Commit

Permalink
Merge pull request #269 from fohofmann/master
Browse files Browse the repository at this point in the history
postprocessing body, add quiet for tqdm
  • Loading branch information
wasserth authored Feb 8, 2024
2 parents 9b2c145 + c37f089 commit 030fe5f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions totalsegmentator/nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,15 +426,15 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
# Postprocessing multilabel (run here on lower resolution)
if task_name == "body":
img_pred_pp = keep_largest_blob_multilabel(img_pred.get_fdata().astype(np.uint8),
class_map[task_name], ["body_trunc"], debug=False)
class_map[task_name], ["body_trunc"], debug=False, quiet=quiet)
img_pred = nib.Nifti1Image(img_pred_pp, img_pred.affine)

if task_name == "body":
vox_vol = np.prod(img_pred.header.get_zooms())
size_thr_mm3 = 50000 / vox_vol
img_pred_pp = remove_small_blobs_multilabel(img_pred.get_fdata().astype(np.uint8),
class_map[task_name], ["body_extremities"],
interval=[size_thr_mm3, 1e10], debug=False)
interval=[size_thr_mm3, 1e10], debug=False, quiet=quiet)
img_pred = nib.Nifti1Image(img_pred_pp, img_pred.affine)

if preview:
Expand Down
8 changes: 4 additions & 4 deletions totalsegmentator/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def keep_largest_blob(data, debug=False):
return (blob_map == largest_blob_label).astype(np.uint8)


def keep_largest_blob_multilabel(data, class_map, rois, debug=False):
def keep_largest_blob_multilabel(data, class_map, rois, debug=False, quiet=False):
"""
Keep the largest blob for the classes defined in rois.
Expand All @@ -33,7 +33,7 @@ def keep_largest_blob_multilabel(data, class_map, rois, debug=False):
"""
st = time.time()
class_map_inv = {v: k for k, v in class_map.items()}
for roi in tqdm(rois):
for roi in tqdm(rois, disable=quiet):
idx = class_map_inv[roi]
data_roi = data == idx
cleaned_roi = keep_largest_blob(data_roi, debug) > 0.5
Expand Down Expand Up @@ -74,7 +74,7 @@ def remove_small_blobs(img: np.ndarray, interval=[10, 30], debug=False) -> np.nd
return mask


def remove_small_blobs_multilabel(data, class_map, rois, interval=[10, 30], debug=False):
def remove_small_blobs_multilabel(data, class_map, rois, interval=[10, 30], debug=False, quiet=False):
"""
Remove small blobs for the classes defined in rois.
Expand All @@ -87,7 +87,7 @@ def remove_small_blobs_multilabel(data, class_map, rois, interval=[10, 30], debu
st = time.time()
class_map_inv = {v: k for k, v in class_map.items()}

for roi in tqdm(rois):
for roi in tqdm(rois, disable=quiet):
idx = class_map_inv[roi]
data_roi = (data == idx)
cleaned_roi = remove_small_blobs(data_roi, interval, debug) > 0.5 # Remove small blobs from this ROI
Expand Down

0 comments on commit 030fe5f

Please sign in to comment.