Commit
* add lisa support
* fix default and fix attribute traversal for layers
* improve lisa callback logging
* fix LISA by ensuring params are not frozen during __init__
* example config for lisa

Co-authored-by: Aman Karmani <[email protected]>
Showing 4 changed files with 208 additions and 0 deletions.
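Only two of the four changed files are reproduced below: the example LISA config and the callback module. The callback reads its settings from the trainer arguments (trainer.args.lisa_*), so the remaining files have to plumb those options through and attach the callback. A rough, hypothetical sketch of that wiring, where the function name and the cfg object are assumptions rather than code from this commit:

# Hypothetical wiring sketch, not part of this commit: attach the LISA callback
# only when a step interval is configured. `cfg` stands in for the parsed YAML
# config; `lisa_callback_factory` is defined in the callback module shown below.
def maybe_add_lisa_callback(trainer, cfg):
    if getattr(cfg, "lisa_step_interval", None):
        trainer.add_callback(lisa_callback_factory(trainer))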
@@ -0,0 +1,75 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./lisa-out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:

lisa_n_layers: 4
lisa_step_interval: 20
lisa_layers_attribute: model.layers

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 5e-5  # recommendation from lisa paper for 7b

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true

warmup_steps: 100
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
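The three lisa_* keys control the schedule: lisa_n_layers is how many decoder blocks stay trainable at a time, lisa_step_interval is how often (in optimizer steps) the active set is reshuffled, and lisa_layers_attribute is the dotted path to the model's block list. A minimal sketch of how such a dotted path resolves to the layer container, using the same functools.reduce pattern as the callback module below (the helper name here is illustrative only):

# Illustrative helper: resolve a dotted attribute path such as "model.layers".
from functools import reduce

def get_nested_attr(root, dotted_path):
    # Equivalent to root.model.layers when dotted_path == "model.layers".
    return reduce(getattr, dotted_path.split("."), root)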
@@ -0,0 +1,91 @@
"""
module for LISA
Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl
Arxiv: https://arxiv.org/abs/2403.17919
License: Apache 2.0
"""

import logging
from functools import reduce
from typing import TYPE_CHECKING

import numpy as np
from transformers import TrainerCallback

if TYPE_CHECKING:
    from axolotl.core.trainer_builder import AxolotlTrainer

LOG = logging.getLogger("axolotl.callbacks.lisa")


def lisa_callback_factory(trainer: "AxolotlTrainer"):
    class LISACallback(TrainerCallback):
        """trainer callback for lisa layer switching"""

        def __init__(
            self, n_layers, step_interval, trainer, layers_attribute="model.layers"
        ):
            super().__init__()
            self.n_layers = n_layers
            self.step_interval = step_interval
            self.layers_attribute = layers_attribute
            self.trainer = trainer

            # Resolve the configured layer container (e.g. model.layers) once to
            # count the blocks; parameters are deliberately left unfrozen here.
            layers = reduce(
                getattr, self.layers_attribute.split("."), self.trainer.model
            )
            self.total_layers = len(layers)
            self.active_layers_indices = []

            LOG.info(
                f"LISA will activate {self.n_layers}/{self.total_layers} layers "
                f"({self.n_layers * 100 / self.total_layers:.1f}%) every {self.step_interval} steps"
            )

        def freeze_all_layers(self):
            layers = reduce(
                getattr, self.layers_attribute.split("."), self.trainer.model
            )
            for layer in layers:
                for param in layer.parameters():
                    param.requires_grad = False

        def on_step_begin(
            self, args, state, control, **kwargs
        ):  # pylint: disable=unused-argument
            # Check if it's time to switch active layers, including at step 0
            if state.global_step % self.step_interval == 0 or state.global_step == 1:
                self.switch_active_layers()

        def switch_active_layers(self):
            # First, disable gradients for all layers
            self.freeze_all_layers()

            # Randomly select n_layers to activate
            layers = reduce(
                getattr, self.layers_attribute.split("."), self.trainer.model
            )
            self.active_layers_indices = np.random.choice(
                range(self.total_layers), self.n_layers, replace=False
            )
            LOG.info(
                f"Activating layers at indices: {self.active_layers_indices} for the next steps."
            )

            # Enable gradients only for the selected layers
            for idx in self.active_layers_indices:
                for param in layers[idx].parameters():
                    param.requires_grad = True

    lisa_callback = LISACallback(
        n_layers=trainer.args.lisa_n_layers,
        step_interval=trainer.args.lisa_step_interval,
        trainer=trainer,
        layers_attribute=trainer.args.lisa_layers_attribute,
    )

    return lisa_callback
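A self-contained sketch of the callback in action, assuming torch is installed and lisa_callback_factory from the module above is in scope; the fake trainer below only mimics the attributes the callback actually touches (model and args.lisa_*):

# Toy demonstration (assumptions: torch available, lisa_callback_factory in scope).
from types import SimpleNamespace

from torch import nn

# Stand-in for a decoder-only LM: eight blocks reachable via "model.layers".
toy_model = SimpleNamespace(
    model=SimpleNamespace(layers=nn.ModuleList([nn.Linear(16, 16) for _ in range(8)]))
)
fake_trainer = SimpleNamespace(
    model=toy_model,
    args=SimpleNamespace(
        lisa_n_layers=2, lisa_step_interval=20, lisa_layers_attribute="model.layers"
    ),
)

callback = lisa_callback_factory(fake_trainer)
callback.switch_active_layers()

# Only the two randomly selected blocks should remain trainable.
active = [
    i
    for i, layer in enumerate(toy_model.model.layers)
    if all(p.requires_grad for p in layer.parameters())
]
print(active)  # e.g. [1, 5]

Because the factory only reads trainer.model and trainer.args, this also shows which pieces of the real AxolotlTrainer the callback depends on.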