forked from coordinated-systems-lab/SCAN-AAAI2021
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MATRiX_prediction_server.py
134 lines (109 loc) · 4.4 KB
/
MATRiX_prediction_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn
#from torch.autograd import Variable
import glob
import matplotlib
import seaborn as sns
matplotlib.use('agg')
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from torch.utils.data import DataLoader
from arguments import parse_arguments
from model import TrajectoryGenerator
from data import dataset, collate_function
from generative_utils import *
from utils import *
import socket
# Address the prediction server binds to (loopback interface only).
HOST = "127.0.0.1"
PORT = 2063

# Parse Arguments, then force the settings this server expects. These
# assignments run after parsing, so they win over any command-line values.
_SERVER_ARG_OVERRIDES = {
    "obs_len": 8,
    "pred_len": 12,
    "model_type": "spatial",
    "dset_name": "zara1",
    "best_k": 5,
    "l": 0.1,
    "delim": "\t",
}
args = parse_arguments()
for _arg_name, _arg_value in _SERVER_ARG_OVERRIDES.items():
    setattr(args, _arg_name, _arg_value)
def get_prediction(batch, model, args, generative=False):
    """Run ``model`` on one DataLoader batch and return its predictions.

    ``args`` and ``generative`` are accepted for call compatibility but are
    not used by this function.

    Returns:
        tuple: ``(predictions, sequence)`` where ``predictions`` is the first
        value produced by ``predict`` with a new axis inserted at dim 1, and
        ``sequence`` is the third value produced by ``predict``.
    """
    # get_batch / predict presumably come from the wildcard utils imports —
    # verify there for the exact batch preparation and output layout.
    prepared_batch = get_batch(batch)
    pred_out, _unused_a, seq_out, _unused_b, _unused_c, _unused_d = predict(prepared_batch, model)
    return pred_out.unsqueeze(1), seq_out
def reload_data(pattern='data/MATRiX/sample.txt'):
    """Build a fresh test dataset from the trajectory file(s) on disk.

    Called on every prediction request, so the dataset always reflects the
    current file contents.

    Args:
        pattern: glob pattern selecting the input file(s). Defaults to the
            MATRiX sample file that was previously hard-coded here.

    Returns:
        The ``dataset`` built from the matching files with the module-level
        ``args``.
    """
    files = glob.glob(pattern)
    if not files:
        # glob silently returns [] for a missing file; make that visible
        # instead of quietly building an empty dataset.
        print(f'Warning: no input files match {pattern!r}')
    testdataset = dataset(files, args)
    print(f'Number of Test Samples: {len(testdataset)}')
    print('-'*100)
    return testdataset
# Set Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Module-level aliases of the best-k / l settings.
k = args.best_k
l = args.l
# Initialize Model, Load Saved Weights
generative = ('generative' in args.model_type)
model = TrajectoryGenerator(
    model_type=args.model_type,
    obs_len=args.obs_len,
    pred_len=args.pred_len,
    feature_dim=2,
    embedding_dim=args.embedding_dim,
    encoder_dim=args.encoder_dim,
    decoder_dim=args.decoder_dim,
    attention_dim=args.attention_dim,
    domain_parameter=args.domain_parameter,
    delta_bearing=args.delta_bearing,
    delta_heading=args.delta_heading,
    pretrained_scene='resnet18',
    device=device,
    # Noise settings only apply to generative model variants.
    noise_dim=args.noise_dim if generative else None,
    noise_type=args.noise_type if generative else None,
).float().to(device)
if generative:
    model_file = f'./trained-models/{args.model_type}/{args.dset_name}/{args.best_k}V-{args.l}_g.pt'
else:
    model_file = f'./trained-models/{args.model_type}/{args.dset_name}.pt'
# map_location remaps tensors onto the selected device, so a checkpoint saved
# on a GPU still loads on a CPU-only host (and vice versa).
model.load_state_dict(torch.load(model_file, map_location=device))
# This process only runs inference: put dropout / batch-norm layers into
# inference mode. (If predict() also sets the mode, this is harmless.)
model.eval()
def _encode_predictions(predictions, num_ped):
    """Serialize the predicted points of ``num_ped`` agents into one string.

    Each agent's points become ``"x/y"`` tokens joined with ``,``; the
    per-agent strings are joined with ``|``.

    Args:
        predictions: tensor indexed as ``predictions[:, agent, ...]``; only the
            first entry along dim 0 is used. Presumably shaped
            (1, num_ped, pred_len, 2) — TODO confirm against ``predict``.
        num_ped: number of agents to encode.
    """
    agent_strings = []
    for agent in range(num_ped):
        points = predictions[:, agent, ...].tolist()[0]
        agent_strings.append(",".join(f"{pt[0]}/{pt[1]}" for pt in points))
    return "|".join(agent_strings)


def MATRIX_predictions():
    """Predict trajectories for the MATRiX sample scene and encode them.

    The dataset is rebuilt via ``reload_data`` on every call. Batches with
    fewer than two agents are skipped; the first remaining batch is run
    through the model and its predictions are encoded with
    ``_encode_predictions``.

    Returns:
        str: the encoded predictions, or ``""`` when no batch contains at
        least two agents (so callers can always ``.encode()`` the result).
    """
    # NOTE(review): the original nesting of the final ``return`` was lost;
    # this returns after the first qualifying batch. For a sample file that
    # holds a single scene, "first" and "last" batch coincide.
    testdataset = reload_data()
    testloader = DataLoader(testdataset, batch_size=1,
                            collate_fn=collate_function(), shuffle=False)
    for batch in testloader:
        # Unpacking checks the 10-field batch layout; only the agent count is
        # needed here because get_prediction consumes the whole batch.
        (_sequence, _target, _dist_matrix, _bearing_matrix, _heading_matrix,
         _ip_mask, _op_mask, pedestrians, _batch_mean, _batch_var) = batch
        if pedestrians.item() < 2:
            continue
        predictions, sequence = get_prediction(batch, model, args)
        predictions = predictions.squeeze(0).detach().cpu()
        num_ped = sequence.squeeze(0).size(0)
        print(f"Predicting trajectories for {num_ped} agents")
        return _encode_predictions(predictions, num_ped)
    return ""
# Start server: accept a single client on HOST:PORT. Every chunk received
# triggers a fresh prediction run whose encoded result is sent back; the
# payload itself is not inspected. The loop ends when the client closes the
# connection, after which the script exits.
# NOTE(review): TCP is a byte stream with no framing here — one client message
# may arrive split across several recv() calls (each triggering a prediction),
# and an empty prediction string sends no bytes at all.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    # Allow an immediate restart instead of failing with "Address already in
    # use" while a previous socket on this port is still in TIME_WAIT.
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind((HOST, PORT))
    s.listen()
    conn, addr = s.accept()
    with conn:
        print(f"Connected by {addr}")
        while True:
            data = conn.recv(1024)
            if not data:
                break
            dataString = MATRIX_predictions()
            print(dataString)
            conn.sendall(dataString.encode())