diff --git a/train_individualTF.py b/train_individualTF.py index ac00af8..9687ae1 100644 --- a/train_individualTF.py +++ b/train_individualTF.py @@ -234,50 +234,50 @@ def main(): frames = [] dt = [] - for id_b,batch in enumerate(test_dl): - inp_.append(batch['src']) - gt.append(batch['trg'][:,:,0:2]) + for id_b, batch in enumerate(test_dl): + inp_.append(batch['src'][:, :, 0:2]) + gt.append(batch['trg'][:, :, 0:2]) frames.append(batch['frames']) peds.append(batch['peds']) dt.append(batch['dataset']) - inp = (batch['src'][:, 1:, 2:4].to(device) - mean.to(device)) / std.to(device) src_att = torch.ones((inp.shape[0], 1, inp.shape[1])).to(device) start_of_seq = torch.Tensor([0, 0, 1]).unsqueeze(0).unsqueeze(1).repeat(inp.shape[0], 1, 1).to( device) - dec_inp=start_of_seq + dec_inp = start_of_seq for i in range(args.preds): trg_att = subsequent_mask(dec_inp.shape[1]).repeat(dec_inp.shape[0], 1, 1).to(device) out = model(inp, dec_inp, src_att, trg_att) - dec_inp=torch.cat((dec_inp,out[:,-1:,:]),1) - + dec_inp = torch.cat((dec_inp, out[:, -1:, :]), 1) - preds_tr_b=(dec_inp[:,1:,0:2]*std.to(device)+mean.to(device)).cpu().numpy().cumsum(1)+batch['src'][:,-1:,0:2].cpu().numpy() + preds_tr_b = (dec_inp[:, 1:, 0:2] * std.to(device) + mean.to(device)).cpu().numpy().cumsum(1) + \ + batch['src'][:, -1:, 0:2].cpu().numpy() pr.append(preds_tr_b) print("test epoch %03i/%03i batch %04i / %04i" % ( - epoch, args.max_epoch, id_b, len(test_dl))) + epoch, args.max_epoch, id_b, len(test_dl))) peds = np.concatenate(peds, 0) frames = np.concatenate(frames, 0) dt = np.concatenate(dt, 0) gt = np.concatenate(gt, 0) + inp_ = np.concatenate(inp_, 0) dt_names = test_dataset.data['dataset_name'] pr = np.concatenate(pr, 0) - mad, fad, errs = baselineUtils.distance_metrics(gt, pr) - + mad, fad, errs = baselineUtils.distance_metrics(gt, pr) # In this method, we take euclidean dist bw all true trajectory points and pred trajs points, and then divide by total number of trajs points log.add_scalar('eval/DET_mad', mad, epoch) log.add_scalar('eval/DET_fad', fad, epoch) - + # print(gt.shape, inp_.shape, pr.shape) # log.add_scalar('eval/DET_mad', mad, epoch) # log.add_scalar('eval/DET_fad', fad, epoch) scipy.io.savemat(f"output/Individual/{args.name}/det_{epoch}.mat", - {'input': inp, 'gt': gt, 'pr': pr, 'peds': peds, 'frames': frames, 'dt': dt, + {'input': inp_, 'gt': gt, 'pr': pr, 'peds': peds, 'frames': frames, 'dt': dt, 'dt_names': dt_names}) + if epoch%args.save_step==0: torch.save(model.state_dict(),f'models/Individual/{args.name}/{epoch:05d}.pth')