Skip to content

Commit

Permalink
Merge pull request #14 from nithinvenny07/master
Browse files Browse the repository at this point in the history
resume training implemented
  • Loading branch information
FGiuliari authored Jan 6, 2021
2 parents 18a7877 + 9cf67f1 commit baec1a1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions train_individualTF.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def main():
parser.add_argument('--save_step', type=int, default=1)
parser.add_argument('--warmup', type=int, default=10)
parser.add_argument('--evaluate', type=bool, default=True)
parser.add_argument('--model_pth', type=str)



Expand Down Expand Up @@ -107,8 +108,8 @@ def main():
import individual_TF
model=individual_TF.IndividualTF(2, 3, 3, N=args.layers,
d_model=args.emb_size, d_ff=2048, h=args.heads, dropout=args.dropout,mean=[0,0],std=[0,0]).to(device)


if args.resume_train:
model.load_state_dict(torch.load(f'models/Individual/{args.name}/{args.model_pth}'))

tr_dl = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
val_dl = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
Expand Down

0 comments on commit baec1a1

Please sign in to comment.