Skip to content

Commit

Permalink
fix printing of loading cpt
Browse files Browse the repository at this point in the history
  • Loading branch information
Anton Emelyanov committed Aug 12, 2021
1 parent 25719f1 commit 4eb1fbe
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,6 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, deepspeed=False):
# Checkpoint.
checkpoint_name = get_checkpoint_name(args.load, iteration, release)

if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))

# Load the checkpoint.
if os.path.isfile(checkpoint_name):
sd = torch.load(checkpoint_name, map_location='cpu')
Expand All @@ -344,6 +340,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, deepspeed=False):
)
sd = torch.load(checkpoint_name, map_location='cpu')

if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))

if isinstance(model, torchDDP):
model = model.module

Expand Down

0 comments on commit 4eb1fbe

Please sign in to comment.