-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patha1_dtw.py
executable file
·129 lines (96 loc) · 3.1 KB
/
a1_dtw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from dtw import dtw
from fastdtw import fastdtw
from ast import literal_eval
from math import radians, cos, sin, atan2, sqrt
import gmplot
import operator
import numpy as np
import pandas as pd
import timeit
import errno
import os
# Ask which algorithm to run: '1' selects plain DTW, '2' selects Fast DTW.
# Keep prompting until one of the two accepted answers is typed.
answer = ''
while answer not in ('1', '2'):
    answer = raw_input("Press 1 for DTW\nPress 2 for Fast DTW\nAnswer: ")

print("Running with DTW" if answer == '1' else "Running with Fast DTW")

# The "Trajectory" column is stored as a Python literal in the CSV text, so it
# is parsed back into a Python object with literal_eval while loading.
trajectory_converters = {"Trajectory": literal_eval}

# Training data: default (comma) separator, rows indexed by 'tripId'.
train_frame = pd.read_csv(
    'train_set.csv',  # replace with the correct path
    converters=trajectory_converters,
    index_col='tripId'
)

# Query data: tab-separated.
test_frame = pd.read_csv(
    'test_set_a1.csv',  # replace with the correct path
    converters=trajectory_converters,
    sep='\t'
)

# Only the trajectories (and, for the training set, the journey pattern ids)
# are needed from here on, so keep them as plain arrays.
trainSet_ids = np.array(train_frame["journeyPatternId"])
trainSet = np.array(train_frame["Trajectory"])
testSet = np.array(test_frame["Trajectory"])

# The frames themselves are not used any further.
del train_frame, test_frame
def haversine(x, y):
    """Return the great-circle distance in kilometres between two points.

    Each argument is a 3-item sequence ``(timestamp, longitude, latitude)``
    with the coordinates given in degrees. The timestamp is ignored; it is
    only unpacked so that raw trajectory points can be passed in directly
    (this is the ``dist`` callback handed to dtw / fastdtw).

    The result uses a spherical Earth of radius 6371 km.
    """
    earth_radius_km = 6371

    _, lon1, lat1 = x
    _, lon2, lat2 = y
    lon1, lat1, lon2, lat2 = map(radians, (lon1, lat1, lon2, lat2))

    dif_lons = lon2 - lon1
    dif_lats = lat2 - lat1
    a = sin(dif_lats / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dif_lons / 2) ** 2

    # Mathematically 0 <= a <= 1, but rounding can push it marginally outside
    # that range for (near-)antipodal points, which would make sqrt(1 - a)
    # raise ValueError. Clamp before taking the square roots.
    a = min(1.0, max(0.0, a))

    c = 2 * atan2(sqrt(a), sqrt(1 - a))
    return c * earth_radius_km
# For a quick trial run, the training set can be cut down before the search,
# e.g. trainSet = trainSet[0:100].

# Number of neighbours reported (and mapped) for every test trajectory.
NEIGHBOR_COUNT = 5

# Centre and zoom level used for every generated map.
MAP_CENTER_LAT = 53.350140
MAP_CENTER_LON = -6.266155
MAP_ZOOM = 12

# The HTML maps go to a directory named after the chosen algorithm.
path_dir = 'Maps_A1_DTW' if answer == '1' else 'Maps_A1_FastDTW'
try:
    os.mkdir(path_dir)
except OSError as exc:
    # An already existing directory is fine; anything else is a real error.
    if exc.errno != errno.EEXIST:
        raise


def _trajectory_distance(query, candidate):
    """Return the DTW (answer '1') or Fast DTW distance between two trajectories."""
    if answer == '1':
        distance, _cost, _accuracy, _path = dtw(query, candidate, dist=haversine)
    else:
        distance, _path = fastdtw(query, candidate, dist=haversine)
    return distance


def _draw_trajectory(points, html_name):
    """Plot one trajectory of (timestamp, lon, lat) points into path_dir/html_name."""
    _, lons, lats = zip(*points)
    gmap = gmplot.GoogleMapPlotter(MAP_CENTER_LAT, MAP_CENTER_LON, MAP_ZOOM)
    gmap.plot(lats, lons, 'green', edge_width=3)
    gmap.draw(os.path.join(path_dir, html_name))


row_list = []
for test_counter, test in enumerate(testSet, 1):
    start = timeit.default_timer()

    # Distance from this test trajectory to every training trajectory,
    # remembered together with the training index, then ranked ascending.
    trajectories_list = [
        (index, _trajectory_distance(test, trajectory))
        for index, trajectory in enumerate(trainSet)
    ]
    trajectories_list.sort(key=operator.itemgetter(1))

    # NOTE(review): the slice starts at rank 1, so the single closest training
    # trajectory (rank 0) is never reported. That is only right if every test
    # trajectory also occurs in the training set (rank 0 would then be the
    # trajectory itself) -- confirm against the data; otherwise start at 0.
    list_map = trajectories_list[1:1 + NEIGHBOR_COUNT]

    map_prefix = 'TrajectoryID_' + str(test_counter)
    _draw_trajectory(test, map_prefix + '.html')

    neighbor_cells = []
    for n_count, (num_map, distance) in enumerate(list_map, 1):
        # str() keeps this working when the journey pattern id was not read
        # as a string (e.g. a missing value loaded as NaN).
        neighbor_cells.append(
            "JP_ID: " + str(trainSet_ids[num_map]) + " DTW: " + str(distance) + "km"
        )
        _draw_trajectory(
            trainSet[num_map],
            map_prefix + '_Neighbor_' + str(n_count) + '.html'
        )

    # Keep every row at the full width expected by the result table, even when
    # the training set is too small to supply NEIGHBOR_COUNT neighbours.
    neighbor_cells.extend([None] * (NEIGHBOR_COUNT - len(neighbor_cells)))

    stop = timeit.default_timer()
    row_list.append(["Trip Id " + str(test_counter), stop - start] + neighbor_cells)

# Write the results to a1_results_dtw.csv or a1_results_fastdtw.csv,
# depending on the algorithm the user picked.
result_columns = ['TripId', 'Time'] + [
    'Neighbor' + str(rank) for rank in range(1, NEIGHBOR_COUNT + 1)
]
df = pd.DataFrame(row_list, columns=result_columns)
results_file = 'a1_results_dtw.csv' if answer == '1' else 'a1_results_fastdtw.csv'
df.to_csv(results_file, sep='\t', encoding='utf-8')