From b7e4fa2edec449b2085ff28b7ba7cdc188d7c9cb Mon Sep 17 00:00:00 2001 From: wasserth Date: Tue, 22 Aug 2023 16:06:37 +0200 Subject: [PATCH] more adaptions for v2 classes --- bin/TotalSegmentator | 14 +++-- totalsegmentator/config.py | 6 ++ totalsegmentator/libs.py | 68 ++++++++++++---------- totalsegmentator/map_to_binary.py | 4 +- totalsegmentator/preview.py | 39 +++++++------ totalsegmentator/python_api.py | 93 +++++++++++++++++++++---------- 6 files changed, 142 insertions(+), 82 deletions(-) diff --git a/bin/TotalSegmentator b/bin/TotalSegmentator index 1bb962b5f..2c147b7ad 100644 --- a/bin/TotalSegmentator +++ b/bin/TotalSegmentator @@ -45,11 +45,15 @@ def main(): # cerebral_bleed: Intracerebral hemorrhage # liver_vessels: hepatic vessels - parser.add_argument("-ta", "--task", choices=["total", "lung_vessels", "cerebral_bleed", - "hip_implant", "coronary_arteries", "body", "pleural_pericard_effusion", - "liver_vessels", "bones_extremities", "tissue_types", - "heartchambers_highres", "head", "aortic_branches", "heartchambers_test", - "bones_tissue_test", "aortic_branches_test", "test"], + parser.add_argument("-ta", "--task", choices=["total", "body", "vertebrae_body", + + "lung_vessels", "cerebral_bleed", "hip_implant", "coronary_arteries", + "pleural_pericard_effusion", "test", + + "appendicular_bones", "tissue_types", "heartchambers_highres", + "face", + ], + # future: liver_vessels, head, help="Select which model to use. This determines what is predicted.", default="total") diff --git a/totalsegmentator/config.py b/totalsegmentator/config.py index 712851be5..838c0197f 100644 --- a/totalsegmentator/config.py +++ b/totalsegmentator/config.py @@ -11,6 +11,8 @@ import requests import torch +from totalsegmentator.libs import is_valid_license + def setup_nnunet(): # check if environment variable totalsegmentator_config is set @@ -53,6 +55,10 @@ def setup_totalseg(totalseg_id=None): def set_license_number(license_number): + if not is_valid_license(license_number): + print("ERROR: Invalid license number. Please check your license number or contact support.") + sys.exit(1) + home_path = Path("/tmp") if str(Path.home()) == "/" else Path.home() totalseg_config_file = home_path / ".totalsegmentator" / "config.json" diff --git a/totalsegmentator/libs.py b/totalsegmentator/libs.py index 4a4763056..f68bab092 100644 --- a/totalsegmentator/libs.py +++ b/totalsegmentator/libs.py @@ -55,16 +55,7 @@ def get_config_dir(): # f.write(chunk) -def has_valid_license(): - home_path = Path("/tmp") if str(Path.home()) == "/" else Path.home() - totalseg_config_file = home_path / ".totalsegmentator" / "config.json" - if totalseg_config_file.exists(): - config = json.load(open(totalseg_config_file, "r")) - license_number = config["license_number"] - else: - print(f"ERROR: Could not find config file: {totalseg_config_file}") - return False - +def is_valid_license(license_number): try: url = f"http://backend.totalsegmentator.com:80/" r = requests.post(url + "is_valid_license_number", @@ -78,6 +69,23 @@ def has_valid_license(): except Exception as e: print(f"An Exception occured: {e}") return False + + +def has_valid_license(): + home_path = Path("/tmp") if str(Path.home()) == "/" else Path.home() + totalseg_config_file = home_path / ".totalsegmentator" / "config.json" + if totalseg_config_file.exists(): + config = json.load(open(totalseg_config_file, "r")) + if "license_number" in config: + license_number = config["license_number"] + else: + # print(f"ERROR: A license number has not been set so far.") + return False + else: + # print(f"ERROR: Could not find config file: {totalseg_config_file}") + return False + + return is_valid_license(license_number) def download_model_with_license_and_unpack(task_name, config_dir): @@ -122,7 +130,7 @@ def download_model_with_license_and_unpack(task_name, config_dir): print(f" downloaded in {time.time()-st:.2f}s") else: if r.json()['status'] == "invalid_license": - print(f"Invalid license number ({license_number}). Please check your license number or contact support.") + print(f"ERROR: Invalid license number ({license_number}). Please check your license number or contact support.") sys.exit(1) except Exception as e: @@ -188,25 +196,25 @@ def download_pretrained_weights(task_id): # (config_dir / "2d").mkdir(exist_ok=True, parents=True) old_weights = [ - "Task251_TotalSegmentator_part1_organs_1139subj", - "Task252_TotalSegmentator_part2_vertebrae_1139subj", - "Task253_TotalSegmentator_part3_cardiac_1139subj", - "Task254_TotalSegmentator_part4_muscles_1139subj", - "Task255_TotalSegmentator_part5_ribs_1139subj", - "Task256_TotalSegmentator_3mm_1139subj", - "Task258_lung_vessels_248subj", - "Task200_covid_challenge", - "Task201_covid", - "Task150_icb_v0", - "Task260_hip_implant_71subj", - "Task269_Body_extrem_6mm_1200subj", - "Task503_cardiac_motion", - "Task273_Body_extrem_1259subj", - "Task315_thoraxCT", - "Task008_HepaticVessel", - "Task417_heart_mixed_317subj", - "Task278_TotalSegmentator_part6_bones_1259subj", - "Task435_Heart_vessels_118subj" + "nnUNet/3d_fullres/Task251_TotalSegmentator_part1_organs_1139subj", + "nnUNet/3d_fullres/Task252_TotalSegmentator_part2_vertebrae_1139subj", + "nnUNet/3d_fullres/Task253_TotalSegmentator_part3_cardiac_1139subj", + "nnUNet/3d_fullres/Task254_TotalSegmentator_part4_muscles_1139subj", + "nnUNet/3d_fullres/Task255_TotalSegmentator_part5_ribs_1139subj", + "nnUNet/3d_fullres/Task256_TotalSegmentator_3mm_1139subj", + "nnUNet/3d_fullres/Task258_lung_vessels_248subj", + "nnUNet/3d_fullres/Task200_covid_challenge", + "nnUNet/3d_fullres/Task201_covid", + "nnUNet/3d_fullres/Task150_icb_v0", + "nnUNet/3d_fullres/Task260_hip_implant_71subj", + "nnUNet/3d_fullres/Task269_Body_extrem_6mm_1200subj", + "nnUNet/3d_fullres/Task503_cardiac_motion", + "nnUNet/3d_fullres/Task273_Body_extrem_1259subj", + "nnUNet/3d_fullres/Task315_thoraxCT", + "nnUNet/3d_fullres/Task008_HepaticVessel", + "nnUNet/3d_fullres/Task417_heart_mixed_317subj", + "nnUNet/3d_fullres/Task278_TotalSegmentator_part6_bones_1259subj", + "nnUNet/3d_fullres/Task435_Heart_vessels_118subj" ] url = "http://backend.totalsegmentator.com" diff --git a/totalsegmentator/map_to_binary.py b/totalsegmentator/map_to_binary.py index 32c6a52ec..703b09a08 100644 --- a/totalsegmentator/map_to_binary.py +++ b/totalsegmentator/map_to_binary.py @@ -277,7 +277,7 @@ 10: "metacarpal", 11: "phalanges_hand" }, - "tissue": { + "tissue_types": { 1: "subcutaneous_fat", 2: "skeletal_muscle", 3: "torso_fat" @@ -294,7 +294,7 @@ commercial_models = { "heartchambers_highres": 301, "appendicular_bones": 296, - "tissue": 481, + "tissue_types": 481, "face": 303 } # future diff --git a/totalsegmentator/preview.py b/totalsegmentator/preview.py index 5b00f4b51..ed9bbd1ad 100644 --- a/totalsegmentator/preview.py +++ b/totalsegmentator/preview.py @@ -81,30 +81,37 @@ "liver_vessels": [ ["liver_vessels", "liver_tumor"] ], - "heartchambers_test": [ + "vertebrae_body": [ + ["vertebrae_body"] + ], + "heartchambers_highres": [ ["heart_myocardium"], ["heart_atrium_left", "heart_ventricle_left"], ["heart_atrium_right", "heart_ventricle_right"], ["aorta", "pulmonary_artery"] ], - "bones_tissue_test": [ - ["femur", "patella", "tibia", "fibula", "tarsal", "metatarsal", "phalanges_feet", - "humerus", "ulna", "radius", "carpal", "metacarpal", "phalanges_hand", "sternum", - "skull", "spinal_cord"], - ["subcutaneous_fat", "skeletal_muscle", "torso_fat"] + "appendicular_bones": [ + ["patella", "tibia", "fibula", "tarsal", "metatarsal", "phalanges_feet", + "ulna", "radius", "carpal", "metacarpal", "phalanges_hand"] + ], + "tissue_types": [ + ["subcutaneous_fat"], + ["torso_fat"], + ["skeletal_muscle"] ], - "aortic_branches_test": [ - ["brachiocephalic_trunk", "subclavian_artery_right", "subclavian_artery_left", "aorta", - "common_carotid_artery_right", "common_carotid_artery_left"], - ["superior_vena_cava", - "brachiocephalic_vein_left", "brachiocephalic_vein_right", "atrial_appendage_left"], - ["pulmonary_vein", "pulmonary_artery"], - ["heart_atrium_left", "heart_atrium_right", "thyroid_gland"] + "face": [ + ["face"] ], + # "aortic_branches_test": [ + # ["brachiocephalic_trunk", "subclavian_artery_right", "subclavian_artery_left", "aorta", + # "common_carotid_artery_right", "common_carotid_artery_left"], + # ["superior_vena_cava", + # "brachiocephalic_vein_left", "brachiocephalic_vein_right", "atrial_appendage_left"], + # ["pulmonary_vein", "pulmonary_artery"], + # ["heart_atrium_left", "heart_atrium_right", "thyroid_gland"] + # ], "test": [ - ["carpal", "clavicula", "femur", "fibula", "humerus", "metacarpal", "metatarsal", - "patella", "hips", "phalanges_hand", "radius", "ribs", "scapula", "skull", "spine", - "sternum", "tarsal", "tibia", "phalanges_feet", "ulna"] + ["ulna"] ] } diff --git a/totalsegmentator/python_api.py b/totalsegmentator/python_api.py index 9cc23eb00..816eb36e1 100644 --- a/totalsegmentator/python_api.py +++ b/totalsegmentator/python_api.py @@ -8,12 +8,26 @@ import torch from totalsegmentator.statistics import get_basic_statistics_for_entire_dir, get_radiomics_features_for_entire_dir -from totalsegmentator.libs import download_pretrained_weights -from totalsegmentator.config import setup_nnunet, setup_totalseg, increase_prediction_counter, send_usage_stats +from totalsegmentator.libs import download_pretrained_weights, has_valid_license +from totalsegmentator.config import setup_nnunet, setup_totalseg, increase_prediction_counter +from totalsegmentator.config import send_usage_stats, set_license_number from totalsegmentator.map_to_binary import class_map from totalsegmentator.map_to_total import map_to_total +def show_license_info(): + if not has_valid_license(): + print(""" + In contrast to the other tasks this task is not openly available. + It requires a license. For academic usage a free license can be + acquired here: + https://totalsegmentator-academic.streamlit.app + + For commercial usage see: + https://totalsegmentator-commercial.streamlit.app + """) + + def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6, fast=False, nora_tag="None", preview=False, task="total", roi_subset=None, statistics=False, radiomics=False, crop_path=None, body_seg=False, @@ -52,15 +66,13 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6, if fast: task_id = 297 resample = 3.0 - # trainer = "nnUNetTrainer_4000epochs_NoMirroring" trainer = "nnUNetTrainerNoMirroring" crop = None if not quiet: print("Using 'fast' option: resampling to lower resolution (3mm)") task = "total_fast" else: - task_id = [291, 292, 293, 294, 295, 296] + task_id = [291, 292, 293, 294, 295] resample = 1.5 - # trainer = "nnUNetTrainer_4000epochs_NoMirroring" trainer = "nnUNetTrainerNoMirroring" crop = None model = "3d_fullres" @@ -68,7 +80,7 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6, elif task == "lung_vessels": task_id = 258 resample = None - trainer = "nnUNetTrainerV2" + trainer = "nnUNetTrainer" crop = "lung" if ml: raise ValueError("task lung_vessels does not work with option --ml, because of postprocessing.") if fast: raise ValueError("task lung_vessels does not work with option --fast") @@ -77,7 +89,7 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6, elif task == "covid": task_id = 201 resample = None - trainer = "nnUNetTrainerV2" + trainer = "nnUNetTrainer" crop = "lung" model = "3d_fullres" folds = [0] @@ -86,7 +98,7 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6, elif task == "cerebral_bleed": task_id = 150 resample = None - trainer = "nnUNetTrainerV2" + trainer = "nnUNetTrainer" crop = "brain" model = "3d_fullres" folds = [0] @@ -94,7 +106,7 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6, elif task == "hip_implant": task_id = 260 resample = None - trainer = "nnUNetTrainerV2" + trainer = "nnUNetTrainer" crop = "pelvis" model = "3d_fullres" folds = [0] @@ -102,7 +114,7 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6, elif task == "coronary_arteries": task_id = 503 resample = None - trainer = "nnUNetTrainerV2" + trainer = "nnUNetTrainer" crop = "heart" model = "3d_fullres" folds = [0] @@ -110,17 +122,17 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6, if fast: raise ValueError("task coronary_arteries does not work with option --fast") elif task == "body": if fast: - task_id = 269 + task_id = 300 resample = 6.0 - trainer = "nnUNetTrainerV2" + trainer = "nnUNetTrainer" crop = None model = "3d_fullres" folds = [0] if not quiet: print("Using 'fast' option: resampling to lower resolution (6mm)") else: - task_id = 273 + task_id = 299 resample = 1.5 - trainer = "nnUNetTrainerV2" + trainer = "nnUNetTrainer" crop = None model = "3d_fullres" folds = [0] @@ -128,7 +140,7 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6, elif task == "pleural_pericard_effusion": task_id = 315 resample = None - trainer = "nnUNetTrainerV2" + trainer = "nnUNetTrainer" crop = "lung" crop_addon = [50, 50, 50] model = "3d_fullres" @@ -137,37 +149,60 @@ def totalsegmentator(input, output, ml=False, nr_thr_resamp=1, nr_thr_saving=6, elif task == "liver_vessels": task_id = 8 resample = None - trainer = "nnUNetTrainerV2" + trainer = "nnUNetTrainer" crop = "liver" crop_addon = [20, 20, 20] model = "3d_fullres" folds = None if fast: raise ValueError("task liver_vessels does not work with option --fast") - elif task == "heartchambers_test": - task_id = 417 + elif task == "vertebrae_body": + task_id = 302 + resample = 1.5 + trainer = "nnUNetTrainer" + crop = None + model = "3d_fullres" + folds = [0] + if fast: raise ValueError("task vertebrae_body does not work with option --fast") + + # Commercial models + elif task == "heartchambers_highres": + task_id = 301 resample = None - trainer = "nnUNetTrainerV2" + trainer = "nnUNetTrainer" crop = "heart" crop_addon = [5, 5, 5] - model = "3d_lowres" + model = "3d_fullres" folds = None - if fast: raise ValueError("task heartchambers_test does not work with option --fast") - elif task == "bones_tissue_test": - task_id = 278 + if fast: raise ValueError("task heartchambers_highres does not work with option --fast") + show_license_info() + elif task == "appendicular_bones": + task_id = 296 resample = 1.5 - trainer = "nnUNetTrainerV2_ep4000_nomirror" + trainer = "nnUNetTrainer" crop = None model = "3d_fullres" folds = [0] - if fast: raise ValueError("task bones_tissue_test does not work with option --fast") - elif task == "aortic_branches_test": - task_id = 435 + if fast: raise ValueError("task appendicular_bones does not work with option --fast") + show_license_info() + elif task == "tissue_types": + task_id = 481 resample = 1.5 - trainer = "nnUNetTrainerV2_nomirror" + trainer = "nnUNetTrainer" crop = None model = "3d_fullres" folds = [0] - if fast: raise ValueError("task aortic_branches_test does not work with option --fast") + if fast: raise ValueError("task tissue_types does not work with option --fast") + show_license_info() + elif task == "face": + task_id = 303 + resample = 1.5 + trainer = "nnUNetTrainer" + crop = None + model = "3d_fullres" + folds = [0] + if fast: raise ValueError("task face does not work with option --fast") + show_license_info() + elif task in ["bones_extremities", "tissue_types", "heartchambers_highres", "head", "aortic_branches"]: print("\nThis model is only available upon purchase of a license (free licenses available for " +