-
Notifications
You must be signed in to change notification settings - Fork 89
/
Copy pathmain.py
124 lines (102 loc) · 4.81 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import utils
import config
import logging
import numpy as np
import torch
from torch.utils.data import DataLoader
from train import train, test, translate
from data_loader import MTDataset
from utils import english_tokenizer_load
from model import make_model, LabelSmoothing
class NoamOpt:
    """Optimizer wrapper that applies the Noam learning-rate schedule.

    Every call to :meth:`step` increments an internal step counter, computes
    the scheduled rate, writes it into each parameter group of the wrapped
    optimizer and then steps that optimizer. The schedule is::

        lr(step) = factor * model_size ** -0.5
                   * min(step ** -0.5, step * warmup ** -1.5)

    That is, the rate grows linearly for ``warmup`` steps, peaks at
    ``step == warmup``, and then decays with the inverse square root of the step.

    Args:
        model_size: Model dimension used to scale the rate.
        factor: Overall multiplier applied to the rate.
        warmup: Number of warm-up steps.
        optimizer: Wrapped optimizer exposing ``param_groups`` and ``step()``.
            May be ``None`` when only :meth:`rate` is used (e.g. for plotting).
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0  # last rate applied by step(); 0 until the first step

    def step(self):
        """Advance one step: set the scheduled rate, then step the optimizer."""
        self._step += 1
        rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the scheduled learning rate at ``step``.

        Args:
            step: Step to evaluate; defaults to the current internal step.
                Step 0 is evaluated as step 1, because ``0 ** -0.5`` raises
                ``ZeroDivisionError`` (e.g. when queried before any step).
        """
        if step is None:
            step = self._step
        if step == 0:
            step = 1
        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))
def get_std_opt(model, factor=1, warmup=10000):
    """Build the default Noam-scheduled Adam optimizer for ``model``.

    Wraps ``torch.optim.Adam(betas=(0.9, 0.98), eps=1e-9)`` in a ``NoamOpt``.
    Adam's base ``lr`` is 0 because ``NoamOpt.step`` overwrites it on every step.

    Args:
        model: Model to optimize; ``model.src_embed[0].d_model`` supplies the
            schedule's model size.
        factor: Multiplier on the Noam rate (default 1).
        warmup: Number of warm-up steps (default 10000). With batch_size 32
            there are about 5530 steps per epoch, so the default is roughly
            1.8 epochs of warm-up.

    Returns:
        A ``NoamOpt`` wrapping the Adam optimizer.
    """
    return NoamOpt(model.src_embed[0].d_model, factor, warmup,
                   torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
def run():
    """Train, validate and test the translation model end to end.

    Configures logging, then builds the train/dev/test datasets and dataloaders
    from the paths in ``config``. It constructs the model plus a
    ``DataParallel`` wrapper and picks the loss (``config.use_smoothing``) and
    optimizer (``config.use_noamopt``). Finally it calls ``train`` and then
    ``test``.
    """
    utils.set_logger(config.log_path)

    train_dataset = MTDataset(config.train_data_path)
    dev_dataset = MTDataset(config.dev_data_path)
    test_dataset = MTDataset(config.test_data_path)
    logging.info("-------- Dataset Build! --------")

    # Only the training set is shuffled; dev/test keep their original order.
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size,
                                  collate_fn=train_dataset.collate_fn)
    dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=config.batch_size,
                                collate_fn=dev_dataset.collate_fn)
    test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=config.batch_size,
                                 collate_fn=test_dataset.collate_fn)
    logging.info("-------- Get Dataloader! --------")

    # Initialize the model.
    model = make_model(config.src_vocab_size, config.tgt_vocab_size, config.n_layers,
                       config.d_model, config.d_ff, config.n_heads, config.dropout)
    model_par = torch.nn.DataParallel(model)

    # Training setup: loss function.
    if config.use_smoothing:
        criterion = LabelSmoothing(size=config.tgt_vocab_size, padding_idx=config.padding_idx, smoothing=0.1)
        # Follow config.device rather than hard-coding .cuda(), so CPU runs work.
        # NOTE(review): assumes make_model/train place model and batches on
        # config.device as well — confirm.
        criterion.to(config.device)
    else:
        # Use the same padding index as the label-smoothing branch instead of
        # a hard-coded 0, so padded target positions are ignored consistently.
        criterion = torch.nn.CrossEntropyLoss(ignore_index=config.padding_idx, reduction='sum')

    # Training setup: optimizer.
    if config.use_noamopt:
        optimizer = get_std_opt(model)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

    train(train_dataloader, dev_dataloader, model, model_par, criterion, optimizer)
    test(test_dataloader, model, criterion)
def check_opt():
    """Plot three Noam learning-rate schedules for visual inspection.

    Compares the project's default optimizer (``get_std_opt`` on a freshly
    built model) with two reference ``NoamOpt`` settings over steps
    1..49999, then shows the figure with matplotlib.
    """
    # matplotlib is only needed for this diagnostic, so import it lazily;
    # numpy (np) is already imported at module level.
    import matplotlib.pyplot as plt

    model = make_model(config.src_vocab_size, config.tgt_vocab_size, config.n_layers,
                       config.d_model, config.d_ff, config.n_heads, config.dropout)
    # Three settings of the learning-rate schedule hyperparameters.
    schedules = [get_std_opt(model),
                 NoamOpt(512, 1, 20000, None),
                 NoamOpt(256, 1, 10000, None)]
    rates = [[schedule.rate(i) for schedule in schedules] for i in range(1, 50000)]
    plt.plot(np.arange(1, 50000), rates)
    # Derive labels from the actual settings: the default optimizer's model size
    # comes from config.d_model, so a hard-coded "512:10000" could be wrong.
    plt.legend([f"{schedule.model_size}:{schedule.warmup}" for schedule in schedules])
    plt.show()
def one_sentence_translate(sent, beam_search=True):
    """Translate a single English sentence and hand the batch to ``translate``.

    The sentence is encoded with the English tokenizer and wrapped in BOS/EOS
    ids. It is then turned into a LongTensor of shape (1, seq_len) on
    ``config.device``.

    Args:
        sent: Source sentence as raw English text.
        beam_search: Forwarded to ``translate`` as ``use_beam``.
    """
    # Initialize the model.
    # NOTE(review): the weights here are freshly initialized; presumably
    # ``translate`` loads the trained checkpoint — confirm in train.py.
    model = make_model(config.src_vocab_size, config.tgt_vocab_size, config.n_layers,
                       config.d_model, config.d_ff, config.n_heads, config.dropout)
    # Load the tokenizer once and reuse it, instead of once per bos/eos/encode call.
    tokenizer = english_tokenizer_load()
    bos_id = tokenizer.bos_id()  # 2 per the original comment
    eos_id = tokenizer.eos_id()  # 3 per the original comment
    src_tokens = [[bos_id] + tokenizer.EncodeAsIds(sent) + [eos_id]]
    batch_input = torch.LongTensor(np.array(src_tokens)).to(config.device)
    translate(batch_input, model, use_beam=beam_search)
def translate_example():
    """Run single-sentence translation on one fixed English example.

    The reference (target) translation is kept in the comment below for
    comparing against the model's output by eye.
    """
    sentence = (
        "The near-term policy remedies are clear: raise the minimum wage to a level that will keep a "
        "fully employed worker and his or her family out of poverty, and extend the earned-income tax credit "
        "to childless workers."
    )
    # Reference translation (target side, kept verbatim):
    # 近期的政策对策很明确:把最低工资提升到足以一个全职工人及其家庭免于贫困的水平,扩大对无子女劳动者的工资所得税减免。
    use_beam = True
    one_sentence_translate(sentence, beam_search=use_beam)
if __name__ == "__main__":
    import os
    import warnings

    # Expose only GPUs 2 and 3 to this process.
    # NOTE(review): torch and config are imported at module top, so this only
    # takes effect if nothing has initialized CUDA yet — confirm.
    # NOTE(review): CUDA documents this variable as a comma-separated list of
    # device indices; the space in '2, 3' is kept as-is — verify it parses as
    # intended on the target driver.
    os.environ.update({'CUDA_VISIBLE_DEVICES': '2, 3'})
    # Silence every warning for the rest of the run.
    warnings.filterwarnings('ignore')

    # Call run() here instead to train and evaluate the model.
    translate_example()