diff --git a/train_GeneticAgent.py b/train_GeneticAgent.py index 2a0f5c3..c68f34e 100644 --- a/train_GeneticAgent.py +++ b/train_GeneticAgent.py @@ -46,6 +46,7 @@ def train(): scores = np.array(list(tqdm(pool.imap_unordered(agent.get_fitness, args), desc='Population done', leave=False))) agent.weights = agent.weights[scores.argsort()[::-1]] + best_per_gen[gen] = agent.weights[0] all_scores[gen] = scores survivors = int(np.ceil(elite_pct * population_size)) @@ -55,7 +56,6 @@ def train(): idx = np.random.randint(0, len(parents), size=2) parent1, parent2 = parents[idx, :] - best_per_gen[gen] = agent.weights[scores.argmax()] if np.random.random() < crossover_rate: child1, child2 = agent.breed(parent1, parent2)