fix(eval): implement DataParallel for multi-GPU inferencing
KevKibe committed Sep 12, 2024
1 parent 5818ca3 commit 5b1f181
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions training/run_eval.py
@@ -47,6 +47,7 @@
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
from transformers.utils import check_min_version, is_accelerate_available
from transformers.utils.versions import require_version
+from torch.nn import DataParallel


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -561,6 +562,9 @@ def main():
cache_dir=data_args.cache_dir,
variant=data_args.model_variant,
)
+    if torch.cuda.device_count() > 1:
+        print(f"Using {torch.cuda.device_count()} GPUs")
+        model = DataParallel(model)
model.to("cuda:0", dtype=dtype)

model_pipeline = None
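For context, a minimal sketch of how the wrapped model behaves at inference time, assuming at least two visible GPUs (the toy Linear module and batch shapes below are illustrative stand-ins, not the script's Whisper model):

    import torch
    from torch.nn import DataParallel

    # Stand-in module; in run_eval.py this is the loaded Whisper model.
    model = torch.nn.Linear(80, 32)

    if torch.cuda.device_count() > 1:
        # DataParallel keeps a replica of the module on each visible GPU
        # and scatters the input batch along dim 0 on every forward call.
        model = DataParallel(model)

    # .to() is an nn.Module method, so it works on the wrapper too;
    # cuda:0 is DataParallel's default source/output device.
    model.to("cuda:0")

    batch = torch.randn(16, 80, device="cuda:0")  # split into ~16/N slices
    out = model(batch)                            # gathered back on cuda:0

One caveat: a DataParallel wrapper does not delegate arbitrary attribute access to the wrapped model, so model-specific methods such as generate() must be reached through model.module once the model is wrapped.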
