Skip to content

Commit

Permalink
Take out the old function and add a new test function for the still i…
Browse files Browse the repository at this point in the history
…mage with text. The font size format needs to be improved.
  • Loading branch information
Fatima Davelouis Gallardo authored and Fatima Davelouis Gallardo committed Nov 20, 2018
1 parent 86fba08 commit 292ca10
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 66 deletions.
39 changes: 29 additions & 10 deletions bin/plotting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from driving_gridworld.matplotlib import Progress
from driving_gridworld.matplotlib import add_decorations, remove_labels_and_ticks
from driving_gridworld.gridworld import DrivingGridworld
from driving_gridworld.matplotlib import new_plot_frame_with_text
from driving_gridworld.human_ui import observation_to_img, obs_to_rgb
#from exe.road import new_road
from driving_gridworld.road import Road
from driving_gridworld.car import Car
from driving_gridworld.obstacles import Pedestrian, Bump
Expand All @@ -21,35 +21,54 @@ def new_road(headlight_range=2):
return Road(
headlight_range,
Car(2, 2),
obstacles=[
Bump(0, 2),
Pedestrian(1, 1)
],
obstacles=[Bump(0, 2), Pedestrian(1, 1)],
allowed_obstacle_appearance_columns=[{2}, {1}],
allow_crashing=True)


def ensure_dir(dir_name):
try:
os.mkdir(dir_name)
except FileExistsError:
return


# create a single image with no text:
def test_still_image_with_text():
def test_still_image_with_no_text():
game = DrivingGridworld(new_road)
observation = game.its_showtime()[0]
img = observation_to_img(observation, obs_to_rgb)

fig, ax = plt.subplots(figsize=(3, 10))
ax = add_decorations(img, remove_labels_and_ticks(ax))
ax.imshow(img, aspect=1.5)
plt.show()
# plt.show()
my_path = os.path.dirname(os.path.realpath(__file__))
dir_name = my_path + '/../tmp'
ensure_dir(dir_name)
fig.savefig(dir_name + '/img_no_text.pdf')
# download_figure('dg-img.pdf')


def test_still_image_with_text():
game = DrivingGridworld(new_road)
observation = game.its_showtime()[0]
img = observation_to_img(observation, obs_to_rgb)
reward_function_list = [Progress(), Bumps(), Ditch(), Crashes()]
info_lists = []
frames = [[]]
info_lists.append([f.new_info() for f in reward_function_list])
fig, ax = plt.subplots(figsize=(3, 15))
frame, ax_texts = new_plot_frame_with_text(
img, 0, *info_lists[0], fig=fig, ax=ax)[:2]
frames[0] += [frame] + ax_texts

ax = add_decorations(img, remove_labels_and_ticks(ax))
ax.imshow(img, aspect=1.5)
plt.show()
my_path = os.path.dirname(os.path.realpath(__file__))
dir_name = my_path + '/../tmp'
ensure_dir(dir_name)
fig.savefig(dir_name + '/img_with_text.pdf')


if __name__ == '__main__':
test_still_image_with_text()
test_still_image_with_no_text()
57 changes: 1 addition & 56 deletions driving_gridworld/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,67 +177,12 @@ def new_plot_frame_with_text(img,
horizontalalignment='left',
fontproperties=font)
]
add_decorations(ax)
add_decorations(img, ax)

return ax.imshow(
extended_img, animated=animated, aspect=1.5), ax_texts, fig, ax


def plot_frame_with_text(img,
reward,
discounted_return,
action,
fig=None,
ax=None,
animated=False,
show_grid=False):
white_matrix = np.ones(img.shape)
extended_img = np.concatenate((img, white_matrix), axis=1)

text_list = [
'Action: {}'.format(ACTION_NAMES[action]),
'Reward: {:0.2f}'.format(reward),
'Return: {:0.2f}'.format(discounted_return)
]

if fig is None:
fig = plt.figure()

if ax is None:
ax = fig.add_subplot(111)

ax.grid(show_grid)

# Remove ticks and tick labels
ax.set_xticklabels([])
ax.set_yticklabels([])
for tic in ax.xaxis.get_major_ticks():
tic.tick1On = tic.tick2On = False
for tic in ax.yaxis.get_major_ticks():
tic.tick1On = tic.tick2On = False

column = img.shape[1] - 0.4
ax_texts = [ax.annotate(t, (column, i)) for i, t in enumerate(text_list)]

return ax.imshow(extended_img, animated=animated), ax_texts, fig, ax


def plot_rollout(policy, game, num_steps=100, policy_on_game=False):
rollout = Rollout(policy, game, policy_on_game=policy_on_game)
frames = []

fig = None
ax = None
for t, o, a, r, d, o_prime, dr in rollout:
if t >= num_steps:
break

frame, ax_texts, fig, ax = plot_frame_with_text(
observation_to_img(o), r, dr, a, fig=fig, ax=ax)
frames.append([frame] + ax_texts)
return frames, fig, ax


class Simulator(object):
def __init__(self, policy, game):
self.policy = policy
Expand Down

1 comment on commit 292ca10

@daveloui
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "old function" in the commit message refers to plot_frame_with_text.

Please sign in to comment.