Skip to content

Commit

Permalink
add median to stats; update phase prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Jul 19, 2024
1 parent 3af0160 commit a42a54f
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 11 deletions.
25 changes: 22 additions & 3 deletions totalsegmentator/bin/totalseg_get_phase.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
import sys
from pathlib import Path
import time
import argparse
import json
import pickle
Expand Down Expand Up @@ -51,18 +52,36 @@ def pi_time_to_phase(pi_time: float) -> str:

def get_ct_contrast_phase(ct_img: nib.Nifti1Image, model_file: Path = None):

organs = ["liver", "spleen", "kidney_left", "kidney_right", "pancreas", "urinary_bladder", "gallbladder",
organs = ["liver", "pancreas", "urinary_bladder", "gallbladder",
"heart", "aorta", "inferior_vena_cava", "portal_vein_and_splenic_vein",
"iliac_vena_left", "iliac_vena_right", "iliac_artery_left", "iliac_artery_right",
"pulmonary_vein"]
"pulmonary_vein", "brain", "colon", "small_bowel"]

organs_hn = ["internal_carotid_artery_right", "internal_carotid_artery_left",
"internal_jugular_vein_right", "internal_jugular_vein_left"]

st = time.time()
seg_img, stats = totalsegmentator(ct_img, None, ml=True, fast=True, statistics=True,
roi_subset=None, statistics_exclude_masks_at_border=False,
quiet=True)
quiet=True, stats_aggregation="median")
print(f"ts took: {time.time()-st:.2f}s")

if stats["brain"]["volume"] > 100:
print(f"Brain in image, therefore also running headneck model.")
st = time.time()
seg_img_hn, stats_hn = totalsegmentator(ct_img, None, ml=True, fast=False, statistics=True,
task="headneck_bones_vessels",
roi_subset=None, statistics_exclude_masks_at_border=False,
quiet=True, stats_aggregation="median")
print(f"hn took: {time.time()-st:.2f}s")
else:
stats_hn = {organ: {"intensity": 0.0} for organ in organs_hn}

features = []
for organ in organs:
features.append(stats[organ]["intensity"])
for organ in organs_hn:
features.append(stats_hn[organ]["intensity"])

if model_file is None:
# weights from longitudinalliver dataset
Expand Down
5 changes: 3 additions & 2 deletions totalsegmentator/nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
crop_addon=[3,3,3], roi_subset=None, output_type="nifti",
statistics=False, quiet=False, verbose=False, test=0, skip_saving=False,
device="cuda", exclude_masks_at_border=True, no_derived_masks=False,
v1_order=False):
v1_order=False, stats_aggregation="mean"):
"""
crop: string or a nibabel image
resample: None or float (target spacing for all dimensions) or list of floats
Expand Down Expand Up @@ -548,7 +548,8 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
else:
stats_file = None
stats = get_basic_statistics(img_pred.get_fdata(), img_in_rsp, stats_file,
quiet, task_name, exclude_masks_at_border, roi_subset)
quiet, task_name, exclude_masks_at_border, roi_subset,
metric=stats_aggregation)
if not quiet: print(f" calculated in {time.time()-st:.2f}s")

if resample is not None:
Expand Down
7 changes: 4 additions & 3 deletions totalsegmentator/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
force_split=False, output_type="nifti", quiet=False, verbose=False, test=0,
skip_saving=False, device="gpu", license_number=None,
statistics_exclude_masks_at_border=True, no_derived_masks=False,
v1_order=False, fastest=False, roi_subset_robust=None):
v1_order=False, fastest=False, roi_subset_robust=None, stats_aggregation="mean"):
"""
Run TotalSegmentator from within python.
Expand Down Expand Up @@ -472,7 +472,8 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
output_type=output_type, statistics=statistics_fast,
quiet=quiet, verbose=verbose, test=test, skip_saving=skip_saving, device=device,
exclude_masks_at_border=statistics_exclude_masks_at_border,
no_derived_masks=no_derived_masks, v1_order=v1_order)
no_derived_masks=no_derived_masks, v1_order=v1_order,
stats_aggregation=stats_aggregation)
seg = seg_img.get_fdata().astype(np.uint8)

try:
Expand All @@ -496,7 +497,7 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
stats_file = None
stats = get_basic_statistics(seg, ct_img, stats_file,
quiet, task, statistics_exclude_masks_at_border,
roi_subset)
roi_subset, metric=stats_aggregation)
# get_radiomics_features_for_entire_dir(input, output, output / "statistics_radiomics.json")
if not quiet: print(f" calculated in {time.time()-st:.2f}s")

Expand Down
12 changes: 9 additions & 3 deletions totalsegmentator/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def get_basic_statistics(seg: np.array,
quiet: bool=False,
task: str="total",
exclude_masks_at_border: bool=True,
roi_subset: list=None):
roi_subset: list=None,
metric: str="mean"):
"""
ct_file: path to a ct_file or a nifti file object
"""
Expand All @@ -123,8 +124,13 @@ def get_basic_statistics(seg: np.array,
else:
stats[mask_name]["volume"] = data.sum() * vox_vol # vol in mm3; 0.2s
roi_mask = (data > 0).astype(np.uint8) # 0.16s
# stats[mask_name]["intensity"] = ct[roi_mask > 0].mean().round(2) if roi_mask.sum() > 0 else 0.0 # 3.0s
stats[mask_name]["intensity"] = np.average(ct, weights=roi_mask).round(2) if roi_mask.sum() > 0 else 0.0 # 0.9s
st = time.time()
if metric == "mean":
# stats[mask_name]["intensity"] = ct[roi_mask > 0].mean().round(2) if roi_mask.sum() > 0 else 0.0 # 3.0s
stats[mask_name]["intensity"] = np.average(ct, weights=roi_mask).round(2) if roi_mask.sum() > 0 else 0.0 # 0.9s # fast lowres mode: 0.03s
elif metric == "median":
stats[mask_name]["intensity"] = np.median(ct[roi_mask > 0]).round(2) if roi_mask.sum() > 0 else 0.0 # 0.9s # fast lowres mode: 0.014s
# print(f"took: {time.time()-st:.4f}s")

if file_out is not None:
# For nora json is good
Expand Down

0 comments on commit a42a54f

Please sign in to comment.