-
Notifications
You must be signed in to change notification settings - Fork 0
/
simulator.py
371 lines (294 loc) · 14.8 KB
/
simulator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 13 16:48:10 2017
@author: ECOWIZARD
"""
###########################################
# Suppress matplotlib user warnings
# Necessary for newer version of matplotlib
import warnings
warnings.filterwarnings("ignore", category = UserWarning, module = "matplotlib")
###########################################
import os
import time
import random
import importlib
import csv
class Simulator4x1(object):
    """Simulates agents in a dynamic smartcab environment.

    Uses PyGame to display GUI, if available.

    Each key of ``env.states`` is drawn as one ``blocksize``-pixel square along
    the top of the window (the default window is four blocks wide). Unless
    ``vanilla`` is set, the primary agent's environment model / belief state is
    drawn underneath by :meth:`renderMindState`.
    """

    colors = {
        'black'   : (  0,   0,   0),
        'white'   : (255, 255, 255),
        'red'     : (255,   0,   0),
        'green'   : (  0, 255,   0),
        'dgreen'  : (  0, 228,   0),
        'blue'    : (  0,   0, 255),
        'cyan'    : (  0, 200, 200),
        'magenta' : (200,   0, 200),
        'yellow'  : (255, 255,   0),
        'mustard' : (200, 200,   0),
        'orange'  : (255, 128,   0),
        'maroon'  : (200,   0,   0),
        'crimson' : (128,   0,   0),
        'gray'    : (155, 155, 155)
    }

    def __init__(self, env, size=None, update_delay=2.0, display=True, log_metrics=False, optimized=False, vanilla=False):
        """Create a simulator for ``env``.

        Args:
            env: environment to simulate. Must expose ``primary_agent`` here;
                ``run`` additionally uses ``reset``, ``step``, ``done``,
                ``success``, ``step_data`` and ``trial_data``.
            size: ``(width, height)`` of the window in pixels; defaults to
                ``(4 * blocksize, int(2.2 * blocksize))``.
            update_delay: seconds between environment steps; also used as the
                GUI frame delay (in ms, minimum 1).
            display: try to open a PyGame window. If PyGame cannot be imported
                or initialised, the display is disabled and the reason printed.
            log_metrics: write one CSV row per trial under ``logs/``; for a
                learning agent, the final Q-table is also written to a ``.txt``.
            optimized: for learning agents, use the "improved-learning" log
                file names instead of "default-learning".
            vanilla: use the ``qsim_`` log-file prefix (instead of ``sim_``)
                and skip the belief-state overlay when rendering.
        """
        self.env = env
        self.vanilla = vanilla

        self.blocksize = 100
        self.size = size if size is not None else (4 * self.blocksize, int(2.2 * self.blocksize))
        self.width, self.height = self.size
        self.road_width = 44

        self.bg_color = self.colors['gray']
        self.line_color = self.colors['mustard']
        self.boundary = self.colors['black']
        self.stop_color = self.colors['crimson']

        self.quit = False
        self.start_time = None
        self.current_time = 0.0
        self.last_updated = 0.0
        self.update_delay = update_delay  # duration between each step (in seconds)

        self.display = display
        if self.display:
            try:
                self.pygame = importlib.import_module('pygame')
                self.pygame.init()
                self.screen = self.pygame.display.set_mode(self.size)
                self.frame_delay = max(1, int(self.update_delay * 1000))  # delay between GUI frames in ms (min: 1)
                self.primary_agent_sprite_size = (42, 42)
                self.agent_circle_radius = 20  # radius of circle, when using simple representation
                self.paused = False
            except ImportError as e:
                self.display = False
                print("Simulator.__init__(): Unable to import pygame; display disabled.\n{}: {}".format(e.__class__.__name__, e))
            except Exception as e:
                self.display = False
                print("Simulator.__init__(): Error initializing GUI objects; display disabled.\n{}: {}".format(e.__class__.__name__, e))

        # Setup metrics to report
        self.log_metrics = log_metrics
        self.optimized = optimized

        if self.log_metrics:
            a = self.env.primary_agent

            # File names: <prefix>_<variant>.csv (and .txt for the Q-table).
            prefix = "qsim" if vanilla else "sim"
            if a.learning:
                # Whether the user is optimizing the parameters and decay functions
                variant = "improved-learning" if self.optimized else "default-learning"
            else:
                variant = "no-learning"

            # open() does not create missing directories.
            os.makedirs("logs", exist_ok=True)

            self.log_filename = os.path.join("logs", "{}_{}.csv".format(prefix, variant))
            if a.learning:
                self.table_filename = os.path.join("logs", "{}_{}.txt".format(prefix, variant))
                self.table_file = open(self.table_filename, 'w')

            self.log_fields = ['trial', 'testing', 'parameters', 'net_reward', 'age', 'success']
            # The csv module requires newline='' (otherwise Windows gets blank rows).
            self.log_file = open(self.log_filename, 'w', newline='')
            self.log_writer = csv.DictWriter(self.log_file, fieldnames=self.log_fields)
            self.log_writer.writeheader()

    def run(self, tolerance=0.05, n_test=0):
        """Run training trials, then ``n_test`` testing trials.

        At least 20 training trials are always run. After that, a learning
        agent keeps training until ``agent.epsilon < tolerance`` (this assumes
        epsilon decays toward 0; otherwise training continues until the user
        quits), while a non-learning agent switches to testing immediately.
        Trial numbering restarts at 1 when testing begins.

        Args:
            tolerance: epsilon threshold below which a learning agent starts
                testing.
            n_test: number of testing trials to simulate; 0 runs none.

        On exit the log files (if any) are closed and PyGame is shut down.
        """
        self.quit = False

        # Get the primary agent
        a = self.env.primary_agent

        total_trials = 1
        testing = False
        trial = 1

        while True:

            # Flip testing switch once the minimum training period is over.
            if not testing and total_trials > 20:  # Must complete minimum 20 training trials
                if a.learning:
                    print("epsilon = {}".format(a.epsilon))
                    print("tolerance = {}".format(tolerance))
                    if a.epsilon < tolerance:  # assumes epsilon decays to 0
                        testing = True
                        trial = 1
                else:
                    testing = True
                    trial = 1

            # Break if we've reached the limit of testing trials. This is also
            # checked right after the switch flips, so n_test=0 runs no tests.
            if testing and trial > n_test:
                break

            # Pretty print to terminal
            print()
            print("/-------------------------")
            if testing:
                print("| Testing trial {}".format(trial))
            else:
                print("| Training trial {}".format(trial))
            print("\\-------------------------")
            print()

            self.env.reset(testing)
            self._run_trial_steps(trial, testing)

            if self.quit:
                break

            # Collect metrics from trial
            if self.log_metrics:
                self.log_writer.writerow({
                    'trial': trial,
                    'testing': self.env.trial_data['testing'],
                    'parameters': self.env.trial_data['parameters'],
                    'net_reward': self.env.trial_data['net_reward'],
                    'age': self.env.trial_data['age'],
                    'success': self.env.trial_data['success']
                })

            # Trial finished
            if self.env.success:
                print("\nTrial Completed!")
                print("Agent reached the destination.")
            else:
                print("\nTrial Aborted!")
                print("Agent did not reach the destination.")

            total_trials += 1
            trial += 1

        # Clean up
        if self.log_metrics:
            if a.learning:
                self._write_q_table(a)
            self.log_file.close()

        print("\nSimulation ended. . . ")

        if self.display:
            self.pygame.display.quit()  # shut down pygame

    def _run_trial_steps(self, trial, testing):
        """Step ``self.env`` every ``update_delay`` seconds until it is done or the user quits."""
        self.current_time = 0.0
        self.last_updated = 0.0
        self.start_time = time.time()

        while True:
            try:
                # Update current time
                self.current_time = time.time() - self.start_time

                # Handle GUI events
                if self.display:
                    for event in self.pygame.event.get():
                        if event.type == self.pygame.QUIT:
                            self.quit = True
                        elif event.type == self.pygame.KEYDOWN:
                            if event.key == self.pygame.K_ESCAPE:
                                self.quit = True
                            elif event.unicode == ' ':
                                self.paused = True

                    if self.paused:
                        self.pause()

                # Update environment
                if self.current_time - self.last_updated >= self.update_delay:
                    self.env.step()
                    self.last_updated = self.current_time

                # Render text
                self.render_text(trial, testing)

                # Render GUI and sleep
                if self.display:
                    self.render(trial, testing)
                    self.pygame.time.wait(self.frame_delay)
                elif not self.env.done:
                    # Without a GUI nothing throttles this loop; sleep until the
                    # next scheduled step instead of busy-spinning (which also
                    # floods the terminal through render_text).
                    remaining = self.update_delay - (time.time() - self.start_time - self.last_updated)
                    if remaining > 0:
                        time.sleep(remaining)

            except KeyboardInterrupt:
                self.quit = True

            # Checked outside a ``finally`` so that unexpected exceptions from
            # env.step()/rendering propagate instead of being discarded by break.
            if self.quit or self.env.done:
                break

    def _write_q_table(self, agent):
        """Write ``agent.Q`` (state -> {action: reward}) to ``self.table_file`` and close it."""
        f = self.table_file
        f.write("/-----------------------------------------\n")
        f.write("| State-action rewards from modified Q-Learning\n")
        f.write("\\-----------------------------------------\n\n")

        for state in agent.Q:
            f.write("{}\n".format(state))
            for action, reward in agent.Q[state].items():
                print("{} , {}".format(action, reward))
                f.write(" -- {} : {:.2f}\n".format(action, reward))
            f.write("\n")

        f.close()

    def render_text(self, trial, testing=False):
        """Print the non-GUI status of the simulation to the terminal.

        If ``env.step_data`` is non-empty the agent's previous state is
        printed; otherwise a "Simulating trial" header is printed, followed by
        epsilon/alpha for a learning agent. ``trial`` and ``testing`` are
        accepted for interface compatibility but are not used.
        """
        status = self.env.step_data
        if status:  # Continuing the trial
            # Previous State
            if status['state']:
                print("Agent previous state: {}".format(status['state']))
            else:
                print("!! Agent state not been updated!")

        # Starting new trial
        else:
            a = self.env.primary_agent
            print("Simulating trial. . . ")
            if a.learning:
                print("epsilon = {:.4f}; alpha = {:.4f}".format(a.epsilon, a.alpha))
            else:
                print("Agent not set to learn.")

    def renderMindState(self, xadjustment, yadjustment):
        """Draw the primary agent's environment model (no-op when ``vanilla``).

        Each state in ``environmentmodel.pomdpStates`` is drawn as a red
        polygon from its ``coords``, scaled by 50 px per unit and offset by
        ``(xadjustment, yadjustment)``: filled for the state returned by
        ``getState()``, outlined otherwise. ``currentBelief``, projected to
        2-D by ``translateNDpointto2D``, is drawn as a small orange dot.
        """
        if self.vanilla:
            return
        model = self.env.primary_agent.environmentmodel
        currentState = model.getState()
        magnification = 50

        # NOTE(review): assumes translateNDpointto2D returns a mutable
        # 2-element sequence (it is updated in place below) — confirm.
        beliefstate = model.translateNDpointto2D(model.currentBelief)
        beliefstate[0] = int((beliefstate[0] * magnification) + xadjustment)
        beliefstate[1] = int((beliefstate[1] * magnification) + yadjustment)

        for state in model.pomdpStates:
            # Integer pixel coordinates keep older pygame versions happy.
            coords = [[int((x * magnification) + xadjustment), int((y * magnification) + yadjustment)]
                      for x, y in state.coords]
            width = 0 if currentState == state else 1  # 0 = filled, 1 = outline
            self.pygame.draw.polygon(self.screen, self.colors["red"], coords, width)
        self.pygame.draw.circle(self.screen, self.colors["orange"], beliefstate, 3, 0)

    def render(self, trial, testing=False):
        """Draw one GUI frame and flip the display.

        ``trial`` and ``testing`` are accepted for interface compatibility but
        are not used; supplementary trial data is printed by render_text.
        """
        # Reset the screen.
        self.screen.fill(self.bg_color)

        # * Static elements
        # Boundary
        screen_w, screen_h = self.screen.get_size()
        self.pygame.draw.rect(self.screen, self.boundary, self.pygame.Rect(1, 1, screen_w - 2, screen_h - 2))

        # One black-bordered white square per environment state along the top row.
        for position, _state in self.env.states.items():
            left = position * self.blocksize
            self.pygame.draw.rect(self.screen, self.colors['black'], self.pygame.Rect(left, 0, self.blocksize, self.blocksize))
            self.pygame.draw.rect(self.screen, self.colors['white'], self.pygame.Rect(left + 2, 2, self.blocksize - 4, self.blocksize - 4))

        # * Dynamic elements
        for agent, state in self.env.agent_states.items():
            # Integer pixel position/radius: pygame.draw.circle in older pygame
            # versions rejects floats (which true division produces here).
            agent_pos = (int((state['location'] * self.blocksize) + (self.blocksize / 3)),
                         int(self.blocksize - (self.blocksize / 2)))
            agent_color = self.colors[agent.color]
            # Cell 2 is highlighted in yellow while an agent occupies it.
            if state['location'] == 2:
                self.pygame.draw.rect(self.screen, self.colors['yellow'], self.pygame.Rect((2 * self.blocksize) + 2, 2, self.blocksize - 4, self.blocksize - 4))

            # Draw simple agent (circle)
            self.pygame.draw.circle(self.screen, agent_color, agent_pos, int(self.blocksize / 3))
            self.renderMindState(self.blocksize / 2, 1.4 * self.blocksize)

        # Flip buffers
        self.pygame.display.flip()

    def pause(self):
        """Pause the GUI simulation until any key is pressed.

        The paused duration is added to ``start_time`` so it does not count
        toward the step timer.
        """
        abs_pause_time = time.time()
        self.font = self.pygame.font.Font(None, 30)
        pause_text = "Simulation Paused. Press any key to continue. . ."
        text_surface = self.font.render(pause_text, True, self.colors['red'], self.bg_color)
        # The default window is only 4 * blocksize (400 px) wide, so a fixed
        # x=400 drew the message entirely off-screen; centre it instead (or pin
        # it to the left edge when it is wider than the window).
        text_pos = (max(0, (self.width - text_surface.get_width()) // 2), self.height - 30)
        self.screen.blit(text_surface, text_pos)
        self.pygame.display.flip()
        print(pause_text)
        while self.paused:
            for event in self.pygame.event.get():
                if event.type == self.pygame.KEYDOWN:
                    self.paused = False
            self.pygame.time.wait(self.frame_delay)
        # Erase the message by redrawing it in the background colour.
        self.screen.blit(self.font.render(pause_text, True, self.bg_color, self.bg_color), text_pos)
        self.start_time += (time.time() - abs_pause_time)