finetune.py
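"""
Fine-tuning utilities: full fine-tuning with the Hugging Face `Trainer`, optional LoRA (PEFT)
and DeepSpeed support, and a data collator that restricts the loss to completion tokens.
"""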
from .base import BaseClass
from peft import get_peft_config, get_peft_model
from transformers import set_seed, DataCollatorForLanguageModeling, Trainer, TrainingArguments
import json
from loguru import logger
from sklearn.model_selection import train_test_split
import torch
from .preprocessing import DatasetProcessor
from .basic_model_loader import load_model, load_tokenizer
from .utils import get_max_length, log
class Finetune(BaseClass):
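    """
    Fine-tunes a causal language model with the Hugging Face `Trainer`, with optional LoRA
    (via PEFT), fp16/bf16 mixed precision, and DeepSpeed, on a dataset prepared by the
    configured `DatasetProcessor`.
    """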
def __init__(self, preprocessor=DatasetProcessor(), config_file="../configs/config_finetune.json", **kwargs):
"""
Initializes the Finetune class.
Args:
preprocessor (DatasetProcessor, optional): The preprocessor object to use for dataset processing. Defaults to DatasetProcessor().
config_file (str, optional): The path to the configuration file. Defaults to "../configs/config_finetune.json".
**kwargs: Additional keyword arguments to update the configuration.
Attributes:
config (dict): The configuration dictionary loaded from the config file.
model (None): The model object (initialized as None).
dtype (torch.dtype): The data type to use for tensors (initialized as torch.float32).
deepspeed_config (dict): The DeepSpeed configuration dictionary (initialized as None if use_deepspeed is False).
lora_config_peft (dict): The PEFT configuration dictionary generated from lora_config.
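        Example (illustrative; the available keys are defined in the JSON config file, so the
        overrides below are placeholders):
            finetuner = Finetune(
                num_train_epochs=1,                # overrides the value from the config file
                per_device_train_batch_size=4,
            )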
"""
self.config = json.load(open(config_file, "r"))
for kwarg in kwargs:
self.config[kwarg] = kwargs[kwarg]
self.__dict__.update(self.config)
self.model = None
self.dtype = torch.float32
if self.fp16:
self.dtype = torch.float16
if self.bf16:
self.dtype = torch.bfloat16
if not self.use_deepspeed:
deepspeed_config = None
self.config["deepspeed_config_file"] = None
else:
deepspeed_config = json.load(open(self.deepspeed_config_file, "r"))
self.config["model_name"] = None
self.config["deepspeed_config"] = deepspeed_config
super().__init__(**self.config, preprocessor=preprocessor)
self.lora_config_peft = get_peft_config(self.lora_config)
def load_model(self, model_name, model=None):
"""
Loads a pre-trained model for contamination detection.
Args:
model_name (str): The name of the model to load.
model (Optional[torch.nn.Module]): An optional pre-initialized model object.
Returns:
None
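        Example (illustrative):
            finetuner.load_model("microsoft/phi-2")  # loads the pinned phi-2 revision and,
                                                     # if use_lora is set, wraps it with PEFT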
"""
revision = 'main'
        # The phi-2 repository kept updating its modeling file, which broke this code several times, so we pin a specific revision.
if model_name == 'microsoft/phi-2':
revision = '39afec137e35c2bd2c67d939af6efae96198fd00'
if model is not None:
self.model = model
else:
log(logger.debug, f"Loading model for {model_name} with revision {revision}")
self.model = load_model(model_name, dtype=self.dtype, revision=revision)
if self.use_lora:
self.model = get_peft_model(self.model, self.lora_config_peft)
def finetune(self, model_name, dataset, data_collator=None, model=None, neptune_run=None, **kwargs):
"""
Fine-tunes the model with the given dataset.
Args:
model_name (str): The name of the model to be fine-tuned.
dataset (Dataset): The dataset used for fine-tuning.
data_collator (DataCollator, optional): The data collator used for batching and preparing the data. Defaults to None.
model (Model, optional): The pre-trained model to be fine-tuned. Defaults to None.
neptune_run (NeptuneRun, optional): The Neptune run object for logging. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
Model: The fine-tuned model.
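        Example (illustrative sketch; `my_dataset` is a placeholder for a dataset accepted by
        the configured DatasetProcessor):
            finetuner = Finetune()
            model = finetuner.finetune("gpt2", my_dataset)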
"""
log(logger.info, f"Finetuning model with {self} and dataset with size {len(dataset)}")
self.model_name = model_name
dataset = self.preprocessor.prepare_dataset(dataset, self.model_name)
set_seed(42)
if not self.reload or self.model is None:
log(logger.debug, "Loading model")
self.load_model(model_name, model=model)
tokenizer = load_tokenizer(model_name)
self.model.config.pad_token_id = tokenizer.pad_token_id
if data_collator is None:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
max_length = get_max_length(self.model.config, default_length=self.max_length_default)
if len(dataset) > 1 and self.test_split_size > 0:
log(logger.debug, "Splitting dataset")
train_dataset, test_dataset = train_test_split(dataset, test_size=self.test_split_size, random_state=42)
else:
train_dataset = dataset
test_dataset = None
callbacks = []
report_to = 'all'
if neptune_run is not None:
report_to = 'none'
training_args = TrainingArguments(
output_dir=self.output_dir, # output directory
num_train_epochs=self.num_train_epochs, # total number of training epochs
per_device_train_batch_size=self.per_device_train_batch_size, # batch size per device during training
per_device_eval_batch_size=self.per_device_eval_batch_size, # batch size for evaluation
            warmup_ratio=self.warmup_ratio,  # fraction of training used for learning rate warmup
weight_decay=self.weight_decay, # strength of weight decay
logging_dir=self.logging_dir, # directory for storing logs
logging_steps=self.logging_steps,
learning_rate=self.learning_rate,
save_steps=self.save_steps,
deepspeed=self.deepspeed_config_file,
save_total_limit=self.save_total_limit,
eval_steps=self.eval_steps,
evaluation_strategy="steps",
save_strategy="steps",
fp16=self.fp16,
bf16=self.bf16,
report_to=report_to,
)
trainer = Trainer(
model=self.model, # the instantiated 🤗 Transformers model to be trained
args=training_args, # training arguments, defined above
train_dataset=train_dataset, # training dataset
eval_dataset=test_dataset, # evaluation dataset
data_collator=data_collator,
callbacks=callbacks,
)
log(logger.info, "Starting Training")
trainer.train()
return self.model
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
"""
A data collator for completion-only language modeling tasks.
This class extends the `DataCollatorForLanguageModeling` class and provides a custom implementation
of the `torch_call` method. It is specifically designed for completion-only language modeling tasks.
Args:
tokenizer (PreTrainedTokenizer): The tokenizer used for tokenizing the input text.
Attributes:
tokenizer (PreTrainedTokenizer): The tokenizer used for tokenizing the input text.
Methods:
        torch_call(examples): Overrides the base class method to build the padded batch and set every non-completion position in the labels to -100.
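    Example (illustrative; token ids are made up — each example carries "input_ids" for the
    full prompt + completion sequence and "labels" for the completion tokens only):
        examples = [{"input_ids": [5, 6, 7, 8], "labels": torch.tensor([7, 8])}]
        batch = collator.torch_call(examples)
        # with tokenizer.padding_side == "left": batch["labels"] -> tensor([[-100, -100, 7, 8]])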
"""
def torch_call(self, examples):
batch = super().torch_call([{"input_ids": example["input_ids"]} for example in examples])
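        # The parent collator has already padded input_ids; rebuild the labels so that prompt
        # and padding positions are -100 (ignored by the loss) and only completion tokens count.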
if self.tokenizer.padding_side == "left":
batch_labels = [torch.cat([-100 * torch.ones(len(input_) - len(example["labels"]), dtype=torch.long), example["labels"]])
for input_, example in zip(batch["input_ids"], examples)]
else:
batch_labels = [torch.cat([example["labels"], -100 * torch.ones(len(input_) - len(example["labels"]), dtype=torch.long)])
for input_, example in zip(batch["input_ids"], examples)]
batch_labels = torch.stack(batch_labels)
batch["labels"] = batch_labels.long()
return batch
class FinetuneInstructions(Finetune):
def __init__(self, preprocessor, config_file="../configs/config_finetune.json", **kwargs):
"""
Initializes the Finetune class. This class is specifically designed for fine-tuning models on instruction-based datasets.
Args:
preprocessor: The preprocessor object used for data preprocessing.
config_file: The path to the configuration file (default: "../configs/config_finetune.json").
**kwargs: Additional keyword arguments.
Returns:
None
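        Example (illustrative; `instruction_dataset` is a placeholder):
            finetuner = FinetuneInstructions(preprocessor=DatasetProcessor())
            model = finetuner.finetune("gpt2", instruction_dataset)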
"""
super().__init__(preprocessor, config_file, **kwargs)
def finetune(self, model_name, dataset, data_collator=None, model=None, **kwargs):
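        """
        Fine-tunes the model on an instruction-style dataset.

        If no data collator is supplied, a DataCollatorForCompletionOnlyLM with left padding is
        created so that the loss is computed on completion tokens only. All other arguments are
        forwarded to Finetune.finetune.
        """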
if data_collator is None:
log(logger.info, "Using Data collator for completion only LM")
tokenizer = load_tokenizer(model_name)
tokenizer.padding_side = 'left'
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForCompletionOnlyLM(tokenizer=tokenizer,
mlm=False)
return super().finetune(model_name, dataset, data_collator,
model=model, **kwargs)