Fix some non-standard wording
gaohongkui committed Mar 23, 2023
1 parent 045fb7c commit 64d5cb8
Showing 6 changed files with 30 additions and 35 deletions.
7 changes: 0 additions & 7 deletions README.md
@@ -1,9 +1,3 @@
<!--
* @Date: 2021-06-24 15:57:09
* @LastEditors: GodK
* @LastEditTime: 2021-07-27 22:33:36
-->

# GlobalPointer_pytorch

> If you like this project, feel free to click the star in the upper right corner. Thanks to everyone who stars it.
@@ -66,4 +60,3 @@ python evaluate.py
Default configuration (hyperparameters are already set in `config.py`); the dataset is CLUENER

* Validation set best F1: 0.7966
* Test set F1:
2 changes: 1 addition & 1 deletion common/utils.py
@@ -6,7 +6,7 @@
import torch


def multilabel_categorical_crossentropy(y_pred, y_true):
def multilabel_categorical_crossentropy(y_true, y_pred):
"""
https://kexue.fm/archives/7359
"""
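For reference, a minimal sketch of the multilabel categorical crossentropy from the linked post, written with the corrected (y_true, y_pred) argument order; it illustrates the technique and is not necessarily the exact body in common/utils.py.

import torch

def multilabel_categorical_crossentropy(y_true, y_pred):
    # y_true is 1 for target classes and 0 otherwise; y_pred holds raw logits (no sigmoid/softmax)
    y_pred = (1 - 2 * y_true) * y_pred          # flip the sign of positive-class logits
    y_pred_neg = y_pred - y_true * 1e12         # mask out positives when scoring negatives
    y_pred_pos = y_pred - (1 - y_true) * 1e12   # mask out negatives when scoring positives
    zeros = torch.zeros_like(y_pred[..., :1])   # the implicit threshold class at score 0
    neg_loss = torch.logsumexp(torch.cat([y_pred_neg, zeros], dim=-1), dim=-1)
    pos_loss = torch.logsumexp(torch.cat([y_pred_pos, zeros], dim=-1), dim=-1)
    return neg_loss + pos_loss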
12 changes: 6 additions & 6 deletions config.py
@@ -8,7 +8,7 @@
"exp_name": "cluener",
"encoder": "BERT",
"data_home": "./datasets",
"bert_path": "./pretrained_models/bert-base-chinese", # bert-base-cased, bert-base-chinese
"bert_path": "./pretrained_models/bert-base-chinese", # bert-base-chinese or other plm from https://huggingface.co/models
"run_type": "train", # train, eval
"f1_2_save": 0.5, # 存模型的最低f1值
"logger": "default" # wandb or default,default意味着只输出日志到控制台
@@ -23,24 +23,24 @@
train_config = {
"train_data": "train.json",
"valid_data": "dev.json",
"test_data": "test.json",
"test_data": "dev.json",
"ent2id": "ent2id.json",
"path_to_save_model": "./outputs", # 在logger不是wandb时生效
"hyper_parameters": {
"lr": 5e-5,
"lr": 2e-5,
"batch_size": 64,
"epochs": 50,
"seed": 2333,
"max_seq_len": 128,
"scheduler": "CAWR"
"scheduler": "CAWR" # CAWR, Step, None
}
}

eval_config = {
"model_state_dir": "./outputs/cluener/", # 预测时注意填写模型路径(时间tag文件夹)
"run_id": "",
"last_k_model": 1, # 取倒数第几个model_state
"test_data": "test.json",
"predict_data": "test.json",
"ent2id": "ent2id.json",
"save_res_dir": "./results",
"hyper_parameters": {
@@ -58,7 +58,7 @@
step_scheduler = {
# StepLR
"decay_rate": 0.999,
"decay_steps": 100,
"decay_steps": 200,
}

# ---------------------------------------------
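A hedged sketch of how eval_config's model_state_dir, run_id, and last_k_model might be combined to pick a checkpoint; the model_state_dict_*.pt naming and the use of glob are assumptions for illustration only.

import glob
import os

state_dir = os.path.join(eval_config["model_state_dir"], eval_config["run_id"])
# sort the saved states by modification time and take the last_k_model-th from the end
states = sorted(glob.glob(os.path.join(state_dir, "model_state_dict_*.pt")), key=os.path.getmtime)
model_state_path = states[-eval_config["last_k_model"]]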
27 changes: 13 additions & 14 deletions evaluate.py
@@ -26,8 +26,8 @@
tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"], add_special_tokens=True, do_lower_case=False)


def load_data(data_path, data_type="test"):
if data_type == "test":
def load_data(data_path, data_type="predict"):
if data_type == "predict":
datas = []
with open(data_path, encoding="utf-8") as f:
for line in f:
@@ -43,16 +43,16 @@ def load_data(data_path, data_type="test"):
ent_type_size = len(ent2id)


def data_generator(data_type="test"):
def data_generator(data_type="predict"):
"""
Load the data and build a DataLoader.
"""

if data_type == "test":
test_data_path = os.path.join(config["data_home"], config["exp_name"], config["test_data"])
test_data = load_data(test_data_path, "test")
if data_type == "predict":
predict_data_path = os.path.join(config["data_home"], config["exp_name"], config["predict_data"])
predict_data = load_data(predict_data_path, "predict")

all_data = test_data
all_data = predict_data

# TODO: sentence truncation
max_tok_num = 0
@@ -65,17 +65,16 @@ def data_generator(data_type="test"):

data_maker = DataMaker(tokenizer)

if data_type == "test":
# test_inputs = data_maker.generate_inputs(test_data, max_seq_len, ent2id, data_type="test")
test_dataloader = DataLoader(MyDataset(test_data),
if data_type == "predict":
predict_dataloader = DataLoader(MyDataset(predict_data),
batch_size=hyper_parameters["batch_size"],
shuffle=False,
num_workers=config["num_workers"],
drop_last=False,
collate_fn=lambda x: data_maker.generate_batch(x, max_seq_len, ent2id,
data_type="test")
data_type="predict")
)
return test_dataloader
return predict_dataloader


def decode_ent(text, pred_matrix, tokenizer, threshold=0):
@@ -137,11 +136,11 @@ def load_model():


def evaluate():
test_dataloader = data_generator(data_type="test")
predict_dataloader = data_generator(data_type="predict")

model = load_model()

predict_res = predict(test_dataloader, model)
predict_res = predict(predict_dataloader, model)

if not os.path.exists(os.path.join(config["save_res_dir"], config["exp_name"])):
os.mkdir(os.path.join(config["save_res_dir"], config["exp_name"]))
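As a rough illustration of the renamed predict pipeline, a hedged sketch of how decode_ent can turn a GlobalPointer score matrix into character-level spans; ent2id is assumed to be the mapping loaded above, and the exact repository code may differ.

import numpy as np

def decode_ent(text, pred_matrix, tokenizer, threshold=0):
    # pred_matrix: (ent_type_size, seq_len, seq_len); any score above threshold marks an entity span
    id2ent = {idx: ent for ent, idx in ent2id.items()}
    offsets = tokenizer(text, return_offsets_mapping=True, truncation=True)["offset_mapping"]
    ent_list = {}
    for type_id, tok_start, tok_end in zip(*np.where(pred_matrix > threshold)):
        char_start, char_end = offsets[tok_start][0], offsets[tok_end][1]
        span_text = text[char_start:char_end]
        ent_list.setdefault(id2ent[type_id], {}).setdefault(span_text, []).append([char_start, char_end])
    return ent_list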
8 changes: 4 additions & 4 deletions models/GlobalPointer.py
@@ -56,7 +56,7 @@ def generate_inputs(self, datas, max_seq_len, ent2id, data_type="train"):
)

labels = None
if data_type != "test":
if data_type != "predict":
ent2token_spans = self.preprocessor.get_ent2token_spans(
sample["text"], sample["entity_list"]
)
@@ -89,13 +89,13 @@ def generate_batch(self, batch_data, max_seq_len, ent2id, data_type="train"):
input_ids_list.append(sample[1])
attention_mask_list.append(sample[2])
token_type_ids_list.append(sample[3])
if data_type != "test":
if data_type != "predict":
labels_list.append(sample[4])

batch_input_ids = torch.stack(input_ids_list, dim=0)
batch_attention_mask = torch.stack(attention_mask_list, dim=0)
batch_token_type_ids = torch.stack(token_type_ids_list, dim=0)
batch_labels = torch.stack(labels_list, dim=0) if data_type != "test" else None
batch_labels = torch.stack(labels_list, dim=0) if data_type != "predict" else None

return sample_list, batch_input_ids, batch_attention_mask, batch_token_type_ids, batch_labels

@@ -173,7 +173,7 @@ def forward(self, input_ids, attention_mask, token_type_ids):
# outputs:(batch_size, seq_len, ent_type_size, inner_dim*2)
outputs = torch.stack(outputs, dim=-2)
# qw,kw:(batch_size, seq_len, ent_type_size, inner_dim)
qw, kw = outputs[..., :self.inner_dim], outputs[..., self.inner_dim:] # TODO: obtain qw/kw via Linear layers instead?
qw, kw = outputs[..., :self.inner_dim], outputs[..., self.inner_dim:]

if self.RoPE:
# pos_emb:(batch_size, seq_len, inner_dim)
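For readers following the forward hunk above: after outputs is split into qw and kw, RoPE rotates both with sin/cos position embeddings so that the span score qw·kw depends on relative position. A minimal sketch of that rotation, assuming pos_emb has shape (batch_size, seq_len, inner_dim) with interleaved sin/cos values:

cos_pos = pos_emb[..., None, 1::2].repeat_interleave(2, dim=-1)  # (batch_size, seq_len, 1, inner_dim)
sin_pos = pos_emb[..., None, ::2].repeat_interleave(2, dim=-1)
qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], dim=-1).reshape(qw.shape)  # pair rotation (x1, x2) -> (-x2, x1)
qw = qw * cos_pos + qw2 * sin_pos
kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], dim=-1).reshape(kw.shape)
kw = kw * cos_pos + kw2 * sin_pos
# span scores: (batch_size, ent_type_size, seq_len, seq_len)
logits = torch.einsum("bmhd,bnhd->bhmn", qw, kw)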
9 changes: 6 additions & 3 deletions train.py
@@ -38,7 +38,7 @@

model_state_dict_dir = wandb.run.dir
logger = wandb
else:
elif config["run_type"] == "train":
model_state_dict_dir = os.path.join(config["path_to_save_model"], config["exp_name"],
time.strftime("%Y-%m-%d_%H.%M.%S", time.gmtime()))
if not os.path.exists(model_state_dict_dir):
@@ -160,7 +160,7 @@ def train_step(batch_train, model, optimizer, criterion):

logits = model(batch_input_ids, batch_attention_mask, batch_token_type_ids)

loss = criterion(logits, batch_labels)
loss = criterion(batch_labels, logits)

optimizer.zero_grad()
loss.backward()
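Swapping the call to criterion(batch_labels, logits) keeps train_step consistent with the (y_true, y_pred) signature fixed in common/utils.py. As a hedged sketch, the criterion presumably flattens each entity-type head of the GlobalPointer output before applying the multilabel crossentropy:

def loss_fun(y_true, y_pred):
    # y_true, y_pred: (batch_size, ent_type_size, seq_len, seq_len)
    batch_size, ent_type_size = y_pred.shape[:2]
    y_true = y_true.reshape(batch_size * ent_type_size, -1)
    y_pred = y_pred.reshape(batch_size * ent_type_size, -1)
    return multilabel_categorical_crossentropy(y_true, y_pred).mean()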
@@ -203,6 +203,8 @@ def loss_fun(y_true, y_pred):
decay_rate = hyper_parameters["decay_rate"]
decay_steps = hyper_parameters["decay_steps"]
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=decay_steps, gamma=decay_rate)
else:
scheduler = None

pbar = tqdm(enumerate(dataloader), total=len(dataloader))
total_loss = 0.
@@ -213,7 +215,8 @@ def loss_fun(y_true, y_pred):
total_loss += loss

avg_loss = total_loss / (batch_ind + 1)
scheduler.step()
if scheduler is not None:
scheduler.step()

pbar.set_description(
f'Project:{config["exp_name"]}, Epoch: {epoch + 1}/{hyper_parameters["epochs"]}, Step: {batch_ind + 1}/{len(dataloader)}')
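With the new # CAWR, Step, None comment in config.py and the scheduler is not None guard above, the scheduler setup presumably branches on the configured name. A hedged sketch of that selection; the CosineAnnealingWarmRestarts arguments shown are illustrative and not taken from the repository:

if hyper_parameters["scheduler"] == "CAWR":
    # cosine annealing with warm restarts; a restart period of one epoch is an assumed default
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=len(dataloader), T_mult=1)
elif hyper_parameters["scheduler"] == "Step":
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=decay_steps, gamma=decay_rate)
else:
    scheduler = None  # "None" in the config keeps the learning rate fixed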
