Skip to content

Commit

Permalink
update scaler
Browse files Browse the repository at this point in the history
  • Loading branch information
MengqingCao committed Oct 19, 2024
1 parent 0dee0b4 commit acc40bb
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/open_clip_train/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,12 +328,12 @@ def main(args):
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

scaler = None
if args.precision == "amp":
if args.device == "npu" and torch.npu.is_available():
from torch.npu.amp import GradScaler
else:
from torch.cuda.amp import GradScaler
scaler = GradScaler() if args.precision == "amp" else None
try:
scaler = torch.amp.GradScaler(device=device)
except (AttributeError, TypeError) as e:
scaler = torch.cuda.amp.GradScaler()

# optionally resume from a checkpoint
start_epoch = 0
Expand Down

0 comments on commit acc40bb

Please sign in to comment.