diff --git a/alphazero/config.py b/alphazero/config.py
index b3f21efa..2b5d4f71 100644
--- a/alphazero/config.py
+++ b/alphazero/config.py
@@ -10,6 +10,8 @@
 root_exploration_fraction = 0.25
 pb_c_base = 1 # 19652 in pseudocode
 pb_c_init = 1.25
+min_reward = -1. # Minimum reward to return for invalid actions
+reward_buffer = 25 # 250 in the R2 paper
 
 # Network
 l2_regularization_coef = 1e-4
diff --git a/alphazero/game.py b/alphazero/game.py
index 77affa39..caaa85bb 100644
--- a/alphazero/game.py
+++ b/alphazero/game.py
@@ -72,11 +72,18 @@ def expand(self, parent: Node):
         """
 
         # Create the children nodes and add them to the graph
-        self.add_edges_from(((parent, child) for child in parent.build_children()))
+        children = list(parent.build_children())
+
+        if not children:
+            parent.terminal = True
+            parent._reward = config.min_reward
+            return parent._reward
+
+        self.add_edges_from(((parent, child) for child in children))
 
         # Run the policy network to get value and prior_logit predictions
-        values, prior_logits = model(parent.policy_inputs_with_children())
-        prior_logits = prior_logits[1:].numpy().flatten()
+        values, prior_logits = model.predict(parent.policy_inputs_with_children())
+        prior_logits = prior_logits[1:].flatten()
 
         # if we're adding noise, perturb the logits
         if self.dirichlet_noise:
diff --git a/stable_radical_optimization/initialize.py b/stable_radical_optimization/initialize.py
index 4e81f967..a56575af 100644
--- a/stable_radical_optimization/initialize.py
+++ b/stable_radical_optimization/initialize.py
@@ -1,12 +1,19 @@
-import psycopg2
+import argparse
 import sys
-
 sys.path.append("..")
 
+import psycopg2
+
 import alphazero.config as config
 import stable_rad_config
 
 # Initialize PostgreSQL tables
+parser = argparse.ArgumentParser(description='Initialize the postgres tables.')
+parser.add_argument("--drop", action='store_true', help="whether to drop existing tables, if found")
+args = parser.parse_args()
+
+
+
 dbparams = {
     'dbname': 'bde',
     'port': 5432,
@@ -20,25 +27,28 @@
 ## But, we don't want this to run every time we run the script,
 ## just keeping it here as a reference
 
-with psycopg2.connect(**dbparams) as conn: 
+with psycopg2.connect(**dbparams) as conn:
     with conn.cursor() as cur:
-        cur.execute("""
-        DROP TABLE IF EXISTS {table}_reward;
-        CREATE TABLE {table}_reward (
-            id serial PRIMARY KEY,
+        if args.drop:
+            cur.execute("""
+            DROP TABLE IF EXISTS {table}_reward;
+            DROP TABLE IF EXISTS {table}_replay;
+            DROP TABLE IF EXISTS {table}_game;
+            """.format(table=config.sql_basename))
+
+        cur.execute("""
+        CREATE TABLE IF NOT EXISTS {table}_reward (
+            smiles varchar(50) PRIMARY KEY,
             time timestamp DEFAULT CURRENT_TIMESTAMP,
             real_reward real,
-            smiles varchar(50) UNIQUE,
             atom_type varchar(2),
             buried_vol real,
             max_spin real,
             atom_index int
             );
 
-        DROP TABLE IF EXISTS {table}_replay;
-
-        CREATE TABLE {table}_replay (
+        CREATE TABLE IF NOT EXISTS {table}_replay (
             id serial PRIMARY KEY,
             time timestamp DEFAULT CURRENT_TIMESTAMP,
             experiment_id varchar(50),
@@ -49,12 +59,11 @@
             position int,
             data BYTEA);
 
-        DROP TABLE IF EXISTS {table}_game;
-
-        CREATE TABLE {table}_game (
+        CREATE TABLE IF NOT EXISTS {table}_game (
             id serial PRIMARY KEY,
             time timestamp DEFAULT CURRENT_TIMESTAMP,
             experiment_id varchar(50),
-            gameid varchar(8),
-            real_reward real);
-        """.format(table=config.sql_basename))
\ No newline at end of file
+            gameid varchar(8),
+            real_reward real,
+            final_smiles varchar(50));
+        """.format(table=config.sql_basename))
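A note on the initialize.py diff above: table creation is now idempotent via CREATE TABLE IF NOT EXISTS, existing tables are dropped only when the new --drop flag is passed, and smiles replaces the serial id as the primary key of {table}_reward, which is what the ON CONFLICT DO NOTHING insert in run_mcts.py below relies on. Intended usage (assuming the script is invoked from stable_radical_optimization/):

    python initialize.py          # create any missing tables, keep existing data
    python initialize.py --drop   # drop and recreate the _reward, _replay, and _game tables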
diff --git a/stable_radical_optimization/run_mcts.py b/stable_radical_optimization/run_mcts.py
index f697cab3..c5b7f1d5 100644
--- a/stable_radical_optimization/run_mcts.py
+++ b/stable_radical_optimization/run_mcts.py
@@ -1,6 +1,9 @@
+import os
 import sys
 import uuid
+
 sys.path.append('..')
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 
 import numpy as np
 import pandas as pd
@@ -18,10 +21,9 @@
 import tensorflow as tf
 import nfp
 
-#tf.logging.set_verbosity(tf.logging.ERROR)
-
 model = tf.keras.models.load_model(
-    '/projects/rlmolecule/pstjohn/models/20200923_radical_stability_model')
+    '/projects/rlmolecule/pstjohn/models/20200923_radical_stability_model',
+    compile=False)
 
 dbparams = {
     'dbname': 'bde',
@@ -32,32 +34,38 @@
     'options': f'-c search_path=rl',
 }
 
-def get_ranked_rewards(reward, conn=None):
+def get_ranked_rewards(reward):
 
-    if conn is None:
-        conn = psycopg2.connect(**dbparams)
-
-    with conn:
-        n_rewards = pd.read_sql_query("""
-        select count(*) from {table}_reward
-        """.format(table=config.sql_basename), conn)
+    with psycopg2.connect(**dbparams) as conn:
+        with conn.cursor() as cur:
+            cur.execute("select count(*) from {table}_game;".format(
+                table=config.sql_basename))
+            n_games = cur.fetchone()[0]
+
+        if n_games < config.reward_buffer:
+            # Here, we don't have enough of a game buffer
+            # to decide if the move is good or not
+            return np.random.choice([-1., 1.])
 
-        if n_rewards['count'][0] < config.batch_size:
-            return np.random.choice([-1.,1.])
         else:
-            param = {config.ranked_reward_alpha, config.batch_size}
-            r_alpha = pd.read_sql_query("""
-            select percentile_cont(%s) within group (order by real_reward) from (
-            select real_reward
-            from {table}_reward
-            order by id desc limit %s) as finals
-            """.format(table=config.sql_basename), conn, params=param)
-            if reward > r_alpha['percentile_cont'][0]:
+            with conn.cursor() as cur:
+                cur.execute("""
+                select percentile_cont(%s) within group (order by real_reward)
+                from (select real_reward from {table}_game
+                      order by id desc limit %s) as finals
+                """.format(table=config.sql_basename),
+                    (config.ranked_reward_alpha, config.reward_buffer))
+
+                r_alpha = cur.fetchone()[0]
+
+            if reward > r_alpha:
                 return 1.
-            elif reward < r_alpha['percentile_cont'][0]:
+
+            elif reward < r_alpha:
                 return -1.
+
             else:
-                return np.random.choice([-1.,1.])
+                return np.random.choice([-1., 1.])
 
 
 class StabilityNode(Node):
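The rewritten get_ranked_rewards above is the ranked-reward (R2) thresholding step: once at least config.reward_buffer games have been recorded, a raw reward maps to +1 or -1 according to whether it beats the alpha-percentile of the last reward_buffer final game rewards; the cold-start case and exact ties fall back to a coin flip. A minimal self-contained sketch of the same logic, with an in-memory list standing in for the {table}_game query (the function name, arguments, and the alpha value here are illustrative, not part of the patch):

    import numpy as np

    def ranked_reward_sketch(reward, recent_rewards, alpha=0.75, buffer_size=25):
        """Threshold a raw reward against the alpha-percentile of recent games."""
        if len(recent_rewards) < buffer_size:
            # Not enough finished games yet to judge this reward either way
            return np.random.choice([-1., 1.])
        # percentile_cont in the SQL above interpolates linearly between ranks,
        # as does np.percentile with its default settings
        r_alpha = np.percentile(recent_rewards[-buffer_size:], 100 * alpha)
        if reward > r_alpha:
            return 1.
        elif reward < r_alpha:
            return -1.
        return np.random.choice([-1., 1.])  # exactly at the threshold: coin flip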
@@ -66,7 +74,9 @@ def get_reward(self):
 
         with psycopg2.connect(**dbparams) as conn:
             with conn.cursor() as cur:
-                cur.execute("select real_reward from {table}_reward where smiles = %s".format(table=config.sql_basename), (self.smiles,))
+                cur.execute(
+                    "select real_reward from {table}_reward where smiles = %s".format(
+                        table=config.sql_basename), (self.smiles,))
                 result = cur.fetchone()
 
         if result:
@@ -78,17 +88,17 @@ def get_reward(self):
         # Node is outside the domain of validity
         elif ((self.policy_inputs['atom'] == 1).any() |
               (self.policy_inputs['bond'] == 1).any()):
-            rr = get_ranked_rewards(0.)
             self._true_reward = 0.
-            return rr
+            return config.min_reward
 
         else:
-            spins, buried_vol = model(
+
+            spins, buried_vol = model.predict(
                 {key: tf.constant(np.expand_dims(val, 0))
                  for key, val in self.policy_inputs.items()})
 
-            spins = spins.numpy().flatten()
-            buried_vol = buried_vol.numpy().flatten()
+            spins = spins.flatten()
+            buried_vol = buried_vol.flatten()
 
             atom_index = int(spins.argmax())
             max_spin = spins[atom_index]
@@ -109,7 +119,8 @@ def get_reward(self):
             cur.execute("""
                 INSERT INTO {table}_reward
                 (smiles, real_reward, atom_type, buried_vol, max_spin, atom_index)
-                values (%s, %s, %s, %s, %s, %s);""".format(table=config.sql_basename), (
+                values (%s, %s, %s, %s, %s, %s)
+                ON CONFLICT DO NOTHING;""".format(table=config.sql_basename), (
                     self.smiles, float(reward), atom_type, # This should be the real reward
                     float(spin_buried_vol), float(max_spin), atom_index))
 
@@ -128,19 +139,39 @@ def run_game():
     game = list(G.run_mcts())
     reward = game[-1].reward # here it returns the ranked reward
-    
+
+    try:
+        terminal_true_reward = float(game[-1]._true_reward)
+
+    except AttributeError:
+        # This is hacky until we have a better separation of `true_reward`
+        # and `ranked_reward`. This can happen if a node doesn't have any
+        # children, but still gets chosen as the final state.
+        terminal_true_reward = 0.
+
+    with psycopg2.connect(**dbparams) as conn:
+
+        with conn.cursor() as cur:
+            cur.execute(
+                """
+                INSERT INTO {table}_game
+                (experiment_id, gameid, real_reward, final_smiles) values (%s, %s, %s, %s);
+                """.format(table=config.sql_basename), (
+                    config.experiment_id, gameid, terminal_true_reward, game[-1].smiles))
+
     for i, node in enumerate(game[:-1]):
         with conn.cursor() as cur:
             cur.execute(
-                """INSERT INTO {table}_replay
-                (experiment_id, gameid, smiles, final_smiles, ranked_reward, position, data) values (%s, %s, %s, %s, %s, %s, %s);
-
-                INSERT INTO {table}_game
-                (experiment_id, gameid, real_reward) values (%s, %s, %s);
+                """
+                INSERT INTO {table}_replay
+                (experiment_id, gameid, smiles, final_smiles, ranked_reward, position, data)
+                values (%s, %s, %s, %s, %s, %s, %s);
                 """.format(table=config.sql_basename), (
                     config.experiment_id, gameid, node.smiles, game[-1].smiles, reward, i,
-                    node.get_action_inputs_as_binary(),config.experiment_id, gameid, float(game[-1]._true_reward)))
+                    node.get_action_inputs_as_binary()))
+
+    print(f'finishing game {gameid}', flush=True)
diff --git a/stable_radical_optimization/submit_mcts.sh b/stable_radical_optimization/submit_mcts.sh
index 55512f48..0ed0034f 100644
--- a/stable_radical_optimization/submit_mcts.sh
+++ b/stable_radical_optimization/submit_mcts.sh
@@ -3,9 +3,9 @@
 #SBATCH --time=00:20:00
 #SBATCH --job-name=mcts_q2_debug
 #SBATCH --partition=debug
-#SBATCH -n 4
-#SBATCH -c 18
-#SBATCH --output=/scratch/eskordil/git-repos/rlmolecule_new/rlmolecule/mcts.%j.out
+#SBATCH -n 72
+#SBATCH -c 1
+#SBATCH --output=/scratch/pstjohn/rlmolecule/mcts.%j.out
 
 source ~/.bashrc
 conda activate /projects/rlmolecule/pstjohn/envs/tf2_cpu
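Two cross-cutting changes are worth a closing note. First, every direct model(...) call becomes model.predict(...), which returns numpy arrays rather than eager tensors, so the .numpy() conversions disappear; compile=False skips restoring the optimizer and loss that an inference-only worker never uses, and the TF_CPP_MIN_LOG_LEVEL='3' environment variable replaces the stale tf.logging call for silencing TensorFlow. An illustrative sketch of the return-type difference (toy model, not the radical stability network):

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    batch = np.zeros((1, 4), dtype=np.float32)
    assert isinstance(model(batch), tf.Tensor)           # eager tensor; needs .numpy()
    assert isinstance(model.predict(batch), np.ndarray)  # already a numpy array

Second, the Slurm request changes shape but not size: 4 tasks of 18 cores becomes 72 single-core tasks, so each of the same 72 cores presumably runs its own MCTS game loop.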