Merge pull request #14 from CognitiveHorizons/align-resume-mode
Resume mode for aligner.
mrdrozdov authored Feb 17, 2022
2 parents 20ea6f0 + 2678b36 commit f2a539c
Showing 1 changed file with 19 additions and 11 deletions.
align_cfg/main.py: 30 changes (19 additions & 11 deletions)
```diff
@@ -1181,7 +1181,7 @@ def forward(self, batch_map, input_vector=None):
         return output, labels, labels_mask, label_node_ids


-def load_checkpoint(path, net, cuda=False):
+def load_checkpoint(path, net, opt, cuda=False):
     try:
         toload = torch.load(path)
     except:
@@ -1195,7 +1195,7 @@ def load_checkpoint(path, net, cuda=False):
             print('[load] copying {}'.format(k))
             toload['state_dict'][k] = state_dict[k]

-        assert v.shape == toload['state_dict'][k].shape, k
+        assert v.shape == toload['state_dict'][k].shape, (k, v.shape, toload['state_dict'][k].shape)

         seen.add(k)

@@ -1208,8 +1208,14 @@ def load_checkpoint(path, net, cuda=False):
     # TODO: Verify that vocab lines up.
     net.load_state_dict(toload['state_dict'])

+    try:
+        opt.load_state_dict(toload['opt_state_dict'])
+    except Exception as e:
+        print(e)
+        print('WARNING: Failed to load opt state dict. Be careful if resuming training.')
+

-def save_checkpoint(path, dataset, net, metrics=None):
+def save_checkpoint(path, dataset, net, opt, metrics=None):
     state_dict = net.state_dict()

     for k, v in net.named_parameters():
@@ -1222,6 +1228,7 @@ def save_checkpoint(path, dataset, net, metrics=None):
     tosave['text_vocab'] = dataset.text_tokenizer.vocab
     tosave['amr_vocab'] = dataset.amr_tokenizer.vocab
     tosave['metrics'] = metrics
+    tosave['opt_state_dict'] = opt.state_dict()

     try:
         torch.save(tosave, path, _use_new_zipfile_serialization=False)
```
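For readers skimming the diff: the change is the standard PyTorch resume idiom, i.e. checkpoint `opt.state_dict()` (which holds Adam's step counts and moment estimates) alongside the model weights, and restore both on load. A minimal self-contained sketch of the pattern (the toy model, filename, and helper names here are illustrative, not from this repo):

```python
import torch
import torch.nn as nn
import torch.optim as optim

def save(path, net, opt):
    # Persist model weights and optimizer state together, so a resumed
    # run continues with the same Adam moments and step counts.
    torch.save({'state_dict': net.state_dict(),
                'opt_state_dict': opt.state_dict()}, path)

def load(path, net, opt):
    ckpt = torch.load(path)
    net.load_state_dict(ckpt['state_dict'])
    # Older checkpoints predate 'opt_state_dict'; warn instead of crashing,
    # mirroring the try/except added to load_checkpoint above.
    try:
        opt.load_state_dict(ckpt['opt_state_dict'])
    except Exception as e:
        print('WARNING: could not restore optimizer state:', e)

net = nn.Linear(4, 2)
opt = optim.Adam(net.parameters(), lr=1e-3)
save('ckpt.pt', net, opt)
load('ckpt.pt', net, opt)
```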
```diff
@@ -1570,8 +1577,13 @@ def main(args):

     # Init model.
     net = Net.from_dataset_and_config(trn_dataset, model_config, args.cache_dir)

+    # OPTIMIZER
+
+    opt = optim.Adam(net.parameters(), lr=lr)
+
+    # LOAD
     if args.load is not None:
-        load_checkpoint(args.load, net)
+        load_checkpoint(args.load, net, opt)

     if args.cuda:
         net.cuda()
@@ -1585,10 +1597,6 @@ def main(args):

     maybe_write(context)

-    # OPTIMIZER
-
-    opt = optim.Adam(net.parameters(), lr=lr)
-
     # CACHE dataset items.

     for dset in [trn_dataset] + val_dataset_list:
```
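Note on ordering: the `# OPTIMIZER` block moves above the `# LOAD` block because `load_checkpoint` now restores optimizer state, so `opt` must exist before the checkpoint is loaded; previously it was only created after loading.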
```diff
@@ -1691,10 +1699,10 @@ def func():
         print('trn epoch = {}, loss = {:.3f}, loss-nr = {:.3f}, ppl = {:.3f}, pr = {:.3f}'.format(
             epoch, trn_loss, trn_loss_notreduced, trn_ppl, trn_pr))

-        save_checkpoint(os.path.join(args.log_dir, 'model.latest.pt'), trn_dataset, net, metrics=dict(epoch=epoch, trn_loss=trn_loss, trn_loss_notreduced=trn_loss_notreduced))
+        save_checkpoint(os.path.join(args.log_dir, 'model.latest.pt'), trn_dataset, net, opt, metrics=dict(epoch=epoch, trn_loss=trn_loss, trn_loss_notreduced=trn_loss_notreduced))

-        if (epoch + 1) % args.save_every_epoch == 0:
-            save_checkpoint(os.path.join(args.log_dir, 'model.epoch_{}.pt'.format(epoch)), trn_dataset, net, metrics=dict(epoch=epoch, trn_loss=trn_loss, trn_loss_notreduced=trn_loss_notreduced))
+        if epoch % args.save_every_epoch == 0:
+            save_checkpoint(os.path.join(args.log_dir, 'model.epoch_{}.pt'.format(epoch)), trn_dataset, net, opt, metrics=dict(epoch=epoch, trn_loss=trn_loss, trn_loss_notreduced=trn_loss_notreduced))

     # VALIDATION
```
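One behavioral detail worth noting: switching the periodic-save test from `(epoch + 1) % args.save_every_epoch == 0` to `epoch % args.save_every_epoch == 0` shifts which epochs write numbered checkpoints. With `save_every_epoch = 5`, for example, saves previously landed on epochs 4, 9, 14, ... and now land on epochs 0, 5, 10, ..., including an extra checkpoint at the end of the first epoch.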
