From 3a9ad7c66e99ce4f68ca59ae559bb252cbf5ed97 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 30 Mar 2024 22:55:15 -0400
Subject: [PATCH 1/5] add lisa support

---
 src/axolotl/core/trainer_builder.py         | 24 ++++++
 src/axolotl/utils/callbacks/lisa.py         | 73 +++++++++++++++++++
 .../config/models/input/v0_4_1/__init__.py  | 18 +++++
 3 files changed, 115 insertions(+)
 create mode 100644 src/axolotl/utils/callbacks/lisa.py

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 4d85b40dee..cc7275184e 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -45,6 +45,7 @@
     causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
+from axolotl.utils.callbacks.lisa import lisa_callback_factory
 from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
@@ -200,6 +201,18 @@ class AxolotlTrainingArguments(TrainingArguments):
     orpo_alpha: Optional[float] = field(
         default=None,
     )
+    lisa_n_layers: Optional[int] = field(
+        default=None,
+        metadata={"help": "the number of active layers in LISA"},
+    )
+    lisa_step_interval: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to switch layers in LISA"},
+    )
+    lisa_layers_attribute: Optional[str] = field(
+        default=None,
+        metadata={"help": "path under the model to access the layers"},
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -938,6 +951,8 @@ def get_post_trainer_create_callbacks(self, trainer):
             )
             callbacks.append(early_stop_cb)
 
+        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
+            callbacks.append(lisa_callback_factory(trainer))
         return callbacks
 
     def _get_trainer_cls(self):
@@ -1229,6 +1244,15 @@ def build(self, total_num_steps):
                 "relora_prune_ratio"
             ] = self.cfg.relora_prune_ratio
 
+        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
+            training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
+            training_arguments_kwargs[
+                "lisa_step_interval"
+            ] = self.cfg.lisa_step_interval
+            training_arguments_kwargs[
+                "lisa_layers_attribute"
+            ] = self.cfg.lisa_layers_attribute
+
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py
new file mode 100644
index 0000000000..a42f48bb82
--- /dev/null
+++ b/src/axolotl/utils/callbacks/lisa.py
@@ -0,0 +1,73 @@
+"""module for LISA"""
+import ast
+from typing import TYPE_CHECKING
+
+import numpy as np
+from transformers import TrainerCallback
+
+if TYPE_CHECKING:
+    from axolotl.core.trainer_builder import AxolotlTrainer
+
+
+def lisa_callback_factory(trainer: "AxolotlTrainer"):
+    class LISACallback(TrainerCallback):
+        """trainer callback for lisa layer switching"""
+
+        def __init__(
+            self, n_layers, step_interval, trainer, layers_attribute="model.layers"
+        ):
+            super().__init__()
+            self.n_layers = n_layers
+            self.step_interval = step_interval
+            self.layers_attribute = layers_attribute
+            self.trainer = trainer
+
+            self.total_layers = len(
+                ast.literal_eval("self.trainer.model." + self.layers_attribute)
+            )
+            self.freeze_all_layers()
+            self.active_layers_indices = []
+
+        def freeze_all_layers(self):
+            layers = ast.literal_eval(
+                "self.trainer.model." + self.layers_attribute
+            )  # Dynamically execute to get layers
+            for layer in layers:
+                for param in layer.parameters():
+                    param.requires_grad = False
+
+        def on_step_begin(
+            self, args, state, control, **kwargs
+        ):  # pylint: disable=unused-argument
+            # Check if it's time to switch active layers, including at step 0
+            if state.global_step % self.step_interval == 0 or state.global_step == 1:
+                self.switch_active_layers()
+
+        def switch_active_layers(self):
+            # First, disable gradients for all layers
+            self.freeze_all_layers()
+
+            # Randomly select n_layers to activate
+            layers = ast.literal_eval(
+                "self.trainer.model" + self.layers_attribute
+            )  # Re-fetch layer references
+            self.active_layers_indices = np.random.choice(
+                range(self.total_layers), self.n_layers, replace=False
+            )
+            print(
+                f"Activating layers at indices: {self.active_layers_indices} for the next steps."
+            )
+
+            # Enable gradients only for the selected layers
+            for idx in self.active_layers_indices:
+                for param in layers[idx].parameters():
+                    param.requires_grad = True
+
+    lisa_callback = LISACallback(
+        n_layers=trainer.args.lisa_n_layers,
+        step_interval=trainer.args.lisa_step_interval,
+        trainer=trainer,
+        layers_attribute=trainer.args.lisa_layers_attribute,
+    )
+
+    return lisa_callback
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index c07c0ff75a..c66ae70d43 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -370,6 +370,23 @@ class MLFlowConfig(BaseModel):
     hf_mlflow_log_artifacts: Optional[bool] = None
 
 
+class LISAConfig(BaseModel):
+    """LISA options"""
+
+    lisa_n_layers: Optional[int] = Field(
+        default=None,
+        metadata={"help": "the number of active layers in LISA"},
+    )
+    lisa_step_interval: Optional[int] = Field(
+        default=None,
+        metadata={"help": "how often to switch layers in LISA"},
+    )
+    lisa_layers_attribute: Optional[str] = Field(
+        default="",
+        metadata={"help": "path under the model to access the layers"},
+    )
+
+
 class WandbConfig(BaseModel):
     """wandb configuration subset"""
 
@@ -404,6 +421,7 @@ class AxolotlInputConfig(
     HyperparametersConfig,
     WandbConfig,
     MLFlowConfig,
+    LISAConfig,
     RemappedParameters,
     DeprecatedParameters,
     BaseModel,
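A note on the dynamic lookup in the new callback above: ast.literal_eval only parses Python literals (strings, numbers, tuples, lists, dicts, sets, booleans, None), so handing it an attribute path raises ValueError at runtime. A minimal standalone reproduction (illustrative only, not part of the series):

    import ast

    print(ast.literal_eval("[1, 2, 3]"))  # fine: a list literal
    try:
        ast.literal_eval("self.trainer.model.model.layers")
    except ValueError as err:
        # attribute access is not a literal, so this always raises
        print(f"rejected: {err}")

The next patch replaces these lookups with functools.reduce over getattr.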
From 21a5094226697ccf6085d99f0d101b8ddca0cd5e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 31 Mar 2024 00:27:04 -0400
Subject: [PATCH 2/5] fix default and fix attribute traversal for layers

---
 src/axolotl/utils/callbacks/lisa.py         | 27 ++++++++++++-------
 .../config/models/input/v0_4_1/__init__.py  |  2 +-
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py
index a42f48bb82..4df3225bbb 100644
--- a/src/axolotl/utils/callbacks/lisa.py
+++ b/src/axolotl/utils/callbacks/lisa.py
@@ -1,5 +1,12 @@
-"""module for LISA"""
-import ast
+"""
+module for LISA
+
+Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl
+Arxiv: https://arxiv.org/abs/2403.17919
+License: Apache 2.0
+"""
+
+from functools import reduce
 from typing import TYPE_CHECKING
 
 import numpy as np
@@ -22,16 +29,18 @@ def __init__(
             self.layers_attribute = layers_attribute
             self.trainer = trainer
 
+            reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
+
             self.total_layers = len(
-                ast.literal_eval("self.trainer.model." + self.layers_attribute)
+                reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
             )
             self.freeze_all_layers()
             self.active_layers_indices = []
 
         def freeze_all_layers(self):
-            layers = ast.literal_eval(
-                "self.trainer.model." + self.layers_attribute
-            )  # Dynamically execute to get layers
+            layers = reduce(
+                getattr, self.layers_attribute.split("."), self.trainer.model
+            )
             for layer in layers:
                 for param in layer.parameters():
                     param.requires_grad = False
@@ -48,9 +57,9 @@ def switch_active_layers(self):
             self.freeze_all_layers()
 
             # Randomly select n_layers to activate
-            layers = ast.literal_eval(
-                "self.trainer.model" + self.layers_attribute
-            )  # Re-fetch layer references
+            layers = reduce(
+                getattr, self.layers_attribute.split("."), self.trainer.model
+            )
             self.active_layers_indices = np.random.choice(
                 range(self.total_layers), self.n_layers, replace=False
             )
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index c66ae70d43..5a927602ff 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -382,7 +382,7 @@ class LISAConfig(BaseModel):
         metadata={"help": "how often to switch layers in LISA"},
     )
     lisa_layers_attribute: Optional[str] = Field(
-        default="",
+        default="model.layers",
         metadata={"help": "path under the model to access the layers"},
     )
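The reduce(getattr, ...) idiom introduced above walks a dotted attribute path one attribute at a time, which is how the new default of "model.layers" resolves to the decoder blocks of Llama-style HF models (trainer.model.model.layers). A small sketch with stand-in classes (hypothetical names, for illustration only):

    from functools import reduce

    class Decoder:
        def __init__(self):
            self.layers = ["block0", "block1", "block2"]

    class Model:
        def __init__(self):
            self.model = Decoder()

    model = Model()
    layers = reduce(getattr, "model.layers".split("."), model)
    print(layers)  # ['block0', 'block1', 'block2'] -- same as model.model.layers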
From b357c93f2348b0d4a6710e06d64d606c9d3d4356 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Mon, 1 Apr 2024 04:54:00 +0000
Subject: [PATCH 3/5] improve lisa callback logging

---
 src/axolotl/utils/callbacks/lisa.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py
index 4df3225bbb..6509cd2790 100644
--- a/src/axolotl/utils/callbacks/lisa.py
+++ b/src/axolotl/utils/callbacks/lisa.py
@@ -6,6 +6,7 @@
 License: Apache 2.0
 """
 
+import logging
 from functools import reduce
 from typing import TYPE_CHECKING
 
@@ -15,6 +16,8 @@
 if TYPE_CHECKING:
     from axolotl.core.trainer_builder import AxolotlTrainer
 
+LOG = logging.getLogger("axolotl.callbacks")
+
 
 def lisa_callback_factory(trainer: "AxolotlTrainer"):
     class LISACallback(TrainerCallback):
@@ -34,16 +37,20 @@ def __init__(
             self.total_layers = len(
                 reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
             )
-            self.freeze_all_layers()
+            self.freeze_all_layers(True)
             self.active_layers_indices = []
 
-        def freeze_all_layers(self):
+        def freeze_all_layers(self, summarize=False):
             layers = reduce(
                 getattr, self.layers_attribute.split("."), self.trainer.model
             )
             for layer in layers:
                 for param in layer.parameters():
                     param.requires_grad = False
+            if summarize:
+                LOG.info(
+                    f"Freezing {len(layers)} layers; will activate {self.n_layers} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
+                )
 
         def on_step_begin(
             self, args, state, control, **kwargs
@@ -63,7 +70,7 @@ def switch_active_layers(self):
             self.active_layers_indices = np.random.choice(
                 range(self.total_layers), self.n_layers, replace=False
             )
-            print(
+            LOG.info(
                 f"Activating layers at indices: {self.active_layers_indices} for the next steps."
             )
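For intuition about the switching schedule being logged above: on_step_begin re-randomizes the active layers at step 0, again at step 1 (global_step == 1 also satisfies the condition), and then every step_interval steps. A quick sanity check of the schedule, assuming step_interval=20 as in the example config later in the series:

    step_interval = 20
    switch_steps = [s for s in range(101) if s % step_interval == 0 or s == 1]
    print(switch_steps)  # [0, 1, 20, 40, 60, 80, 100]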
From 6185cd522776bf62c406a677be50dab9d7aca7d3 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Mon, 1 Apr 2024 06:57:28 +0000
Subject: [PATCH 4/5] fix LISA by ensuring params are not frozen during __init__

---
 src/axolotl/utils/callbacks/lisa.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py
index 6509cd2790..ff20959a59 100644
--- a/src/axolotl/utils/callbacks/lisa.py
+++ b/src/axolotl/utils/callbacks/lisa.py
@@ -16,7 +16,7 @@
 if TYPE_CHECKING:
     from axolotl.core.trainer_builder import AxolotlTrainer
 
-LOG = logging.getLogger("axolotl.callbacks")
+LOG = logging.getLogger("axolotl.callbacks.lisa")
 
 
 def lisa_callback_factory(trainer: "AxolotlTrainer"):
@@ -37,20 +37,22 @@ def __init__(
             self.total_layers = len(
                 reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
             )
-            self.freeze_all_layers(True)
             self.active_layers_indices = []
 
-        def freeze_all_layers(self, summarize=False):
+            layers = reduce(
+                getattr, self.layers_attribute.split("."), self.trainer.model
+            )
+            LOG.info(
+                f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
+            )
+
+        def freeze_all_layers(self):
             layers = reduce(
                 getattr, self.layers_attribute.split("."), self.trainer.model
             )
             for layer in layers:
                 for param in layer.parameters():
                     param.requires_grad = False
-            if summarize:
-                LOG.info(
-                    f"Freezing {len(layers)} layers; will activate {self.n_layers} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
-                )
 
         def on_step_begin(
             self, args, state, control, **kwargs
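After this fix, every parameter is still trainable while the trainer and optimizer are being set up; the first freeze now happens on the first on_step_begin, when switch_active_layers runs. The ratio reported by the new LOG.info line follows directly from the two config knobs; for the example config below (4 active layers on Llama-2-7b, which has 32 decoder layers), that works out to:

    n_layers, total_layers = 4, 32  # Llama-2-7b: 32 decoder blocks
    print(f"{n_layers}/{total_layers} layers active per interval "
          f"({n_layers * 100 / total_layers}%)")  # 12.5%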
From 5dd9364c005493bec8b943ff2a438c2616a2513d Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Mon, 1 Apr 2024 07:00:59 +0000
Subject: [PATCH 5/5] example config for lisa

---
 examples/llama-2/lisa.yml | 75 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 75 insertions(+)
 create mode 100644 examples/llama-2/lisa.yml

diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml
new file mode 100644
index 0000000000..e692c7ac1e
--- /dev/null
+++ b/examples/llama-2/lisa.yml
@@ -0,0 +1,75 @@
+base_model: NousResearch/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: teknium/GPT4-LLM-Cleaned
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./lisa-out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+adapter:
+lora_model_dir:
+lora_r:
+lora_alpha:
+lora_dropout:
+lora_target_linear:
+lora_fan_in_fan_out:
+
+lisa_n_layers: 4
+lisa_step_interval: 20
+lisa_layers_attribute: model.layers
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 2
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 5e-5 # recommendation from lisa paper for 7b
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+flash_attn_cross_entropy: false
+flash_attn_rms_norm: true
+flash_attn_fuse_qkv: false
+flash_attn_fuse_mlp: true
+
+warmup_steps: 100
+evals_per_epoch: 4
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.1
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
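Assuming a checkout of axolotl with this series applied, a run with the example config can be launched the usual way:

    accelerate launch -m axolotl.cli.train examples/llama-2/lisa.yml

All LISA behavior is gated on the lisa_* keys; leaving them unset (with adapter also empty, as here) falls back to plain full-parameter fine-tuning.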