diff --git a/.gitignore b/.gitignore index e9c99a8ae..7f5b5ceb1 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,9 @@ store/ unittest_prediction unittest_prediction_fast unittest_prediction_roi_subset +unittest_prediction.nii.gz unittest_prediction_fast.nii.gz +unittest_prediction_fast_body_seg.nii.gz *.tfstate *.tfstate.backup diff --git a/bin/crop_to_body b/bin/crop_to_body index a010cc42d..9bfe5ae49 100644 --- a/bin/crop_to_body +++ b/bin/crop_to_body @@ -34,6 +34,9 @@ def main(): parser.add_argument("-ns", "--nr_thr_saving", type=int, help="Nr of threads for saving segmentations", default=6) + + parser.add_argument("-d", "--device", choices=["gpu", "cpu"], help="Device to run on (default: gpu).", + default="gpu") parser.add_argument("-q", "--quiet", action="store_true", help="Print no intermediate outputs", default=False) @@ -45,23 +48,25 @@ def main(): quiet, verbose = args.quiet, args.verbose - if not torch.cuda.is_available(): - print("No GPU detected. Running on CPU. This can be very slow. The '--fast' option can help to some extend.") + device = "cuda" if args.device == "gpu" else "cpu" + if device == "cuda" and not torch.cuda.is_available(): + print("No GPU detected. Running on CPU.") + device = "cpu" setup_nnunet() from totalsegmentator.nnunet import nnUNet_predict_image # this has to be after setting new env vars crop_addon = [3, 3, 3] # default value - download_pretrained_weights(269) + download_pretrained_weights(300) st = time.time() if not quiet: print("Generating rough body segmentation...") - body_seg = nnUNet_predict_image(args.input, None, 269, model="3d_fullres", folds=[0], - trainer="nnUNetTrainerV2", tta=False, multilabel_image=True, resample=6.0, + body_seg = nnUNet_predict_image(args.input, None, 300, model="3d_fullres", folds=[0], + trainer="nnUNetTrainer", tta=False, multilabel_image=True, resample=6.0, crop=None, crop_path=None, task_name="body", nora_tag="None", preview=False, save_binary=False, nr_threads_resampling=args.nr_thr_resamp, nr_threads_saving=1, - crop_addon=crop_addon, quiet=quiet, verbose=verbose, test=0) + crop_addon=crop_addon, quiet=quiet, verbose=verbose, test=0, device=device) if verbose: print(f"Rough body segmentation generated in {time.time()-st:.2f}s") body_seg_data = body_seg.get_fdata() diff --git a/tests/unittest_prediction.nii.gz b/tests/unittest_prediction.nii.gz deleted file mode 100644 index 1f518a4b5..000000000 Binary files a/tests/unittest_prediction.nii.gz and /dev/null differ diff --git a/tests/unittest_prediction_fast_body_seg.nii.gz b/tests/unittest_prediction_fast_body_seg.nii.gz deleted file mode 100644 index c8ae81e0c..000000000 Binary files a/tests/unittest_prediction_fast_body_seg.nii.gz and /dev/null differ diff --git a/totalsegmentator/python_api.py b/totalsegmentator/python_api.py index f7ffd2315..4884ea7fd 100644 --- a/totalsegmentator/python_api.py +++ b/totalsegmentator/python_api.py @@ -51,16 +51,16 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6, input = Path(input) output = Path(output) - # available devices: gpu | cpu | mps - if device == "gpu": device = "cuda" - nora_tag = "None" if nora_tag is None else nora_tag if not quiet: print("\nIf you use this tool please cite: https://pubs.rsna.org/doi/10.1148/ryai.230024\n") - if not torch.cuda.is_available(): - print("No GPU detected. Running on CPU. This can be very slow. The '--fast' option can help to some extend.") + # available devices: gpu | cpu | mps + if device == "gpu": device = "cuda" + if device == "cuda" and not torch.cuda.is_available(): + print("No GPU detected. Running on CPU. This can be very slow. The '--fast' or the `--roi_subset` option can help to reduce runtime.") + device = "cpu" setup_nnunet() setup_totalseg()