Skip to content

Commit

Permalink
use 6mm model for all models with crop
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Aug 29, 2023
1 parent e09dde7 commit 2515ad4
Showing 1 changed file with 38 additions and 36 deletions.
74 changes: 38 additions & 36 deletions totalsegmentator/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
task_id = 258
resample = None
trainer = "nnUNetTrainer"
crop = "lung"
crop = ["lung_upper_lobe_left", "lung_lower_lobe_left", "lung_upper_lobe_right",
"lung_middle_lobe_right", "lung_lower_lobe_right"]
if ml: raise ValueError("task lung_vessels does not work with option --ml, because of postprocessing.")
if fast: raise ValueError("task lung_vessels does not work with option --fast")
model = "3d_fullres"
Expand All @@ -99,7 +100,8 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
task_id = 201
resample = None
trainer = "nnUNetTrainer"
crop = "lung"
crop = ["lung_upper_lobe_left", "lung_lower_lobe_left", "lung_upper_lobe_right",
"lung_middle_lobe_right", "lung_lower_lobe_right"]
model = "3d_fullres"
folds = [0]
print("WARNING: The COVID model finds many types of lung opacity not only COVID. Use with care!")
Expand All @@ -108,23 +110,23 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
task_id = 150
resample = None
trainer = "nnUNetTrainer"
crop = "brain"
crop = ["brain"]
model = "3d_fullres"
folds = [0]
if fast: raise ValueError("task cerebral_bleed does not work with option --fast")
elif task == "hip_implant":
task_id = 260
resample = None
trainer = "nnUNetTrainer"
crop = "pelvis"
crop = ["femur_left", "femur_right", "hip_left", "hip_right"]
model = "3d_fullres"
folds = [0]
if fast: raise ValueError("task hip_implant does not work with option --fast")
elif task == "coronary_arteries":
task_id = 503
resample = None
trainer = "nnUNetTrainer"
crop = "heart"
crop = ["heart"]
model = "3d_fullres"
folds = [0]
print("WARNING: The coronary artery model does not work very robustly. Use with care!")
Expand All @@ -150,7 +152,8 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
task_id = 315
resample = None
trainer = "nnUNetTrainer"
crop = "lung"
crop = ["lung_upper_lobe_left", "lung_lower_lobe_left", "lung_upper_lobe_right",
"lung_middle_lobe_right", "lung_lower_lobe_right"]
crop_addon = [50, 50, 50]
model = "3d_fullres"
folds = None
Expand All @@ -159,7 +162,7 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
task_id = 8
resample = None
trainer = "nnUNetTrainer"
crop = "liver"
crop = ["liver"]
crop_addon = [20, 20, 20]
model = "3d_fullres"
folds = None
Expand All @@ -178,7 +181,7 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
task_id = 301
resample = None
trainer = "nnUNetTrainer"
crop = "heart"
crop = ["heart"]
crop_addon = [5, 5, 5]
model = "3d_fullres"
folds = None
Expand Down Expand Up @@ -211,12 +214,6 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
folds = [0]
if fast: raise ValueError("task face does not work with option --fast")
show_license_info()

elif task in ["bones_extremities", "tissue_types", "heartchambers_highres",
"head", "aortic_branches"]:
print("\nThis model is only available upon purchase of a license (free licenses available for " +
"academic projects). \nContact [email protected] if you are interested.\n")
sys.exit()
elif task == "test":
task_id = [517]
resample = None
Expand All @@ -242,29 +239,34 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6,

if roi_subset is not None and type(roi_subset) is not list:
raise ValueError("roi_subset must be a list of strings")
if roi_subset is not None and task != "total":
raise ValueError("roi_subset only works with task 'total'")

# Generate rough organ segmentation (6mm) for speed up if roi_subset is used
if roi_subset is not None and len(roi_subset) < 10:
if task != "appendicular_bones" and task != "tissue":
body_seg = False # can not be used together with body_seg
download_pretrained_weights(298)
st = time.time()
if not quiet: print("Generating rough body segmentation...")
organ_seg = nnUNet_predict_image(input, None, 298, model="3d_fullres", folds=[0],
trainer="nnUNetTrainerNoMirroring", tta=False, multilabel_image=True, resample=6.0,
crop=None, crop_path=None, task_name="total", nora_tag="None", preview=False,
save_binary=False, nr_threads_resampling=nr_thr_resamp, nr_threads_saving=1,
crop_addon=crop_addon, output_type=output_type, statistics=False,
quiet=quiet, verbose=verbose, test=0, skip_saving=False, device=device)
class_map_inv = {v: k for k, v in class_map["total"].items()}
crop = np.zeros(organ_seg.shape, dtype=np.uint8)
organ_seg_data = organ_seg.get_fdata()
roi_subset_crop = [map_to_total[roi] if roi in map_to_total else roi for roi in roi_subset]
for roi in roi_subset_crop:
crop[organ_seg_data == class_map_inv[roi]] = 1
crop = nib.Nifti1Image(crop, organ_seg.affine)
crop_addon = [20,20,20]
if verbose: print(f"Rough organ segmentation generated in {time.time()-st:.2f}s")
# Generate rough organ segmentation (6mm) for speed up if crop or roi_subset is used
if crop is not None or \
(roi_subset is not None and len(roi_subset) < 10):

body_seg = False # can not be used together with body_seg
download_pretrained_weights(298)
st = time.time()
if not quiet: print("Generating rough body segmentation...")
organ_seg = nnUNet_predict_image(input, None, 298, model="3d_fullres", folds=[0],
trainer="nnUNetTrainerNoMirroring", tta=False, multilabel_image=True, resample=6.0,
crop=None, crop_path=None, task_name="total", nora_tag="None", preview=False,
save_binary=False, nr_threads_resampling=nr_thr_resamp, nr_threads_saving=1,
crop_addon=crop_addon, output_type=output_type, statistics=False,
quiet=quiet, verbose=verbose, test=0, skip_saving=False, device=device)
class_map_inv = {v: k for k, v in class_map["total"].items()}
crop_mask = np.zeros(organ_seg.shape, dtype=np.uint8)
organ_seg_data = organ_seg.get_fdata()
# roi_subset_crop = [map_to_total[roi] if roi in map_to_total else roi for roi in roi_subset]
roi_subset_crop = crop if crop is not None else roi_subset
for roi in roi_subset_crop:
crop_mask[organ_seg_data == class_map_inv[roi]] = 1
crop_mask = nib.Nifti1Image(crop_mask, organ_seg.affine)
crop_addon = [20,20,20]
crop = crop_mask
if verbose: print(f"Rough organ segmentation generated in {time.time()-st:.2f}s")

# Generate rough body segmentation (6mm) (speedup for big images; not useful in combination with --fast option)
if crop is None and body_seg:
Expand Down

0 comments on commit 2515ad4

Please sign in to comment.