
Commit

minor
shiwentao00 committed Feb 27, 2021
1 parent c6703be commit ee4b718
Showing 6 changed files with 190 additions and 11 deletions.
2 changes: 1 addition & 1 deletion dataloader.py
@@ -5,6 +5,7 @@
 import selfies as sf
 import yaml
 from torch.nn.utils.rnn import pad_sequence
+from pad_idx import PADDING_IDX


 def dataloader_gen(dataset_dir, percentage, vocab_path, batch_size, shuffle, drop_last=False):
@@ -23,7 +24,6 @@ def pad_collate(batch):
     """
     Put the sequences of different lengths in a minibatch by padding.
     """
-    global PADDING_IDX

     lengths = [len(x) for x in batch]
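The hunk above is truncated, but the new import together with pad_sequence suggests how pad_collate now consumes the shared constant. A minimal sketch of the rest of the function, assuming batch-first tensors and a (padded, lengths) return value, none of which is shown in this commit:

# sketch (not part of the commit): likely body of pad_collate
import torch
from torch.nn.utils.rnn import pad_sequence
from pad_idx import PADDING_IDX

def pad_collate(batch):
    """
    Put the sequences of different lengths in a minibatch by padding.
    """
    lengths = [len(x) for x in batch]
    # pad every sequence to the longest one in the minibatch, filling with
    # the shared padding index so downstream code can ignore those positions
    padded = pad_sequence(batch, batch_first=True, padding_value=PADDING_IDX)
    return padded, lengths
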
2 changes: 1 addition & 1 deletion model.py
@@ -2,12 +2,12 @@
 import torch.nn as nn
 from torch.nn.utils.rnn import pack_padded_sequence
 from torch.nn.functional import softmax
+from pad_idx import PADDING_IDX


 class RNN(torch.nn.Module):
     def __init__(self, rnn_config):
         super(RNN, self).__init__()
-        global PADDING_IDX

         self.embedding_layer = nn.Embedding(
             num_embeddings=rnn_config['num_embeddings'],
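model.py now imports the constant instead of relying on a global. The visible part of __init__ stops at num_embeddings; presumably PADDING_IDX is passed to nn.Embedding as padding_idx, which keeps the padding row at zero and excludes it from gradient updates. A hedged sketch (the embedding_dim key and the padding_idx wiring are assumptions, not shown in the diff):

# sketch (not part of the commit): plausible completion of the embedding layer
self.embedding_layer = nn.Embedding(
    num_embeddings=rnn_config['num_embeddings'],
    embedding_dim=rnn_config['embedding_dim'],  # assumed config key
    padding_idx=PADDING_IDX)  # the row at this index stays zero and gets no gradient
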
10 changes: 10 additions & 0 deletions pad_idx.py
@@ -0,0 +1,10 @@
# the padding index (last index) for batching
import yaml

config_dir = "./train.yaml"
with open(config_dir, 'r') as f:
config = yaml.full_load(f)
rnn_config = config['rnn_config']
global PADDING_IDX
PADDING_IDX = rnn_config['num_embeddings'] - 1
print("padding index: ", PADDING_IDX)
7 changes: 1 addition & 6 deletions train.py
@@ -39,13 +39,8 @@
 dataloader, train_size = dataloader_gen(
     dataset_dir, percentage, vocab_path, batch_size, shuffle, drop_last=False)

-# the padding index (last index) for batching, it is set to global because
-# the collate_fn of dataloader needs it.
-rnn_config = config['rnn_config']
-global PADDING_IDX
-PADDING_IDX = rnn_config['num_embeddings'] - 1
-
 # model and training configuration
+rnn_config = config['rnn_config']
 model = RNN(rnn_config).to(device)
 learning_rate = config['learning_rate']
 weight_decay = config['weight_decay']
6 changes: 3 additions & 3 deletions train.yaml
@@ -2,7 +2,7 @@
 out_dir: '../results/run_1/'
 dataset_dir: "../zinc-smiles/"
 vocab_path: "./vocab/selfies_vocab.yaml"
-percentage: 0.2
+percentage: 1.0

 rnn_config:
   # embedding
@@ -15,8 +15,8 @@
   num_layers: 3
   dropout: 0.6

-batch_size: 512
+batch_size: 1024
 shuffle: True
-num_epoch: 200
+num_epoch: 30
 learning_rate: 0.003
 weight_decay: 0.0007
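The changed keys switch training from a 20% sample to the full dataset, double the batch size, and cut the epoch count from 200 to 30. A sketch of how train.py presumably consumes them, modeled on the yaml.full_load pattern that pad_idx.py uses (the exact loading code in train.py is outside the shown hunks):

# sketch (assumed, mirrors pad_idx.py): reading the changed hyperparameters
import yaml

with open("./train.yaml", 'r') as f:
    config = yaml.full_load(f)

percentage = config['percentage']        # 1.0: train on the full dataset
batch_size = config['batch_size']        # 1024
num_epoch = config['num_epoch']          # 30
learning_rate = config['learning_rate']  # unchanged at 0.003
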
174 changes: 174 additions & 0 deletions vocab/selfies_vocab.yaml
@@ -0,0 +1,174 @@
<eos>: 171
<pad>: 173
<sos>: 172
'[#C-expl]': 108
'[#C]': 116
'[#N+expl]': 152
'[#N]': 40
'[#S]': 158
'[/B-expl]': 18
'[/B]': 26
'[/Br]': 84
'[/C-expl]': 7
'[/C@@Hexpl]': 17
'[/C@@expl]': 145
'[/C@Hexpl]': 87
'[/C@expl]': 27
'[/CH-expl]': 120
'[/C]': 29
'[/Cexpl]': 161
'[/Cl]': 165
'[/F]': 22
'[/I]': 70
'[/N+expl]': 86
'[/NHexpl]': 46
'[/N]': 25
'[/O+expl]': 90
'[/O-expl]': 67
'[/O]': 88
'[/Oexpl]': 151
'[/P@@expl]': 14
'[/P@expl]': 4
'[/P]': 96
'[/S+expl]': 92
'[/S@@expl]': 166
'[/S@expl]': 121
'[/S]': 80
'[/Siexpl]': 111
'[/Snexpl]': 68
'[=17Oexpl]': 110
'[=B]': 79
'[=C]': 20
'[=IH2expl]': 105
'[=I]': 77
'[=N+expl]': 125
'[=N-expl]': 43
'[=N]': 49
'[=Nexpl]': 19
'[=O+expl]': 109
'[=O]': 170
'[=P+expl]': 164
'[=P@@Hexpl]': 10
'[=P@@expl]': 107
'[=P@Hexpl]': 1
'[=P@expl]': 47
'[=PHexpl]': 148
'[=P]': 100
'[=S+expl]': 41
'[=S@+expl]': 48
'[=S@@+expl]': 73
'[=S@@expl]': 74
'[=S@expl]': 146
'[=SHexpl]': 150
'[=S]': 98
'[=Siexpl]': 157
'[=Snexpl]': 95
'[B-expl]': 57
'[B@-expl]': 2
'[B@@-expl]': 32
'[BH-expl]': 82
'[BH2-expl]': 39
'[BH3-expl]': 134
'[B]': 101
'[Br+expl]': 94
'[Br]': 5
'[Branch1_1]': 123
'[Branch1_2]': 44
'[Branch1_3]': 35
'[Branch2_1]': 128
'[Branch2_2]': 56
'[Branch2_3]': 122
'[C+expl]': 23
'[C-expl]': 144
'[C@@Hexpl]': 131
'[C@@expl]': 167
'[C@Hexpl]': 72
'[C@expl]': 114
'[CH-expl]': 147
'[CH2-expl]': 51
'[CH2expl]': 12
'[CHexpl]': 135
'[C]': 140
'[Cexpl]': 83
'[Cl]': 65
'[Expl/Ring1]': 91
'[Expl/Ring2]': 138
'[Expl=Ring1]': 30
'[Expl=Ring2]': 154
'[Expl\Ring1]': 133
'[Expl\Ring2]': 50
'[F]': 59
'[I+expl]': 75
'[IH2expl]': 54
'[I]': 76
'[N+expl]': 24
'[N-expl]': 63
'[N@+expl]': 169
'[N@@+expl]': 6
'[N@@H+expl]': 163
'[N@H+expl]': 13
'[NHexpl]': 97
'[N]': 15
'[Nexpl]': 156
'[O+expl]': 104
'[O-expl]': 99
'[O]': 33
'[Oexpl]': 153
'[P+expl]': 160
'[P@+expl]': 89
'[P@@+expl]': 9
'[P@@Hexpl]': 62
'[P@@expl]': 21
'[P@Hexpl]': 136
'[P@expl]': 28
'[PHexpl]': 3
'[P]': 81
'[Ring1]': 115
'[Ring2]': 36
'[S+expl]': 113
'[S@+expl]': 58
'[S@@+expl]': 139
'[S@@Hexpl]': 130
'[S@@expl]': 118
'[S@Hexpl]': 53
'[S@expl]': 168
'[SH3expl]': 159
'[SHexpl]': 132
'[S]': 34
'[Si-expl]': 117
'[SiH3expl]': 42
'[Siexpl]': 55
'[Sn+2expl]': 11
'[Sn+3expl]': 119
'[Sn+expl]': 61
'[SnH2+expl]': 103
'[SnH2expl]': 142
'[SnH4+2expl]': 8
'[SnH6+3expl]': 69
'[SnHexpl]': 143
'[Snexpl]': 149
'[\B]': 102
'[\Br]': 124
'[\C-expl]': 106
'[\C@@Hexpl]': 52
'[\C@@expl]': 137
'[\C@Hexpl]': 129
'[\C@expl]': 78
'[\CH-expl]': 16
'[\C]': 162
'[\Cl]': 66
'[\F]': 112
'[\I]': 141
'[\N+expl]': 31
'[\NHexpl]': 127
'[\N]': 37
'[\O-expl]': 38
'[\O]': 71
'[\Oexpl]': 60
'[\P]': 0
'[\S+expl]': 45
'[\S@@expl]': 93
'[\S@expl]': 126
'[\S]': 85
'[\Siexpl]': 155
'[\Snexpl]': 64
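
The new vocabulary file defines 174 tokens with indices 0 through 173, and the special token <pad> sits at the last index, 173. That matches what pad_idx.py computes as num_embeddings - 1, assuming train.yaml sets num_embeddings to the vocabulary size of 174 (the value itself is outside the shown hunks). A quick consistency check:

# sketch: confirm <pad> is the last index, matching PADDING_IDX = num_embeddings - 1
import yaml

with open("./vocab/selfies_vocab.yaml", 'r') as f:
    vocab = yaml.full_load(f)

assert len(vocab) == 174
assert vocab['<pad>'] == max(vocab.values()) == 173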
