Skip to content

Commit

Permalink
allow nifti obj as input to python api
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Feb 5, 2024
1 parent 4207b5c commit 6331ed6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
24 changes: 16 additions & 8 deletions tests/tests_os.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,24 @@
import shutil
import subprocess

import nibabel as nib

from totalsegmentator.python_api import totalsegmentator


if __name__ == "__main__":

# Test python api
# Test organ predictions - fast - statistics
totalsegmentator('tests/reference_files/example_ct_sm.nii.gz', 'tests/unittest_prediction_fast', fast=True, device="cpu")
pytest.main(['-v', 'tests/test_end_to_end.py::test_end_to_end::test_prediction_fast'])
shutil.rmtree('tests/unittest_prediction_fast')
# Test organ predictions - fast
totalsegmentator("tests/reference_files/example_ct_sm.nii.gz", "tests/unittest_prediction_fast", fast=True, device="cpu")
pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_fast"])
shutil.rmtree("tests/unittest_prediction_fast")

# Test python api - nifti input
input_img = nib.load("tests/reference_files/example_ct_sm.nii.gz")
totalsegmentator(input_img, "tests/unittest_prediction_fast", fast=True, device="cpu")
pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_fast"])
shutil.rmtree("tests/unittest_prediction_fast")

# Test terminal
# Test organ predictions - fast - multilabel
Expand All @@ -23,10 +31,10 @@
file_in = os.path.join("tests", "reference_files", "example_ct_sm.nii.gz")
file_out = os.path.join("tests", "unittest_prediction_fast.nii.gz")
subprocess.call(f"TotalSegmentator -i {file_in} -o {file_out} --fast --ml -d cpu", shell=True)
pytest.main(['-v', 'tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel_fast'])
pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel_fast"])
os.remove(file_out)

# Test Dicom input
totalsegmentator('tests/reference_files/example_ct_dicom', 'tests/unittest_prediction_dicom.nii.gz', fast=True, ml=True, device="cpu")
pytest.main(['-v', 'tests/test_end_to_end.py::test_end_to_end::test_prediction_dicom'])
os.remove('tests/unittest_prediction_dicom.nii.gz')
totalsegmentator("tests/reference_files/example_ct_dicom", "tests/unittest_prediction_dicom.nii.gz", fast=True, ml=True, device="cpu")
pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_dicom"])
os.remove("tests/unittest_prediction_dicom.nii.gz")
10 changes: 8 additions & 2 deletions totalsegmentator/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,14 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
if not quiet: print("Calculating radiomics...")
st = time.time()
stats_dir = output.parent if ml else output
get_radiomics_features_for_entire_dir(input, output, stats_dir / "statistics_radiomics.json")
if not quiet: print(f" calculated in {time.time()-st:.2f}s")
with tempfile.TemporaryDirectory(prefix="radiomics_tmp_") as tmp_folder:
if isinstance(input, Nifti1Image):
input_path = tmp_folder / "ct.nii.gz"
nib.save(input, input_path)
else:
input_path = input
get_radiomics_features_for_entire_dir(input_path, output, stats_dir / "statistics_radiomics.json")
if not quiet: print(f" calculated in {time.time()-st:.2f}s")

return seg_img

0 comments on commit 6331ed6

Please sign in to comment.