-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
110 lines (96 loc) · 4.88 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# -*- coding: utf-8 -*-
from transformers import AutoTokenizer
import torch.optim as optim
from model.citation_model import *
from utils.scheduler import WarmupMultiStepLR
from train_valid.dataset_train import dataset_train
from train_valid.dataset_valid import dataset_valid
from utils.dataload import *
from utils.util import *
from sklearn.metrics import classification_report, confusion_matrix
import optuna
import time
import argparse
import json
# Command-line interface:
#   --mode      'optuna' runs hyperparameter search; anything else trains with
#               values loaded from params.json
#   --tp        key selecting a parameter set inside params.json[dataname]
#   --dataname  dataset name ('scicite', 'ACT', 'ACL', ... — see __main__)
parser = argparse.ArgumentParser()
parser.add_argument("--mode", help="decide find parameters or train", default=None, type=str)
parser.add_argument("--tp", help="type of params", default=None, type=str)
parser.add_argument("--dataname", help="dataname", default=None, type=str)
# Prefer the first CUDA GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
def run_optuna(params, path, dev):
    """Search lr/auw with Optuna, then retrain once with the best values found.

    Args:
        params: parsed CLI namespace; ``params.dataname`` selects the dataset.
            On success, ``params.lr`` and ``params.auw`` are overwritten with
            the best trial's values before the final training run.
        path: checkpoint path forwarded to ``dataset_train`` as ``model_path``.
        dev: torch device used for training.
    """
    print('Run optuna')
    setup_seed(0)
    token = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
    criterion = nn.CrossEntropyLoss()
    dataset = load_data(16, params.dataname, radio=0.8)

    def objective(trial):
        # Fresh model per trial so weights from earlier trials don't leak in.
        model = Model('allenai/scibert_scivocab_uncased')
        lr = trial.suggest_float('lr', 1e-4, 1e-3, log=True)
        auw = trial.suggest_float('auw', 0.001, 0.01, log=True)
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=2e-4)
        scheduler = WarmupMultiStepLR(optimizer, [90, 110], gamma=0.1, warmup_epochs=5)
        best_model_f1, best_epoch = dataset_train(model, token, dataset, criterion, optimizer, 151, auw, dev,
                                                  scheduler, model_path=path)
        # Optuna maximizes this value (direction='maximize' below).
        return best_model_f1

    study = optuna.create_study(study_name='studyname', direction='maximize', storage='sqlite:///optuna.db', load_if_exists=True)
    study.optimize(objective, n_trials=5)
    print("Best_Params:{} \t Best_Value:{}".format(study.best_params, study.best_value))
    history = study.trials_dataframe(attrs=('number', 'value', 'params', 'state'))
    print(history)
    # Fix: write the winning hyperparameters onto the namespace we were given
    # and reuse our own path/device, instead of reaching for the module-level
    # globals ``args`` and ``device`` (which only happen to alias them).
    params.lr = float(format(study.best_params['lr'], '.6f'))
    params.auw = float(format(study.best_params['auw'], '.6f'))
    main_run(params, path, dev)
def main_run(params, path, dev):
    """Train the citation-classification model and evaluate it on the test split.

    Args:
        params: parsed CLI namespace supplying ``lr``, ``auw`` and ``dataname``.
        path: checkpoint path used both for saving during training and for
            loading the best model at test time.
        dev: torch device for training and evaluation.
    """
    setup_seed(0)
    token = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
    model = Model('allenai/scibert_scivocab_uncased')
    criterion = nn.CrossEntropyLoss()
    n_epoch = 151
    dataset = load_data(16, params.dataname, radio=0.8)
    # Fix: removed the unused locals ``lr = 0.000184`` / ``au_weight = 0.007413``;
    # the values actually used are params.lr / params.auw throughout.
    optimizer = optim.SGD(model.parameters(), lr=params.lr, momentum=0.9, weight_decay=2e-4)
    scheduler = WarmupMultiStepLR(optimizer, [90, 110], gamma=0.1, warmup_epochs=5)
    best_model_f1, best_epoch = dataset_train(model, token, dataset, criterion, optimizer, n_epoch, params.auw, dev,
                                              scheduler, model_path=path)
    print("best_model_f1:{} \t best_epoch:{}".format(best_model_f1, best_epoch))
    # Evaluate the best checkpoint on the held-out test split. Fix: pass ``dev``
    # consistently rather than the module-level global ``device``.
    test_f1, test_micro_f1, test_true_label, test_pre_label = dataset_valid(model, token,
                                                                            dataset['test'], dev,
                                                                            mode='test', path=path)
    print('Test'.center(20, '='))
    print('Test_True_Label:', collections.Counter(test_true_label))
    print('Test_Pre_Label:', collections.Counter(test_pre_label))
    print('Test macro F1: %.4f \t Test micro F1: %.4f' % (test_f1, test_micro_f1))
    print('Test'.center(20, '='))
    test_true = torch.Tensor(test_true_label).tolist()
    test_pre = torch.Tensor(test_pre_label).tolist()
    generate_submission(test_pre, 'mul_rev_val_f1_{:.5}_best_epoch_{}'.format(best_model_f1, best_epoch), test_f1, params.dataname)
    # scicite has 3 classes; the other supported datasets use 6.
    if params.dataname == 'scicite':
        labels = [0, 1, 2]
    else:
        labels = [0, 1, 2, 3, 4, 5]
    c_matrix = confusion_matrix(test_true, test_pre, labels=labels)
    per_eval = classification_report(test_true, test_pre, labels=labels)
    log_result(test_f1, best_model_f1, c_matrix, per_eval, lr=params.lr, epoch=n_epoch, fun_name='main_multi_rev')
if __name__ == "__main__":
    args = parser.parse_args()
    tst = time.time()
    # Dataset-specific checkpoint path; the default covers scicite.
    modelpath = "citation_mul_rev_model.pth"
    if args.dataname == "ACT":
        modelpath = "bbnACT"
    elif args.dataname == "ACL":
        modelpath = "bbn for acl"
    if args.mode == 'optuna':
        run_optuna(args, modelpath, device)
    else:
        # Load the tuned hyperparameters for this dataset/param-set from disk.
        with open('params.json', 'r', encoding='utf-8') as f:
            config = json.load(f)
        args.lr = config[args.dataname][args.tp]['lr']
        args.auw = config[args.dataname][args.tp]['auw']
        # Bug fix: pass the ``modelpath`` variable, not the literal string
        # 'modelpath' — otherwise the dataset-specific path computed above is
        # never used and the checkpoint lands in a file named "modelpath".
        main_run(args, modelpath, device)
    ten = time.time()
    print('Total time: {}'.format(ten - tst))