diff --git a/train_individualTF.py b/train_individualTF.py index a17cf14..ac00af8 100644 --- a/train_individualTF.py +++ b/train_individualTF.py @@ -42,6 +42,7 @@ def main(): parser.add_argument('--save_step', type=int, default=1) parser.add_argument('--warmup', type=int, default=10) parser.add_argument('--evaluate', type=bool, default=True) + parser.add_argument('--model_pth', type=str) @@ -107,8 +108,8 @@ def main(): import individual_TF model=individual_TF.IndividualTF(2, 3, 3, N=args.layers, d_model=args.emb_size, d_ff=2048, h=args.heads, dropout=args.dropout,mean=[0,0],std=[0,0]).to(device) - - + if args.resume_train: + model.load_state_dict(torch.load(f'models/Individual/{args.name}/{args.model_pth}')) tr_dl = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0) val_dl = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)