From 4511ac5c90e59c07773a95af942b4939b569e818 Mon Sep 17 00:00:00 2001 From: Saransh Karira Date: Thu, 18 Oct 2018 11:20:09 +0530 Subject: [PATCH] Update train.py Initialized Dataparallel before restoring from checkpoints to properly load wts --- train.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 3c3513a3..77f96ead 100644 --- a/train.py +++ b/train.py @@ -90,10 +90,6 @@ def weights_init(m): crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh) crnn.apply(weights_init) -if opt.pretrained != '': - print('loading pretrained model from %s' % opt.pretrained) - crnn.load_state_dict(torch.load(opt.pretrained)) -print(crnn) image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH) text = torch.IntTensor(opt.batchSize * 5) @@ -105,6 +101,13 @@ def weights_init(m): image = image.cuda() criterion = criterion.cuda() +if opt.pretrained != '': + print('loading pretrained model from %s' % opt.pretrained) + crnn.load_state_dict(torch.load(opt.pretrained)) +print(crnn) + + + image = Variable(image) text = Variable(text) length = Variable(length)