# main_sft.py (forked from rui-ye/OpenFedLLM)
import copy
import os
from tqdm import tqdm
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM
from peft import get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict, prepare_model_for_kbit_training
from utils import *
from federated_learning import *
from config import get_config, save_config, get_model_config, get_training_args
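
# Federated instruction tuning with PEFT adapters: the script loads and partitions a
# dataset across clients, runs local SFT on the clients sampled in each round, and
# aggregates the resulting adapter weights on the server.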
# ===== Define the arguments =====
script_args, fed_args, peft_config = get_config()
training_args = get_training_args(script_args, script_args.learning_rate)
save_config(script_args, fed_args)
print(script_args, fed_args)
# ===== Load the dataset =====
dataset = get_dataset(script_args.dataset_name, script_args.local_data_dir)
dataset = process_sft_dataset(script_args.dataset_name, dataset, script_args.dataset_sample)
# ===== Split the dataset into clients =====
local_datasets = split_dataset(fed_args, script_args, dataset)
sample_num_list = [len(local_datasets[i]) for i in range(fed_args.num_clients)]
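# sample_num_list is handed to global_aggregate below, presumably so each client's update
# can be weighted by its local dataset size (FedAvg-style weighting).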
# ===== Get model config =====
device_map, quantization_config, torch_dtype = get_model_config(script_args)
model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=script_args.trust_remote_code,
    torch_dtype=torch_dtype,
)
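
# prepare_model_for_kbit_training (below) is the standard PEFT helper for k-bit, QLoRA-style
# fine-tuning; among other things it upcasts a few numerically sensitive layers to fp32 and
# hooks up gradient checkpointing when requested.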
if script_args.load_in_8bit or script_args.load_in_4bit:
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=training_args.gradient_checkpointing
    )
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# ===== Define the global and local models =====
global_dict = copy.deepcopy(get_peft_model_state_dict(model))
local_dict_list = [copy.deepcopy(global_dict) for i in range(fed_args.num_clients)]
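# proxy_dict / opt_proxy_dict are optional extra server-side aggregation state (unused for
# plain FedAvg), and the auxiliary dicts below hold the control variates consumed in the
# client loop when fed_args.fed_alg == 'scaffold'.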
proxy_dict, opt_proxy_dict = get_proxy_dict(fed_args, global_dict)
global_auxiliary, auxiliary_model_list, auxiliary_delta_dict = get_auxiliary_dict(fed_args, global_dict)
# ===== Define the tokenizer =====
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, use_fast=False, padding_side="right")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token   # following vicuna
# ===== Define the formatting function (cater to TRL SFTTrainer)=====
formatting_prompts_func, response_template = get_formatting_prompts_func(script_args.template, tokenizer.eos_token)
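# The leading ids of the standalone encoding are dropped so the template matches how it is
# tokenized mid-text inside the dataset samples (SentencePiece-style tokenizers such as
# Llama's encode a string differently at the start of a sequence).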
response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]` for Llama2
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
# ===== Start federated training =====
training_loss = [[] for i in range(fed_args.num_clients)]

for round in tqdm(range(fed_args.num_rounds)):

    clients_this_round = get_clients_this_round(fed_args, round)

    print(f">> ==================== Round {round+1} : {clients_this_round} ====================")

    for client in range(fed_args.num_clients):

        if client not in clients_this_round:
            training_loss[client].append(-1)            # -1 is an indicator of not training
            continue

        set_peft_model_state_dict(model, global_dict)   # sync the global model to the local model

        sub_dataset = get_dataset_this_round(local_datasets[client], round, fed_args, script_args)      # get the required sub-dataset for this round
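        # Assumption: cosine_learning_rate anneals the client learning rate from
        # script_args.learning_rate down to 1e-6 over the federated rounds, roughly as
        # 1e-6 + 0.5 * (lr - 1e-6) * (1 + cos(pi * round / num_rounds)).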
        new_lr = cosine_learning_rate(round, fed_args.num_rounds, script_args.learning_rate, 1e-6)      # manually schedule the learning rate
        training_args = get_training_args(script_args, new_lr)

        # ===== Train local model on the client side =====
        trainer = get_fed_local_sft_trainer(
            model=model,
            tokenizer=tokenizer,
            training_args=training_args,
            local_dataset=sub_dataset,
            formatting_prompts_func=formatting_prompts_func,
            data_collator=data_collator,
            global_dict=global_dict,
            fed_args=fed_args,
            script_args=script_args,
            local_auxiliary=auxiliary_model_list[client],
            global_auxiliary=global_auxiliary,
        )

        results = trainer.train()
        training_loss[client].append(results.training_loss)

        # ===== Client transmits local information to server =====
        if fed_args.fed_alg == 'scaffold':
            auxiliary_model_list[client], auxiliary_delta_dict[client] = trainer.get_auxiliary_param()

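        # deepcopy is needed because the same `model` object is reused for every client:
        # get_peft_model_state_dict returns tensors that alias the live adapter weights, so
        # without a copy this snapshot would be overwritten by the next client's training.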
        local_dict_list[client] = copy.deepcopy(get_peft_model_state_dict(model))   # copy is needed!

    # ===== Server aggregates the local models =====
    global_dict, global_auxiliary = global_aggregate(
        fed_args, global_dict, local_dict_list, sample_num_list,
        clients_this_round, round, proxy_dict=proxy_dict,
        opt_proxy_dict=opt_proxy_dict, auxiliary_info=(global_auxiliary, auxiliary_delta_dict)
    )
    set_peft_model_state_dict(model, global_dict)   # Update global model

    # ===== Save the model =====
    if (round+1) % 50 == 0:
        trainer.save_model(os.path.join(script_args.output_dir, f"checkpoint-{round+1}"))

    np.save(os.path.join(script_args.output_dir, "training_loss.npy"), np.array(training_loss))
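    # training_loss is saved every round as a (num_clients, rounds_completed) array;
    # -1 entries mark rounds in which a client was not sampled.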