diff --git a/README.md b/README.md
index 1478497..0903961 100644
--- a/README.md
+++ b/README.md
@@ -7,10 +7,14 @@ Molecule-RNN is a recurrent neural network built with Pytorch to generate molecu
 2. Modify the path of dataset in ```train.yaml``` to your downloaded dataset by setting the value of ```dataset_dir```.
 3. Run the training script.
-```python train.py```
+```
+python train.py
+```
 The training loss:


 ## Sampling
 We can generate molecules by sampling the model according to the output distribution.
-```python sample.py```
\ No newline at end of file
+```
+python sample.py
+```
\ No newline at end of file
diff --git a/dataloader.py b/dataloader.py
index 862d16d..c202a5c 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -19,6 +19,22 @@ def dataloader_gen(dataset_dir, percentage, vocab_path, batch_size, shuffle, dro
     return dataloader, len(dataset)


+def pad_collate(batch):
+    """
+    Put the sequences of different lengths into a minibatch by padding.
+    """
+    global PADDING_IDX
+
+    lengths = [len(x) for x in batch]
+
+    batch = [torch.tensor(x) for x in batch]
+
+    # use any integer that is not in the vocab as padding
+    x_padded = pad_sequence(batch, batch_first=True, padding_value=PADDING_IDX)
+
+    return x_padded, lengths
+
+
 class SMILESDataset(Dataset):
     def __init__(self, dataset_dir: str, percentage: float, vocab):
         """
@@ -67,7 +83,6 @@ def __init__(self, vocab_path) -> None:
             self.vocab = yaml.full_load(f)

         self.int2tocken = {value: key for key, value in self.vocab.items()}
-        self.int2tocken[0] = ''

     def tokenize_smiles(self, mol):
         """convert the smiles to selfies, then return
@@ -83,16 +98,3 @@ def tokenize_smiles(self, mol):

     def list2selfies(self, selfies):
         return "".join(selfies)
-
-
-def pad_collate(batch):
-    """
-    Put the sequences of different lengths in a minibatch by paddding.
-    """
-    lengths = [len(x) for x in batch]
-
-    batch = [torch.tensor(x) for x in batch]
-
-    x_padded = pad_sequence(batch, batch_first=True, padding_value=0)
-
-    return x_padded, lengths
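The new `pad_collate` reads the module-level `PADDING_IDX` and is meant to be handed to the `DataLoader` created inside `dataloader_gen` as its `collate_fn`. A minimal sketch of that wiring, assuming `PADDING_IDX` follows the `num_embeddings - 1` convention from train.py below; the toy token sequences are made up for illustration:

```python
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

PADDING_IDX = 173  # assumed: num_embeddings (174) - 1, as set in train.py


def pad_collate(batch):
    """Pad variable-length token sequences into one minibatch tensor."""
    lengths = [len(x) for x in batch]
    batch = [torch.tensor(x) for x in batch]
    x_padded = pad_sequence(batch, batch_first=True, padding_value=PADDING_IDX)
    return x_padded, lengths


# three tokenized "molecules" of different lengths
data = [[5, 1, 7], [2, 4], [9, 3, 3, 6]]
loader = DataLoader(data, batch_size=3, collate_fn=pad_collate)
x, lengths = next(iter(loader))
print(x.shape, lengths)  # torch.Size([3, 4]) [3, 2, 4]
```

The lengths are returned alongside the padded tensor because the model needs them later to pack the batch.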
- """ - lengths = [len(x) for x in batch] - - batch = [torch.tensor(x) for x in batch] - - x_padded = pad_sequence(batch, batch_first=True, padding_value=0) - - return x_padded, lengths diff --git a/model.py b/model.py index bba5383..59af406 100644 --- a/model.py +++ b/model.py @@ -1,19 +1,18 @@ import torch import torch.nn as nn -from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pack_padded_sequence -from torch.nn.utils.rnn import pad_packed_sequence from torch.nn.functional import softmax class RNN(torch.nn.Module): def __init__(self, rnn_config): super(RNN, self).__init__() + global PADDING_IDX self.embedding_layer = nn.Embedding( num_embeddings=rnn_config['num_embeddings'], embedding_dim=rnn_config['embedding_dim'], - padding_idx=0 + padding_idx=PADDING_IDX ) self.rnn = nn.LSTM( @@ -24,8 +23,10 @@ def __init__(self, rnn_config): dropout=rnn_config['dropout'] ) + # output does not include , so + # decrease the num_embeddings by 2 self.linear = nn.Linear( - rnn_config['hidden_size'], rnn_config['num_embeddings']) + rnn_config['hidden_size'], rnn_config['num_embeddings'] - 1) def forward(self, data, lengths): embeddings = self.embedding_layer(data) @@ -86,7 +87,3 @@ def sample(self, vocab): output = vocab.list2selfies(output) return output - - def beam_search(self): - # TO-DO - pass diff --git a/output_sample.py b/output_sample.py deleted file mode 100644 index 6e8a053..0000000 --- a/output_sample.py +++ /dev/null @@ -1,6 +0,0 @@ -import selfies as sf -print(sf.decoder('[C][/N][Br][=P+expl][/B-expl][C@expl][Br+expl][S@+expl][C@expl][/C@expl][=O+expl][Br][I+expl][=Nexpl][/C@Hexpl][/C@expl][C][\Siexpl][C][IH2expl][Branch2_3][=B][=N-expl][/B][ # C][=O+expl][\S@@expl][Branch2_1][\O-expl][/C@expl][\C@@expl][Branch2_1][=O+expl][=N][/CH-expl][Branch1_2][/O-expl][P@@expl][/S][C@expl][C][\C@expl][=S+expl][/C@@Hexpl][B-expl][\Cl][Expl\Ring2][/N+expl][\C@@Hexpl][B@@-expl][S@@+expl][Branch1_3][/F][#C][C][O][Expl=Ring2][\S@expl][C][C][Branch1_2][C][Oexpl][C@Hexpl][B@@-expl][\C-expl][P+expl][CH2-expl][=S@expl][=B][B-expl][SHexpl][C][O+expl][\I][P@expl][/C@@expl][P@Hexpl][N][C][C][B][P][\S@expl][P@@expl][Sn+2expl][Br+expl][\S@expl][C@expl][=N-expl][/B][=N+expl][/B-expl][\C@@expl][I+expl][\O][Oexpl][\Snexpl][C][C][C][=P][/CH-expl][C][C][C-expl][SH3expl][Branch1_3][Br][\C-expl][/N][/Cexpl][C][=S@@expl][C][N-expl][C][IH2expl][\Siexpl][Sn+3expl][C][N][S@@expl][=I][=S@+expl][P@expl][=O][/Br][#N][/P@@expl][=P][/C@@expl][\O-expl][Cl][/O][/S][B][N@@H+expl][Expl/Ring2][#C][I+expl][S@Hexpl][P@@Hexpl][/P@@expl][=N+expl][Sn+3expl][\C@expl][Branch1_2][/P@@expl][/C@expl][Br][\P][/I][\B][\O-expl][\C@expl]')) 
diff --git a/train.py b/train.py
index 04b5ab0..4a2e649 100644
--- a/train.py
+++ b/train.py
@@ -39,8 +39,13 @@
     dataloader, train_size = dataloader_gen(
         dataset_dir, percentage, vocab_path, batch_size, shuffle, drop_last=False)

-    # model and training configuration
+    # the padding index for batching; it is set as a global because
+    # the collate_fn of the dataloader needs it.
     rnn_config = config['rnn_config']
+    global PADDING_IDX
+    PADDING_IDX = rnn_config['num_embeddings'] - 1
+
+    # model and training configuration
     model = RNN(rnn_config).to(device)
     learning_rate = config['learning_rate']
     weight_decay = config['weight_decay']
diff --git a/train.yaml b/train.yaml
index 2de72a3..e84507f 100644
--- a/train.yaml
+++ b/train.yaml
@@ -1,12 +1,12 @@
 ---
 out_dir: '../results/run_1/'
 dataset_dir: "../zinc-smiles/"
-vocab_path: "./vocab.yaml"
+vocab_path: "./vocab/selfies_vocab.yaml"
 percentage: 0.2
 rnn_config:
     # embedding
-    num_embeddings: 174 # 173 from vocab.yaml + padding(0)
+    num_embeddings: 174 # size of vocab + <eos> + <sos> + <pad>
     embedding_dim: 256

     # rnn
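One caveat with the `global PADDING_IDX` pattern above: Python globals are module-scoped, so the assignment in train.py creates `train.PADDING_IDX`, while `pad_collate` in dataloader.py and `RNN.__init__` in model.py resolve the name in their own modules. A variation that sidesteps cross-module globals is to bind the index into the collate function explicitly; the extra `padding_idx` parameter here is a hypothetical change, not what this diff implements:

```python
from functools import partial

import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch, padding_idx):
    """Same as the dataloader.py version, but takes the padding index as an argument."""
    lengths = [len(x) for x in batch]
    batch = [torch.tensor(x) for x in batch]
    return pad_sequence(batch, batch_first=True, padding_value=padding_idx), lengths


num_embeddings = 174              # from train.yaml
padding_idx = num_embeddings - 1  # same convention as train.py

data = [[5, 1, 7], [2, 4]]  # toy tokenized sequences
loader = DataLoader(data, batch_size=2,
                    collate_fn=partial(pad_collate, padding_idx=padding_idx))
```

With `partial`, every consumer receives the index explicitly and no module has to write into another module's namespace.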
diff --git a/vocab.yaml b/vocab.yaml
deleted file mode 100644
index 21fb5e4..0000000
--- a/vocab.yaml
+++ /dev/null
@@ -1,173 +0,0 @@
-<eos>: 121
-<sos>: 104
-'[#C-expl]': 17
-'[#C]': 7
-'[#N+expl]': 143
-'[#N]': 88
-'[#S]': 71
-'[/B-expl]': 91
-'[/B]': 32
-'[/Br]': 170
-'[/C-expl]': 70
-'[/C@@Hexpl]': 74
-'[/C@@expl]': 64
-'[/C@Hexpl]': 85
-'[/C@expl]': 81
-'[/CH-expl]': 155
-'[/C]': 150
-'[/Cexpl]': 59
-'[/Cl]': 63
-'[/F]': 163
-'[/I]': 110
-'[/N+expl]': 90
-'[/NHexpl]': 116
-'[/N]': 21
-'[/O+expl]': 142
-'[/O-expl]': 100
-'[/O]': 1
-'[/Oexpl]': 48
-'[/P@@expl]': 94
-'[/P@expl]': 126
-'[/P]': 78
-'[/S+expl]': 145
-'[/S@@expl]': 169
-'[/S@expl]': 98
-'[/S]': 159
-'[/Siexpl]': 107
-'[/Snexpl]': 79
-'[=17Oexpl]': 40
-'[=B]': 38
-'[=C]': 45
-'[=IH2expl]': 47
-'[=I]': 36
-'[=N+expl]': 113
-'[=N-expl]': 4
-'[=N]': 31
-'[=Nexpl]': 34
-'[=O+expl]': 18
-'[=O]': 172
-'[=P+expl]': 67
-'[=P@@Hexpl]': 25
-'[=P@@expl]': 97
-'[=P@Hexpl]': 108
-'[=P@expl]': 14
-'[=PHexpl]': 83
-'[=P]': 99
-'[=S+expl]': 154
-'[=S@+expl]': 118
-'[=S@@+expl]': 130
-'[=S@@expl]': 92
-'[=S@expl]': 9
-'[=SHexpl]': 77
-'[=S]': 80
-'[=Siexpl]': 129
-'[=Snexpl]': 171
-'[B-expl]': 23
-'[B@-expl]': 84
-'[B@@-expl]': 12
-'[BH-expl]': 86
-'[BH2-expl]': 11
-'[BH3-expl]': 53
-'[B]': 119
-'[Br+expl]': 114
-'[Br]': 41
-'[Branch1_1]': 149
-'[Branch1_2]': 55
-'[Branch1_3]': 6
-'[Branch2_1]': 10
-'[Branch2_2]': 134
-'[Branch2_3]': 5
-'[C+expl]': 127
-'[C-expl]': 115
-'[C@@Hexpl]': 168
-'[C@@expl]': 13
-'[C@Hexpl]': 162
-'[C@expl]': 165
-'[CH-expl]': 105
-'[CH2-expl]': 161
-'[CH2expl]': 50
-'[CHexpl]': 62
-'[C]': 57
-'[Cexpl]': 15
-'[Cl]': 20
-'[Expl/Ring1]': 141
-'[Expl/Ring2]': 96
-'[Expl=Ring1]': 26
-'[Expl=Ring2]': 140
-'[Expl\Ring1]': 30
-'[Expl\Ring2]': 122
-'[F]': 93
-'[I+expl]': 103
-'[IH2expl]': 102
-'[I]': 101
-'[N+expl]': 125
-'[N-expl]': 160
-'[N@+expl]': 39
-'[N@@+expl]': 153
-'[N@@H+expl]': 42
-'[N@H+expl]': 2
-'[NHexpl]': 60
-'[N]': 151
-'[Nexpl]': 43
-'[O+expl]': 44
-'[O-expl]': 128
-'[O]': 135
-'[Oexpl]': 139
-'[P+expl]': 156
-'[P@+expl]': 22
-'[P@@+expl]': 146
-'[P@@Hexpl]': 124
-'[P@@expl]': 138
-'[P@Hexpl]': 33
-'[P@expl]': 123
-'[PHexpl]': 52
-'[P]': 132
-'[Ring1]': 54
-'[Ring2]': 144
-'[S+expl]': 49
-'[S@+expl]': 89
-'[S@@+expl]': 106
-'[S@@Hexpl]': 131
-'[S@@expl]': 66
-'[S@Hexpl]': 75
-'[S@expl]': 173
-'[SH3expl]': 51
-'[SHexpl]': 157
-'[S]': 27
-'[Si-expl]': 58
-'[SiH3expl]': 147
-'[Siexpl]': 3
-'[Sn+2expl]': 76
-'[Sn+3expl]': 158
-'[Sn+expl]': 46
-'[SnH2+expl]': 65
-'[SnH2expl]': 152
-'[SnH4+2expl]': 35
-'[SnH6+3expl]': 8
-'[SnHexpl]': 136
-'[Snexpl]': 37
-'[\B]': 24
-'[\Br]': 95
-'[\C-expl]': 166
-'[\C@@Hexpl]': 61
-'[\C@@expl]': 68
-'[\C@Hexpl]': 120
-'[\C@expl]': 117
-'[\CH-expl]': 28
-'[\C]': 69
-'[\Cl]': 111
-'[\F]': 56
-'[\I]': 109
-'[\N+expl]': 82
-'[\NHexpl]': 29
-'[\N]': 16
-'[\O-expl]': 19
-'[\O]': 73
-'[\Oexpl]': 164
-'[\P]': 72
-'[\S+expl]': 167
-'[\S@@expl]': 133
-'[\S@expl]': 112
-'[\S]': 148
-'[\Siexpl]': 137
-'[\Snexpl]': 87
diff --git a/vocab.py b/vocab/selfies_vocab.py
similarity index 76%
rename from vocab.py
rename to vocab/selfies_vocab.py
index 5de61a4..c156987 100644
--- a/vocab.py
+++ b/vocab/selfies_vocab.py
@@ -15,27 +15,31 @@ def read_smiles_file(path, percentage):


 if __name__ == "__main__":
-    dataset_dir = "../zinc-smiles/"
-    output_vocab = "../vocab.yaml"
+    dataset_dir = "../../zinc-smiles/"
+    output_vocab = "./selfies_vocab.yaml"
     smiles_files = [f for f in listdir(
         dataset_dir) if isfile(join(dataset_dir, f))]

     all_selfies = []
     for i, f in enumerate(smiles_files):
-        smiles = read_smiles_file(dataset_dir + f, 1)
+        smiles = read_smiles_file(dataset_dir + f, 0.0001)
         selfies = [sf.encoder(x) for x in smiles if sf.encoder(x) is not None]
         all_selfies.extend(selfies)
         print('{} out of {} files processed.'.format(i, len(smiles_files)))

     vocab = sf.get_alphabet_from_selfies(all_selfies)
-    vocab.add('<eos>')
-    vocab.add('<sos>')

     vocab_dict = {}
     for i, token in enumerate(vocab):
-        # reserve 0 for padding
-        vocab_dict[token] = i + 1
+        vocab_dict[token] = i
+
+    i += 1
+    vocab_dict['<eos>'] = i
+    i += 1
+    vocab_dict['<sos>'] = i
+    i += 1
+    vocab_dict['<pad>'] = i

     with open(output_vocab, 'w') as f:
         yaml.dump(vocab_dict, f)
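Under the new scheme the special tokens sit at the end of the vocabulary, which is exactly what makes `PADDING_IDX = num_embeddings - 1` in train.py agree with the `padding_value` used by `pad_collate`. A quick sanity check on the generated file, assuming the special tokens are spelled `<eos>`, `<sos>`, and `<pad>` as in selfies_vocab.py above:

```python
import yaml

with open('./vocab/selfies_vocab.yaml') as f:
    vocab = yaml.full_load(f)

num_embeddings = len(vocab)  # should match num_embeddings (174) in train.yaml
assert vocab['<pad>'] == num_embeddings - 1  # the padding token must be last
assert vocab['<sos>'] == num_embeddings - 2
assert vocab['<eos>'] == num_embeddings - 3
```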