# -*- coding:utf-8 -*-
# @author: 木子川
# @Email: [email protected]
# @VX:fylaicai
import os
import time

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score

from config import parsers
from utils import read_data, MyDataset
from model import BertTextModel_encode_layer, BertTextModel_last_layer
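
# This script trains a BERT-based text classifier: it loads the train/dev
# splits with read_data, wraps them in DataLoaders, picks one of the two
# model variants from model.py, optimizes with AdamW + cross-entropy loss,
# and saves both the best-on-dev checkpoint and the final checkpoint.
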
def train(model, device, trainLoader, opt, epoch):
    model.train()
    loss_sum, count = 0.0, 0
    for batch_index, batch_con in enumerate(trainLoader):
        batch_con = tuple(p.to(device) for p in batch_con)
        pred = model(batch_con)
        opt.zero_grad()
        loss = loss_fn(pred, batch_con[-1])
        loss.backward()
        opt.step()

        # accumulate the scalar loss (item() avoids keeping the graph alive)
        loss_sum += loss.item()
        count += 1

        # report the average loss every 1000 batches
        if batch_index % 1000 == 999:
            msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"
            print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))
            loss_sum, count = 0.0, 0

    # report the average loss of any batches left over at the end of the epoch
    if count > 0:
        msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"
        print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))

def dev(model, device, devLoader, save_best):
    # acc_min holds the best dev accuracy seen so far (initialized in __main__)
    global acc_min
    model.eval()
    all_true, all_pred = [], []
    with torch.no_grad():
        for batch_con in tqdm(devLoader):
            batch_con = tuple(p.to(device) for p in batch_con)
            pred = model(batch_con)
            pred = torch.argmax(pred, dim=1)
            pred_label = pred.cpu().numpy().tolist()
            true_label = batch_con[-1].cpu().numpy().tolist()
            all_true.extend(true_label)
            all_pred.extend(pred_label)
    acc = accuracy_score(all_true, all_pred)
    print(f"dev acc:{acc:.4f}")
    if acc > acc_min:
        acc_min = acc
        torch.save(model.state_dict(), save_best)
        print("best model saved")

if __name__ == "__main__":
    start = time.time()
    args = parsers()
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    train_text, train_label = read_data(args.train_file)
    dev_text, dev_label = read_data(args.dev_file)
    trainData = MyDataset(train_text, train_label, with_labels=True)
    trainLoader = DataLoader(trainData, batch_size=args.batch_size, shuffle=True)
    devData = MyDataset(dev_text, dev_label, with_labels=True)
    # the dev set does not need to be shuffled for evaluation
    devLoader = DataLoader(devData, batch_size=args.batch_size, shuffle=False)

    # prefix the checkpoint names with the chosen model variant
    root, name = os.path.split(args.save_model_best)
    save_best = os.path.join(root, str(args.select_model_last) + "_" + name)
    root, name = os.path.split(args.save_model_last)
    save_last = os.path.join(root, str(args.select_model_last) + "_" + name)

    # choose the model
    if args.select_model_last:
        # model 1: BertTextModel_last_layer
        model = BertTextModel_last_layer().to(device)
    else:
        # model 2: BertTextModel_encode_layer
        model = BertTextModel_encode_layer().to(device)

    opt = AdamW(model.parameters(), lr=args.learn_rate)
    loss_fn = CrossEntropyLoss()
    # best dev accuracy seen so far; updated by dev()
    acc_min = float("-inf")

    for epoch in range(args.epochs):
        train(model, device, trainLoader, opt, epoch)
        dev(model, device, devLoader, save_best)

    model.eval()
    torch.save(model.state_dict(), save_last)

    end = time.time()
    print(f"running time: {(end - start) / 60:.4f} min")