diff --git a/.gitignore b/.gitignore index 373ee1c..cbeaa88 100644 --- a/.gitignore +++ b/.gitignore @@ -13,7 +13,7 @@ MUJOCO_LOG.TXT saved/ #virtual env -env_rl4chem/ +env_* #notebooks .ipynb_checkpoints/ @@ -24,10 +24,12 @@ datasets/ filtered_data/ docked_data/ raw_data/ +data/ #package dependencies dockstring/ openbabel.git +oracle/ #local experiments store_room/ diff --git a/cfgs/config.yaml b/cfgs/config.yaml index 898f74c..82074da 100644 --- a/cfgs/config.yaml +++ b/cfgs/config.yaml @@ -8,7 +8,7 @@ seed: 1 #environment specific target: 'fa7' selfies_enc_type: 'one_hot' -max_selfie_length: 25 +max_selfie_length: 22 vina_program: 'qvina2' temp_dir: 'tmp' exhaustiveness: 1 @@ -22,6 +22,7 @@ timeout_dock: 100 num_train_steps: 1000000 env_buffer_size: 100000 +explore_molecules: 250 parallel_molecules: 250 batch_size: 256 obs_dtype: diff --git a/dqn.py b/dqn.py new file mode 100644 index 0000000..6b980ec --- /dev/null +++ b/dqn.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributions as td +import numpy as np +import random +import wandb + +import utils + +class QNetwork(nn.Module): + def __init__(self, input_dims, hidden_dims, output_dims): + super().__init__() + self.network = nn.Sequential( + nn.Linear(input_dims, hidden_dims), + nn.ReLU(), nn.Linear(hidden_dims, hidden_dims), nn.LayerNorm(hidden_dims), + nn.ReLU(), nn.Linear(hidden_dims, output_dims)) + + def forward(self, x): + return self.network(x) + +class DQNAgent: + def __init__(self, device, obs_dims, num_actions, + gamma, tau, update_interval, target_update_interval, lr, batch_size, + hidden_dims, wandb_log, log_interval): + + self.device = device + + #learning + self.gamma = gamma + self.tau = tau + self.update_interval = update_interval + self.target_update_interval = target_update_interval + self.batch_size = batch_size + + #logging + self.wandb_log = wandb_log + self.log_interval = log_interval + + self._init_networks(obs_dims, num_actions, hidden_dims) + self._init_optims(lr) + + def get_action(self, obs, eval=False): + with torch.no_grad(): + obs = torch.tensor(obs, dtype=torch.float32, device=self.device) + q_values = self.q(obs) + action = torch.argmax(q_values) + return action.cpu().numpy() + + def _init_networks(self, obs_dims, num_actions, hidden_dims): + self.q = QNetwork(obs_dims, hidden_dims, num_actions).to(self.device) + self.q_target = QNetwork(obs_dims, hidden_dims, num_actions).to(self.device) + utils.hard_update(self.q_target, self.q) + + def _init_optims(self, lr): + self.q_opt = torch.optim.Adam(self.q.parameters(), lr=lr["q"]) + + def get_save_dict(self): + return { + "q": self.q.state_dict(), + "q_target":self.q_target.state_dict(), + } + + def load_save_dict(self, saved_dict): + self.q.load_state_dict(saved_dict["q"]) + self.q_target.load_state_dict(saved_dict["q_target"]) + + def update(self, buffer, step): + metrics = dict() + if step % self.log_interval == 0 and self.wandb_log: + log = True + else: + log = False + + if step % self.update_interval == 0: + state_batch, action_batch, reward_batch, next_state_batch, done_batch, time_batch = buffer.sample(self.batch_size) + state_batch = torch.tensor(state_batch, dtype=torch.float32, device=self.device) + next_state_batch = torch.tensor(next_state_batch, dtype=torch.float32, device=self.device) + action_batch = torch.tensor(action_batch, dtype=torch.long, device=self.device) + reward_batch = torch.tensor(reward_batch, dtype=torch.float32, device=self.device) + done_batch = 
torch.tensor(done_batch, dtype=torch.float32, device=self.device)
+            discount_batch = self.gamma*(1-done_batch)
+
+            with torch.no_grad():
+                target_max, _ = self.q_target(next_state_batch).max(dim=1)
+                # discount_batch already contains gamma*(1-done), so gamma is not applied again here
+                td_target = reward_batch + target_max * discount_batch
+
+            old_val = self.q(state_batch).gather(1, action_batch).squeeze()
+
+            loss = F.mse_loss(td_target, old_val)
+            self.q_opt.zero_grad()
+            loss.backward()
+            self.q_opt.step()
+
+            if log:
+                metrics['mean_q_target'] = torch.mean(td_target).item()
+                metrics['max_reward'] = torch.max(reward_batch).item()
+                metrics['min_reward'] = torch.min(reward_batch).item()
+                metrics['variance_q_target'] = torch.var(td_target).item()
+                metrics['min_q_target'] = torch.min(td_target).item()
+                metrics['max_q_target'] = torch.max(td_target).item()
+                metrics['critic_loss'] = loss.item()
+
+        if step % self.target_update_interval == 0:
+            utils.soft_update(self.q_target, self.q, self.tau)
+
+        if log:
+            wandb.log(metrics, step=step)
\ No newline at end of file
diff --git a/env.py b/env.py
index f1f58ee..25e1b3e 100644
--- a/env.py
+++ b/env.py
@@ -6,6 +6,77 @@ from docking import DockingVina
 from collections import defaultdict
 
+class selfies_vocabulary(object):
+    def __init__(self, vocab_path=None):
+
+        if vocab_path is None:
+            self.alphabet = sf.get_semantic_robust_alphabet()
+        else:
+            self.alphabet = set()
+            with open(vocab_path, 'r') as f:
+                chars = f.read().split()
+            for char in chars:
+                self.alphabet.add(char)
+
+        self.special_tokens = ['BOS', 'EOS', 'PAD', 'UNK']
+
+        self.alphabet_list = list(self.alphabet)
+        self.alphabet_list.sort()
+        self.alphabet_list = self.alphabet_list + self.special_tokens
+        self.alphabet_length = len(self.alphabet_list)
+
+        self.alphabet_to_idx = {s: i for i, s in enumerate(self.alphabet_list)}
+        self.idx_to_alphabet = {i: s for s, i in self.alphabet_to_idx.items()}
+
+    def tokenize(self, selfies, add_bos=False, add_eos=False):
+        """Takes a SELFIES string and returns a list of characters/tokens"""
+        tokenized = list(sf.split_selfies(selfies))
+        if add_bos:
+            tokenized.insert(0, "BOS")
+        if add_eos:
+            tokenized.append('EOS')
+        return tokenized
+
+    def encode(self, selfies, add_bos=False, add_eos=False):
+        """Takes a SELFIES string and encodes it to an array of indices"""
+        char_list = self.tokenize(selfies, add_bos, add_eos)
+        encoded_selfies = np.zeros(len(char_list), dtype=np.uint8)
+        for i, char in enumerate(char_list):
+            encoded_selfies[i] = self.alphabet_to_idx[char]
+        return encoded_selfies
+
+    def decode(self, encoded_selfies, rem_bos=True, rem_eos=True):
+        """Takes an array of indices and returns the corresponding SELFIES string"""
+        if rem_bos and encoded_selfies[0] == self.bos:
+            encoded_selfies = encoded_selfies[1:]
+        if rem_eos and encoded_selfies[-1] == self.eos:
+            encoded_selfies = encoded_selfies[:-1]
+
+        chars = []
+        for i in encoded_selfies:
+            chars.append(self.idx_to_alphabet[i])
+        selfies = "".join(chars)
+        return selfies
+
+    def __len__(self):
+        return len(self.alphabet_to_idx)
+
+    @property
+    def bos(self):
+        return self.alphabet_to_idx['BOS']
+
+    @property
+    def eos(self):
+        return self.alphabet_to_idx['EOS']
+
+    @property
+    def pad(self):
+        return self.alphabet_to_idx['PAD']
+
+    @property
+    def unk(self):
+        return self.alphabet_to_idx['UNK']
+
 class docking_env(object):
     '''This environment is build assuming selfies version 2.1.1
     To-do
@@ -86,8 +157,6 @@ def __init__(self, cfg):
 
         # Intitialising smiles batch for parallel evaluation
         self.smiles_batch = []
-        self.selfies_batch = []
-        self.len_selfies_batch = []
 
         # Initialize Step
         self.t = 0
@@ -123,29 +192,27 @@ def step(self, action):
 
         if done:
             molecule_smiles = sf.decoder(self.molecule_selfie)
             pretty_selfies = sf.encoder(molecule_smiles)
-
-            self.smiles_batch.append(molecule_smiles)
-            self.selfies_batch.append(pretty_selfies)
-            self.len_selfies_batch.append(sf.len_selfies(pretty_selfies))
 
             info["episode"]["l"] = self.t
+            info["episode"]["smiles"] = molecule_smiles
+            info["episode"]["selfies"] = pretty_selfies
+            info["episode"]["selfies_len"] = sf.len_selfies(pretty_selfies)
             reward = -1000
         else:
             reward = 0
 
         return self.enc_selifes_fn(self.molecule_selfie), reward, done, info
 
+    def _add_smiles_to_batch(self, molecule_smiles):
+        self.smiles_batch.append(molecule_smiles)
+
     def _reset_store_batch(self):
         # Intitialising smiles batch for parallel evaluation
         self.smiles_batch = []
-        self.selfies_batch = []
-        self.len_selfies_batch = []
 
     def get_reward_batch(self):
         info = defaultdict(dict)
         docking_scores = self.predictor.predict(self.smiles_batch)
         reward_batch = np.clip(-np.array(docking_scores), a_min=0.0, a_max=None)
         info['smiles'] = self.smiles_batch
-        info['selfies'] = self.selfies_batch
-        info['len_selfies'] = self.len_selfies_batch
         info['docking_scores'] = docking_scores
         self._reset_store_batch()
         return reward_batch, info
@@ -168,31 +235,6 @@ class args():
         timeout_dock= 100
 
     env = docking_env(args)
-    state = env.reset()
-    done = False
-    next_state, reward, done, info = env.step(15)
-    print(env.action_space[15])
-    print(next_state)
-    print(env.molecule_selfie)
-    # print(state)
-    # while not done:
-    #     action = np.random.randint(env.num_actions)
-    #     next_state, reward, done, info = env.step(action)
-    #     print(action)
-    #     print(env.alphabet_to_idx[env.action_space[action]])
-    #     print(next_state)
-    #     print('\n')
-
-    # possible_states = []
-    # for a1 in range(env.num_actions):
-    #     for a2 in range(env.num_actions):
-    #         env.reset()
-    #         env.step(a1)
-    #         env.step(a2)
-
-    # reward_batch, reward_info = env.get_reward_batch()
-    # print(np.argmax(reward_batch), np.max(reward_batch))
-    # print(reward_info['smiles'][np.argmax(reward_batch)])
 
 '''
 ENV stats
diff --git a/regression/cfgs/config.yaml b/regression/cfgs/config.yaml
new file mode 100644
index 0000000..606f282
--- /dev/null
+++ b/regression/cfgs/config.yaml
@@ -0,0 +1,62 @@
+#common
+target: 'ESR2'
+string_rep: 'selfies'
+data_path: 'data/filtered_dockstring-dataset.tsv'
+splits_path: 'data/filtered_cluster_split.tsv'
+
+#learning
+lr: 1e-3
+seed: 1
+device: 'cuda'
+num_epochs: 20
+batch_size: 64
+max_grad_norm: 0.5
+model_name:
+vocab_size:
+pad_idx:
+input_size:
+
+#CharRNN
+charrnn:
+  _target_: models.CharRNN
+  vocab_size: ${vocab_size}
+  pad_idx: ${pad_idx}
+  device: ${device}
+  num_layers: 3
+  hidden_size: 256
+  embedding_size: 32
+
+#CharMLP
+charmlp:
+  _target_: models.CharMLP
+  vocab_size: ${vocab_size}
+  pad_idx: ${pad_idx}
+  device: ${device}
+  hidden_size: 256
+  embedding_size: 32
+  input_size: ${input_size}
+
+#CharConv
+charconv:
+  _target_: models.CharConv
+  vocab_size: ${vocab_size}
+  pad_idx: ${pad_idx}
+  device: ${device}
+  hidden_size: 256
+  embedding_size: 32
+  input_size: ${input_size}
+
+#eval
+eval_interval: 100
+
+#logging
+wandb_log: False
+wandb_entity: 'raj19'
+wandb_run_name: 'docking-regression'
+log_interval: 100
+
+hydra:
+  run:
+    dir: ./local_exp/${now:%Y.%m.%d}/${now:%H.%M.%S}_${seed}
+  job:
+    chdir: False
\ No newline at end of file
diff --git a/regression/char_conv.py b/regression/char_conv.py
new file mode 100644
index 0000000..4fad2d7
--- /dev/null
+++ b/regression/char_conv.py
@@ -0,0
+1,112 @@ +import hydra +import wandb +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import sklearn.metrics +from omegaconf import DictConfig +from data import get_data + +class Averager(): + def __init__(self): + self.n = 0 + self.v = 0 + + def add(self, x): + self.v = (self.v * self.n + x) / (self.n + 1) + self.n += 1 + + def item(self): + return self.v + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + +def get_params(model): + return (p for p in model.parameters() if p.requires_grad) + +def train(cfg): + #get trainloader + train_loader, val_loader, vocab, input_size = get_data(cfg, get_max_len=True) + + #get model + cfg.vocab_size = len(vocab) + cfg.pad_idx = vocab.pad + cfg.input_size = int(input_size) + model = hydra.utils.instantiate(cfg.charconv) + + #set optimizer + optimizer = optim.Adam(get_params(model), lr=cfg.lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.num_epochs, eta_min=0.0) + + num_params = count_parameters(model) + print('The model has ', num_params, ' number of trainable parameters.') + + avg_train_loss = Averager() + # avg_grad_norm = Averager() + for epoch in range(cfg.num_epochs): + metrics = dict() + for step, (x, y) in enumerate(train_loader): + preds = model(x) + loss = F.mse_loss(preds, y) + optimizer.zero_grad() + loss.backward() + # grad_norm = nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm) + optimizer.step() + avg_train_loss.add(loss.item()) + # avg_grad_norm.add(grad_norm.item()) + + if step % cfg.eval_interval == 0: + metrics.update( + eval(model, val_loader)) + print('Epoch = ', epoch, 'Step = ', step, ' r2_score = ', metrics['r2 score']) + model.train() + + if cfg.wandb_log and step % cfg.log_interval==0: + metrics['train loss'] = avg_train_loss.item() + # metrics['average grad norm'] = avg_grad_norm.item() + metrics['lr'] = scheduler.get_last_lr()[0] + metrics['epoch'] = epoch + avg_train_loss = Averager() + wandb.log(metrics) + scheduler.step() + + metrics.update( + eval(model, val_loader)) + print('Epoch = ', epoch, 'Step = ', step, ' r2_score = ', metrics['r2 score']) + +def eval(model, val_loader): + metrics = dict() + preds_list = [] + targets_list = [] + avg_loss = Averager() + model.eval() + with torch.no_grad(): + for step, (x, y) in enumerate(val_loader): + preds = model(x) + preds_list.append(preds) + targets_list.append(y) + loss = F.mse_loss(preds, y) + avg_loss.add(loss.item()) + + preds_list = torch.cat(preds_list).tolist() + targets_list = torch.cat(targets_list).tolist() + r2_score = sklearn.metrics.r2_score(y_true = targets_list, y_pred = preds_list) + metrics['r2 score'] = r2_score + metrics['val loss'] = avg_loss.item() + return metrics + +@hydra.main(config_path='cfgs', config_name='config', version_base=None) +def main(cfg: DictConfig): + hydra_cfg = hydra.core.hydra_config.HydraConfig.get() + from char_conv import train + cfg.model_name = 'char_conv' + if cfg.wandb_log: + project_name = 'docking-regression-' + cfg.target + wandb.init(project=project_name, entity=cfg.wandb_entity, config=dict(cfg), dir=hydra_cfg['runtime']['output_dir']) + wandb.run.name = cfg.wandb_run_name + train(cfg) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/regression/char_mlp.py b/regression/char_mlp.py new file mode 100644 index 0000000..743ecef --- /dev/null +++ b/regression/char_mlp.py @@ -0,0 +1,112 @@ +import hydra +import wandb +import torch +import torch.nn as 
nn +import torch.optim as optim +import torch.nn.functional as F +import sklearn.metrics +from omegaconf import DictConfig +from data import get_data + +class Averager(): + def __init__(self): + self.n = 0 + self.v = 0 + + def add(self, x): + self.v = (self.v * self.n + x) / (self.n + 1) + self.n += 1 + + def item(self): + return self.v + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + +def get_params(model): + return (p for p in model.parameters() if p.requires_grad) + +def train(cfg): + #get trainloader + train_loader, val_loader, vocab, input_size = get_data(cfg, get_max_len=True) + + #get model + cfg.vocab_size = len(vocab) + cfg.pad_idx = vocab.pad + cfg.input_size = int(input_size) #length of the maximum sequence in the dataset + model = hydra.utils.instantiate(cfg.charmlp) + + #set optimizer + optimizer = optim.Adam(get_params(model), lr=cfg.lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.num_epochs, eta_min=0.0) + + num_params = count_parameters(model) + print('The model has ', num_params, ' number of trainable parameters.') + + avg_train_loss = Averager() + avg_grad_norm = Averager() + for epoch in range(cfg.num_epochs): + metrics = dict() + for step, (x, y) in enumerate(train_loader): + preds = model(x) + loss = F.mse_loss(preds, y) + optimizer.zero_grad() + loss.backward() + grad_norm = nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm) + optimizer.step() + avg_train_loss.add(loss.item()) + avg_grad_norm.add(grad_norm.item()) + + if step % cfg.eval_interval == 0: + metrics.update( + eval(model, val_loader)) + print('Epoch = ', epoch, 'Step = ', step, ' r2_score = ', metrics['r2 score']) + model.train() + + if cfg.wandb_log and step % cfg.log_interval==0: + metrics['train loss'] = avg_train_loss.item() + metrics['average grad norm'] = avg_grad_norm.item() + metrics['lr'] = scheduler.get_last_lr()[0] + metrics['epoch'] = epoch + avg_train_loss = Averager() + wandb.log(metrics) + scheduler.step() + + metrics.update( + eval(model, val_loader)) + print('Epoch = ', epoch, 'Step = ', step, ' r2_score = ', metrics['r2 score']) + +def eval(model, val_loader): + metrics = dict() + preds_list = [] + targets_list = [] + avg_loss = Averager() + model.eval() + with torch.no_grad(): + for step, (x, y) in enumerate(val_loader): + preds = model(x) + preds_list.append(preds) + targets_list.append(y) + loss = F.mse_loss(preds, y) + avg_loss.add(loss.item()) + + preds_list = torch.cat(preds_list).tolist() + targets_list = torch.cat(targets_list).tolist() + r2_score = sklearn.metrics.r2_score(y_true = targets_list, y_pred = preds_list) + metrics['r2 score'] = r2_score + metrics['val loss'] = avg_loss.item() + return metrics + +@hydra.main(config_path='cfgs', config_name='config', version_base=None) +def main(cfg: DictConfig): + hydra_cfg = hydra.core.hydra_config.HydraConfig.get() + from char_mlp import train + cfg.model_name = 'char_mlp' + if cfg.wandb_log: + project_name = 'docking-regression-' + cfg.target + wandb.init(project=project_name, entity=cfg.wandb_entity, config=dict(cfg), dir=hydra_cfg['runtime']['output_dir']) + wandb.run.name = cfg.wandb_run_name + train(cfg) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/regression/char_rnn.py b/regression/char_rnn.py new file mode 100644 index 0000000..d479c0c --- /dev/null +++ b/regression/char_rnn.py @@ -0,0 +1,111 @@ +import hydra +import wandb +import torch +import torch.nn as nn +import torch.optim as optim 
+import torch.nn.functional as F +import sklearn.metrics +from omegaconf import DictConfig +from data import get_data + +class Averager(): + def __init__(self): + self.n = 0 + self.v = 0 + + def add(self, x): + self.v = (self.v * self.n + x) / (self.n + 1) + self.n += 1 + + def item(self): + return self.v + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + +def get_params(model): + return (p for p in model.parameters() if p.requires_grad) + +def train(cfg): + #get trainloader + train_loader, val_loader, vocab = get_data(cfg) + + #get model + cfg.vocab_size = len(vocab) + cfg.pad_idx = vocab.pad + model = hydra.utils.instantiate(cfg.charrnn) + + #set optimizer + optimizer = optim.Adam(get_params(model), lr=cfg.lr, eps=1e-4) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.num_epochs, eta_min=0.0) + + num_params = count_parameters(model) + print('The model has ', num_params, ' number of trainable parameters.') + + avg_train_loss = Averager() + avg_grad_norm = Averager() + for epoch in range(cfg.num_epochs): + metrics = dict() + for step, (x, y, lens) in enumerate(train_loader): + preds = model(x, lens) + loss = F.mse_loss(preds, y) + optimizer.zero_grad() + loss.backward() + grad_norm = nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm) + optimizer.step() + avg_train_loss.add(loss.item()) + avg_grad_norm.add(grad_norm.item()) + + if step % cfg.eval_interval == 0: + metrics.update( + eval(model, val_loader)) + print('Epoch = ', epoch, 'Step = ', step, ' r2_score = ', metrics['r2 score']) + model.train() + + if cfg.wandb_log and step % cfg.log_interval==0: + metrics['train loss'] = avg_train_loss.item() + metrics['average grad norm'] = avg_grad_norm.item() + metrics['lr'] = scheduler.get_last_lr()[0] + metrics['epoch'] = epoch + avg_train_loss = Averager() + wandb.log(metrics) + scheduler.step() + + metrics.update( + eval(model, val_loader)) + print('Epoch = ', epoch, 'Step = ', step, ' r2_score = ', metrics['r2 score']) + +def eval(model, val_loader): + metrics = dict() + preds_list = [] + targets_list = [] + avg_loss = Averager() + model.eval() + with torch.no_grad(): + for step, (x, y, lens) in enumerate(val_loader): + preds = model(x, lens) + preds_list.append(preds) + targets_list.append(y) + loss = F.mse_loss(preds, y) + avg_loss.add(loss.item()) + + preds_list = torch.cat(preds_list).tolist() + targets_list = torch.cat(targets_list).tolist() + r2_score = sklearn.metrics.r2_score(y_true = targets_list, y_pred = preds_list) + metrics['r2 score'] = r2_score + metrics['val loss'] = avg_loss.item() + return metrics + +@hydra.main(config_path='cfgs', config_name='config', version_base=None) +def main(cfg: DictConfig): + hydra_cfg = hydra.core.hydra_config.HydraConfig.get() + from char_rnn import train + cfg.model_name = 'char_rnn' + if cfg.wandb_log: + project_name = 'docking-regression-' + cfg.target + wandb.init(project=project_name, entity=cfg.wandb_entity, config=dict(cfg), dir=hydra_cfg['runtime']['output_dir']) + wandb.run.name = cfg.wandb_run_name + train(cfg) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/regression/data.py b/regression/data.py new file mode 100644 index 0000000..871520d --- /dev/null +++ b/regression/data.py @@ -0,0 +1,265 @@ +import re +import torch +import numpy as np +import pandas as pd +import selfies as sf + +from torch.utils.data import DataLoader + +def replace_halogen(string): + """Regex to replace Br and Cl with single letters""" + 
br = re.compile('Br') + cl = re.compile('Cl') + string = br.sub('R', string) + string = cl.sub('L', string) + + return string + +class selfies_vocabulary(object): + def __init__(self, vocab_path='data/dockstring_selfies_vocabulary.txt', robust_alphabet=False): + + if robust_alphabet: + self.alphabet = sf.get_semantic_robust_alphabet() + else: + self.alphabet = set() + with open(vocab_path, 'r') as f: + chars = f.read().split() + for char in chars: + self.alphabet.add(char) + + self.special_tokens = ['BOS', 'EOS', 'PAD', 'UNK'] + + self.alphabet_list = list(self.alphabet) + self.alphabet_list.sort() + self.alphabet_list = self.alphabet_list + self.special_tokens + self.alphabet_length = len(self.alphabet_list) + + self.alphabet_to_idx = {s: i for i, s in enumerate(self.alphabet_list)} + self.idx_to_alphabet = {s: i for i, s in self.alphabet_to_idx.items()} + + def tokenize(self, selfies, add_bos=False, add_eos=False): + """Takes a SELFIES and return a list of characters/tokens""" + tokenized = list(sf.split_selfies(selfies)) + if add_bos: + tokenized.insert(0, "BOS") + if add_eos: + tokenized.append('EOS') + return tokenized + + def encode(self, selfies, add_bos=False, add_eos=False): + """Takes a list of SELFIES and encodes to array of indices""" + char_list = self.tokenize(selfies, add_bos, add_eos) + encoded_selfies = np.zeros(len(char_list), dtype=np.uint8) + for i, char in enumerate(char_list): + encoded_selfies[i] = self.alphabet_to_idx[char] + return encoded_selfies + + def decode(self, encoded_seflies, rem_bos=True, rem_eos=True): + """Takes an array of indices and returns the corresponding SELFIES""" + if rem_bos and encoded_seflies[0] == self.bos: + encoded_seflies = encoded_seflies[1:] + if rem_eos and encoded_seflies[-1] == self.eos: + encoded_seflies = encoded_seflies[:-1] + + chars = [] + for i in encoded_seflies: + chars.append(self.idx_to_alphabet[i]) + selfies = "".join(chars) + return selfies + + def __len__(self): + return len(self.alphabet_to_idx) + + @property + def bos(self): + return self.alphabet_to_idx['BOS'] + + @property + def eos(self): + return self.alphabet_to_idx['EOS'] + + @property + def pad(self): + return self.alphabet_to_idx['PAD'] + + @property + def unk(self): + return self.alphabet_to_idx['UNK'] + +class smiles_vocabulary(object): + def __init__(self, vocab_path='data/dockstring_smiles_vocabulary.txt'): + + self.alphabet = set() + with open(vocab_path, 'r') as f: + chars = f.read().split() + for char in chars: + self.alphabet.add(char) + + self.special_tokens = ['BOS', 'EOS', 'PAD', 'UNK'] + + self.alphabet_list = list(self.alphabet) + self.alphabet_list.sort() + self.alphabet_list = self.alphabet_list + self.special_tokens + self.alphabet_length = len(self.alphabet_list) + + self.alphabet_to_idx = {s: i for i, s in enumerate(self.alphabet_list)} + self.idx_to_alphabet = {s: i for i, s in self.alphabet_to_idx.items()} + + def tokenize(self, smiles, add_bos=False, add_eos=False): + """Takes a SMILES and return a list of characters/tokens""" + regex = '(\[[^\[\]]{1,6}\])' + smiles = replace_halogen(smiles) + char_list = re.split(regex, smiles) + tokenized = [] + for char in char_list: + if char.startswith('['): + tokenized.append(char) + else: + chars = [unit for unit in char] + [tokenized.append(unit) for unit in chars] + if add_bos: + tokenized.insert(0, "BOS") + if add_eos: + tokenized.append('EOS') + return tokenized + + def encode(self, smiles, add_bos=False, add_eos=False): + """Takes a list of SMILES and encodes to array of indices""" + char_list 
= self.tokenize(smiles, add_bos, add_eos) + encoded_smiles = np.zeros(len(char_list), dtype=np.uint8) + for i, char in enumerate(char_list): + encoded_smiles[i] = self.alphabet_to_idx[char] + return encoded_smiles + + def decode(self, encoded_smiles, rem_bos=True, rem_eos=True): + """Takes an array of indices and returns the corresponding SMILES""" + if rem_bos and encoded_smiles[0] == self.bos: + encoded_smiles = encoded_smiles[1:] + if rem_eos and encoded_smiles[-1] == self.eos: + encoded_smiles = encoded_smiles[:-1] + + chars = [] + for i in encoded_smiles: + chars.append(self.idx_to_alphabet[i]) + smiles = "".join(chars) + smiles = smiles.replace("L", "Cl").replace("R", "Br") + return smiles + + def __len__(self): + return len(self.alphabet_to_idx) + + @property + def bos(self): + return self.alphabet_to_idx['BOS'] + + @property + def eos(self): + return self.alphabet_to_idx['EOS'] + + @property + def pad(self): + return self.alphabet_to_idx['PAD'] + + @property + def unk(self): + return self.alphabet_to_idx['UNK'] + +class StringDataset: + def __init__(self, vocab, data, target, device, add_bos=False, add_eos=False): + """ + Arguments: + vocab: CharVocab instance for tokenization + data (list): SMILES/SELFIES strings for the datasety + target (arra): Array of target values + target (list): + """ + self.data = data + self.target = target + self.vocab = vocab + self.device = device + self.encoded_data = [vocab.encode(s, add_bos, add_eos) for s in data] + self.len = [len(s) for s in self.encoded_data] + self.max_len = np.max(self.len) + + def __len__(self): + """ + Computes a number of objects in the dataset + """ + return len(self.data) + + def __getitem__(self, index): + return torch.tensor(self.encoded_data[index], dtype=torch.long), self.target[index] + + def get_collate_fn(self, model_name): + if model_name == 'char_rnn': + def collate_fn(batch): + x, y = list(zip(*batch)) + lens = [len(s) for s in x] + x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=self.vocab.pad).to(self.device) + y = torch.tensor(y, dtype=torch.float32, device=self.device) + return x, y, lens + elif model_name in ['char_mlp', 'char_conv']: + def collate_fn(batch): + x, y = list(zip(*batch)) + x = list(x) + x[0] = torch.nn.ConstantPad1d((0, self.max_len - x[0].shape[0]), self.vocab.pad)(x[0]) + x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=self.vocab.pad).to(self.device) + y = torch.tensor(y, dtype=torch.float32, device=self.device) + return x, y + else: + raise NotImplementedError + + return collate_fn + +def get_data(cfg, get_max_len=False): + data_path=cfg.data_path + splits_path=cfg.splits_path + target=cfg.target + string_rep=cfg.string_rep + batch_size=cfg.batch_size + + assert target in ['ESR2', 'F2', 'KIT', 'PARP1', 'PGR'] + dockstring_df = pd.read_csv(data_path) + dockstring_splits = pd.read_csv(splits_path) + + assert np.all(dockstring_splits.smiles == dockstring_df.smiles) + + df_train = dockstring_df[dockstring_splits["split"] == "train"].dropna(subset=[target]) + df_test = dockstring_df[dockstring_splits["split"] == "test"].dropna(subset=[target]) + + y_train = df_train[target].values + y_test = df_test[target].values + + y_train = np.minimum(y_train, 5.0) + y_test = np.minimum(y_test, 5.0) + + if string_rep == 'smiles': + x_train = list(df_train['canon_smiles']) + x_test = list(df_test['canon_smiles']) + vocab = smiles_vocabulary() + + elif string_rep == 'selfies': + assert sf.__version__ == '2.1.0' + bond_constraints = sf.get_semantic_constraints() + 
bond_constraints['I'] = 5 + sf.set_semantic_constraints(bond_constraints) + assert sf.get_semantic_constraints()['I'] == 5 + + x_train = list(df_train['selfies_'+sf.__version__]) + x_test = list(df_test['selfies_'+sf.__version__]) + vocab = selfies_vocabulary() + else: + raise NotImplementedError + + train_dataset = StringDataset(vocab, x_train, y_train, device='cuda') + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False, collate_fn=train_dataset.get_collate_fn(cfg.model_name)) + val_dataset = StringDataset(vocab, x_test, y_test, device='cuda') + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, collate_fn=train_dataset.get_collate_fn(cfg.model_name)) + + if get_max_len: + max_len = max(train_dataset.max_len, val_dataset.max_len) + train_dataset.max_len = max_len + val_dataset.max_len = max_len + return train_loader, val_loader, vocab, max_len + else: + return train_loader, val_loader, vocab \ No newline at end of file diff --git a/regression/models.py b/regression/models.py new file mode 100644 index 0000000..3eb680d --- /dev/null +++ b/regression/models.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.rnn as rnn_utils + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + +class CharMLP(nn.Module): + def __init__(self, vocab_size, pad_idx, embedding_size, input_size, hidden_size, device): + super(CharMLP, self).__init__() + self.device = device + self.input_size = input_size + self.embedding_size = embedding_size + self.embedding_layer = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx).to(self.device) + self.fc_layer = nn.Sequential( + nn.Linear(input_size * embedding_size, hidden_size), + nn.ReLU(), nn.Linear(hidden_size, hidden_size), + nn.ReLU(), nn.Linear(hidden_size, 1), + ).to(self.device) + + def forward(self, x): + x = self.embedding_layer(x).view(-1, self.input_size * self.embedding_size) + preds = self.fc_layer(x).squeeze(-1) + return preds + +class CharConv(nn.Module): + def __init__(self, vocab_size, pad_idx, embedding_size, input_size, hidden_size, device): + super(CharConv, self).__init__() + self.device = device + self.input_size = input_size + self.embedding_size = embedding_size + self.embedding_layer = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx).to(self.device) + self.conv1d1 = nn.Conv1d(in_channels=self.embedding_size, out_channels=9, kernel_size=9).to(self.device) + self.conv1d2 = nn.Conv1d(9, 9, kernel_size=9).to(self.device) + self.conv1d3 = nn.Conv1d(9, 10, kernel_size=11).to(self.device) + + self.temp = (self.input_size - 26) * 10 + self.linear_layer_1 = nn.Linear(self.temp, 256).to(self.device) + self.linear_layer_2 = nn.Linear(256, 1).to(self.device) + + def forward(self, x): + x = torch.transpose(self.embedding_layer(x), 1, 2) + x = F.relu(self.conv1d1(x)) + x = F.relu(self.conv1d2(x)) + x = F.relu(self.conv1d3(x)) + x = x.view(x.shape[0], -1) + x = F.selu(self.linear_layer_1(x)) + preds = self.linear_layer_2(x).squeeze(-1) + return preds + +class CharRNN(nn.Module): + def __init__(self, vocab_size, pad_idx, embedding_size, num_layers, hidden_size, device): + super(CharRNN, self).__init__() + self.device = device + self.embedding_layer = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx).to(self.device) + self.lstm_layer = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True).to(self.device) + self.linear_layer 
= nn.Linear(hidden_size, 1).to(self.device) + + def forward(self, x, lengths, hiddens=None): + x = self.embedding_layer(x) + x = rnn_utils.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) + _, (h_last, _) = self.lstm_layer(x, hiddens) + preds = self.linear_layer(h_last[-1]).squeeze(-1) + return preds \ No newline at end of file diff --git a/regression/requirements_extra.txt b/regression/requirements_extra.txt new file mode 100644 index 0000000..66d3106 --- /dev/null +++ b/regression/requirements_extra.txt @@ -0,0 +1 @@ +scikit-learn==1.0.2 \ No newline at end of file diff --git a/regression/scripts/cedar/1.sh b/regression/scripts/cedar/1.sh new file mode 100644 index 0000000..d1ee5a6 --- /dev/null +++ b/regression/scripts/cedar/1.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --account=rrg-gberseth +#SBATCH --time=00:45:00 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=4G +#SBATCH --gres=gpu:v100l:1 +#SBATCH --array=1-3 + +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +echo "moved code to slurm tmpdir" + +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem_old.sif bash -c "source activate rl4chem && cd RL4Chem/regression && pip install pandas && pip install scikit-learn &&\ +python char_conv.py target=F2 seed=$SLURM_ARRAY_TASK_ID wandb_log=True string_rep=smiles wandb_run_name=smiles_chaconv_$SLURM_ARRAY_TASK_ID" \ No newline at end of file diff --git a/regression/scripts/cedar/10.sh b/regression/scripts/cedar/10.sh new file mode 100644 index 0000000..edb89f8 --- /dev/null +++ b/regression/scripts/cedar/10.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --account=rrg-gberseth +#SBATCH --time=00:45:00 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=4G +#SBATCH --gres=gpu:v100l:1 +#SBATCH --array=1-3 + +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +echo "moved code to slurm tmpdir" + +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem_old.sif bash -c "source activate rl4chem && cd RL4Chem/regression && pip install pandas && pip install scikit-learn &&\ +python char_conv.py target=ESR2 seed=$SLURM_ARRAY_TASK_ID wandb_log=True wandb_run_name=selfies_chaconv_$SLURM_ARRAY_TASK_ID" \ No newline at end of file diff --git a/regression/scripts/cedar/2.sh b/regression/scripts/cedar/2.sh new file mode 100644 index 0000000..7d6d854 --- /dev/null +++ b/regression/scripts/cedar/2.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --account=rrg-gberseth +#SBATCH --time=00:45:00 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=4G +#SBATCH --gres=gpu:v100l:1 +#SBATCH --array=1-3 + +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +echo "moved code to slurm tmpdir" + +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem_old.sif bash -c "source activate rl4chem && cd RL4Chem/regression && pip install pandas && pip install scikit-learn &&\ +python char_conv.py target=F2 seed=$SLURM_ARRAY_TASK_ID wandb_log=True 
wandb_run_name=selfies_chaconv_$SLURM_ARRAY_TASK_ID" \ No newline at end of file diff --git a/regression/scripts/cedar/3.sh b/regression/scripts/cedar/3.sh new file mode 100644 index 0000000..aa305de --- /dev/null +++ b/regression/scripts/cedar/3.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --account=rrg-gberseth +#SBATCH --time=00:45:00 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=4G +#SBATCH --gres=gpu:v100l:1 +#SBATCH --array=1-3 + +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +echo "moved code to slurm tmpdir" + +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem_old.sif bash -c "source activate rl4chem && cd RL4Chem/regression && pip install pandas && pip install scikit-learn &&\ +python char_conv.py target=KIT seed=$SLURM_ARRAY_TASK_ID wandb_log=True string_rep=smiles wandb_run_name=smiles_chaconv_$SLURM_ARRAY_TASK_ID" \ No newline at end of file diff --git a/regression/scripts/cedar/4.sh b/regression/scripts/cedar/4.sh new file mode 100644 index 0000000..44cee88 --- /dev/null +++ b/regression/scripts/cedar/4.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --account=rrg-gberseth +#SBATCH --time=00:45:00 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=4G +#SBATCH --gres=gpu:v100l:1 +#SBATCH --array=1-3 + +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +echo "moved code to slurm tmpdir" + +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem_old.sif bash -c "source activate rl4chem && cd RL4Chem/regression && pip install pandas && pip install scikit-learn &&\ +python char_conv.py target=KIT seed=$SLURM_ARRAY_TASK_ID wandb_log=True wandb_run_name=selfies_chaconv_$SLURM_ARRAY_TASK_ID" \ No newline at end of file diff --git a/regression/scripts/cedar/5.sh b/regression/scripts/cedar/5.sh new file mode 100644 index 0000000..f5a2775 --- /dev/null +++ b/regression/scripts/cedar/5.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --account=rrg-gberseth +#SBATCH --time=00:45:00 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=4G +#SBATCH --gres=gpu:v100l:1 +#SBATCH --array=1-3 + +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +echo "moved code to slurm tmpdir" + +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem_old.sif bash -c "source activate rl4chem && cd RL4Chem/regression && pip install pandas && pip install scikit-learn &&\ +python char_conv.py target=PARP1 seed=$SLURM_ARRAY_TASK_ID wandb_log=True string_rep=smiles wandb_run_name=smiles_chaconv_$SLURM_ARRAY_TASK_ID" \ No newline at end of file diff --git a/regression/scripts/cedar/6.sh b/regression/scripts/cedar/6.sh new file mode 100644 index 0000000..8b989f5 --- /dev/null +++ b/regression/scripts/cedar/6.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --account=rrg-gberseth +#SBATCH --time=00:45:00 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=4G +#SBATCH --gres=gpu:v100l:1 +#SBATCH --array=1-3 + +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem 
--exclude=env_rl4chem +echo "moved code to slurm tmpdir" + +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem_old.sif bash -c "source activate rl4chem && cd RL4Chem/regression && pip install pandas && pip install scikit-learn &&\ +python char_conv.py target=PARP1 seed=$SLURM_ARRAY_TASK_ID wandb_log=True wandb_run_name=selfies_chaconv_$SLURM_ARRAY_TASK_ID" \ No newline at end of file diff --git a/regression/scripts/cedar/7.sh b/regression/scripts/cedar/7.sh new file mode 100644 index 0000000..28bd070 --- /dev/null +++ b/regression/scripts/cedar/7.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --account=rrg-gberseth +#SBATCH --time=00:45:00 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=4G +#SBATCH --gres=gpu:v100l:1 +#SBATCH --array=1-3 + +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +echo "moved code to slurm tmpdir" + +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem_old.sif bash -c "source activate rl4chem && cd RL4Chem/regression && pip install pandas && pip install scikit-learn &&\ +python char_conv.py target=PGR seed=$SLURM_ARRAY_TASK_ID wandb_log=True string_rep=smiles wandb_run_name=smiles_chaconv_$SLURM_ARRAY_TASK_ID" \ No newline at end of file diff --git a/regression/scripts/cedar/8.sh b/regression/scripts/cedar/8.sh new file mode 100644 index 0000000..f584fae --- /dev/null +++ b/regression/scripts/cedar/8.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --account=rrg-gberseth +#SBATCH --time=00:45:00 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=4G +#SBATCH --gres=gpu:v100l:1 +#SBATCH --array=1-3 + +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +echo "moved code to slurm tmpdir" + +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem_old.sif bash -c "source activate rl4chem && cd RL4Chem/regression && pip install pandas && pip install scikit-learn &&\ +python char_conv.py target=PGR seed=$SLURM_ARRAY_TASK_ID wandb_log=True wandb_run_name=selfies_chaconv_$SLURM_ARRAY_TASK_ID" \ No newline at end of file diff --git a/regression/scripts/cedar/9.sh b/regression/scripts/cedar/9.sh new file mode 100644 index 0000000..3c9df35 --- /dev/null +++ b/regression/scripts/cedar/9.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --account=rrg-gberseth +#SBATCH --time=00:45:00 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=4G +#SBATCH --gres=gpu:v100l:1 +#SBATCH --array=1-3 + +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +echo "moved code to slurm tmpdir" + +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem_old.sif bash -c "source activate rl4chem && cd RL4Chem/regression && pip install pandas && pip install scikit-learn &&\ +python char_conv.py target=ESR2 seed=$SLURM_ARRAY_TASK_ID wandb_log=True string_rep=smiles 
wandb_run_name=smiles_chaconv_$SLURM_ARRAY_TASK_ID" \ No newline at end of file diff --git a/regression/scripts/mila/1.sh b/regression/scripts/mila/1.sh new file mode 100644 index 0000000..f98c662 --- /dev/null +++ b/regression/scripts/mila/1.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +#SBATCH -t 00:30:00 +#SBATCH -c 4 +#SBATCH --partition=main +#SBATCH --mem=4G +#SBATCH --gres=gpu:1 +#SBATCH --array=1-3 + +array=(-1 "0.5" "10" "100") + +module --quiet load anaconda/3 +conda activate rl4chem +echo "activated conda environment" + +rsync -a $HOME/RL4Chem/ $SLURM_TMPDIR/RL4Chem +echo "moved code to slurm tmpdir" + +cd $SLURM_TMPDIR/RL4Chem/regression +python regression.py seed=$SLURM_ARRAY_TASK_ID wandb_log=True wandb_run_name="selfies_charnn_"$SLURM_ARRAY_TASK_ID \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c1e16ec..03b1cb4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ importlib-resources==5.10.2 numpy==1.24.1 omegaconf==2.3.0 packaging==23.0 +pandas==1.5.3 pathtools==0.1.2 Pillow==9.4.0 protobuf==4.21.12 diff --git a/sac.py b/sac.py index 6256cec..a260f45 100644 --- a/sac.py +++ b/sac.py @@ -20,15 +20,16 @@ def forward(self, x): class Actor(nn.Module): def __init__(self, input_dims, hidden_dims, output_dims, dist='categorical'): super(Actor, self).__init__() - self.fc1 = nn.Linear(input_dims, hidden_dims) - self.fc2 = nn.Linear(hidden_dims, hidden_dims) - self.fc3 = nn.Linear(hidden_dims, output_dims) + + self.actor = nn.Sequential( + nn.Linear(input_dims, hidden_dims), + nn.ReLU(), nn.Linear(hidden_dims, hidden_dims), nn.LayerNorm(hidden_dims), + nn.ReLU(), nn.Linear(hidden_dims, output_dims), + ) self.dist = dist def forward(self, x): - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - logits = self.fc3(x) + logits = self.actor(x) if self.dist == 'categorical': dist = td.Categorical(logits=logits) elif self.dist == 'one_hot_categorical': @@ -73,7 +74,6 @@ def __init__(self, device, obs_dims, num_actions, self.policy_update_interval = policy_update_interval self.target_update_interval = target_update_interval self.batch_size = batch_size - self.aug = NoiseAug() #exploration self.entropy_coefficient = entropy_coefficient @@ -86,7 +86,7 @@ def __init__(self, device, obs_dims, num_actions, self._init_networks(obs_dims, num_actions, hidden_dims) self._init_optims(lr) - def get_action(self, obs, step, eval=False): + def get_action(self, obs, eval=False): with torch.no_grad(): obs = torch.FloatTensor(obs).to(self.device) action_dist = self.actor(obs) @@ -103,24 +103,17 @@ def update(self, buffer, step): else: log = False - state_batch, action_batch, reward_batch, next_state_batch, done_batch = buffer.sample(self.batch_size) - - # state_batch = self.aug(torch.FloatTensor(state_batch).to(self.device)) - # next_state_batch = self.aug(torch.FloatTensor(next_state_batch).to(self.device)) - state_batch = torch.FloatTensor(state_batch).to(self.device) - next_state_batch = torch.FloatTensor(next_state_batch).to(self.device) - action_batch = torch.FloatTensor(action_batch).to(self.device) - reward_batch = torch.FloatTensor(reward_batch).to(self.device) - done_batch = torch.FloatTensor(done_batch).to(self.device) + state_batch, action_batch, reward_batch, next_state_batch, done_batch, time_batch = buffer.sample(self.batch_size) + state_batch = torch.tensor(state_batch, dtype=torch.float32, device=self.device) + next_state_batch = torch.tensor(next_state_batch, dtype=torch.float32, device=self.device) + action_batch = torch.tensor(action_batch, 
dtype=torch.float32, device=self.device) + reward_batch = torch.tensor(reward_batch, dtype=torch.float32, device=self.device) + done_batch = torch.tensor(done_batch, dtype=torch.float32, device=self.device) discount_batch = self.gamma*(1-done_batch) self.update_critic(state_batch, action_batch, reward_batch, next_state_batch, discount_batch, log, metrics) - actor_log = False - if step % self.policy_update_interval == 0: - for _ in range(self.policy_update_interval): - actor_log = not actor_log if log else actor_log - self.update_actor(state_batch, actor_log, metrics) - + self.update_actor(state_batch, log, metrics) + if step%self.target_update_interval==0: utils.soft_update(self.critic_target, self.critic, self.tau) @@ -149,6 +142,8 @@ def update_critic(self, state_batch, action_batch, reward_batch, next_state_batc if log: metrics['mean_q_target'] = torch.mean(target_Q).item() + metrics['max_reward'] = torch.max(reward_batch).item() + metrics['min_reward'] = torch.min(reward_batch).item() metrics['variance_q_target'] = torch.var(target_Q).item() metrics['min_q_target'] = torch.min(target_Q).item() metrics['max_q_target'] = torch.max(target_Q).item() diff --git a/scripts/cedar/1.sh b/scripts/cedar/1.sh index 1a60480..5345f7c 100644 --- a/scripts/cedar/1.sh +++ b/scripts/cedar/1.sh @@ -9,8 +9,8 @@ array=(-1 "fa7" "parp1" "5ht1b") -rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_crl echo "moved code to slurm tmpdir" singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem.sif bash -c "source activate rl4chem && cd RL4Chem &&\ -python train.py target=${array[SLURM_ARRAY_TASK_ID]} max_selfie_length=23 wandb_log=True wandb_run_name=max_len_23 seed=1 num_sub_proc=20" \ No newline at end of file +python train.py target=${array[SLURM_ARRAY_TASK_ID]} max_selfie_length=22 wandb_log=True wandb_run_name=dqn_ln_max_len_22 seed=1 num_sub_proc=20" \ No newline at end of file diff --git a/scripts/cedar/2.sh b/scripts/cedar/2.sh index 1a60480..048d33f 100644 --- a/scripts/cedar/2.sh +++ b/scripts/cedar/2.sh @@ -13,4 +13,4 @@ rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --excl echo "moved code to slurm tmpdir" singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem.sif bash -c "source activate rl4chem && cd RL4Chem &&\ -python train.py target=${array[SLURM_ARRAY_TASK_ID]} max_selfie_length=23 wandb_log=True wandb_run_name=max_len_23 seed=1 num_sub_proc=20" \ No newline at end of file +python train.py target=${array[SLURM_ARRAY_TASK_ID]} max_selfie_length=22 wandb_log=True wandb_run_name=ln_max_len_22 seed=2 num_sub_proc=20" \ No newline at end of file diff --git a/scripts/cedar/3.sh b/scripts/cedar/3.sh index 1a60480..bf59e4e 100644 --- a/scripts/cedar/3.sh +++ b/scripts/cedar/3.sh @@ -13,4 +13,4 @@ rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --excl echo "moved code to slurm tmpdir" singularity exec --nv --home $SLURM_TMPDIR --env 
WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem.sif bash -c "source activate rl4chem && cd RL4Chem &&\ -python train.py target=${array[SLURM_ARRAY_TASK_ID]} max_selfie_length=23 wandb_log=True wandb_run_name=max_len_23 seed=1 num_sub_proc=20" \ No newline at end of file +python train.py target=${array[SLURM_ARRAY_TASK_ID]} max_selfie_length=22 wandb_log=True wandb_run_name=ln_max_len_22 seed=3 num_sub_proc=20" \ No newline at end of file diff --git a/scripts/cedar/4.sh b/scripts/cedar/4.sh index de81bd5..f5c39af 100644 --- a/scripts/cedar/4.sh +++ b/scripts/cedar/4.sh @@ -1,14 +1,16 @@ #!/bin/bash #SBATCH --account=rrg-gberseth -#SBATCH --time=5:00:00 -#SBATCH --cpus-per-task=4 +#SBATCH --time=02:00:00 +#SBATCH --cpus-per-task=24 #SBATCH --mem=4G #SBATCH --gres=gpu:v100l:1 #SBATCH --array=1-3 -rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +array=(-1 "fa7" "parp1" "5ht1b") -cd $SLURM_TMPDIR/RL4Chem +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +echo "moved code to slurm tmpdir" -python train.py max_selfie_length=20 entropy_coefficient=0.1 wandb_log=True seed=$SLURM_ARRAY_TASK_ID wandb_run_name=max_len20_ent0.1_seed$SLURM_ARRAY_TASK_ID \ No newline at end of file +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem.sif bash -c "source activate rl4chem && cd RL4Chem &&\ +python train.py target=${array[SLURM_ARRAY_TASK_ID]} max_selfie_length=25 wandb_log=True wandb_run_name=ln_max_len_25 seed=1 num_sub_proc=24" \ No newline at end of file diff --git a/scripts/cedar/5.sh b/scripts/cedar/5.sh new file mode 100644 index 0000000..995a541 --- /dev/null +++ b/scripts/cedar/5.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --account=rrg-gberseth +#SBATCH --time=02:00:00 +#SBATCH --cpus-per-task=24 +#SBATCH --mem=4G +#SBATCH --gres=gpu:v100l:1 +#SBATCH --array=1-3 + +array=(-1 "fa7" "parp1" "5ht1b") + +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +echo "moved code to slurm tmpdir" + +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem.sif bash -c "source activate rl4chem && cd RL4Chem &&\ +python train.py target=${array[SLURM_ARRAY_TASK_ID]} max_selfie_length=25 wandb_log=True wandb_run_name=ln_max_len_25 seed=2 num_sub_proc=24" \ No newline at end of file diff --git a/scripts/cedar/6.sh b/scripts/cedar/6.sh new file mode 100644 index 0000000..c307e3f --- /dev/null +++ b/scripts/cedar/6.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --account=rrg-gberseth +#SBATCH --time=02:00:00 +#SBATCH --cpus-per-task=24 +#SBATCH --mem=4G +#SBATCH --gres=gpu:v100l:1 +#SBATCH --array=1-3 + +array=(-1 "fa7" "parp1" "5ht1b") + +rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --exclude=env_rl4chem +echo "moved code to slurm tmpdir" + +singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 
$SCRATCH/rl4chem.sif bash -c "source activate rl4chem && cd RL4Chem &&\ +python train.py target=${array[SLURM_ARRAY_TASK_ID]} max_selfie_length=25 wandb_log=True wandb_run_name=ln_max_len_25 seed=3 num_sub_proc=24" \ No newline at end of file diff --git a/train.py b/train.py index b4c1e17..472f41e 100644 --- a/train.py +++ b/train.py @@ -10,6 +10,19 @@ from pathlib import Path from omegaconf import DictConfig +from collections import defaultdict + +class Topk(): + def __init__(self): + self.top25 = [] + + def add(self, scores): + scores = np.sort(np.concatenate([self.top25, scores])) + self.top25 = scores[-25:] + + def top(self, k): + assert k <= 25 + return self.top25[-k] def make_agent(env, device, cfg): obs_dims = np.prod(env.observation_shape) @@ -21,6 +34,7 @@ def make_agent(env, device, cfg): env_buffer = utils.ReplayMemory(env_buffer_size, obs_dims, obs_dtype, action_dtype) fresh_env_buffer = utils.FreshReplayMemory(cfg.parallel_molecules, env.episode_length, obs_dims, obs_dtype, action_dtype) + docking_buffer = defaultdict(lambda: None) if cfg.agent == 'sac': from sac import SacAgent @@ -29,7 +43,7 @@ def make_agent(env, device, cfg): cfg.hidden_dims, cfg.wandb_log, cfg.agent_log_interval) else: raise NotImplementedError - return agent, env_buffer, fresh_env_buffer + return agent, env_buffer, fresh_env_buffer, docking_buffer def make_env(cfg): print(cfg.id) @@ -39,160 +53,147 @@ def make_env(cfg): else: raise NotImplementedError -class Workspace: - def __init__(self, cfg): - self.work_dir = Path.cwd() - self.cfg = cfg - if self.cfg.save_snapshot: - self.checkpoint_path = self.work_dir / 'checkpoints' - self.checkpoint_path.mkdir(exist_ok=True) - - self.set_seed() - self.device = torch.device(cfg.device) - self.train_env, self.eval_env = make_env(self.cfg) - self.agent, self.env_buffer, self.fresh_env_buffer = make_agent(self.train_env, self.device, self.cfg) - self.current_reward_batch = np.zeros((cfg.parallel_molecules,), dtype=np.float32) - self.current_reward_info = dict() - self._train_step = 0 - self._train_episode = 0 - self._best_eval_returns = -np.inf - self._best_train_returns = -np.inf - - def set_seed(self): - random.seed(self.cfg.seed) - np.random.seed(self.cfg.seed) - torch.manual_seed(self.cfg.seed) - torch.cuda.manual_seed_all(self.cfg.seed) - - def _explore(self): - print('random exploration of ', self.cfg.parallel_molecules, ' number of molecules begins') - explore_steps = self.cfg.parallel_molecules * self.train_env.episode_length - - state, done = self.train_env.reset(), False - for _ in range(explore_steps): - action = np.random.randint(self.train_env.num_actions) - next_state, reward, done, info = self.train_env.step(action) - - self.fresh_env_buffer.push((state, action, reward, next_state, done)) - - if done: - state, done = self.train_env.reset(), False - else: - state = next_state +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def collect_random_molecule(env, fresh_env_buffer): + state, done, t = env.reset(), False, 0 + while not done: + action = np.random.randint(env.num_actions) + next_state, reward, done, info = env.step(action) + fresh_env_buffer.push((state, action, reward, next_state, done, t)) + t += 1 + state = next_state + return info['episode'] + +def explore(cfg, train_env, env_buffer, fresh_env_buffer, docking_buffer): + explore_mols = 0 + while explore_mols < cfg.explore_molecules: + episode_info = collect_random_molecule(train_env, fresh_env_buffer) + + if 
docking_buffer[episode_info['smiles']] is not None: + fresh_env_buffer.remove_last_episode(episode_info['l']) + else: + docking_buffer[episode_info['smiles']] = 0 + train_env._add_smiles_to_batch(episode_info['smiles']) + explore_mols += 1 + + reward_start_time = time.time() + parallel_reward_batch, parallel_reward_info = train_env.get_reward_batch() + reward_eval_time = time.time() - reward_start_time + + #Update main buffer and docking_buffer and reset fresh buffer + fresh_env_buffer.update_final_rewards(parallel_reward_batch) + env_buffer.push_fresh_buffer(fresh_env_buffer) + fresh_env_buffer.reset() + for i, smiles_string in enumerate(parallel_reward_info['smiles']): + docking_buffer[smiles_string] = parallel_reward_batch[i] + + print('Total strings explored = ', cfg.explore_molecules, ' Reward evaluation time = ', reward_eval_time) + print(np.sort(parallel_reward_batch)) + + return parallel_reward_batch + +def collect_molecule(env, agent, fresh_env_buffer): + state, done, t = env.reset(), False, 0 + while not done: + action = agent.get_action(state) + next_state, reward, done, info = env.step(action) + fresh_env_buffer.push((state, action, reward, next_state, done, t)) + t += 1 + state = next_state + return info['episode'] + +def train(cfg): + set_seed(cfg.seed) + device = torch.device(cfg.device) + + #get train and eval envs + train_env, eval_env = make_env(cfg) + + #get agent and memory + agent, env_buffer, fresh_env_buffer, docking_buffer = make_agent(train_env, device, cfg) + topk = Topk() + + #explore + docking_scores = explore(cfg, train_env, env_buffer, fresh_env_buffer, docking_buffer) + topk.add(docking_scores) + + #train + train_step = 0 + molecule_counter = 0 + unique_molecule_counter = 0 + while train_step < cfg.num_train_steps: - reward_start_time = time.time() - self.current_reward_batch, self.current_reward_info = self.train_env.get_reward_batch() - reward_eval_time = time.time() - reward_start_time - self.fresh_env_buffer.update_final_rewards(self.current_reward_batch) - self.env_buffer.push_fresh_buffer(self.fresh_env_buffer) - self.fresh_env_buffer.reset() - print('Total strings = ', len(self.current_reward_info['selfies']), 'Unique strings = ', len(set(self.current_reward_info['selfies'])), ' Evaluation time = ', reward_eval_time) - print(np.sort(self.current_reward_batch)) - - def train(self): - self._eval() - self._explore() - - parallel_counter = 0 - state, done, episode_start_time, episode_metrics = self.train_env.reset(), False, time.time(), dict() + + episode_info = collect_molecule(train_env, agent, fresh_env_buffer) + molecule_counter += 1 + if docking_buffer[episode_info['smiles']] is not None: + fresh_env_buffer.update_last_episode_reward(docking_buffer[episode_info['smiles']]) + else: + docking_buffer[episode_info['smiles']] = 0 + train_env._add_smiles_to_batch(episode_info['smiles']) + unique_molecule_counter += 1 + + for _ in range(episode_info['l']): + agent.update(env_buffer, train_step) + train_step += 1 - for _ in range(1, self.cfg.num_train_steps): - action = self.agent.get_action(state, self._train_step) - next_state, reward, done, info = self.train_env.step(action) - self.fresh_env_buffer.push((state, action, reward, next_state, done)) - self._train_step += 1 + print('Total strings = ', molecule_counter, 'Unique strings = ', unique_molecule_counter) + + if molecule_counter % cfg.parallel_molecules == 0 and unique_molecule_counter != 0: + + reward_start_time = time.time() + parallel_reward_batch, parallel_reward_info = train_env.get_reward_batch() 
+ reward_eval_time = time.time() - reward_start_time - if done: - self._train_episode += 1 - print("Episode: {}, total numsteps: {}".format(self._train_episode, self._train_step)) - state, done, episode_start_time = self.train_env.reset(), False, time.time() - if self.cfg.wandb_log: - episode_metrics['episodic_length'] = info["episode"]["l"] - episode_metrics['steps_per_second'] = info["episode"]["l"]/(time.time() - episode_start_time) - episode_metrics['env_buffer_length'] = len(self.env_buffer) - episode_metrics['episodic_reward'] = self.current_reward_batch[parallel_counter] - episode_metrics['episodic_selfies_len'] = self.current_reward_info['len_selfies'][parallel_counter] - wandb.log(episode_metrics, step=self._train_step) - parallel_counter += 1 - else: - state = next_state - - self.agent.update(self.env_buffer, self._train_step) - - if self._train_step % self.cfg.eval_episode_interval == 0: - self._eval() - - if self.cfg.save_snapshot and self._train_step % self.cfg.save_snapshot_interval == 0: - self.save_snapshot() - - if parallel_counter == self.cfg.parallel_molecules: - reward_start_time = time.time() - self.current_reward_batch, self.current_reward_info = self.train_env.get_reward_batch() - reward_eval_time = time.time() - reward_start_time - self.fresh_env_buffer.update_final_rewards(self.current_reward_batch) - self.env_buffer.push_fresh_buffer(self.fresh_env_buffer) - self.fresh_env_buffer.reset() - - unique_strings = len(set(self.current_reward_info['selfies'])) - print('Total strings = ', len(self.current_reward_info['selfies']), 'Unique strings = ', unique_strings, ' Evaluation time = ', reward_eval_time) - print(np.sort(self.current_reward_batch)) - best_idx = np.argmax(self.current_reward_batch) - print(self.current_reward_info['smiles'][best_idx]) - - if self.cfg.wandb_log: - wandb.log({'reward_eval_time' : reward_eval_time, - 'unique strings': unique_strings}, step = self._train_step) - parallel_counter = 0 - - def _eval(self): - steps = 0 - for _ in range(self.cfg.num_eval_episodes): - done = False - state = self.eval_env.reset() - while not done: - action = self.agent.get_action(state, self._train_step, True) - next_state, _, done ,info = self.eval_env.step(action) - state = next_state - - steps += info["episode"]["l"] - - final_rewards, _ = self.eval_env.get_reward_batch() - eval_metrics = dict() - eval_metrics['eval_episodic_return'] = sum(final_rewards)/self.cfg.num_eval_episodes - eval_metrics['eval_episodic_length'] = steps/self.cfg.num_eval_episodes - - print("Episode: {}, total numsteps: {}, average Evaluation return: {}".format(self._train_episode, self._train_step, round(eval_metrics['eval_episodic_return'], 2))) - - if self.cfg.save_snapshot and sum(final_rewards)/self.cfg.num_eval_episodes >= self._best_eval_returns: - self.save_snapshot(best=True) - self._best_eval_returns = sum(final_rewards)/self.cfg.num_eval_episodes - - if self.cfg.wandb_log: - wandb.log(eval_metrics, step = self._train_step) - - def save_snapshot(self, best=False): - if best: - snapshot = Path(self.checkpoint_path) / 'best.pt' - else: - snapshot = Path(self.checkpoint_path) / Path(str(self._train_step)+'.pt') - save_dict = self.agent.get_save_dict() - torch.save(save_dict, snapshot) + topk.add(parallel_reward_batch) + + #Update main buffer and docking_buffer and reset fresh buffer + fresh_env_buffer.update_final_rewards(parallel_reward_batch) + env_buffer.push_fresh_buffer(fresh_env_buffer) + fresh_env_buffer.reset() + for i, smiles_string in 
enumerate(parallel_reward_info['smiles']): + docking_buffer[smiles_string] = parallel_reward_batch[i] + + print('Evaluation time = ', reward_eval_time) + print(np.sort(parallel_reward_batch)) + + if cfg.wandb_log: + metrics = dict() + metrics['reward_eval_time'] = reward_eval_time + metrics['total_strings'] = molecule_counter + metrics['unique_strings'] = unique_molecule_counter + metrics['env_buffer_size'] = len(env_buffer) + metrics['top1'] = topk.top(1) + metrics['top5'] = topk.top(5) + metrics['top25'] = topk.top(25) + wandb.log(metrics, step=train_step) + + return docking_buffer + #eval + #To-do @hydra.main(config_path='cfgs', config_name='config', version_base=None) def main(cfg: DictConfig): - from train import Workspace as W + from train import train hydra_cfg = hydra.core.hydra_config.HydraConfig.get() if cfg.wandb_log: project_name = 'rl4chem_' + cfg.target - with wandb.init(project=project_name, entity=cfg.wandb_entity, config=dict(cfg), dir=hydra_cfg['runtime']['output_dir']): - wandb.run.name = cfg.wandb_run_name - workspace = W(cfg) - workspace.train() - else: - workspace = W(cfg) - workspace.train() - + wandb.init(project=project_name, entity=cfg.wandb_entity, config=dict(cfg), dir=hydra_cfg['runtime']['output_dir']) + wandb.run.name = cfg.wandb_run_name + + docking_buffer = train(cfg) + sorted_docking_buffer = sorted(docking_buffer.items(), key=lambda x:x[1]) + + with open(str(hydra_cfg['runtime']['output_dir']) + '/molecules.txt', 'w') as f: + for (smiles, score) in reversed(sorted_docking_buffer): + f.write(str(score) + ' || ' + smiles + "\n") + if __name__ == '__main__': main() \ No newline at end of file diff --git a/utils.py b/utils.py index 8ad5f0c..bb5a294 100644 --- a/utils.py +++ b/utils.py @@ -1,9 +1,10 @@ +import random import numpy as np import torch.nn as nn from typing import Iterable +from collections import namedtuple, deque # Replay memory - class ReplayMemory(): def __init__(self, buffer_limit, obs_dims, obs_dtype, action_dtype): self.buffer_limit = buffer_limit @@ -15,87 +16,103 @@ def __init__(self, buffer_limit, obs_dims, obs_dtype, action_dtype): self.action = np.empty((buffer_limit, 1), dtype=action_dtype) self.reward = np.empty((buffer_limit,), dtype=np.float32) self.terminal = np.empty((buffer_limit,), dtype=bool) + self.time_step = np.empty((buffer_limit,), dtype=np.uint8) self.idx = 0 self.full = False def push(self, transition): - state, action, reward, next_state, done = transition + state, action, reward, next_state, done, time_step = transition self.observation[self.idx] = state self.next_observation[self.idx] = next_state self.action[self.idx] = action self.reward[self.idx] = reward self.terminal[self.idx] = done + self.time_step[self.idx] = time_step self.idx = (self.idx + 1) % self.buffer_limit self.full = self.full or self.idx == 0 def push_batch(self, transitions, N): - states, actions, rewards, next_states, dones = transitions + states, actions, rewards, next_states, dones, time_steps = transitions idxs = np.arange(self.idx, self.idx + N) % self.buffer_limit self.observation[idxs] = states self.next_observation[idxs] = next_states self.action[idxs] = actions self.reward[idxs] = rewards self.terminal[idxs] = dones + self.time_step[idxs] = time_steps self.full = self.full or (self.idx + N >= self.buffer_limit) self.idx = (idxs[-1] + 1) % self.buffer_limit def push_fresh_buffer(self, fresh_buffer): - N = fresh_buffer.num_episodes * fresh_buffer.max_episode_len + N = fresh_buffer.buffer_limit if fresh_buffer.full else fresh_buffer.step_idx 
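The replay-buffer changes in this utils.py hunk support delayed rewards: episodes are pushed with placeholder rewards, the docking scores are written in afterwards via update_final_rewards (or update_last_episode_reward for cached duplicates), and push_fresh_buffer then copies only the filled transitions into the main ring buffer. A self-contained usage sketch of that handoff, assuming the constructor signatures shown in this diff (sizes, dtypes, actions and the random observations are purely illustrative):

import numpy as np
from utils import ReplayMemory, FreshReplayMemory

obs_dims = 8
env_buffer = ReplayMemory(buffer_limit=1000, obs_dims=obs_dims,
                          obs_dtype=np.float32, action_dtype=np.uint8)
fresh = FreshReplayMemory(num_episodes=2, max_episode_len=3, obs_dims=obs_dims,
                          obs_dtype=np.float32, action_dtype=np.uint8)

# Two 3-step episodes; rewards are placeholders (0.0) until docking finishes.
for _ in range(2):
    for t in range(3):
        s = np.random.rand(obs_dims).astype(np.float32)
        s_next = np.random.rand(obs_dims).astype(np.float32)
        fresh.push((s, t, 0.0, s_next, t == 2, t))  # (s, a, r, s', done, time_step)

# Docking scores arrive later, one per finished episode, and are written onto
# the terminal transitions before the episodes enter the main buffer.
fresh.update_final_rewards(np.array([4.2, 7.1], dtype=np.float32))
env_buffer.push_fresh_buffer(fresh)
fresh.reset()

states, actions, rewards, next_states, dones, time_steps = env_buffer.sample(4)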
idxs = np.arange(self.idx, self.idx + N) % self.buffer_limit - self.observation[idxs] = fresh_buffer.observation.reshape((-1, self.obs_dims)) - self.next_observation[idxs] = fresh_buffer.next_observation.reshape((-1, self.obs_dims)) - self.action[idxs] = fresh_buffer.action.reshape((-1, 1)) - self.reward[idxs] = fresh_buffer.reward.reshape((-1,)) - self.terminal[idxs] = fresh_buffer.terminal.reshape((-1,)) - + self.observation[idxs] = fresh_buffer.observation[:N] + self.next_observation[idxs] = fresh_buffer.next_observation[:N] + self.action[idxs] = fresh_buffer.action[:N] + self.reward[idxs] = fresh_buffer.reward[:N] + self.terminal[idxs] = fresh_buffer.terminal[:N] + self.time_step[idxs] = fresh_buffer.time_step[:N] self.full = self.full or (self.idx + N >= self.buffer_limit) self.idx = (idxs[-1] + 1) % self.buffer_limit def sample(self, n): - idxes = np.random.randint(0, self.buffer_limit if self.full else self.idx, size=n) - - return self.observation[idxes], self.action[idxes], self.reward[idxes], self.next_observation[idxes], self.terminal[idxes] + idxes = np.random.randint(0, self.buffer_limit if self.full else self.idx, size=n) + return self.observation[idxes], self.action[idxes], self.reward[idxes], self.next_observation[idxes], self.terminal[idxes], self.time_step[idxes] def __len__(self): return self.buffer_limit if self.full else self.idx+1 class FreshReplayMemory(): def __init__(self, num_episodes, max_episode_len, obs_dims, obs_dtype, action_dtype): - self.obs_dtype = obs_dtype self.obs_dims = obs_dims self.action_dtype = action_dtype - self.num_episodes = num_episodes - self.max_episode_len = max_episode_len + self.buffer_limit = num_episodes * max_episode_len self.reset() def reset(self): - self.observation = np.empty((self.num_episodes, self.max_episode_len, self.obs_dims), dtype=self.obs_dtype) - self.next_observation = np.empty((self.num_episodes, self.max_episode_len, self.obs_dims), dtype=self.obs_dtype) - self.action = np.empty((self.num_episodes, self.max_episode_len, 1), dtype=self.action_dtype) - self.reward = np.empty((self.num_episodes, self.max_episode_len), dtype=np.float32) - self.terminal = np.empty((self.num_episodes, self.max_episode_len), dtype=bool) - + self.observation = np.empty((self.buffer_limit, self.obs_dims), dtype=self.obs_dtype) + self.next_observation = np.empty((self.buffer_limit, self.obs_dims), dtype=self.obs_dtype) + self.action = np.empty((self.buffer_limit, 1), dtype=self.action_dtype) + self.reward = np.empty((self.buffer_limit,), dtype=np.float32) + self.terminal = np.empty((self.buffer_limit,), dtype=bool) + self.time_step = np.empty((self.buffer_limit,), dtype=np.uint8) + self.reward_indices = [] self.full = False self.step_idx = 0 - self.episode_idx = 0 def push(self, transition): - state, action, reward, next_state, done = transition - self.observation[self.episode_idx, self.step_idx] = state - self.next_observation[self.episode_idx, self.step_idx] = next_state - self.action[self.episode_idx, self.step_idx] = action - self.reward[self.episode_idx, self.step_idx] = reward - self.terminal[self.episode_idx, self.step_idx] = done + state, action, reward, next_state, done, time_step = transition + self.observation[self.step_idx] = state + self.next_observation[self.step_idx] = next_state + self.action[self.step_idx] = action + self.reward[self.step_idx] = reward + self.terminal[self.step_idx] = done + self.time_step[self.step_idx] = time_step + + if done: + self.reward_indices.append(self.step_idx) + + self.step_idx = (self.step_idx + 1) % 
self.buffer_limit + self.full = self.full or self.step_idx == 0 + + def remove_last_episode(self, episode_len): + last_episode_start_id = self.step_idx - episode_len + if self.full : assert self.step_idx == 0 + if last_episode_start_id < 0 : assert self.full + if last_episode_start_id < 0 : assert self.step_idx == 0 + self.reward_indices.pop() + self.step_idx = last_episode_start_id % self.buffer_limit + if self.full : self.full = not last_episode_start_id < 0 - self.step_idx = (self.step_idx + 1) % self.max_episode_len - self.episode_idx = self.episode_idx + 1 * (self.step_idx == 0) - self.full = self.full or self.episode_idx == self.num_episodes + def update_last_episode_reward(self, reward): + self.reward[self.reward_indices[-1]] = reward + self.reward_indices.pop() def update_final_rewards(self, reward_batch): - self.reward[:, -1] = reward_batch - + self.reward[self.reward_indices] = reward_batch + # NN weight utils def weight_init(m):