diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml
new file mode 100644
index 0000000000..e692c7ac1e
--- /dev/null
+++ b/examples/llama-2/lisa.yml
@@ -0,0 +1,75 @@
+base_model: NousResearch/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: teknium/GPT4-LLM-Cleaned
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./lisa-out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+adapter:
+lora_model_dir:
+lora_r:
+lora_alpha:
+lora_dropout:
+lora_target_linear:
+lora_fan_in_fan_out:
+
+lisa_n_layers: 4
+lisa_step_interval: 20
+lisa_layers_attribute: model.layers
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 2
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 5e-5 # recommended in the LISA paper for 7B models
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+flash_attn_cross_entropy: false
+flash_attn_rms_norm: true
+flash_attn_fuse_qkv: false
+flash_attn_fuse_mlp: true
+
+warmup_steps: 100
+evals_per_epoch: 4
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.1
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 4d85b40dee..cc7275184e 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -45,6 +45,7 @@
     causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
+from axolotl.utils.callbacks.lisa import lisa_callback_factory
 from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
@@ -200,6 +201,18 @@ class AxolotlTrainingArguments(TrainingArguments):
     orpo_alpha: Optional[float] = field(
         default=None,
     )
+    lisa_n_layers: Optional[int] = field(
+        default=None,
+        metadata={"help": "the number of active layers in LISA"},
+    )
+    lisa_step_interval: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to switch layers in LISA"},
+    )
+    lisa_layers_attribute: Optional[str] = field(
+        default=None,
+        metadata={"help": "path under the model to access the layers"},
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -938,6 +951,8 @@ def get_post_trainer_create_callbacks(self, trainer):
             )
             callbacks.append(early_stop_cb)
 
+        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
+            callbacks.append(lisa_callback_factory(trainer))
         return callbacks
 
     def _get_trainer_cls(self):
@@ -1229,6 +1244,15 @@ def build(self, total_num_steps):
                 "relora_prune_ratio"
             ] = self.cfg.relora_prune_ratio
 
+        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
+            training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
+            training_arguments_kwargs[
+                "lisa_step_interval"
+            ] = self.cfg.lisa_step_interval
+            training_arguments_kwargs[
+                "lisa_layers_attribute"
+            ] = self.cfg.lisa_layers_attribute
+
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py
new file mode 100644
index 0000000000..ff20959a59
--- /dev/null
+++ b/src/axolotl/utils/callbacks/lisa.py
@@ -0,0 +1,91 @@
+""" +module for LISA + +Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl +Arxiv: https://arxiv.org/abs/2403.17919 +License: Apache 2.0 +""" + +import logging +from functools import reduce +from typing import TYPE_CHECKING + +import numpy as np +from transformers import TrainerCallback + +if TYPE_CHECKING: + from axolotl.core.trainer_builder import AxolotlTrainer + +LOG = logging.getLogger("axolotl.callbacks.lisa") + + +def lisa_callback_factory(trainer: "AxolotlTrainer"): + class LISACallback(TrainerCallback): + """trainer callback for lisa layer switching""" + + def __init__( + self, n_layers, step_interval, trainer, layers_attribute="model.layers" + ): + super().__init__() + self.n_layers = n_layers + self.step_interval = step_interval + self.layers_attribute = layers_attribute + self.trainer = trainer + + reduce(getattr, self.layers_attribute.split("."), self.trainer.model) + + self.total_layers = len( + reduce(getattr, self.layers_attribute.split("."), self.trainer.model) + ) + self.active_layers_indices = [] + + layers = reduce( + getattr, self.layers_attribute.split("."), self.trainer.model + ) + LOG.info( + f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps" + ) + + def freeze_all_layers(self): + layers = reduce( + getattr, self.layers_attribute.split("."), self.trainer.model + ) + for layer in layers: + for param in layer.parameters(): + param.requires_grad = False + + def on_step_begin( + self, args, state, control, **kwargs + ): # pylint: disable=unused-argument + # Check if it's time to switch active layers, including at step 0 + if state.global_step % self.step_interval == 0 or state.global_step == 1: + self.switch_active_layers() + + def switch_active_layers(self): + # First, disable gradients for all layers + self.freeze_all_layers() + + # Randomly select n_layers to activate + layers = reduce( + getattr, self.layers_attribute.split("."), self.trainer.model + ) + self.active_layers_indices = np.random.choice( + range(self.total_layers), self.n_layers, replace=False + ) + LOG.info( + f"Activating layers at indices: {self.active_layers_indices} for the next steps." 
+            )
+
+            # Enable gradients only for the selected layers
+            for idx in self.active_layers_indices:
+                for param in layers[idx].parameters():
+                    param.requires_grad = True
+
+    lisa_callback = LISACallback(
+        n_layers=trainer.args.lisa_n_layers,
+        step_interval=trainer.args.lisa_step_interval,
+        trainer=trainer,
+        layers_attribute=trainer.args.lisa_layers_attribute,
+    )
+
+    return lisa_callback
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index c07c0ff75a..5a927602ff 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -370,6 +370,23 @@ class MLFlowConfig(BaseModel):
     hf_mlflow_log_artifacts: Optional[bool] = None
 
 
+class LISAConfig(BaseModel):
+    """LISA options"""
+
+    lisa_n_layers: Optional[int] = Field(
+        default=None,
+        metadata={"help": "the number of active layers in LISA"},
+    )
+    lisa_step_interval: Optional[int] = Field(
+        default=None,
+        metadata={"help": "how often to switch layers in LISA"},
+    )
+    lisa_layers_attribute: Optional[str] = Field(
+        default="model.layers",
+        metadata={"help": "path under the model to access the layers"},
+    )
+
+
 class WandbConfig(BaseModel):
     """wandb configuration subset"""
 
@@ -404,6 +421,7 @@ class AxolotlInputConfig(
     HyperparametersConfig,
     WandbConfig,
     MLFlowConfig,
+    LISAConfig,
     RemappedParameters,
     DeprecatedParameters,
     BaseModel,
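
Note on the example config: with `lisa_n_layers: 4` and `lisa_step_interval: 20` on Llama-2-7B (32 decoder layers), only 4/32 = 12.5% of the decoder layers receive gradient updates during any given 20-step window, and the active subset is re-sampled each time the interval elapses. The example is launched like the other llama-2 examples, e.g. `accelerate launch -m axolotl.cli.train examples/llama-2/lisa.yml`.

For reviewers who want to poke at the switching logic without loading a 7B model, here is a minimal, self-contained sketch of the same mechanism. `DummyCausalLM`, `get_layers`, and `switch_active_layers` are illustrative stand-ins, not part of Axolotl or the callback above; only the mechanism mirrors the callback: resolve the layer list via the dotted `lisa_layers_attribute` path, freeze everything, and re-enable a random subset of layers on an interval.

```python
# Standalone sketch of LISA-style layer switching on a toy model (assumption:
# names and sizes here are made up for illustration only).
from functools import reduce

import numpy as np
from torch import nn


class DummyBackbone(nn.Module):
    """Stands in for a decoder stack: exposes a `.layers` ModuleList."""

    def __init__(self, n_layers: int = 8, hidden: int = 16):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(n_layers)])


class DummyCausalLM(nn.Module):
    """Stands in for a causal LM whose layers live under `model.layers`."""

    def __init__(self):
        super().__init__()
        self.model = DummyBackbone()


def get_layers(model: nn.Module, layers_attribute: str = "model.layers"):
    # reduce(getattr, ...) walks the dotted path (model -> .model -> .layers),
    # which is how lisa_layers_attribute is resolved in the callback.
    return reduce(getattr, layers_attribute.split("."), model)


def switch_active_layers(model: nn.Module, n_layers: int, layers_attribute: str):
    layers = get_layers(model, layers_attribute)
    # Freeze every layer, then randomly re-enable gradients for n_layers of them.
    for layer in layers:
        for param in layer.parameters():
            param.requires_grad = False
    active = np.random.choice(range(len(layers)), n_layers, replace=False)
    for idx in active:
        for param in layers[idx].parameters():
            param.requires_grad = True
    return active


if __name__ == "__main__":
    model = DummyCausalLM()
    n_layers, step_interval = 2, 5  # analogous to lisa_n_layers / lisa_step_interval
    for step in range(20):
        if step % step_interval == 0:
            active = switch_active_layers(model, n_layers, "model.layers")
            print(f"step {step}: active layer indices {sorted(active.tolist())}")
```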