# main.py (forked from studio-ousia/luke)
import os
import sys
import yaml
import json
import time
import logging
import argparse
import functools
import subprocess
logging.getLogger("transformers").setLevel(logging.WARNING)
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import RobertaTokenizer
from transformers import WEIGHTS_NAME
from luke.utils.model_utils import ModelArchive
from luke.utils.entity_vocab import MASK_TOKEN, PAD_TOKEN
from entity_disambiguation.model import LukeForEntityDisambiguation
from utils import set_seed
from utils.evaluate import evaluate
from utils.trainer import Trainer
from utils.dataset import EntityDisambiguationDataset, convert_documents_to_features
import wandb
logger = logging.getLogger(__name__)
LOG_FORMAT = "[%(asctime)s] [%(levelname)s] %(message)s (%(funcName)s@%(filename)s:%(lineno)s)"
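
# This script fine-tunes and/or evaluates LUKE (LukeForEntityDisambiguation) on
# entity disambiguation datasets (CoNLL testb, ACE2004, AQUAINT, MSNBC,
# Wikipedia, ClueWeb); metrics are logged to wandb and TensorBoard.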
def main():
    wandb.init(project="reproduce-luke-ed", entity="xuzf")
    parser = argparse.ArgumentParser()
    # parameters
    ## experiment paths
    parser.add_argument("--model-file", type=str, help='path of the pre-trained model weights and config')
    parser.add_argument("--data-dir", type=str, help='path of the training or evaluation data')
    parser.add_argument("--output-dir",
                        default="./output/run-" + time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()),
                        help='output directory')
    ## train/eval/test scope
    parser.add_argument("--do_train", action='store_true', help='fine-tune the model on CoNLL-2003')
    parser.add_argument("--do_eval", action='store_true', help='evaluate on testb')
    parser.add_argument("--do_test", action='store_true', help='test on the chosen datasets')
    parser.add_argument("--test-set",
                        nargs='+',
                        help='select one or more test datasets',
                        choices=["test_b", "test_b_ppr", "ace2004", "aquaint", "msnbc", "wikipedia", "clueweb"])
    ## hardware settings
    parser.add_argument("--master-port", type=int)
    parser.add_argument("--num-gpus", type=int, help='the number of GPUs')
    parser.add_argument("--local-rank",
                        type=int,
                        help='local_rank is only used with multiple GPUs, so it is valid only when num_gpus > 0. '
                        'It is the index of the GPU used by the current process: '
                        'local_rank=-1 means use the default CUDA device; '
                        'local_rank>=0 means use the corresponding GPU.')
    ## hyperparameters
    parser.add_argument("--num-train-epochs", type=int)
    parser.add_argument("--train-batch-size", type=int)
    parser.add_argument("--max-seq-length", type=int)
    parser.add_argument("--max-candidate-length", type=int)
    parser.add_argument("--masked-entity-prob", type=float)
    parser.add_argument("--document-split-mode", choices=["simple", "per_mention"])
    parser.add_argument("--update-entity-emb", action='store_true', help='the entity embeddings are frozen by default')
    parser.add_argument("--update-entity-bias", action='store_true', help='the entity bias is frozen by default')
    parser.add_argument("--seed", type=int)
    parser.add_argument("--no-context-entities", action='store_true', help='context entities are used by default')
    parser.add_argument("--context-entity-selection-order",
                        choices=["natural", "random", "highest_prob"],
                        help='order of entity disambiguation')
    # defaults come from config.yaml; command-line arguments take precedence
    with open('./config.yaml', 'r') as fconfig:
        args_config = yaml.load(fconfig, Loader=yaml.FullLoader)
    args = parser.parse_args()
    if 'train_args' in args_config.keys():
        train_args = args_config['train_args']
    else:
        train_args = dict({})
    for key, value in vars(args).items():
        if value is None or value is False:
            continue
        if key in train_args.keys():
            train_args[key] = value
        else:
            args_config[key] = value
    args_config['train_args'] = train_args
    args_config_log = args_config.copy()
    args_config.pop('train_args')
    args_config.update(train_args)
    args = argparse.Namespace(**args_config)
    # record the parameter settings of the current run
    wandb.config.update(args)
    if args.local_rank == -1 and args.num_gpus > 1:  ## single machine, multiple GPUs: spawn one worker process per GPU
        current_env = os.environ.copy()
        current_env["MASTER_ADDR"] = "127.0.0.1"
        current_env["MASTER_PORT"] = str(args.master_port)
        current_env["WORLD_SIZE"] = str(args.num_gpus)
        processes = []
        # create num_gpus child processes with local_rank ranging from 0 to num_gpus - 1
        for args.local_rank in range(0, args.num_gpus):
            current_env["RANK"] = str(args.local_rank)
            current_env["LOCAL_RANK"] = str(args.local_rank)
            # re-invoke this script as a module for each worker
            cmd = [sys.executable, "-u", "-m", "main", "--local-rank={}".format(args.local_rank)]
            cmd.extend(sys.argv[1:])
            process = subprocess.Popen(cmd, env=current_env)
            processes.append(process)
        for process in processes:
            process.wait()
            if process.returncode != 0:
                raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
        sys.exit(0)
    else:  ## single-process execution (one GPU or CPU)
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)
        if args.local_rank not in (-1, 0):
            logging.basicConfig(format=LOG_FORMAT, level=logging.WARNING)
        else:
            logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)
            fh = logging.FileHandler(filename=os.path.join(args.output_dir, 'run.log'), mode='w+', encoding='utf-8')
            fh.setFormatter(logging.Formatter(LOG_FORMAT))
            logger.addHandler(fh)
        # save the configuration of this run
        logger.info("Output dir: %s", args.output_dir)
        with open(os.path.join(args.output_dir, 'config.yaml'), 'w+') as fconfig:
            yaml.dump(args_config_log, fconfig)
        logger.info("Save config: %s", os.path.join(args.output_dir, 'config.yaml'))
        # select the GPU/CPU device
        if args.num_gpus == 0:
            args.device = torch.device("cpu")
        elif args.local_rank == -1:
            args.device = torch.device("cuda")
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
            args.device = torch.device("cuda")
        # load the pre-trained model
        if args.model_file:
            model_archive = ModelArchive.load(args.model_file)
            args.entity_vocab = model_archive.entity_vocab
            args.bert_model_name = model_archive.bert_model_name
            if model_archive.bert_model_name.startswith("roberta"):
                # the current example code does not support the fast tokenizer
                args.tokenizer = RobertaTokenizer.from_pretrained(model_archive.bert_model_name)
            else:
                args.tokenizer = model_archive.tokenizer
            args.model_config = model_archive.config
            args.max_mention_length = model_archive.max_mention_length
            args.model_weights = model_archive.state_dict
        run(args=args)
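
# run() builds an entity vocabulary from all datasets, remaps the pre-trained
# entity embeddings to it, then fine-tunes and/or evaluates the model
# according to args.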
def run(args: argparse.Namespace):
    set_seed(args.seed)
    # load the datasets
    dataset = EntityDisambiguationDataset(args.data_dir)
    # collect the titles of all entities (gold entities and candidates)
    entity_titles = []
    for data in dataset.get_all_datasets():
        for document in data:
            for mention in document.mentions:
                entity_titles.append(mention.title)
                for candidate in mention.candidates:
                    entity_titles.append(candidate.title)
    entity_titles = frozenset(entity_titles)
    # build the entity vocabulary
    entity_vocab = {PAD_TOKEN: 0, MASK_TOKEN: 1}
    for n, title in enumerate(sorted(entity_titles), 2):
        entity_vocab[title] = n
    # build the model from model_config and args
    model_config = args.model_config
    model_config.entity_vocab_size = len(entity_vocab)
    model_weights = args.model_weights
    orig_entity_vocab = args.entity_vocab
    orig_entity_emb = model_weights["entity_embeddings.entity_embeddings.weight"]
    if orig_entity_emb.size(0) != len(entity_vocab):  # detect whether the model is already fine-tuned
        # remap the pre-trained entity embeddings and bias to the new entity vocabulary
        entity_emb = orig_entity_emb.new_zeros((len(entity_titles) + 2, model_config.hidden_size))
        orig_entity_bias = model_weights["entity_predictions.bias"]
        entity_bias = orig_entity_bias.new_zeros(len(entity_titles) + 2)
        for title, index in entity_vocab.items():
            if title in orig_entity_vocab:
                orig_index = orig_entity_vocab[title]
                entity_emb[index] = orig_entity_emb[orig_index]
                entity_bias[index] = orig_entity_bias[orig_index]
        model_weights["entity_embeddings.entity_embeddings.weight"] = entity_emb
        model_weights["entity_embeddings.mask_embedding"] = entity_emb[1].view(1, -1)
        model_weights["entity_predictions.decoder.weight"] = entity_emb
        model_weights["entity_predictions.bias"] = entity_bias
        del orig_entity_bias, entity_emb, entity_bias
    del orig_entity_emb
    model = LukeForEntityDisambiguation(model_config)
    model.load_state_dict(model_weights, strict=False)
    model.to(args.device)
    wandb.watch(model, log_graph=True)
    def collate_fn(batch, is_eval=False):
        def create_padded_sequence(attr_name, padding_value):
            tensors = [torch.tensor(getattr(o, attr_name), dtype=torch.long) for o in batch]
            return torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=padding_value)

        ret = dict(
            word_ids=create_padded_sequence("word_ids", args.tokenizer.pad_token_id),
            word_segment_ids=create_padded_sequence("word_segment_ids", 0),
            word_attention_mask=create_padded_sequence("word_attention_mask", 0),
            entity_ids=create_padded_sequence("entity_ids", 0),
            entity_position_ids=create_padded_sequence("entity_position_ids", -1),
            entity_segment_ids=create_padded_sequence("entity_segment_ids", 0),
            entity_attention_mask=create_padded_sequence("entity_attention_mask", 0),
        )
        ret["entity_candidate_ids"] = create_padded_sequence("entity_candidate_ids", 0)
        if is_eval:
            ret["document"] = [o.document for o in batch]
            ret["mentions"] = [o.mentions for o in batch]
            ret["target_mention_indices"] = [o.target_mention_indices for o in batch]
        return ret
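
    # fine-tuning: the entity embeddings and entity bias stay frozen unless
    # --update-entity-emb / --update-entity-bias are given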
    if args.do_train:
        train_data = convert_documents_to_features(
            dataset.train,
            args.tokenizer,
            entity_vocab,
            "train",
            "simple",
            args.max_seq_length,
            args.max_candidate_length,
            args.max_mention_length,
        )
        train_dataloader = DataLoader(train_data, batch_size=args.train_batch_size, collate_fn=collate_fn, shuffle=True)
        logger.info("Update entity embeddings during training: %s", args.update_entity_emb)
        if not args.update_entity_emb:
            model.entity_embeddings.entity_embeddings.weight.requires_grad = False
        logger.info("Update entity bias during training: %s", args.update_entity_bias)
        if not args.update_entity_bias:
            model.entity_predictions.bias.requires_grad = False
        num_train_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        trainer = EntityDisambiguationTrainer(args, model, train_dataloader, num_train_steps)
        trainer.train()
        if args.output_dir:
            logger.info("Saving model to %s", args.output_dir)
            torch.save(model.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
            torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
    results = {}
    if args.do_eval:
        model.eval()
        for dataset_name in set(args.test_set):
            print("***** Dataset: %s *****" % dataset_name)
            eval_documents = getattr(dataset, dataset_name)
            eval_data = convert_documents_to_features(
                eval_documents,
                args.tokenizer,
                entity_vocab,
                "eval",
                args.document_split_mode,
                args.max_seq_length,
                args.max_candidate_length,
                args.max_mention_length,
            )
            eval_dataloader = DataLoader(eval_data,
                                         batch_size=1,
                                         collate_fn=functools.partial(collate_fn, is_eval=True))
            predictions_file = None
            if args.output_dir:
                predictions_file = os.path.join(args.output_dir, "eval_predictions_%s.jsonl" % dataset_name)
            results[dataset_name] = evaluate(args, eval_dataloader, model, entity_vocab, predictions_file)
            f1 = results[dataset_name]['f1']
            precision = results[dataset_name]['precision']
            recall = results[dataset_name]['recall']
            writer = SummaryWriter(os.path.join(args.output_dir, 'tensorboard/eval_' + dataset_name))
            writer.add_scalar('f1', f1)
            writer.add_scalar('precision', precision)
            writer.add_scalar('recall', recall)
            writer.close()
            wandb.log({dataset_name + '/precision': precision, dataset_name + '/recall': recall, dataset_name + '/f1': f1})
    if args.output_dir:
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as f:
            json.dump(results, f, indent=2, sort_keys=True)
    logger.info('tensorboard --logdir=' + args.output_dir)
    return results
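
# EntityDisambiguationTrainer masks a random subset of the entities in each
# training batch (a fraction of roughly masked_entity_prob per example) by
# replacing their entity ids with the [MASK] entity id (1); the original ids
# are kept in entity_labels as the prediction targets.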
class EntityDisambiguationTrainer(Trainer):
    def _create_model_arguments(self, batch):
        batch["entity_labels"] = batch["entity_ids"].clone()
        for index, entity_length in enumerate(batch["entity_attention_mask"].sum(1).tolist()):
            masked_entity_length = max(1, round(entity_length * self.args.masked_entity_prob))
            permutated_indices = torch.randperm(entity_length)[:masked_entity_length]
            batch["entity_ids"][index, permutated_indices[:masked_entity_length]] = 1  # [MASK]
            batch["entity_labels"][index, permutated_indices[masked_entity_length:]] = -1
        return batch
if __name__ == "__main__":
    main()