Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: used initial observation the entire time #274

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions examples/openai-lander/evolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@

env = gym.make('LunarLander-v2')

print("action space: {0!r}".format(env.action_space))
print("observation space: {0!r}".format(env.observation_space))

# print("action space: {0!r}".format(env.action_space))
# print("observation space: {0!r}".format(env.observation_space))

class LanderGenome(neat.DefaultGenome):
def __init__(self, key):
Expand Down Expand Up @@ -86,22 +85,21 @@ def __init__(self, num_workers):
def simulate(self, nets):
scores = []
for genome, net in nets:
observation_init_vals, observation_init_info = env.reset()
observation, observation_init_info = env.reset()
step = 0
data = []
while 1:
step += 1
if step < 200 and random.random() < 0.2:
action = env.action_space.sample()
else:
output = net.activate(observation_init_vals)
output = net.activate(observation)
action = np.argmax(output)

# Note: done has been deprecated.
observation, reward, terminated, done, info = env.step(action)
observation, reward, terminated, truncated, info = env.step(action)
data.append(np.hstack((observation, action, reward)))

if terminated:
if terminated or truncated:
break

data = np.array(data)
Expand Down Expand Up @@ -169,7 +167,7 @@ def run():
pop.add_reporter(neat.StdOutReporter(True))
# Checkpoint every 25 generations or 900 seconds.
pop.add_reporter(neat.Checkpointer(25, 900))

best_genomes = None
# Run until the winner from a generation is able to solve the environment
# or the user interrupts the process.
ec = PooledErrorCompute(NUM_CORES)
Expand Down Expand Up @@ -203,7 +201,7 @@ def run():
solved = True
best_scores = []
for k in range(100):
observation_init_vals, observation_init_info = env.reset()
observation, observation_init_info = env.reset()
score = 0
step = 0
while 1:
Expand All @@ -212,31 +210,40 @@ def run():
# determine the best action given the current state.
votes = np.zeros((4,))
for n in best_networks:
output = n.activate(observation_init_vals)
output = n.activate(observation)
votes[np.argmax(output)] += 1

best_action = np.argmax(votes)
# Note: done has been deprecated.
observation, reward, terminated, done, info = env.step(best_action)

observation, reward, terminated, truncated, info = env.step(best_action)
score += reward
env.render()
if terminated:
if terminated or truncated:
break

ec.episode_score.append(score)
ec.episode_length.append(step)

best_scores.append(score)
avg_score = sum(best_scores) / len(best_scores)
print(k, score, avg_score)
print(f'Solved {k} times. '
f'Last score {score}, '
f'Average score {avg_score}, '
f'Max score {np.max(best_scores)}')
if avg_score < 200:
solved = False
break

if solved:
print("Solved.")
break
except KeyboardInterrupt:
print("User break.")
break

# Save the winners.
finally:
# Save the winners.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could put an `if best_genomes:` check here to handle the case where the user terminates within the first 5 generations (if I read it correctly, `best_genomes` will be undefined in that case).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I defined `best_genomes = None` at the top to make sure the variable exists and can be checked, and added an `if best_genomes:` check at the bottom.

if best_genomes:
for n, g in enumerate(best_genomes):
name = 'winner-{0}'.format(n)
with open(name + '.pickle', 'wb') as f:
Expand All @@ -245,10 +252,6 @@ def run():
visualize.draw_net(config, g, view=False, filename=name + "-net.gv")
visualize.draw_net(config, g, view=False, filename=name + "-net-pruned.gv", prune_unused=True)

break
except KeyboardInterrupt:
print("User break.")
break

env.close()

Expand Down