-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
58 lines (47 loc) · 1.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
Author: Aryaman Pandya
File contents: Program entry point. Select game, initialize agent, train agent.
This file will also be used to output plots and monitor training progress.
This file will also be used for evaluation and testing.
"""
import logging
import sys
import torch
from src.agent import AlphaZeroAgent
from src.mcts import MCTS
from src.models import OthelloNN
sys.path.append("Othello")
from othello_game import OthelloGame
# Root logger configuration: emit INFO and above, one pipe-delimited record
# per line carrying timestamp, level, logger name, source file, function,
# line number and the message itself.
logging.basicConfig(
    format="%(asctime)s | [%(levelname)s] | %(name)s | %(filename)s | %(funcName)s() | line.%(lineno)d | %(message)s",
    level=logging.INFO,
)
def main():
    """
    Assemble the Othello AlphaZero training setup and run training.

    Uses the first CUDA device when one is available and the CPU otherwise,
    moves a freshly constructed OthelloNN onto that device, builds an Adam
    optimizer over its parameters, wires the optimizer, game and an MCTS
    instance into an AlphaZeroAgent, and finally calls the agent's train().
    """
    # Prefer the GPU when torch reports one.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    network = OthelloNN().to(device)

    # n=8 is presumably the board side length -- confirm against OthelloGame.
    othello = OthelloGame(n=8)

    # Adam settings: step size plus weight decay (used as L2 regularisation).
    adam_settings = {"lr": 3e-4, "weight_decay": 1e-3}
    optimizer = torch.optim.Adam(network.parameters(), **adam_settings)

    tree_search = MCTS(othello)
    agent = AlphaZeroAgent(
        optimizer=optimizer,
        num_simulations=25,
        game=othello,
        c_uct=1,
        device=device,
        mcts=tree_search,
    )

    agent.train(
        train_batch_size=16,
        neural_network=network,
        num_episodes=5,
        num_epochs=100,
    )
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()