dialogpt.py

import pytorch_lightning as pl
import transformers
from transformers import GPT2DoubleHeadsModel
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
import logging
import math
logger = logging.getLogger(__file__)
try:
    import wandb
except ImportError:
    logger.warning("Unable to import wandb. Table-level logging will not work -- only inference, or training with no logging will work")
from pytorch_lightning.metrics import Accuracy
from load_data import SPEAKER1_ID, SPECIAL_TOKENS, MAX_GPT2_LENGTH
from utils import MODEL_INPUTS, PAD_VALUE


class HuggingFaceModel(pl.LightningModule):
    def __init__(self, model_name, config, tokenizer):
        super().__init__()
        # todo: validate config structure
        self.config = config
        self.model_name = model_name
        self.model = GPT2DoubleHeadsModel.from_pretrained(model_name)
        self.tokenizer = tokenizer
        self.model.resize_token_embeddings(len(tokenizer))
        self.curr_eval_table = []
        self.accuracy = Accuracy()

    def configure_optimizers(self):
        opt_config = self.config["optimizer"]
        opt_name = opt_config["name"]
        if hasattr(optim, opt_name):
            try:  # Default: PyTorch optimizer
                optimizer = getattr(optim, opt_name)(self.model.parameters(), **opt_config["kwargs"])  # must include LR, for one
            except TypeError:  # possibly a transformers optimizer (AdamW)
                optimizer = getattr(transformers, opt_name)(self.model.parameters(), **opt_config["kwargs"])
        else:
            raise Exception('Unexpected learning algorithm "{}"'.format(opt_name))
        scheduler_config = self.config["scheduler"]
        scheduler = {
            'scheduler': OneCycleLR(optimizer, opt_config["kwargs"]["lr"], **scheduler_config),  # todo: don't hardcode this
            'interval': 'step'
        }
        return [optimizer], [scheduler]
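
    # Note on configure_optimizers above: everything under config["scheduler"] is
    # splatted into OneCycleLR as keyword arguments, and the peak learning rate is
    # reused from config["optimizer"]["kwargs"]["lr"], so the optimizer kwargs must
    # define "lr".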

    def forward(self, batch):
        batch[1] = batch[1].squeeze(-1)  # mc_token_ids
        batch[3] = batch[3].squeeze(-1)  # mc_labels
        inputs = dict(zip(MODEL_INPUTS, batch))
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        # model type: GPT2DoubleHeadsModel (https://huggingface.co/transformers/model_doc/gpt2.html#gpt2doubleheadsmodel)
        train_config = self.config["train"]
        outputs = self(batch)
        lm_loss, mc_loss, _, mc_logits = outputs[:4]
        loss = lm_loss * train_config["lm_weight"] + mc_loss * train_config["mc_weight"]
        mc_acc = self.accuracy(mc_logits, batch[MODEL_INPUTS.index("mc_labels")])
        self.log('loss', loss)
        self.log('mc_acc', mc_acc, prog_bar=False, on_epoch=True)
        self.log('lm_loss', lm_loss, prog_bar=True)
        self.log('ppl', math.exp(lm_loss), prog_bar=True)
        self.log('mc_loss', mc_loss, prog_bar=True)
        return {'loss': loss}

    def eval_step(self, batch, batch_idx):
        bos, eos, speaker1, speaker2 = self.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1])
        train_config = self.config["train"]
        outputs = self(batch)
        lm_loss, mc_loss, _, mc_logits = outputs[:4]
        loss = lm_loss * train_config["lm_weight"] + mc_loss * train_config["mc_weight"]
        mc_labels = batch[MODEL_INPUTS.index("mc_labels")]
        mc_acc = self.accuracy(mc_logits, mc_labels)
        input_ids = batch[MODEL_INPUTS.index("input_ids")]  # (bs, 2, len)
        distractor, orig = input_ids[:, 0], input_ids[:, 1]
        orig_token_type_ids = batch[MODEL_INPUTS.index("token_type_ids")][:, 1]
        targets = torch.index_select(batch[MODEL_INPUTS.index("labels")], 1, mc_labels.view(-1)).squeeze(1)  # (bs, len)
        short_distractor = distractor[orig != distractor]  # (bs, short_len)
        switch_tensor = torch.tensor([speaker2], device=self.model.device)
        # strip padding, then append <speaker2> so the model generates a reply until EOS
        short_orig = torch.cat([orig[orig_token_type_ids != PAD_VALUE], switch_tensor], dim=-1)
        if short_orig.ndim == 1:
            short_orig = short_orig.unsqueeze(0)  # shape (n, len)
        # copy the inference config so the per-batch length offsets don't accumulate in self.config
        dynamic_config = dict(self.config['inference'])
        dynamic_config['min_length'] = short_orig.size(-1) + self.config['inference']['min_length']
        dynamic_config['max_length'] = min(short_orig.size(-1) + self.config['inference']['max_length'], MAX_GPT2_LENGTH)
        candidate_sents = self.model.generate(short_orig,
                                              pad_token_id=self.tokenizer.eos_token_id,
                                              **dynamic_config)
        # drop the prompt tokens so only the generated continuation is logged
        self.log_text_predictions(short_orig,
                                  short_distractor,
                                  targets,
                                  candidate_sents[:, short_orig.size(-1):])
        self.log('val_loss', loss)
        self.log('val_mc_acc', mc_acc)
        self.log('val_lm_loss', lm_loss)
        self.log('val_mc_loss', mc_loss)
        return loss

    def validation_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx)

    def eval_epoch_end(self, batches, table_name):
        table = wandb.Table(data=self.curr_eval_table,
                            columns=["Original", "Target", "Distractor", "Predicted"])
        self.logger.experiment.log({table_name: table})
        self.curr_eval_table = []

    def log_text_predictions(self, orig, distractor, labels, predictions):
        original_text = self.tokenizer.batch_decode(orig)
        predictions_text = self.tokenizer.batch_decode(predictions, skip_special_tokens=True)
        labels = labels[labels != PAD_VALUE]
        # handle batch size = 1 edge case
        if labels.ndim == 1: labels = labels.unsqueeze(0)
        if distractor.ndim == 1: distractor = distractor.unsqueeze(0)
        targets_text = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
        distractor_text = self.tokenizer.batch_decode(distractor, skip_special_tokens=True)
        self.curr_eval_table += list(zip(original_text, targets_text, distractor_text, predictions_text))
        logger.info(f"Generated: '{original_text}' => '{predictions_text}'")

    def validation_epoch_end(self, batches):
        self.eval_epoch_end(batches, f"textgen_val_{self.current_epoch}_step{self.global_step}")

    def test_epoch_end(self, batches):
        self.eval_epoch_end(batches, "textgen_test")
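

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative assumptions only: the checkpoint name,
# the config values, and the way special tokens are registered are
# placeholders; the dataloaders are expected to come from load_data).
# ---------------------------------------------------------------------------
# from transformers import GPT2Tokenizer
#
# tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialoGPT-medium")
# tokenizer.add_special_tokens({"additional_special_tokens": list(SPECIAL_TOKENS)})
# config = {
#     "optimizer": {"name": "AdamW", "kwargs": {"lr": 6.25e-5}},
#     "scheduler": {"total_steps": 10000, "pct_start": 0.1},
#     "train": {"lm_weight": 2.0, "mc_weight": 1.0},
#     "inference": {"min_length": 1, "max_length": 40, "do_sample": True, "top_p": 0.9},
# }
# model = HuggingFaceModel("microsoft/DialoGPT-medium", config, tokenizer)
# trainer = pl.Trainer(gpus=1, max_epochs=3)
# trainer.fit(model, train_dataloader, val_dataloader)  # dataloaders built elsewhere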