Skip to content

Commit

Permalink
Formats train genetic
Browse files Browse the repository at this point in the history
  • Loading branch information
runeharlyk committed Jan 14, 2024
1 parent 78ed313 commit 9d744ba
Showing 1 changed file with 34 additions and 24 deletions.
58 changes: 34 additions & 24 deletions train_GeneticAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,35 @@
import multiprocessing
from environment.tetris import Tetris
from agents.GeneticAgent import GeneticAgent

multiprocessing.freeze_support()
from utils.plot import ScatterPlot

level_multi = False
cols = 10
rows = 20
state_size = 5
level_multi = False
cols = 10
rows = 20
state_size = 5

population_size = 20
elite_pct = 0.1
parent_pct = 0.8
elite_pct = 0.1
parent_pct = 0.8

num_gens = 10
max_steps = 1000
num_gens = 10
max_steps = 1000

mutation_value = 0.02
mutation_value = 0.02
mutation_chance = 0.9
crossover_rate = 0.7
crossover_rate = 0.7

agent = GeneticAgent(state_size, elite_pct, population_size, max_steps, num_gens, mutation_value, mutation_chance)
agent = GeneticAgent(
state_size,
elite_pct,
population_size,
max_steps,
num_gens,
mutation_value,
mutation_chance,
)


def train():
Expand All @@ -37,30 +46,31 @@ def train():
scores = np.array(list(tqdm(pool.imap_unordered(agent.get_fitness, args), desc='Population done', leave=False)))

agent.weights = agent.weights[scores.argsort()[::-1]]

all_scores[gen] = scores
survivors = int(np.ceil(elite_pct*population_size))
parents = agent.weights[:int(population_size*parent_pct)-survivors]
survivors = int(np.ceil(elite_pct * population_size))
parents = agent.weights[: int(population_size * parent_pct) - survivors]

for i in range(survivors, population_size, 2):
idx = np.random.randint(0, len(parents), size=2)
parent1, parent2 = parents[idx,:]
best_per_gen[gen] = agent.weights[scores.argmax()]
parent1, parent2 = parents[idx, :]

best_per_gen[gen] = agent.weights[scores.argmax()]
if np.random.random() < crossover_rate:
child1, child2 = agent.breed(parent1, parent2)
else:

else:
child1, child2 = parent1, parent2

agent.weights[i] = child1
if i + 1 < population_size:
agent.weights[i+1] = child2
agent.weights[i + 1] = child2
for i, score in enumerate(scores):
plot.add_point(i+gen*population_size,score, True)
plot.add_point(i + gen * population_size, score, True)
tqdm.write(f"Final best weights: {agent.weights[0]}")
plot.save("genetic_scores.csv")
np.savetxt("best_weights_per_gen.csv", best_per_gen, delimiter=",")

if __name__ == '__main__':

if __name__ == "__main__":
train()

0 comments on commit 9d744ba

Please sign in to comment.