Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split learning LLM example. #366

Merged
merged 27 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
620b7b9
added LLM split learning using current APIs.
dixiyao Nov 7, 2023
dadd62f
renamed names of functions and configuration variables.
dixiyao Nov 7, 2023
4988d0b
cleaned up the trainer.
dixiyao Nov 7, 2023
90158ae
Added two more API functions in split learning trainer.
dixiyao Nov 7, 2023
af2b55b
fixed a bug related to config.
dixiyao Nov 7, 2023
4bb06d5
added self.training args. and fixed a bug in copy_weight.
dixiyao Nov 7, 2023
4a9dc40
fixed a bug related to gradients.
dixiyao Nov 7, 2023
6f5704b
moved loading datasource to the init in the server split learning.
dixiyao Nov 7, 2023
91461ba
changed name of configuration file.
dixiyao Nov 7, 2023
fe7c91c
updated the docs/examples.md about LLM split learning.
dixiyao Nov 7, 2023
9a7fd67
added support for other LLMs in Huggingface.
dixiyao Nov 7, 2023
14c81fd
resolved issues raised by / in model name.
dixiyao Nov 7, 2023
37ab27a
revised the make the name matching as the completely matching.
dixiyao Nov 8, 2023
c4a586d
added support for Llama2 by passing use_auth_token.
dixiyao Nov 8, 2023
9755c99
added llama2 example.
dixiyao Nov 8, 2023
da8e79b
added support for LoRA model.
dixiyao Nov 8, 2023
72d9202
fixed a bug in loading lora model.
dixiyao Nov 8, 2023
fb7f49e
extract and load lora weights onnly.
dixiyao Nov 8, 2023
9ef9253
updated the examples.md about using LoRA to finetune.
dixiyao Nov 8, 2023
ed988d2
revised grammar and spelling in examples.md
dixiyao Nov 8, 2023
913ac63
fixed a bug in calculating the accuracy.
dixiyao Nov 8, 2023
b7821b3
Simplified split learning main function.
HeyHao Nov 9, 2023
d3e79b9
removed useless save_metrics.
dixiyao Nov 9, 2023
b57ebc7
used the checkpoint path as the temporary path for huggingface trainer.
dixiyao Nov 9, 2023
e5092df
use the checkpoint path as the temporary path for the Hugging Face tr…
dixiyao Nov 9, 2023
04d35d5
Check spelling in comments and names
silviafeiwang Nov 10, 2023
5f48c7f
Add libraries needed into requirement.txt
silviafeiwang Nov 10, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,19 @@ python examples/split_learning/controlnet_split_learning/split_learning_main.py
```
````

````{admonition} **Split Learning for Training LLM**
This is an example of fine-tuning the Hugging Face large language model with split learning. The fine-tuning policy includes training the whole model and fine-tuning with the LoRA algorithm. The cut layer in the configuration file should be set as an integer, indicating cutting at which transformer block in the transformer model.

Fine-tune the whole model
```shell
python ./examples/split_learning/llm_split_learning/split_learning_main.py -c ./examples/split_learning/llm_split_learning/split_learning_wikitext103_gpt2.yml
```
Fine-tune with LoRA
```shell
python ./examples/split_learning/llm_split_learning/split_learning_main.py -c ./examples/split_learning/llm_split_learning/split_learning_wikitext2_gpt2_lora.yml
```
````

#### Personalized Federated Learning Algorithms

````{admonition} **FedRep**
Expand Down
194 changes: 194 additions & 0 deletions examples/split_learning/llm_split_learning/split_learning_llm_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""
Obtain LLM models from HuggingFace, specifically designed for split learning
"""
import torch
from transformers import AutoModelForCausalLM, AutoConfig
from peft import get_peft_model, LoraConfig
from plato.config import Config


def get_lora_model(model):
"""Apply LoRA optimization over the model"""
lora_config = Config().parameters.lora
model = get_peft_model(model, LoraConfig(**lora_config._asdict()))
model.print_trainable_parameters()
return model


def get_module(start_module: torch.nn.Module, module_names):
"""
Recursively get a PyTorch module starting from the start module with
a given list of module names.
"""
module = start_module
for module_name in module_names:
module = getattr(module, module_name)
return module


class BaseModel(torch.nn.Module):
"""
The basic model loading HuggingFace model used for the server model and the client model
"""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.model_name = Config().trainer.model_name
use_auth_token = None
if hasattr(Config().parameters, "huggingface_token"):
use_auth_token = Config().parameters.huggingface_token
config_kwargs = {
"cache_dir": None,
"revision": "main",
"use_auth_token": use_auth_token,
}

self.config = AutoConfig.from_pretrained(self.model_name, **config_kwargs)

self.base_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
config=self.config,
cache_dir=Config().params["model_path"] + "/huggingface",
token=use_auth_token,
)
self.cut_layer = Config().parameters.model.cut_layer

def get_input_embeddings(self):
"""
Return the base model get input embeddings.
"""
return self.base_model.get_input_embeddings()

def forward(self, inputs):
"""
The forward function for the base model.
"""
return self.base_model(inputs)


class ClientModel(BaseModel):
"""
The model on the clients in split learning with LLM.
"""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# replace the layers in the base model
# which should be on the cloud with Identity layers()
transformer_module = self.base_model
for module_name in Config().parameters.model.transformer_module_name.split("."):
transformer_module = getattr(transformer_module, module_name)
client_layers = transformer_module[: self.cut_layer]
client_module_names = Config().parameters.model.transformer_module_name.split(
"."
)
client_module = get_module(self.base_model, client_module_names[:-1])
setattr(client_module, client_module_names[-1], client_layers)
# Set layers not on the clients to Identity
for layer in Config().parameters.model.layers_after_transformer:
layer = layer.split(".")
if len(layer) > 1:
module = get_module(self.base_model, layer[:-1])
setattr(module, layer[-1], torch.nn.Identity())
else:
setattr(self.base_model, layer[0], torch.nn.Identity())
# Apply LoRA optimization
if hasattr(Config().parameters, "lora"):
self.base_model = get_lora_model(self.base_model)

def forward(self, inputs):
"""
The forward function on the client
"""
inputs = inputs.long()
return self.base_model(input_ids=inputs, return_dict=False)

def forward_to(self, inputs):
"""
Forward to the cut layer and output intermediate feature
"""
outputs = self.forward(inputs)
return outputs[0]


class ServerModel(BaseModel):
"""
The model used on the cloud
"""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# In this design, we have two copies of the model
# The first copy of the model is the whole model which is used for test.
# The second copy of the model only contains the layers on the server
# used for training.
self.server_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
config=self.config,
cache_dir=Config().params["model_path"] + "/huggingface",
)
transformer_module = get_module(
self.base_model,
Config().parameters.model.transformer_module_name.split("."),
)
server_layers = transformer_module[self.cut_layer :]
server_module_names = Config().parameters.model.transformer_module_name.split(
"."
)
server_module = get_module(self.server_model, server_module_names[:-1])
setattr(server_module, server_module_names[-1], server_layers)
# Apply LoRA optimization
if hasattr(Config().parameters, "lora"):
self.base_model = get_lora_model(self.base_model)
self.server_model = get_lora_model(self.server_model)

def copy_weight(self):
"""
Copy the weights of the training model to the testing model
"""
basic_name = Config().parameters.model.transformer_module_name
# There will a module named base_model.model in LoRA model
if hasattr(Config().parameters, "lora"):
basic_name = "base_model.model." + basic_name
base_model_weights = self.base_model.state_dict()
server_model_weights = self.server_model.state_dict()

transformer_module = self.base_model
for module_name in basic_name.split("."):
transformer_module = getattr(transformer_module, module_name)
layer_names = [
basic_name + "." + str(index)
for index in range(
self.cut_layer,
len(transformer_module),
)
]
for weight_name in base_model_weights.keys():
# Copy the weights of transformer blocks
for layer_index, layer_name in enumerate(layer_names):
if layer_name in weight_name:
suffix = weight_name[
weight_name.find(layer_name) + len(layer_name) :
]
# The name should be completely matched
if not suffix[0] == ".":
continue
server_weight_name = basic_name + "." + str(layer_index) + suffix
base_model_weights[weight_name] = server_model_weights[
server_weight_name
]
# Copy the weights of layers after transformer blocks
for layer in Config().parameters.model.layers_after_transformer:
layer_name = basic_name + "." + layer
if layer_name in weight_name:
base_model_weights[weight_name] = server_model_weights[weight_name]

self.base_model.load_state_dict(base_model_weights)

def forward_from(self, inputs, labels):
"""
Forward from the intermediate feature on the server.
"""
labels = labels.long()
outputs = self.server_model(inputs_embeds=inputs, labels=labels)
return outputs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
A split learning algorithm supporting LoRA fine-tuning LLMs.
"""
from peft import (
set_peft_model_state_dict,
get_peft_model_state_dict,
)
from plato.algorithms import split_learning


class Algorithm(split_learning.Algorithm):
"""
Extract and load only the LoRA weights.
"""

def extract_weights(self, model=None):
# Extract LoRA wegiths
return {
k: v.cpu()
for k, v in get_peft_model_state_dict(self.model.base_model).items()
}

def load_weights(self, weights):
# Load LoRA weights
return set_peft_model_state_dict(self.model.base_model, weights)
28 changes: 28 additions & 0 deletions examples/split_learning/llm_split_learning/split_learning_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
Finetune HuggingFace large language models using split learning.
"""
import split_learning_trainer
from split_learning_llm_model import ServerModel, ClientModel
from split_learning_lora_algorithm import Algorithm as LoRAAlgorithm

from plato.servers.split_learning import Server
from plato.clients.split_learning import Client
from plato.config import Config


def main():
"""A Plato federated learning training session using the split learning algorithm."""

algorithm = LoRAAlgorithm if hasattr(Config().parameters, "lora") else None

client = Client(
trainer=split_learning_trainer.Trainer, model=ClientModel, algorithm=algorithm
)
server = Server(
trainer=split_learning_trainer.Trainer, model=ServerModel, algorithm=algorithm
)
server.run(client)


if __name__ == "__main__":
main()
Loading