From acc40bb023e9da18404c8c7d7cdcec025c37a4fb Mon Sep 17 00:00:00 2001 From: MengqingCao Date: Sat, 19 Oct 2024 13:38:58 +0000 Subject: [PATCH] update scaler --- src/open_clip_train/main.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/open_clip_train/main.py b/src/open_clip_train/main.py index 1089ade53..1aa0750fc 100644 --- a/src/open_clip_train/main.py +++ b/src/open_clip_train/main.py @@ -328,12 +328,12 @@ def main(args): hvd.broadcast_parameters(model.state_dict(), root_rank=0) hvd.broadcast_optimizer_state(optimizer, root_rank=0) + scaler = None if args.precision == "amp": - if args.device == "npu" and torch.npu.is_available(): - from torch.npu.amp import GradScaler - else: - from torch.cuda.amp import GradScaler - scaler = GradScaler() if args.precision == "amp" else None + try: + scaler = torch.amp.GradScaler(device=device) + except (AttributeError, TypeError) as e: + scaler = torch.cuda.amp.GradScaler() # optionally resume from a checkpoint start_epoch = 0