From 4eb1fbe33be6cbc33588fd99ee18d8c528f17a4d Mon Sep 17 00:00:00 2001 From: Anton Emelyanov Date: Thu, 12 Aug 2021 21:37:08 +0300 Subject: [PATCH] fix printing of loading cpt --- src/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/utils.py b/src/utils.py index 6a8b952..14bebc3 100644 --- a/src/utils.py +++ b/src/utils.py @@ -329,10 +329,6 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, deepspeed=False): # Checkpoint. checkpoint_name = get_checkpoint_name(args.load, iteration, release) - if mpu.get_data_parallel_rank() == 0: - print('global rank {} is loading checkpoint {}'.format( - torch.distributed.get_rank(), checkpoint_name)) - # Load the checkpoint. if os.path.isfile(checkpoint_name): sd = torch.load(checkpoint_name, map_location='cpu') @@ -344,6 +340,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, deepspeed=False): ) sd = torch.load(checkpoint_name, map_location='cpu') + if mpu.get_data_parallel_rank() == 0: + print('global rank {} is loading checkpoint {}'.format( + torch.distributed.get_rank(), checkpoint_name)) + if isinstance(model, torchDDP): model = model.module