-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
addressing issues 5, 8, (part of 13) - new #17
Changes from all commits
7671be6
cd6e245
174db85
0eada33
7e0a4de
e257660
13a9802
6053c83
22b7f3c
d8dc035
0f37ecc
1a1665a
c7f7f47
c3e61b9
b830454
75525d9
f30c142
2ed5ede
dbbea76
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,41 @@ | ||
|
||
class AlphaZeroConfig: | ||
|
||
def __init__(self): | ||
|
||
# Molecule | ||
self.max_atoms = 10 # max atoms in molecule | ||
self.min_atoms = 4 # min atoms in molecule | ||
|
||
# MCTS / rollout | ||
self.lru_cache_maxsize = 100000 | ||
self.num_rollouts = 1000 # should we limit, if so how much? | ||
self.num_simulations = 256 # number of simulations used by MCTS per game step | ||
self.root_dirichlet_alpha = 0.0 # 0.3 chess, 0.03 Go, 0.15 shogi | ||
self.root_exploration_fraction = 0.25 | ||
self.pb_c_base = 1 # 19652 in pseudocode | ||
self.pb_c_init = 1.25 | ||
# Molecule | ||
max_atoms = 10 # max atoms in molecule | ||
min_atoms = 4 # min atoms in molecule | ||
|
||
# Network | ||
self.l2_regularization_coef = 1e-4 | ||
self.features = 16 # used by all network layers | ||
self.num_messages = 1 | ||
self.num_heads = 4 # Number of attention heads | ||
self.batch_size = 32 # for gradient updates | ||
self.checkpoint_frequency = 1 # save new model file every N batches | ||
self.batch_update_frequency = 10 # get most recent data every N updates | ||
self.gradient_steps_per_batch = 32 # num step per batch | ||
self.training_iterations = int(1e06) # training iterations for NN | ||
|
||
assert self.features % self.num_heads == 0, \ | ||
"dimension mismatch for attention heads" | ||
# MCTS / rollout | ||
lru_cache_maxsize = 100000 | ||
num_rollouts = 1000 # should we limit, if so how much? | ||
num_simulations = 256 # number of simulations used by MCTS per game step | ||
root_dirichlet_alpha = 0.0 # 0.3 chess, 0.03 Go, 0.15 shogi | ||
root_exploration_fraction = 0.25 | ||
pb_c_base = 1 # 19652 in pseudocode | ||
pb_c_init = 1.25 | ||
min_reward = -1. # Minimum reward to return for invalid actions | ||
reward_buffer = 25 # 250 in the R2 paper | ||
|
||
# Buffers | ||
self.ranked_reward_alpha = 0.9 | ||
self.buffer_max_size = 512 | ||
# Network | ||
l2_regularization_coef = 1e-4 | ||
features = 16 # used by all network layers | ||
num_messages = 1 | ||
num_heads = 4 # Number of attention heads | ||
batch_size = 32 # for gradient updates | ||
checkpoint_frequency = 1 # save new model file every N batches | ||
batch_update_frequency = 10 # get most recent data every N updates | ||
gradient_steps_per_batch = 32 # num step per batch | ||
training_iterations = int(1e06) # training iterations for NN | ||
|
||
# Training | ||
self.training_steps = 100 | ||
#assert self.features % self.num_heads == 0, \ | ||
# "dimension mismatch for attention heads" | ||
|
||
# Buffers | ||
ranked_reward_alpha = 0.9 | ||
buffer_max_size = 512 | ||
|
||
# Training | ||
training_steps = 100 | ||
|
||
# DB tables | ||
sql_basename = "Stable" | ||
|
||
# Experiment id | ||
experiment_id = "0001" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import alphazero.config as config | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. we can directly modify config variables in the run scripts, so There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ok, |
||
|
||
# DB table names modified by the user according to their wish | ||
config.sql_basename = "StableES" | ||
|
||
# Experiment id | ||
config.experiment_id = "0001" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems to fix those retracing errors we were seeing