From 5b1f18187dcbfe852afab7aeb70f9be6cf4c0fce Mon Sep 17 00:00:00 2001
From: KevKibe
Date: Thu, 12 Sep 2024 19:06:53 +0300
Subject: [PATCH] fix(eval): implement DataParallel for multi-GPU inferencing

---
 training/run_eval.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/training/run_eval.py b/training/run_eval.py
index e89da78..5de2537 100644
--- a/training/run_eval.py
+++ b/training/run_eval.py
@@ -47,6 +47,7 @@ from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
 from transformers.utils import check_min_version, is_accelerate_available
 from transformers.utils.versions import require_version
+from torch.nn import DataParallel
 

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -561,6 +562,9 @@ def main():
         cache_dir=data_args.cache_dir,
         variant=data_args.model_variant,
     )
+    if torch.cuda.device_count() > 1:
+        print(f"Using {torch.cuda.device_count()} GPUs")
+        model = DataParallel(model)
     model.to("cuda:0", dtype=dtype)

     model_pipeline = None