Skip to content

Commit

Permalink
fixing a few remaining issues
Browse files Browse the repository at this point in the history
  • Loading branch information
pstjohn committed Oct 2, 2020
1 parent 2ed5ede commit dbbea76
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 60 deletions.
2 changes: 2 additions & 0 deletions alphazero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
root_exploration_fraction = 0.25
pb_c_base = 1 # 19652 in pseudocode
pb_c_init = 1.25
min_reward = -1. # Minimum reward to return for invalid actions
reward_buffer = 25 # 250 in the R2 paper

# Network
l2_regularization_coef = 1e-4
Expand Down
13 changes: 10 additions & 3 deletions alphazero/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,18 @@ def expand(self, parent: Node):
"""

# Create the children nodes and add them to the graph
self.add_edges_from(((parent, child) for child in parent.build_children()))
children = list(parent.build_children())

if not children:
parent.terminal = True
parent._reward = config.min_reward
return parent._reward

self.add_edges_from(((parent, child) for child in children))

# Run the policy network to get value and prior_logit predictions
values, prior_logits = model(parent.policy_inputs_with_children())
prior_logits = prior_logits[1:].numpy().flatten()
values, prior_logits = model.predict(parent.policy_inputs_with_children())
prior_logits = prior_logits[1:].flatten()

# if we're adding noise, perturb the logits
if self.dirichlet_noise:
Expand Down
43 changes: 26 additions & 17 deletions stable_radical_optimization/initialize.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import psycopg2
import argparse
import sys

sys.path.append("..")

import psycopg2

import alphazero.config as config
import stable_rad_config
# Initialize PostgreSQL tables

parser = argparse.ArgumentParser(description='Initialize the postgres tables.')
parser.add_argument("--drop", action='store_true', help="whether to drop existing tables, if found")
args = parser.parse_args()



dbparams = {
'dbname': 'bde',
'port': 5432,
Expand All @@ -20,25 +27,28 @@
## But, we don't want this to run every time we run the script,
## just keeping it here as a reference

with psycopg2.connect(**dbparams) as conn:
with psycopg2.connect(**dbparams) as conn:
with conn.cursor() as cur:
cur.execute("""
DROP TABLE IF EXISTS {table}_reward;

CREATE TABLE {table}_reward (
id serial PRIMARY KEY,
if args.drop:
cur.execute("""
DROP TABLE IF EXISTS {table}_reward;
DROP TABLE IF EXISTS {table}_replay;
DROP TABLE IF EXISTS {table}_game;
""".format(table=config.sql_basename))

cur.execute("""
CREATE TABLE IF NOT EXISTS {table}_reward (
smiles varchar(50) PRIMARY KEY,
time timestamp DEFAULT CURRENT_TIMESTAMP,
real_reward real,
smiles varchar(50) UNIQUE,
atom_type varchar(2),
buried_vol real,
max_spin real,
atom_index int
);
DROP TABLE IF EXISTS {table}_replay;
CREATE TABLE {table}_replay (
CREATE TABLE IF NOT EXISTS {table}_replay (
id serial PRIMARY KEY,
time timestamp DEFAULT CURRENT_TIMESTAMP,
experiment_id varchar(50),
Expand All @@ -49,12 +59,11 @@
position int,
data BYTEA);
DROP TABLE IF EXISTS {table}_game;
CREATE TABLE {table}_game (
CREATE TABLE IF NOT EXISTS {table}_game (
id serial PRIMARY KEY,
time timestamp DEFAULT CURRENT_TIMESTAMP,
experiment_id varchar(50),
gameid varchar(8),
real_reward real);
""".format(table=config.sql_basename))
gameid varchar(8),
real_reward real,
final_smiles varchar(50));
""".format(table=config.sql_basename))
105 changes: 68 additions & 37 deletions stable_radical_optimization/run_mcts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import sys
import uuid

sys.path.append('..')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import pandas as pd
Expand All @@ -18,10 +21,9 @@
import tensorflow as tf
import nfp

#tf.logging.set_verbosity(tf.logging.ERROR)

model = tf.keras.models.load_model(
'/projects/rlmolecule/pstjohn/models/20200923_radical_stability_model')
'/projects/rlmolecule/pstjohn/models/20200923_radical_stability_model',
compile=False)

dbparams = {
'dbname': 'bde',
Expand All @@ -32,32 +34,38 @@
'options': f'-c search_path=rl',
}

def get_ranked_rewards(reward, conn=None):
def get_ranked_rewards(reward):

if conn is None:
conn = psycopg2.connect(**dbparams)

with conn:
n_rewards = pd.read_sql_query("""
select count(*) from {table}_reward
""".format(table=config.sql_basename), conn)
with psycopg2.connect(**dbparams) as conn:
with conn.cursor() as cur:
cur.execute("select count(*) from {table}_game;".format(
table=config.sql_basename))
n_games = cur.fetchone()[0]

if n_games < config.reward_buffer:
# Here, we don't have enough of a game buffer
# to decide if the move is good or not
return np.random.choice([-1., 1.])

if n_rewards['count'][0] < config.batch_size:
return np.random.choice([-1.,1.])
else:
param = {config.ranked_reward_alpha, config.batch_size}
r_alpha = pd.read_sql_query("""
select percentile_cont(%s) within group (order by real_reward) from (
select real_reward
from {table}_reward
order by id desc limit %s) as finals
""".format(table=config.sql_basename), conn, params=param)
if reward > r_alpha['percentile_cont'][0]:
with conn.cursor() as cur:
cur.execute("""
select percentile_cont(%s) within group (order by real_reward)
from (select real_reward from {table}_game
order by id desc limit %s) as finals
""".format(table=config.sql_basename),
(config.ranked_reward_alpha, config.reward_buffer))

r_alpha = cur.fetchone()[0]

if reward > r_alpha:
return 1.
elif reward < r_alpha['percentile_cont'][0]:

elif reward < r_alpha:
return -1.

else:
return np.random.choice([-1.,1.])
return np.random.choice([-1., 1.])


class StabilityNode(Node):
Expand All @@ -66,7 +74,9 @@ def get_reward(self):

with psycopg2.connect(**dbparams) as conn:
with conn.cursor() as cur:
cur.execute("select real_reward from {table}_reward where smiles = %s".format(table=config.sql_basename), (self.smiles,))
cur.execute(
"select real_reward from {table}_reward where smiles = %s".format(
table=config.sql_basename), (self.smiles,))
result = cur.fetchone()

if result:
Expand All @@ -78,17 +88,17 @@ def get_reward(self):
# Node is outside the domain of validity
elif ((self.policy_inputs['atom'] == 1).any() |
(self.policy_inputs['bond'] == 1).any()):
rr = get_ranked_rewards(0.)
self._true_reward = 0.
return rr
return config.min_reward

else:
spins, buried_vol = model(

spins, buried_vol = model.predict(
{key: tf.constant(np.expand_dims(val, 0))
for key, val in self.policy_inputs.items()})

spins = spins.numpy().flatten()
buried_vol = buried_vol.numpy().flatten()
spins = spins.flatten()
buried_vol = buried_vol.flatten()

atom_index = int(spins.argmax())
max_spin = spins[atom_index]
Expand All @@ -109,7 +119,8 @@ def get_reward(self):
cur.execute("""
INSERT INTO {table}_reward
(smiles, real_reward, atom_type, buried_vol, max_spin, atom_index)
values (%s, %s, %s, %s, %s, %s);""".format(table=config.sql_basename), (
values (%s, %s, %s, %s, %s, %s)
ON CONFLICT DO NOTHING;""".format(table=config.sql_basename), (
self.smiles, float(reward), atom_type, # This should be the real reward
float(spin_buried_vol), float(max_spin), atom_index))

Expand All @@ -128,19 +139,39 @@ def run_game():

game = list(G.run_mcts())
reward = game[-1].reward # here it returns the ranked reward


try:
terminal_true_reward = float(game[-1]._true_reward)

except AttributeError:
# This is hacky until we have a better separation of `true_reward`
# and `ranked_reward`. This can happen if a node doesn't have any
# children, but still gets chosen as the final state.
terminal_true_reward = 0.


with psycopg2.connect(**dbparams) as conn:

with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO {table}_game
(experiment_id, gameid, real_reward, final_smiles) values (%s, %s, %s, %s);
""".format(table=config.sql_basename), (
config.experiment_id, gameid, terminal_true_reward, game[-1].smiles))

for i, node in enumerate(game[:-1]):
with conn.cursor() as cur:
cur.execute(
"""INSERT INTO {table}_replay
(experiment_id, gameid, smiles, final_smiles, ranked_reward, position, data) values (%s, %s, %s, %s, %s, %s, %s);
INSERT INTO {table}_game
(experiment_id, gameid, real_reward) values (%s, %s, %s);
"""
INSERT INTO {table}_replay
(experiment_id, gameid, smiles, final_smiles, ranked_reward, position, data)
values (%s, %s, %s, %s, %s, %s, %s);
""".format(table=config.sql_basename), (
config.experiment_id, gameid, node.smiles, game[-1].smiles, reward, i,
node.get_action_inputs_as_binary(),config.experiment_id, gameid, float(game[-1]._true_reward)))
node.get_action_inputs_as_binary()))



print(f'finishing game {gameid}', flush=True)

Expand Down
6 changes: 3 additions & 3 deletions stable_radical_optimization/submit_mcts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
#SBATCH --time=00:20:00
#SBATCH --job-name=mcts_q2_debug
#SBATCH --partition=debug
#SBATCH -n 4
#SBATCH -c 18
#SBATCH --output=/scratch/eskordil/git-repos/rlmolecule_new/rlmolecule/mcts.%j.out
#SBATCH -n 72
#SBATCH -c 1
#SBATCH --output=/scratch/pstjohn/rlmolecule/mcts.%j.out

source ~/.bashrc
conda activate /projects/rlmolecule/pstjohn/envs/tf2_cpu
Expand Down

0 comments on commit dbbea76

Please sign in to comment.