Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lisa #1469

Merged
merged 5 commits into from
Apr 1, 2024
Merged

Lisa #1469

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions examples/llama-2/lisa.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./lisa-out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:

lisa_n_layers: 4
lisa_step_interval: 20
lisa_layers_attribute: model.layers

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 5e-5 # recommendation from lisa paper for 7b

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true

warmup_steps: 100
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
24 changes: 24 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
Expand Down Expand Up @@ -200,6 +201,18 @@ class AxolotlTrainingArguments(TrainingArguments):
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)


class AxolotlTrainer(Trainer):
Expand Down Expand Up @@ -938,6 +951,8 @@ def get_post_trainer_create_callbacks(self, trainer):
)
callbacks.append(early_stop_cb)

if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
return callbacks

def _get_trainer_cls(self):
Expand Down Expand Up @@ -1229,6 +1244,15 @@ def build(self, total_num_steps):
"relora_prune_ratio"
] = self.cfg.relora_prune_ratio

if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
training_arguments_kwargs[
"lisa_step_interval"
] = self.cfg.lisa_step_interval
training_arguments_kwargs[
"lisa_layers_attribute"
] = self.cfg.lisa_layers_attribute

training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
Expand Down
91 changes: 91 additions & 0 deletions src/axolotl/utils/callbacks/lisa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
module for LISA

Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl
Arxiv: https://arxiv.org/abs/2403.17919
License: Apache 2.0
"""

import logging
from functools import reduce
from typing import TYPE_CHECKING

import numpy as np
from transformers import TrainerCallback

if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainer

LOG = logging.getLogger("axolotl.callbacks.lisa")


def lisa_callback_factory(trainer: "AxolotlTrainer"):
class LISACallback(TrainerCallback):
"""trainer callback for lisa layer switching"""

def __init__(
self, n_layers, step_interval, trainer, layers_attribute="model.layers"
):
super().__init__()
self.n_layers = n_layers
self.step_interval = step_interval
self.layers_attribute = layers_attribute
self.trainer = trainer

reduce(getattr, self.layers_attribute.split("."), self.trainer.model)

self.total_layers = len(
reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
)
self.active_layers_indices = []

layers = reduce(
getattr, self.layers_attribute.split("."), self.trainer.model
)
LOG.info(
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
)

def freeze_all_layers(self):
layers = reduce(
getattr, self.layers_attribute.split("."), self.trainer.model
)
for layer in layers:
for param in layer.parameters():
param.requires_grad = False

def on_step_begin(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
# Check if it's time to switch active layers, including at step 0
if state.global_step % self.step_interval == 0 or state.global_step == 1:
self.switch_active_layers()

def switch_active_layers(self):
# First, disable gradients for all layers
self.freeze_all_layers()

# Randomly select n_layers to activate
layers = reduce(
getattr, self.layers_attribute.split("."), self.trainer.model
)
self.active_layers_indices = np.random.choice(
range(self.total_layers), self.n_layers, replace=False
)
LOG.info(
f"Activating layers at indices: {self.active_layers_indices} for the next steps."
)

# Enable gradients only for the selected layers
for idx in self.active_layers_indices:
for param in layers[idx].parameters():
param.requires_grad = True

lisa_callback = LISACallback(
n_layers=trainer.args.lisa_n_layers,
step_interval=trainer.args.lisa_step_interval,
trainer=trainer,
layers_attribute=trainer.args.lisa_layers_attribute,
)

return lisa_callback
18 changes: 18 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,23 @@ class MLFlowConfig(BaseModel):
hf_mlflow_log_artifacts: Optional[bool] = None


class LISAConfig(BaseModel):
"""LISA options"""

lisa_n_layers: Optional[int] = Field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = Field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = Field(
default="model.layers",
metadata={"help": "path under the model to access the layers"},
)


class WandbConfig(BaseModel):
"""wandb configuration subset"""

Expand Down Expand Up @@ -404,6 +421,7 @@ class AxolotlInputConfig(
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
LISAConfig,
RemappedParameters,
DeprecatedParameters,
BaseModel,
Expand Down
Loading