Migrate to peft from opendelta for parameter efficient tuning methods (
glerzing committed Apr 21, 2023
1 parent 9bc0836 commit 551526b
Showing 5 changed files with 97 additions and 114 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -43,6 +43,7 @@ numpy==1.24.2
packaging==23.0
pandas==1.5.3
pathtools==0.1.2
peft==0.2.0
platformdirs==3.1.1
protobuf==4.22.1
psutil==5.9.4
23 changes: 12 additions & 11 deletions tests/test_utils.py
@@ -102,30 +102,31 @@ def test_hf_attr_getters(model_name: str):
"google/flan-t5-large",
],
)
def test_parse_delta_kwargs(model_name):
def test_parse_and_update_peft_kwargs(model_name):
config = transformers.AutoConfig.from_pretrained(model_name)

modified_modules_dict = modeling_utils.MODIFIED_MODULES_DICT[config.model_type]
for default_modifier, default_modified_modules in modified_modules_dict.items():
delta_type, delta_kwargs = modeling_utils.parse_delta_kwargs(
delta_kwargs={"delta_type": "lora", "modified_modules": default_modifier},
target_modules_dict = modeling_utils.TARGET_MODULES_DICT[config.model_type]
for target_type, target_modules in target_modules_dict.items():
peft_kwargs = {"peft_type": "LORA", "target_modules": target_type}
modeling_utils.parse_and_update_peft_kwargs(
peft_kwargs=peft_kwargs,
config=config,
num_layers_unfrozen=4,
)
# Ensure the parsed module regex patterns capture the default module names
for kwarg_mod, default_mod in zip(delta_kwargs["modified_modules"], default_modified_modules):
for default_mod in target_modules:
    assert default_mod in peft_kwargs["target_modules"], f"Parsed target_modules regex should contain the trlx default `{default_mod}`"
assert delta_type == "lora", "Delta type should be lora"

# Ensure the defaults don't get used if the user specifies a list of `modified_modules`
delta_type, delta_kwargs = modeling_utils.parse_delta_kwargs(
delta_kwargs={"delta_type": "lora", "modified_modules": ["a", "b"]},
# Ensure the defaults don't get used if the user specifies a list of `target_modules`
peft_kwargs = {"peft_type": "lora", "target_modules": ["a", "b"]}
modeling_utils.parse_and_update_peft_kwargs(
peft_kwargs=peft_kwargs,
config=config,
num_layers_unfrozen=2,
)
for kwarg_mod in delta_kwargs["modified_modules"]:
for kwarg_mod in peft_kwargs["target_modules"]:
assert kwarg_mod.endswith("a") or kwarg_mod.endswith("b"), "Parsed target module should end with 'a' or 'b'"


34 changes: 21 additions & 13 deletions trlx/data/configs.py
@@ -49,25 +49,33 @@ class ModelConfig:
-1 means all layers are unfrozen.
:type num_layers_unfrozen: int
:param delta_kwargs: Keyword arguments for instantiating OpenDelta models for delta-tuning.
Follow the `OpenDelta.AutoDeltaConfig` specification, e.g. for LoRA style tuning, set
the `delta_type` to `lora` and include the model specific hyper-parameters (e.g. `lora_r`)
{"delta_type": "lora", "modified_modules": "all", "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.0}
:param peft_kwargs: Keyword arguments for the peft (Parameter-Efficient Fine-Tuning) library.
Here is an example of a LoRA configuration:
{"peft_type": "LORA", "target_modules": "all", "r": 8, "lora_alpha": 32, "lora_dropout": 0.1}
or in YAML format:
delta_kwargs:
delta_type: lora
modified_modules: "all"
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0
See: https://opendelta.readthedocs.io/en/latest/modules/auto_delta.html#opendelta.auto_delta.AutoDeltaConfig
:type delta_kwargs: Optional[Dict[str, Any]]
peft_kwargs:
peft_type: "LORA"
target_modules: "all"
r: 8
lora_alpha: 32
lora_dropout: 0.1
Supported peft types include "LORA", "PROMPT_TUNING", "P_TUNING", and "PREFIX_TUNING".
If "target_modules" is set to "attention", "mlp", or "all", trlx will automatically
choose the corresponding module names for the given model architecture.
Some examples of peft configurations can be found here:
https://github.com/huggingface/peft/blob/main/src/peft/peft_model.py
(Parameter-efficient tuning was previously done with OpenDelta, which is no longer supported.)
:type peft_kwargs: Optional[Dict[str, Any]]
"""

model_path: str
model_arch_type: str = "causal"
num_layers_unfrozen: int = -1
delta_kwargs: Optional[Dict[str, Any]] = None
peft_kwargs: Optional[Dict[str, Any]] = None

@classmethod
def from_dict(cls, config: Dict[str, Any]):
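For illustration, the new peft_kwargs field can be used to build a ModelConfig roughly as follows (a minimal sketch based only on the fields shown in this diff; the checkpoint name and hyper-parameter values are placeholders):

from trlx.data.configs import ModelConfig

model_config = ModelConfig(
    model_path="gpt2",  # any pretrained checkpoint name or path
    model_arch_type="causal",
    num_layers_unfrozen=2,  # only the last two transformer blocks stay trainable
    peft_kwargs={
        "peft_type": "LORA",
        "target_modules": "attention",  # resolved to model-specific module names by trlx
        "r": 8,
        "lora_alpha": 32,
        "lora_dropout": 0.1,
    },
)

The equivalent YAML form is shown in the docstring above.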
19 changes: 10 additions & 9 deletions trlx/trainer/accelerate_base_trainer.py
@@ -11,6 +11,7 @@
import torch
from accelerate import Accelerator # type: ignore
from ray.air import session
from peft import get_peft_config, get_peft_model
from rich.console import Console
from rich.table import Table
from transformers import AutoTokenizer
@@ -32,8 +33,7 @@
freeze_bottom_causal_layers,
freeze_bottom_seq2seq_layers,
gather_dict,
get_delta_model_class,
parse_delta_kwargs,
parse_and_update_peft_kwargs,
)

logger = logging.get_logger(__name__)
@@ -142,17 +142,18 @@ def setup_model(self):
else:
freeze_bottom_causal_layers(model.base_model, self.config.model.num_layers_unfrozen)
# Set the delta tuning strategies
if self.config.model.delta_kwargs is not None:
delta_type, delta_kwargs = parse_delta_kwargs(
if self.config.model.peft_kwargs is not None:
parse_and_update_peft_kwargs(
model.base_model.config,
self.config.model.delta_kwargs,
self.config.model.peft_kwargs,
self.config.model.num_layers_unfrozen,
)
delta_model_class = get_delta_model_class(delta_type)
delta_model = delta_model_class(model.base_model, **delta_kwargs)
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
peft_config = get_peft_config(self.config.model.peft_kwargs)
peft_model = get_peft_model(model.base_model, peft_config)

if self.accelerator.is_main_process:
delta_model.log()
peft_model.print_trainable_parameters()

return model

def setup_optimizer(self):
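For context, the two peft calls added above can be exercised on their own, roughly as in the sketch below (a standalone example against the peft==0.2.0 API pinned in requirements.txt, not trlx code; the GPT-2 checkpoint, the "c_attn" target module and the task_type value are illustrative assumptions):

from peft import get_peft_config, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_config = get_peft_config(
    {
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",  # illustrative; trlx passes peft_kwargs through unchanged
        "target_modules": ["c_attn"],  # GPT-2's fused attention projection
        "r": 8,
        "lora_alpha": 32,
        "lora_dropout": 0.1,
    }
)
peft_model = get_peft_model(base_model, peft_config)
peft_model.print_trainable_parameters()  # reports trainable vs. total parameter counts

Note that for LoRA the adapter layers are injected into the passed model itself, which is why setup_model above can keep returning the original model object and only uses the wrapped peft_model for logging.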
134 changes: 53 additions & 81 deletions trlx/utils/modeling.py
@@ -3,26 +3,13 @@

import accelerate
import numpy as np
from peft.utils.config import PeftType
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import transformers

try:
from opendelta import (
AdapterModel,
BitFitModel,
LoraModel,
PrefixModel,
SoftPromptModel,
)

HAS_OPENDELTA = True
except ModuleNotFoundError:
HAS_OPENDELTA = False


def make_head(n_embd: int, out: int, dtype: type = torch.float32) -> nn.Sequential:
"""Returns a generic sequential MLP head."""
return nn.Sequential(
@@ -313,10 +300,29 @@ def update(self, xs: torch.Tensor) -> Tuple[float, float]:
return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt()


# OpenDelta utilities
# Peft utilities

def parse_and_update_peft_kwargs(
config: transformers.PretrainedConfig,
peft_kwargs: Dict[str, Any],
num_layers_unfrozen: int = -1,
):
"""Check the argument peft_type.
If peft_type is equal to "LORA", update the argument peft_kwargs"""

# Check the argument peft_type
assert "peft_type" in peft_kwargs, "peft_type must be specified in peft_kwargs"
assert peft_kwargs["peft_type"] in PeftType._member_names_, f"peft_type must be in {PeftType._member_names_}"

MODIFIED_MODULES_DICT = {
# The argument target_modules only exists for LORA
if peft_kwargs["peft_type"] == "LORA":
update_peft_lora_target_modules(config, peft_kwargs, num_layers_unfrozen)


# Everything below is only used for LORA
# Other peft types don't require specifying target modules

TARGET_MODULES_DICT = {
"gptj": {
"attention": ["attn.q_proj", "attn.k_proj", "attn.v_proj"],
"mlp": ["mlp.fc_in", "mlp.fc_out"],
@@ -398,11 +404,41 @@ def update(self, xs: torch.Tensor) -> Tuple[float, float]:
},
}

def update_peft_lora_target_modules(
config: transformers.PretrainedConfig,
peft_kwargs: Dict[str, Any],
num_layers_unfrozen: int,
):
"""If target_modules is not set explicitly (i.e. is in ["all", "attention", "mlp"]):
Replace target_modules with a regex that matches the name of every
unfrozen LORA target modules."""

target_modules = peft_kwargs.get("target_modules", None)
if target_modules in ["all", "attention", "mlp"]:
if config.model_type not in TARGET_MODULES_DICT:
raise ValueError(
f"`{config.model_type}` is not currently supported for peft training"
" with peft_type 'LORA' and target_modules 'target_modules'."
f"\nEither choose another peft_type in {PeftType._member_names_}"
f", or a model among {list(TARGET_MODULES_DICT.keys())}"
", or specify manually target_modules with a list of modules to modify."
)
target_modules = TARGET_MODULES_DICT[config.model_type][target_modules]

unfrozen_layers_pattern = generate_layer_regex(config, num_layers_unfrozen)

# TODO (jon-tow): `decoder.block.` is hardcoded to support T5 layer naming.
prefix = ".*decoder.block." if config.is_encoder_decoder else ".*"
module_list = [prefix + unfrozen_layers_pattern + module for module in target_modules]

# Update target_modules in peft_kwargs
# Here it is a regex string, but peft would also accept an exhaustive list of modules
peft_kwargs["target_modules"] = "|".join(module_list)

def generate_layer_regex(config: transformers.PretrainedConfig, num_layers_unfrozen: int = -1) -> str:
"""Generates a regex range for the specified number of learnable layers."""
if num_layers_unfrozen == -1:
return "(\d)+."
return r"(\d)+."
num_hidden_layers = hf_get_num_hidden_layers(config)
start_layer = num_hidden_layers - num_layers_unfrozen
if start_layer < 0:
@@ -411,70 +447,6 @@ def generate_layer_regex(config: transformers.PretrainedConfig, num_layers_unfro
return f"{pattern}"


def get_delta_modified_modules(
config: transformers.PretrainedConfig,
modified_modules: List[str],
num_layers_unfrozen: int = -1,
) -> List[str]:
"""Returns a list of module names to be modified for a given delta method with
the specified number of learnable layers."""
unfrozen_layers_pattern = generate_layer_regex(config, num_layers_unfrozen)

# [r] for regex as per https://github.com/thunlp/OpenDelta/blob/main/opendelta/utils/name_based_addressing.py#L20
regex_prefix = "[r]"
# TODO (jon-tow): `decoder.block.` is hardcoded to support T5 layer naming.
decoder_prefix = "decoder.block." if config.is_encoder_decoder else ""
module_list = [regex_prefix + decoder_prefix + unfrozen_layers_pattern + module for module in modified_modules]
return module_list


def get_delta_model_class(model_type: str):
if not HAS_OPENDELTA:
raise ValueError("OpenDelta package required to train with delta models. https://github.com/thunlp/OpenDelta.")
delta_models = {
"bitfit": BitFitModel,
"adapter": AdapterModel,
"prefix": PrefixModel,
"lora": LoraModel,
"softprompt": SoftPromptModel,
}
return delta_models[model_type]


def parse_delta_kwargs(
config: transformers.PretrainedConfig,
delta_kwargs: Dict[str, Any],
num_layers_unfrozen: int = -1,
) -> Tuple[str, Dict[str, Any]]:
"""Parses through delta kwargs to get delta type and proper modified modules."""
# This function is needed to parse through the `delta_kwargs` in order to:
# 1) Get the `delta_type` method name to access the correct `delta_model_class`
# 2a) Accept user specified `modified_modules` and if not provided use the `trlx` default mapping
# 2b) Convert the list of `modified_modules` to a range of layers that fit within the range
# of learnable layers as specified by `num_layers_unfrozen`

# Pop `delta_type` to allow passing the kwargs to the model constructor since
# `delta_type` is not a valid argument of the constructor
delta_type = delta_kwargs.pop("delta_type")
assert delta_type in ["lora"], "Only `LoRA` based delta models are supported"

# Use `trlx` default modified modules if none are specified
modified_modules = delta_kwargs.get("modified_modules", "all")
if modified_modules in ["all", "attention", "mlp"]:
if config.model_type not in MODIFIED_MODULES_DICT:
raise ValueError(
f"Model type `{config.model_type}` is not currently supported for "
"delta training with default modified modules."
)
modified_modules = MODIFIED_MODULES_DICT[config.model_type][modified_modules]
# Update the `modified_modules` with the correct layer ranges
delta_kwargs["modified_modules"] = get_delta_modified_modules(
config, modified_modules, num_layers_unfrozen=num_layers_unfrozen
)

return delta_type, delta_kwargs


def regex_for_range(min_: int, max_: int) -> str: # noqa
"""Returns a regex that matches all numbers in the given range.
