Skip to content

Commit

Permalink
make sure torch cudnn benchmark and num_threads setting is not perman…
Browse files Browse the repository at this point in the history
…ently changed by totalsegmentator
  • Loading branch information
wasserth committed Nov 27, 2024
1 parent 624f7d7 commit 1de4511
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions totalsegmentator/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa

nora_tag = "None" if nora_tag is None else nora_tag

# Store initial torch settings
initial_cudnn_benchmark = torch.backends.cudnn.benchmark
initial_num_threads = torch.get_num_threads()

validate_device_type_api(device)
device = convert_device_to_cuda(device)

Expand Down Expand Up @@ -555,6 +559,10 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
get_radiomics_features_for_entire_dir(input_path, output, stats_dir / "statistics_radiomics.json")
if not quiet: print(f" calculated in {time.time()-st:.2f}s")

# Restore initial torch settings
torch.backends.cudnn.benchmark = initial_cudnn_benchmark
torch.set_num_threads(initial_num_threads)

if statistics or statistics_fast:
return seg_img, stats
else:
Expand Down

0 comments on commit 1de4511

Please sign in to comment.