Noval #2

Open · wants to merge 34 commits into base: main

34 commits
507e239
adding sanity check experiments for small horizon len and layernorm a…
RajGhugare19 Feb 21, 2023
faaee49
adding sanity check experiments for small horizon len and layernorm a…
RajGhugare19 Feb 21, 2023
d75ae8b
adding 25 len experiments as well
RajGhugare19 Feb 21, 2023
f4f6099
does rl produce unique strings?
RajGhugare19 Feb 21, 2023
49186dc
does rl produce unique strings?
RajGhugare19 Feb 21, 2023
0ea6ec3
adding regression eval
RajGhugare19 Feb 22, 2023
ecb8136
adding regression
RajGhugare19 Feb 22, 2023
a569eb8
adding regression
RajGhugare19 Feb 22, 2023
2eed3e9
adding mila script
RajGhugare19 Feb 23, 2023
f31a5b7
adding cedar scripts
RajGhugare19 Feb 23, 2023
3eefc6f
adding mila scripts
RajGhugare19 Feb 23, 2023
323f2d5
adding cedar scripts
RajGhugare19 Feb 23, 2023
9dd7749
adding cedar scripts
RajGhugare19 Feb 23, 2023
bf6e7db
add scripts
RajGhugare19 Feb 23, 2023
b416e09
install pandas
RajGhugare19 Feb 23, 2023
8f6ba52
adding fc mlp
RajGhugare19 Feb 23, 2023
e6f62bd
adding fc mlp
RajGhugare19 Feb 23, 2023
5c2360b
minor edits
RajGhugare19 Feb 23, 2023
e40a5b4
addin scripts for other targets
RajGhugare19 Feb 23, 2023
0d94175
addin scripts for other targets
RajGhugare19 Feb 23, 2023
fe4f30b
adding scripts with horizon 50
RajGhugare19 Feb 23, 2023
e50ee59
adding new code
RajGhugare19 Feb 24, 2023
4d9d628
add
RajGhugare19 Feb 25, 2023
59e365c
addinf sanity check cedar scripts
RajGhugare19 Feb 25, 2023
9cd7efd
len 22 sanity
RajGhugare19 Feb 25, 2023
d7e4eea
adding 0 dropout experiments
RajGhugare19 Feb 27, 2023
d9729eb
adding conv1d experiments
RajGhugare19 Feb 27, 2023
726434d
adding char conv exps
RajGhugare19 Feb 28, 2023
61081fe
adding conv regression experiments based on chem VAE
RajGhugare19 Mar 2, 2023
e51c92a
adding partial development
RajGhugare19 Mar 2, 2023
9f2ccde
adding dqn as agent
RajGhugare19 Mar 10, 2023
040387d
adding fast dqn to cedar
RajGhugare19 Mar 10, 2023
57d449c
bring back sac
RajGhugare19 Mar 10, 2023
e104ff5
removing mback utils
RajGhugare19 Mar 11, 2023
4 changes: 3 additions & 1 deletion .gitignore
@@ -13,7 +13,7 @@ MUJOCO_LOG.TXT
saved/

#virtual env
env_rl4chem/
env_*

#notebooks
.ipynb_checkpoints/
@@ -24,10 +24,12 @@ datasets/
filtered_data/
docked_data/
raw_data/
data/

#package dependencies
dockstring/
openbabel.git
oracle/

#local experiments
store_room/
3 changes: 2 additions & 1 deletion cfgs/config.yaml
@@ -8,7 +8,7 @@ seed: 1
#environment specific
target: 'fa7'
selfies_enc_type: 'one_hot'
max_selfie_length: 25
max_selfie_length: 22
vina_program: 'qvina2'
temp_dir: 'tmp'
exhaustiveness: 1
@@ -22,6 +22,7 @@ timeout_dock: 100
num_train_steps: 1000000
env_buffer_size: 100000

explore_molecules: 250
parallel_molecules: 250
batch_size: 256
obs_dtype:
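The new `explore_molecules` key (alongside the existing `parallel_molecules`) is not documented in this diff. Below is a hedged sketch of how such settings are presumably consumed: it uses only env methods that appear in `env.py` later in this PR (`reset`, `step`, `_add_smiles_to_batch`, `get_reward_batch`, `num_actions`); the surrounding loop is an assumption, not part of the change.

```python
# Hedged sketch: how explore_molecules / parallel_molecules might be used. The env
# methods (reset, step, _add_smiles_to_batch, get_reward_batch, num_actions) come
# from env.py in this PR; the loop itself is an assumption, not part of the diff.
import numpy as np

def collect_and_dock(env, agent, num_molecules, explore=False):
    """Build num_molecules complete molecules, then dock the stored batch in parallel."""
    for _ in range(num_molecules):
        obs, done = env.reset(), False
        while not done:
            if explore:
                action = np.random.randint(env.num_actions)  # random warm-up phase
            else:
                action = agent.get_action(obs)               # learned policy
            obs, reward, done, info = env.step(action)
        # with this PR, the caller explicitly stores the finished SMILES for batched docking
        env._add_smiles_to_batch(info["episode"]["smiles"])
    return env.get_reward_batch()

# e.g. collect_and_dock(env, agent, cfg.explore_molecules, explore=True)   # exploration phase
#      collect_and_dock(env, agent, cfg.parallel_molecules)                # training rollouts
```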
108 changes: 108 additions & 0 deletions dqn.py
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td
import numpy as np
import random
import wandb

import utils

class QNetwork(nn.Module):
    def __init__(self, input_dims, hidden_dims, output_dims):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dims, hidden_dims),
            nn.ReLU(), nn.Linear(hidden_dims, hidden_dims), nn.LayerNorm(hidden_dims),
            nn.ReLU(), nn.Linear(hidden_dims, output_dims))

    def forward(self, x):
        return self.network(x)

class DQNAgent:
    def __init__(self, device, obs_dims, num_actions,
                 gamma, tau, update_interval, target_update_interval, lr, batch_size,
                 hidden_dims, wandb_log, log_interval):

        self.device = device

        #learning
        self.gamma = gamma
        self.tau = tau
        self.update_interval = update_interval
        self.target_update_interval = target_update_interval
        self.batch_size = batch_size

        #logging
        self.wandb_log = wandb_log
        self.log_interval = log_interval

        self._init_networks(obs_dims, num_actions, hidden_dims)
        self._init_optims(lr)

    def get_action(self, obs, eval=False):
        with torch.no_grad():
            obs = torch.tensor(obs, dtype=torch.float32, device=self.device)
            q_values = self.q(obs)
            action = torch.argmax(q_values)
        return action.cpu().numpy()

    def _init_networks(self, obs_dims, num_actions, hidden_dims):
        self.q = QNetwork(obs_dims, hidden_dims, num_actions).to(self.device)
        self.q_target = QNetwork(obs_dims, hidden_dims, num_actions).to(self.device)
        utils.hard_update(self.q_target, self.q)

    def _init_optims(self, lr):
        self.q_opt = torch.optim.Adam(self.q.parameters(), lr=lr["q"])

    def get_save_dict(self):
        return {
            "q": self.q.state_dict(),
            "q_target": self.q_target.state_dict(),
        }

    def load_save_dict(self, saved_dict):
        self.q.load_state_dict(saved_dict["q"])
        self.q_target.load_state_dict(saved_dict["q_target"])

    def update(self, buffer, step):
        metrics = dict()
        if step % self.log_interval == 0 and self.wandb_log:
            log = True
        else:
            log = False

        if step % self.update_interval == 0:
            state_batch, action_batch, reward_batch, next_state_batch, done_batch, time_batch = buffer.sample(self.batch_size)
            state_batch = torch.tensor(state_batch, dtype=torch.float32, device=self.device)
            next_state_batch = torch.tensor(next_state_batch, dtype=torch.float32, device=self.device)
            action_batch = torch.tensor(action_batch, dtype=torch.long, device=self.device)
            reward_batch = torch.tensor(reward_batch, dtype=torch.float32, device=self.device)
            done_batch = torch.tensor(done_batch, dtype=torch.float32, device=self.device)
            discount_batch = self.gamma*(1-done_batch)

            with torch.no_grad():
                target_max, _ = self.q_target(next_state_batch).max(dim=1)
                td_target = reward_batch + target_max * discount_batch  # discount_batch already carries gamma*(1-done)

            old_val = self.q(state_batch).gather(1, action_batch).squeeze()

            loss = F.mse_loss(td_target, old_val)
            self.q_opt.zero_grad()
            loss.backward()
            self.q_opt.step()

            if log:
                metrics['mean_q_target'] = torch.mean(td_target).item()
                metrics['max_reward'] = torch.max(reward_batch).item()
                metrics['min_reward'] = torch.min(reward_batch).item()
                metrics['variance_q_target'] = torch.var(td_target).item()
                metrics['min_q_target'] = torch.min(td_target).item()
                metrics['max_q_target'] = torch.max(td_target).item()
                metrics['critic_loss'] = loss.item()

        if step % self.target_update_interval == 0:
            utils.soft_update(self.q_target, self.q, self.tau)

        if log:
            wandb.log(metrics, step=step)
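How `DQNAgent` is driven is outside this diff, so here is a minimal, hedged sketch of a surrounding training loop. The replay buffer's `push`/`__len__` API and the epsilon-greedy schedule are assumptions (the agent's own `get_action` is purely greedy, and its `update` expects `buffer.sample` to also return a time index); `get_action` and `update` are the methods defined above, with `update` regressing Q(s, a) toward r + gamma * max_a' Q_target(s', a').

```python
# Hedged sketch of a training loop around DQNAgent. The replay buffer API
# (push/__len__) and the epsilon schedule are assumptions, not part of this PR;
# get_action/update are the agent methods defined in dqn.py above.
import numpy as np

def train(env, agent, buffer, num_train_steps, batch_size, eps_start=1.0, eps_end=0.05):
    obs = env.reset()
    for step in range(num_train_steps):
        # linearly decayed epsilon-greedy exploration (DQNAgent.get_action itself is greedy)
        eps = max(eps_end, eps_start - (eps_start - eps_end) * step / num_train_steps)
        if np.random.rand() < eps:
            action = np.random.randint(env.num_actions)
        else:
            action = agent.get_action(obs)

        next_obs, reward, done, info = env.step(action)
        buffer.push(obs, action, reward, next_obs, done)  # assumed buffer signature
        obs = env.reset() if done else next_obs

        if len(buffer) >= batch_size:
            agent.update(buffer, step)  # samples internally, applies the TD(0) update
```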
112 changes: 77 additions & 35 deletions env.py
@@ -6,6 +6,77 @@
from docking import DockingVina
from collections import defaultdict

class selfies_vocabulary(object):
    def __init__(self, vocab_path=None):

        if vocab_path is None:
            self.alphabet = sf.get_semantic_robust_alphabet()
        else:
            self.alphabet = set()
            with open(vocab_path, 'r') as f:
                chars = f.read().split()
                for char in chars:
                    self.alphabet.add(char)

        self.special_tokens = ['BOS', 'EOS', 'PAD', 'UNK']

        self.alphabet_list = list(self.alphabet)
        self.alphabet_list.sort()
        self.alphabet_list = self.alphabet_list + self.special_tokens
        self.alphabet_length = len(self.alphabet_list)

        self.alphabet_to_idx = {s: i for i, s in enumerate(self.alphabet_list)}
        self.idx_to_alphabet = {s: i for i, s in self.alphabet_to_idx.items()}

    def tokenize(self, selfies, add_bos=False, add_eos=False):
        """Takes a SELFIES string and returns a list of characters/tokens"""
        tokenized = list(sf.split_selfies(selfies))
        if add_bos:
            tokenized.insert(0, "BOS")
        if add_eos:
            tokenized.append('EOS')
        return tokenized

    def encode(self, selfies, add_bos=False, add_eos=False):
        """Takes a SELFIES string and encodes it to an array of indices"""
        char_list = self.tokenize(selfies, add_bos, add_eos)
        encoded_selfies = np.zeros(len(char_list), dtype=np.uint8)
        for i, char in enumerate(char_list):
            encoded_selfies[i] = self.alphabet_to_idx[char]
        return encoded_selfies

    def decode(self, encoded_selfies, rem_bos=True, rem_eos=True):
        """Takes an array of indices and returns the corresponding SELFIES"""
        if rem_bos and encoded_selfies[0] == self.bos:
            encoded_selfies = encoded_selfies[1:]
        if rem_eos and encoded_selfies[-1] == self.eos:
            encoded_selfies = encoded_selfies[:-1]

        chars = []
        for i in encoded_selfies:
            chars.append(self.idx_to_alphabet[i])
        selfies = "".join(chars)
        return selfies

    def __len__(self):
        return len(self.alphabet_to_idx)

    @property
    def bos(self):
        return self.alphabet_to_idx['BOS']

    @property
    def eos(self):
        return self.alphabet_to_idx['EOS']

    @property
    def pad(self):
        return self.alphabet_to_idx['PAD']

    @property
    def unk(self):
        return self.alphabet_to_idx['UNK']

class docking_env(object):
    '''This environment is built assuming selfies version 2.1.1
    To-do
@@ -86,8 +86,6 @@ def __init__(self, cfg):

        # Intitialising smiles batch for parallel evaluation
        self.smiles_batch = []
        self.selfies_batch = []
        self.len_selfies_batch = []

        # Initialize Step
        self.t = 0
@@ -123,29 +192,27 @@ def step(self, action):
        if done:
            molecule_smiles = sf.decoder(self.molecule_selfie)
            pretty_selfies = sf.encoder(molecule_smiles)

            self.smiles_batch.append(molecule_smiles)
            self.selfies_batch.append(pretty_selfies)
            self.len_selfies_batch.append(sf.len_selfies(pretty_selfies))
            info["episode"]["l"] = self.t
            info["episode"]["smiles"] = molecule_smiles
            info["episode"]["seflies"] = pretty_selfies
            info["episode"]["selfies_len"] = sf.len_selfies(pretty_selfies)
            reward = -1000
        else:
            reward = 0
        return self.enc_selifes_fn(self.molecule_selfie), reward, done, info

    def _add_smiles_to_batch(self, molecule_smiles):
        self.smiles_batch.append(molecule_smiles)

    def _reset_store_batch(self):
        # Intitialising smiles batch for parallel evaluation
        self.smiles_batch = []
        self.selfies_batch = []
        self.len_selfies_batch = []

    def get_reward_batch(self):
        info = defaultdict(dict)
        docking_scores = self.predictor.predict(self.smiles_batch)
        reward_batch = np.clip(-np.array(docking_scores), a_min=0.0, a_max=None)
        info['smiles'] = self.smiles_batch
        info['selfies'] = self.selfies_batch
        info['len_selfies'] = self.len_selfies_batch
        info['docking_scores'] = docking_scores
        self._reset_store_batch()
        return reward_batch, info
@@ -168,31 +235,6 @@ class args():
    timeout_dock = 100

env = docking_env(args)
state = env.reset()
done = False
next_state, reward, done, info = env.step(15)
print(env.action_space[15])
print(next_state)
print(env.molecule_selfie)
# print(state)
# while not done:
# action = np.random.randint(env.num_actions)
# next_state, reward, done, info = env.step(action)
# print(action)
# print(env.alphabet_to_idx[env.action_space[action]])
# print(next_state)
# print('\n')

# possible_states = []
# for a1 in range(env.num_actions):
# for a2 in range(env.num_actions):
# env.reset()
# env.step(a1)
# env.step(a2)

# reward_batch, reward_info = env.get_reward_batch()
# print(np.argmax(reward_batch), np.max(reward_batch))
# print(reward_info['smiles'][np.argmax(reward_batch)])

'''
ENV stats
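For reference, a hedged round-trip example for the new `selfies_vocabulary` class; the example SELFIES string is arbitrary and the resulting indices depend on the alphabet that gets loaded.

```python
# Hedged usage example for the selfies_vocabulary class added in env.py above.
# The example string is arbitrary; indices depend on the alphabet in use.
vocab = selfies_vocabulary()          # no vocab_path -> sf.get_semantic_robust_alphabet()

selfies = "[C][C][O]"                 # ethanol in SELFIES notation
tokens = vocab.tokenize(selfies, add_bos=True, add_eos=True)   # ['BOS', '[C]', '[C]', '[O]', 'EOS']
encoded = vocab.encode(selfies, add_bos=True, add_eos=True)    # uint8 index array
decoded = vocab.decode(encoded)       # BOS/EOS stripped by default

assert decoded == selfies
print(len(vocab), vocab.pad, vocab.unk)
```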
62 changes: 62 additions & 0 deletions regression/cfgs/config.yaml
@@ -0,0 +1,62 @@
#common
target: 'ESR2'
string_rep: 'selfies'
data_path: 'data/filtered_dockstring-dataset.tsv'
splits_path: 'data/filtered_cluster_split.tsv'

#learning
lr: 1e-3
seed: 1
device: 'cuda'
num_epochs: 20
batch_size: 64
max_grad_norm: 0.5
model_name:
vocab_size:
pad_idx:
input_size:

#CharRNN
charrnn:
  _target_: models.CharRNN
  vocab_size: ${vocab_size}
  pad_idx: ${pad_idx}
  device: ${device}
  num_layers: 3
  hidden_size: 256
  embedding_size: 32

#CharMLP
charmlp:
  _target_: models.CharMLP
  vocab_size: ${vocab_size}
  pad_idx: ${pad_idx}
  device: ${device}
  hidden_size: 256
  embedding_size: 32
  input_size: ${input_size}

#CharConv
charconv:
  _target_: models.CharConv
  vocab_size: ${vocab_size}
  pad_idx: ${pad_idx}
  device: ${device}
  hidden_size: 256
  embedding_size: 32
  input_size: ${input_size}

#eval
eval_interval: 100

#logging
wandb_log: False
wandb_entity: 'raj19'
wandb_run_name: 'docking-regression'
log_interval: 100

hydra:
  run:
    dir: ./local_exp/${now:%Y.%m.%d}/${now:%H.%M.%S}_${seed}
  job:
    chdir: False
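The `_target_` keys suggest the models are built with Hydra's `instantiate`. Below is a hedged sketch of how this config might be consumed: the `models` module is referenced by the config but not shown in this diff, and the values filled in for `vocab_size`, `pad_idx`, `input_size`, and `model_name` are illustrative only.

```python
# Hedged sketch: consuming regression/cfgs/config.yaml with Hydra. Only the _target_
# fields and ${...} interpolations come from the config above; the filled-in values
# and the models module itself are assumptions.
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="cfgs", config_name="config", version_base=None)
def main(cfg: DictConfig):
    # resolve the empty top-level placeholders before instantiation (illustrative values)
    cfg.vocab_size = 70         # e.g. len(vocab)
    cfg.pad_idx = 68            # e.g. vocab.pad
    cfg.input_size = 22         # e.g. max SELFIES length
    cfg.model_name = "charrnn"  # or "charmlp" / "charconv"

    # ${vocab_size}, ${pad_idx}, ${device}, ${input_size} are interpolated from the top level
    model = hydra.utils.instantiate(cfg[cfg.model_name])
    print(model)

if __name__ == "__main__":
    main()
```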