-
Notifications
You must be signed in to change notification settings - Fork 49
/
train.py
319 lines (270 loc) · 11.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We modified the code based on Alpaca train.py. Author: Zheng Yuan, Hongyi Yuan
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
import io
import torch
import torch.nn.functional as F
import transformers
from torch.utils.data import Dataset
from transformers import Trainer
import json
def _make_r_io_base(f, mode: str):
    """Return ``f`` as a readable file object.

    An argument that is already an ``io.IOBase`` instance is handed back
    untouched; anything else is treated as a path and opened with ``mode``.
    """
    if isinstance(f, io.IOBase):
        return f
    return open(f, mode=mode)
def jload(f, mode="r"):
    """Load a .json file into a dictionary.

    Args:
        f: A filesystem path, or an already-open ``io.IOBase`` file object.
        mode: Mode used to open ``f`` when it is a path.

    Returns:
        The parsed JSON value.

    Note:
        The file object is always closed afterwards, including when the
        caller passed in an already-open file object.
    """
    f = _make_r_io_base(f, mode)
    try:
        return json.load(f)
    finally:
        # Close even when parsing raises, so the handle is not leaked.
        f.close()
# Label value for positions that must not contribute to the loss; the
# collator fills query/padding label slots with it and the trainer masks it.
IGNORE_INDEX = -100

# Special tokens added to the tokenizer in train().
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
# NOTE(review): BOS and UNK are also "</s>" here; confirm this matches the
# vocabulary of the tokenizer being trained.
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"

# Alpaca-style prompt templates, filled via str.format(instruction=..., input=...).
# NOTE(review): not referenced by the code shown in this file.
PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response:"
    ),
)
@dataclass
class ModelArguments:
    """Command-line options selecting the model checkpoint."""

    # Hub id or local path handed to ``from_pretrained`` for model and tokenizer.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
@dataclass
class DataArguments:
    """Command-line options describing the training data."""

    # Path to the JSON-lines training file read by ``ScoreDataset``.
    # Annotated Optional because the default is None.
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
    # When True, the collator passes every response through ``stop_response``
    # (truncation at dialogue-turn markers such as "\n\nHuman:") before tokenizing.
    stop_response: bool = field(default=False)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` plus the options used by RRHF training."""

    # Forwarded to ``from_pretrained`` for both the model and the tokenizer.
    cache_dir: Optional[str] = None
    # Replaces the parent class's optimizer default.
    optim: str = "adamw_torch"
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    # Total loss is ``rrhf_weight * ranking_loss + sft_loss`` (RRHFTrainer.compute_loss).
    rrhf_weight: float = 100.0
    # Exponent on the per-candidate length used to normalise scores (RRHFTrainer.get_score).
    length_penalty: float = 1.0
    # Keep only the last two candidate rows of each batch
    # (presumably the provided reference responses — confirm against the data).
    only_use_provide: bool = False
    # Drop the last two candidate rows of each batch.
    only_use_sample: bool = False
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Copy the model's weights to CPU and write them to ``output_dir``.

    The state dict is fetched on every process. Only processes where
    ``trainer.args.should_save`` is true move it to CPU and call
    ``trainer._save``.
    """
    full_state = trainer.model.state_dict()
    if not trainer.args.should_save:
        return
    cpu_state = {}
    for name, tensor in full_state.items():
        cpu_state[name] = tensor.cpu()
    # Drop the local reference to the original state dict before saving.
    del full_state
    trainer._save(output_dir, state_dict=cpu_state)  # noqa
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens to ``tokenizer`` and resize ``model``'s embeddings to match.

    When ``n > 0`` tokens were actually added, the last ``n`` rows of both
    the input and the output embedding matrices are overwritten with the
    mean of that matrix's remaining rows.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return
    matrices = (
        model.get_input_embeddings().weight.data,
        model.get_output_embeddings().weight.data,
    )
    for matrix in matrices:
        matrix[-added:] = matrix[:-added].mean(dim=0, keepdim=True)
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize every string independently.

    Returns a dict in which ``input_ids`` and ``labels`` are the *same* list
    object (the first row of each encoding's ids), and ``input_ids_lens``
    and ``labels_lens`` are the *same* list of non-pad token counts.
    """
    ids_per_text = []
    lens_per_text = []
    for text in strings:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        ids_per_text.append(encoded.input_ids[0])
        lens_per_text.append(encoded.input_ids.ne(tokenizer.pad_token_id).sum().item())
    return dict(
        input_ids=ids_per_text,
        labels=ids_per_text,
        input_ids_lens=lens_per_text,
        labels_lens=lens_per_text,
    )
class ScoreDataset(Dataset):
    """Dataset of scored candidate responses loaded from a JSON-lines file.

    Each non-blank line of ``data_path`` is parsed with ``json.loads`` and
    kept as-is. ``DataCollatorForSupervisedDataset`` later reads the
    ``query``, ``responses`` and ``scores`` keys of each record.
    """

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        # ``tokenizer`` is accepted for interface compatibility but unused here;
        # tokenization happens in the collator.
        super(ScoreDataset, self).__init__()
        logging.warning("Loading data...")
        with open(data_path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. extra newlines at EOF) instead of letting
            # json.loads("") raise; iterate lazily rather than readlines().
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        # The raw record is nested under "input_ids"; the collator unwraps it
        # (it labels this a hack — presumably so the key survives Trainer's
        # column filtering; confirm).
        return dict(input_ids=self.data[i])
def _single_tokenize(text, tokenizer, max_len=None):
    """Tokenize one string and return the id tensor with the batch row removed.

    ``max_len`` falls back to ``tokenizer.model_max_length``; longer inputs
    are truncated to it.
    """
    limit = tokenizer.model_max_length if max_len is None else max_len
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding="longest",
        max_length=limit,
        truncation=True,
    )
    return encoded["input_ids"][0]
def stop_response(res):
    """Cut ``res`` at dialogue-turn markers.

    The markers are checked in the order listed. Whenever one occurs in the
    (possibly already shortened) text, everything from its first occurrence
    onward is dropped and the remainder is whitespace-stripped. Text that
    contains no marker is returned unchanged (and is not stripped).
    """
    for marker in ('\n\nHuman:', '\n\nAssistant:', '\n\nhuman:', '\n\nassistant:'):
        head, found, _tail = res.partition(marker)
        if found:
            res = head.strip()
    return res
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate scored records into one right-padded batch of query+response rows.

    Each instance is ``{"input_ids": record}`` (as returned by ``ScoreDataset``),
    where ``record`` holds the query text under ``query``, a list of response
    texts under ``responses``, and one score per response under ``scores``.
    Every response becomes its own row of the batch.
    """

    tokenizer: transformers.PreTrainedTokenizer
    # When True, each response is passed through the module-level
    # ``stop_response`` helper before tokenizing.
    stop_response: bool

    def __call__(self, instances):
        idxs = []
        all_scores = []
        input_ids = []
        labels = []
        for idx, ins in enumerate(instances):
            ins = ins['input_ids']  # hack: the raw record is nested under "input_ids"
            query = ins['query']
            responses = ins['responses']
            scores = ins['scores']
            # Scores are paired with responses purely by position, so a
            # length mismatch would silently misalign rewards.
            if len(scores) != len(responses):
                raise ValueError(
                    f"Instance {idx} has {len(responses)} responses but {len(scores)} scores; "
                    "each response needs exactly one score."
                )
            all_scores.append(scores)
            idxs.append([idx] * len(scores))

            query_input_ids = _single_tokenize(query, self.tokenizer)
            # Labels are pre-shifted for next-token prediction: the last query
            # position is labelled with the first response token, so only
            # len(query) - 1 ignore slots precede the response ids, and one
            # trailing ignore slot keeps labels the same length as input_ids.
            query_target = torch.LongTensor([IGNORE_INDEX] * (query_input_ids.shape[0] - 1))
            dummy_target = torch.LongTensor([IGNORE_INDEX])
            # Token budget left for each response after the query.
            # NOTE(review): if the query alone reaches model_max_length this is
            # <= 0 and the result depends on the tokenizer's truncation — confirm.
            res_max_len = self.tokenizer.model_max_length - query_input_ids.shape[0]
            for res in responses:
                r = stop_response(res) if self.stop_response else res
                # The EOS token is appended here so the model learns to stop.
                res_input_ids = _single_tokenize(r + self.tokenizer.eos_token, self.tokenizer, max_len=res_max_len)
                input_ids.append(torch.cat((query_input_ids, res_input_ids), dim=0))
                labels.append(torch.cat((query_target, res_input_ids, dummy_target), dim=0))

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        # idxs / scores are stacked into 2-D tensors, so every instance in the
        # batch must carry the same number of responses.
        return dict(
            input_ids=input_ids,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
            labels=labels,
            idxs=torch.LongTensor(idxs),
            scores=torch.FloatTensor(all_scores),
        )
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Build the Trainer keyword arguments: training dataset, no eval set, and the collator."""
    return dict(
        train_dataset=ScoreDataset(tokenizer=tokenizer, data_path=data_args.data_path),
        eval_dataset=None,
        data_collator=DataCollatorForSupervisedDataset(
            tokenizer=tokenizer,
            stop_response=data_args.stop_response,
        ),
    )
class RRHFTrainer(Trainer):
    """HF ``Trainer`` whose loss is a weighted RRHF ranking loss plus an SFT term.

    Each batch row is one candidate response (see the data collator). The
    ranking and SFT terms read only the first reward row, so a batch is
    assumed to hold the candidates of a single query.
    """

    def gather_logits_labels(self, logits, labels):
        """Pick, at every position, the log-prob of that position's label token.

        Args:
            logits: Log-probabilities with one more trailing (vocab) dim than
                ``labels``; ``compute_loss`` passes log-softmaxed model output.
            labels: Next-token ids; ``IGNORE_INDEX`` marks positions to skip.

        Returns:
            Tensor shaped like ``labels`` with the gathered log-probs, zero at
            ignored positions. ``labels`` itself is left unmodified.
        """
        mask = (labels != IGNORE_INDEX).float()
        # gather() needs valid indices, so ignored slots get id 0, but on a
        # copy: the caller's labels must keep their IGNORE_INDEX markers
        # because get_score derives candidate lengths from them. (The old
        # in-place write made every length equal to the padded length.)
        safe_labels = labels.masked_fill(labels == IGNORE_INDEX, 0)
        output = torch.gather(logits, dim=-1, index=safe_labels.unsqueeze(-1)).squeeze(-1)
        return output * mask  # B * L

    def get_score(self, logit_label, labels):
        """Length-normalised score per candidate.

        ``sum(log-probs) / (n_label_tokens ** length_penalty)`` for each row.
        """
        mask = (labels != IGNORE_INDEX).float()
        # A row with no label tokens would otherwise compute 0 / 0 = NaN.
        length = mask.sum(-1).clamp(min=1)
        return logit_label.sum(-1) / (length ** self.args.length_penalty)

    def rrhf_loss(self, scores, idxs, rw_scores):
        """Pairwise ranking loss.

        For every pair (i, j) where candidate j has the higher reward but the
        lower model score, ``scores[i] - scores[j]`` is added to the loss.
        ``idxs`` is unused. NOTE(review): ``[0]`` keeps only the first reward
        row, i.e. this assumes ``rw_scores`` is ``(1, num_candidates)`` (one
        query per batch) — confirm the per-device batch size is 1.
        """
        diff = scores.unsqueeze(0) - scores.unsqueeze(-1)  # diff[i, j] = scores[j] - scores[i]
        rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1)  # same layout for rewards
        aval = torch.bitwise_and(rw_diff > 0, diff < 0)[0]
        return -diff[aval].sum()

    def sft_loss(self, logit_label, idxs, rw_scores):
        """Negative mean log-prob of the candidate with the highest reward.

        ``idxs`` is unused. NOTE(review): ``.mean()`` averages over every
        position of that row, including the zeroed query/padding positions —
        confirm this normalisation is intended.
        """
        max_idx = torch.argmax(rw_scores)
        return -logit_label[max_idx].mean()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute ``rrhf_weight * rrhf_loss + sft_loss`` for one batch.

        Args:
            model: Causal LM, called with ``input_ids`` and ``attention_mask`` only.
            inputs: Collator output (``input_ids``, ``attention_mask``, ``labels``,
                ``idxs``, ``scores``). Modified in place when one of the
                ``only_use_*`` options is set.
            return_outputs: When True, also return the per-candidate scores.
            num_items_in_batch: Accepted (and ignored) so that Trainer versions
                which pass this keyword do not raise ``TypeError``.
        """
        if self.args.only_use_provide:
            # Keep only the last two candidates.
            for key in ("input_ids", "attention_mask", "labels"):
                inputs[key] = inputs[key][-2:]
            for key in ("idxs", "scores"):
                inputs[key] = inputs[key][:, -2:]
        if self.args.only_use_sample:
            # Drop the last two candidates.
            for key in ("input_ids", "attention_mask", "labels"):
                inputs[key] = inputs[key][:-2]
            for key in ("idxs", "scores"):
                inputs[key] = inputs[key][:, :-2]

        logits = model(input_ids=inputs.get('input_ids'), attention_mask=inputs.get('attention_mask'))[0]  # (batch * cand) * L * V
        logits = F.log_softmax(logits, dim=-1)
        labels = inputs.get("labels")
        logit_label = self.gather_logits_labels(logits, labels)
        scores = self.get_score(logit_label, labels)
        rrhf_loss = self.rrhf_loss(scores, inputs.get("idxs"), inputs.get("scores"))
        sft_loss = self.sft_loss(logit_label, inputs.get("idxs"), inputs.get("scores"))
        loss = self.args.rrhf_weight * rrhf_loss + sft_loss
        return (loss, scores) if return_outputs else loss
def train():
    """Parse CLI arguments, build the model, tokenizer and data, run RRHF training, then save."""
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    checkpoint = model_args.model_name_or_path

    model = transformers.AutoModelForCausalLM.from_pretrained(
        checkpoint,
        cache_dir=training_args.cache_dir,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        checkpoint,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    # Tokenizers without a pad token get one, and the embeddings grow to match.
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    # Case-sensitive substring match on the checkpoint path.
    # NOTE(review): no embedding resize follows this call; it assumes "</s>"
    # already exists in the vocabulary — confirm.
    if "llama" in checkpoint:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )

    trainer = RRHFTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **make_supervised_data_module(tokenizer=tokenizer, data_args=data_args),
    )
    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
# Entry point: run training only when this file is executed as a script.
if __name__ == "__main__":
    train()