-
Notifications
You must be signed in to change notification settings - Fork 4
/
sample.py
114 lines (99 loc) · 3.59 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright: Wentao Shi, 2021
from dataloader import SELFIEVocab, RegExVocab, CharVocab
from model import RNN
import argparse
import torch
import yaml
import selfies as sf
from tqdm import tqdm
from rdkit import Chem
# suppress rdkit error
from rdkit import rdBase
rdBase.DisableLog('rdApp.error')
def get_args():
    """Parse the command-line arguments of the sampling script.

    Returns:
        argparse.Namespace with attributes:
            result_dir (str): directory holding config.yaml and receiving
                the sampled molecules file.
            batch_size (int): number of samples generated per mini-batch
                (default 2048).
            num_batches (int): number of mini-batches to generate
                (default 20).
    """
    parser = argparse.ArgumentParser("python")
    parser.add_argument("-result_dir",
                        required=True,
                        help="directory of result files including configuration, "
                             "loss, trained model, and sampled molecules"
                        )
    # type=int makes argparse validate and convert the value, so a value
    # passed on the command line is an int just like the default.
    parser.add_argument("-batch_size",
                        required=False,
                        type=int,
                        default=2048,
                        help="number of samples to generate per mini-batch"
                        )
    parser.add_argument("-num_batches",
                        required=False,
                        type=int,
                        default=20,
                        help="number of batches to generate"
                        )
    return parser.parse_args()
def _load_vocab(which_vocab, vocab_path):
    """Build the vocabulary object selected by the ``which_vocab`` config entry.

    Args:
        which_vocab: one of "selfies", "regex" or "char".
        vocab_path: path passed unchanged to the vocabulary constructor.

    Raises:
        ValueError: if ``which_vocab`` is not a known vocabulary name.
    """
    vocab_classes = {
        "selfies": SELFIEVocab,
        "regex": RegExVocab,
        "char": CharVocab,
    }
    if which_vocab not in vocab_classes:
        raise ValueError("Wrong vocab name for configuration which_vocab!")
    return vocab_classes[which_vocab](vocab_path)


def _ints_to_strings(sampled_ints, int2token):
    """Turn each row of token ids into a string, stopping at the first '<eos>'.

    Args:
        sampled_ints: iterable of rows, each an iterable of token ids.
        int2token: mapping from token id to token string.

    Returns:
        list[str]: one concatenated token string per row.
    """
    strings = []
    for ints in sampled_ints:
        tokens = []
        for x in ints:
            token = int2token[x]
            if token == '<eos>':
                break
            tokens.append(token)
        strings.append("".join(tokens))
    return strings


def _to_valid_smiles(molecule, is_selfies):
    """Return the SMILES of ``molecule`` if RDKit parses it, otherwise None.

    Args:
        molecule: sampled string (SELFIES if ``is_selfies``, else SMILES).
        is_selfies: whether ``molecule`` must first be decoded with selfies.

    Returns:
        The SMILES string when it parses into a molecule with at least one
        atom; None for unparsable, undecodable or empty samples.
    """
    try:
        smiles = sf.decoder(molecule) if is_selfies else molecule
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        # sf.decoder can raise on malformed SELFIES, and MolFromSmiles raises
        # on non-string input such as None; both mean an invalid sample.
        return None
    # MolFromSmiles("") yields a zero-atom Mol rather than None; an empty
    # sample is not a molecule, so count it as invalid too.
    if mol is None or mol.GetNumAtoms() == 0:
        return None
    return smiles


if __name__ == "__main__":
    import os

    args = get_args()
    result_dir = args.result_dir
    batch_size = int(args.batch_size)
    num_batches = int(args.num_batches)

    # load the configuration file saved in the result directory
    config_path = os.path.join(result_dir, "config.yaml")
    with open(config_path, 'r') as f:
        # NOTE(review): full_load is acceptable for a locally produced config;
        # prefer yaml.safe_load if configs may come from untrusted sources.
        config = yaml.full_load(f)

    # detect cpu or gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('device: ', device)

    # load vocab
    vocab = _load_vocab(config["which_vocab"], config["vocab_path"])

    # load model
    model = RNN(config['rnn_config']).to(device)
    model.load_state_dict(torch.load(
        os.path.join(config['out_dir'], 'trained_model.pt'),
        map_location=device))
    model.eval()

    # sample, filter out invalid molecules, and save the valid molecules
    num_valid, num_invalid = 0, 0
    is_selfies = vocab.name == 'selfies'
    out_path = os.path.join(result_dir, "sampled_molecules.out")
    with open(out_path, "w") as out_file:
        for _ in tqdm(range(num_batches)):
            # sample molecules as integers
            sampled_ints = model.sample(
                batch_size=batch_size,
                vocab=vocab,
                device=device
            )
            # convert integers back to token strings (SMILES or SELFIES)
            molecules = _ints_to_strings(sampled_ints.tolist(), vocab.int2tocken)
            # keep only molecules RDKit can parse
            for molecule in molecules:
                smiles = _to_valid_smiles(molecule, is_selfies)
                if smiles is None:
                    num_invalid += 1
                else:
                    num_valid += 1
                    out_file.write(smiles + '\n')

    # compute the valid rate; guard against zero samples (e.g. -num_batches 0)
    total = num_valid + num_invalid
    success_rate = num_valid / total if total else 0.0
    print("sampled {} valid SMILES out of {}, success rate: {}".format(
        num_valid, total, success_rate)
    )