Commit c826c65: Sampling
shiwentao00 committed Feb 26, 2021 (1 parent: 8eeac5c)
Showing 4 changed files with 86 additions and 3 deletions.
6 changes: 6 additions & 0 deletions dataloader.py
@@ -66,6 +66,9 @@ def __init__(self, vocab_path) -> None:
with open(vocab_path, 'r') as f:
self.vocab = yaml.full_load(f)

self.int2tocken = {value: key for key, value in self.vocab.items()}
self.int2tocken[0] = '<pad>'

def tokenize_smiles(self, mol):
"""convert the smiles to selfies, then return
integer tokens."""
@@ -78,6 +81,9 @@ def tokenize_smiles(self, mol):

return ints

def list2selfies(self, selfies):
return "".join(selfies)


def pad_collate(batch):
"""
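Taken together, the dataloader additions give the vocabulary an inverse mapping (integer id back to token, with id 0 reserved for '<pad>') and a helper that joins a token list into one SELFIES string. A minimal sketch of that round trip with a toy vocabulary (the dictionary below is hypothetical, not the repository's vocab file):

# toy vocabulary: token -> integer id (hypothetical, for illustration only)
vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '[C]': 3, '[O]': 4}
int2token = {value: key for key, value in vocab.items()}

sampled_ints = [3, 4, 3]                        # e.g. integers sampled from the RNN
tokens = [int2token[i] for i in sampled_ints]   # ['[C]', '[O]', '[C]']
selfies = "".join(tokens)                       # same joining step as list2selfies
print(selfies)                                  # '[C][O][C]'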
42 changes: 39 additions & 3 deletions model.py
@@ -3,6 +3,7 @@
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from torch.nn.functional import softmax


class RNN(torch.nn.Module):
@@ -47,9 +48,44 @@ def forward(self, data, lengths):
# the targets will also be packed.
return embeddings

def sample(self):
# TO-DO
pass
def sample(self, vocab):
output = []

# get the integer id of the "start of sequence" token
start_int = vocab.vocab['<sos>']

# create a tensor of shape [batch_size=1, seq_step=1]
sos = torch.tensor(start_int).unsqueeze(
dim=0).unsqueeze(dim=0)

# sample first output
x = self.embedding_layer(sos)
x, (h, c) = self.rnn(x)
x = self.linear(x)
x = softmax(x, dim=-1)
x = torch.multinomial(x.squeeze(), 1)
output.append(x.item())

# use first output to iteratively sample until <eos> occurs
while output[-1] != vocab.vocab['<eos>']:
x = x.unsqueeze(dim=0)
x = self.embedding_layer(x)
x, (h, c) = self.rnn(x)
x = self.linear(x)
x = softmax(x, dim=-1)
x = torch.multinomial(x.squeeze(), 1)
output.append(x.item())

# convert integers to tokens
output = [vocab.int2tocken[x] for x in output]

# pop the trailing <eos>
output.pop()

# convert to SELFIES
output = vocab.list2selfies(output)

return output

def beam_search(self):
# TO-DO
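The new sample method draws one token at a time from the softmax distribution and stops only when <eos> is produced, so there is no upper bound on the output length. A small, self-contained sketch of the same per-step sampling with two common extras the method does not include, a temperature and a length cap (sample_next, temperature, and max_len are illustrative names, not part of this repository):

import torch
from torch.nn.functional import softmax

def sample_next(logits, temperature=1.0):
    # scale the logits before the softmax: temperature < 1 sharpens the
    # distribution, temperature > 1 flattens it, and 1.0 matches the code above
    probs = softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs.squeeze(), 1)

# toy demonstration with random logits standing in for self.linear(x); a capped
# loop would read: while output[-1] != vocab.vocab['<eos>'] and len(output) < max_len:
logits = torch.randn(1, 1, 10)   # [batch=1, step=1, vocab_size=10]
print(sample_next(logits, temperature=0.8).item())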
6 changes: 6 additions & 0 deletions output_sample.py
@@ -0,0 +1,6 @@
import selfies as sf
print(sf.decoder('[C][/N][Br][=P+expl][/B-expl][C@expl][Br+expl][S@+expl][C@expl][/C@expl][=O+expl][Br][I+expl][=Nexpl][/C@Hexpl][/C@expl][C][\Siexpl][C][IH2expl][Branch2_3][=B][=N-expl][/B][#C][=O+expl][\S@@expl][Branch2_1][\O-expl][/C@expl][\C@@expl][Branch2_1][=O+expl][=N][/CH-expl][Branch1_2][/O-expl][P@@expl][/S][C@expl][C][\C@expl][=S+expl][/C@@Hexpl][B-expl][\Cl][Expl\Ring2][/N+expl][\C@@Hexpl][B@@-expl][S@@+expl][Branch1_3][/F][#C][C][O][Expl=Ring2][\S@expl][C][C][Branch1_2][C][Oexpl][C@Hexpl][B@@-expl][\C-expl][P+expl][CH2-expl][=S@expl][=B][B-expl][SHexpl][C][O+expl][\I][P@expl][/C@@expl][P@Hexpl][N][C][C][B][P][\S@expl][P@@expl][Sn+2expl][Br+expl][\S@expl][C@expl][=N-expl][/B][=N+expl][/B-expl][\C@@expl][I+expl][\O][Oexpl][\Snexpl][C][C][C][=P][/CH-expl][C][C][C-expl][SH3expl][Branch1_3][Br][\C-expl][/N][/Cexpl][C][=S@@expl][C][N-expl][C][IH2expl][\Siexpl][Sn+3expl][C][N][S@@expl][=I][=S@+expl][P@expl][=O][/Br][#N][/P@@expl][=P][/C@@expl][\O-expl][Cl][/O][/S][B][N@@H+expl][Expl/Ring2][#C][I+expl][S@Hexpl][P@@Hexpl][/P@@expl][=N+expl][Sn+3expl][\C@expl][Branch1_2][/P@@expl][/C@expl][Br][\P][/I][\B][\O-expl][\C@expl]'))
print(sf.decoder('[C][/N][Br][=P+expl][/B-expl][C@expl][Br+expl][S@+expl][C@expl][/C@expl][=O+expl][Br][I+expl][=Nexpl][/C@Hexpl][/C@expl][C][\Siexpl][C][IH2expl][Branch2_3][=B][=N-expl][/B][#C][=O+expl][\S@@expl][Branch2_1][\O-expl][/C@expl][\C@@expl][Branch2_1][=O+expl][=N][/CH-expl][Branch1_2][/O-expl][P@@expl][/S][C@expl][C][\C@expl][=S+expl][/C@@Hexpl][B-expl][\Cl][Expl\Ring2][/N+expl][\C@@Hexpl][B@@-expl][S@@+expl][Branch1_3][/F][#C][C][O][Expl=Ring2][\S@expl][C][C][Branch1_2][C][Oexpl][C@Hexpl][B@@-expl][\C-expl][P+expl][CH2-expl][=S@expl][=B][B-expl][SHexpl][C][O+expl][\I][P@expl][/C@@expl][P@Hexpl][N][C][C][B][P][\S@expl][P@@expl][Sn+2expl][Br+expl][\S@expl][C@expl][=N-expl][/B][=N+expl][/B-expl][\C@@expl][I+expl][\O][Oexpl][\Snexpl][C][C][C][=P][/CH-expl][C][C][C-expl][SH3expl][Branch1_3][Br][\C-expl][/N][/Cexpl][C][=S@@expl][C][N-expl][C][IH2expl][\Siexpl][Sn+3expl][C][N][S@@expl][=I][=S@+expl][P@expl][=O][/Br][#N][/P@@expl][=P][/C@@expl][\O-expl][Cl][/O][/S][B][N@@H+expl][Expl/Ring2][#C][I+expl][S@Hexpl][P@@Hexpl][/P@@expl][=N+expl][Sn+3expl][\C@expl][Branch1_2][/P@@expl][/C@expl][Br][\P][/I][\B][\O-expl][\C@expl]'))

# Sampled SMILES
# C/NBr
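Both prints above run the raw sampled token string through sf.decoder, which is the step that turns SELFIES back into SMILES. A short, self-contained sketch of that decoding on a hand-written string (the example string is illustrative, not a model sample):

import selfies as sf

selfies_str = '[C][O][C]'          # dimethyl ether written as SELFIES
smiles = sf.decoder(selfies_str)   # decodes to 'COC'
print(smiles)
print(sf.encoder(smiles))          # encoding the SMILES gives a SELFIES string back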
35 changes: 35 additions & 0 deletions sample.py
@@ -0,0 +1,35 @@
import torch
import yaml
from dataloader import SELFIEVocab
import selfies as sf
from model import RNN


if __name__ == "__main__":
# detect cpu or gpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device: ', device)

# load the configuration file from the output directory
config_dir = "../results/run_1/config.yaml"
with open(config_dir, 'r') as f:
config = yaml.full_load(f)

# load vocab
vocab = SELFIEVocab(vocab_path=config['vocab_path'])

# load model
rnn_config = config['rnn_config']
model = RNN(rnn_config).to(device)
model.load_state_dict(torch.load(
config['out_dir'] + 'trained_model.pt',
map_location=torch.device('cpu')))

# feed the model <sos> and start sampling
# output sampled SELFIES
selfies = model.sample(vocab)
print('Sampled SELFIES: \n', selfies)

# output sampled SMILES
smiles = sf.decoder(selfies)
print('Sampled SMILES: \n', smiles)
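The script draws a single molecule per run. A sketch of how it might be extended to draw several molecules, reusing the model, vocab, and sf objects created above (the sample count is arbitrary; model.eval() and torch.no_grad() are additions worth considering, since the script as written omits them):

model.eval()                       # disable dropout, if any, during sampling
with torch.no_grad():              # gradients are not needed for inference
    for _ in range(10):            # arbitrary number of samples
        selfies = model.sample(vocab)
        print(sf.decoder(selfies))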
