Skip to content

Commit

Permalink
resume training implemented
Browse files Browse the repository at this point in the history
when --resume_train is specified, model is loaded with weights specified as model_path 
eg : python train_individualTF.py --dataset_name eth --name eth --max_epoch 240 --batch_size 100 --name eth_train --factor 1 --resume_train --model_pth 00013.pth
  • Loading branch information
nithinvenny07 authored Jan 6, 2021
1 parent 18a7877 commit 9cf67f1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions train_individualTF.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def main():
parser.add_argument('--save_step', type=int, default=1)
parser.add_argument('--warmup', type=int, default=10)
parser.add_argument('--evaluate', type=bool, default=True)
parser.add_argument('--model_pth', type=str)



Expand Down Expand Up @@ -107,8 +108,8 @@ def main():
import individual_TF
model=individual_TF.IndividualTF(2, 3, 3, N=args.layers,
d_model=args.emb_size, d_ff=2048, h=args.heads, dropout=args.dropout,mean=[0,0],std=[0,0]).to(device)


if args.resume_train:
model.load_state_dict(torch.load(f'models/Individual/{args.name}/{args.model_pth}'))

tr_dl = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
val_dl = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
Expand Down

0 comments on commit 9cf67f1

Please sign in to comment.