diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml
new file mode 100644
index 0000000000..e692c7ac1e
--- /dev/null
+++ b/examples/llama-2/lisa.yml
@@ -0,0 +1,75 @@
+base_model: NousResearch/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: teknium/GPT4-LLM-Cleaned
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./lisa-out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+adapter:
+lora_model_dir:
+lora_r:
+lora_alpha:
+lora_dropout:
+lora_target_linear:
+lora_fan_in_fan_out:
+
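+# LISA (https://arxiv.org/abs/2403.17919): keep only lisa_n_layers randomly chosen
+# model layers trainable at a time, re-sampling them every lisa_step_interval steps;
+# lisa_layers_attribute is the attribute path to the model's list of layers.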
+lisa_n_layers: 4
+lisa_step_interval: 20
+lisa_layers_attribute: model.layers
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 2
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 5e-5 # recommendation from the LISA paper for 7B
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+flash_attn_cross_entropy: false
+flash_attn_rms_norm: true
+flash_attn_fuse_qkv: false
+flash_attn_fuse_mlp: true
+
+warmup_steps: 100
+evals_per_epoch: 4
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.1
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 4d85b40dee..cc7275184e 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -45,6 +45,7 @@
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
)
+from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
@@ -200,6 +201,18 @@ class AxolotlTrainingArguments(TrainingArguments):
orpo_alpha: Optional[float] = field(
default=None,
)
+ lisa_n_layers: Optional[int] = field(
+ default=None,
+        metadata={"help": "the number of active layers in LISA"},
+ )
+ lisa_step_interval: Optional[int] = field(
+ default=None,
+ metadata={"help": "how often to switch layers in LISA"},
+ )
+ lisa_layers_attribute: Optional[str] = field(
+ default=None,
+ metadata={"help": "path under the model to access the layers"},
+ )
class AxolotlTrainer(Trainer):
@@ -938,6 +951,8 @@ def get_post_trainer_create_callbacks(self, trainer):
)
callbacks.append(early_stop_cb)
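+        # LISA: register the layer-switching callback on the built trainer when configured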
+ if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
+ callbacks.append(lisa_callback_factory(trainer))
return callbacks
def _get_trainer_cls(self):
@@ -1229,6 +1244,15 @@ def build(self, total_num_steps):
"relora_prune_ratio"
] = self.cfg.relora_prune_ratio
+ if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
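+            # LISA: plumb the config values into AxolotlTrainingArguments so the
+            # callback factory can read them from trainer.args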
+ training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
+ training_arguments_kwargs[
+ "lisa_step_interval"
+ ] = self.cfg.lisa_step_interval
+ training_arguments_kwargs[
+ "lisa_layers_attribute"
+ ] = self.cfg.lisa_layers_attribute
+
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py
new file mode 100644
index 0000000000..ff20959a59
--- /dev/null
+++ b/src/axolotl/utils/callbacks/lisa.py
@@ -0,0 +1,91 @@
+"""
+module for LISA
+
+Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl
+Arxiv: https://arxiv.org/abs/2403.17919
+License: Apache 2.0
+"""
+
+import logging
+from functools import reduce
+from typing import TYPE_CHECKING
+
+import numpy as np
+from transformers import TrainerCallback
+
+if TYPE_CHECKING:
+ from axolotl.core.trainer_builder import AxolotlTrainer
+
+LOG = logging.getLogger("axolotl.callbacks.lisa")
+
+
+def lisa_callback_factory(trainer: "AxolotlTrainer"):
+ class LISACallback(TrainerCallback):
+ """trainer callback for lisa layer switching"""
+
+ def __init__(
+ self, n_layers, step_interval, trainer, layers_attribute="model.layers"
+ ):
+ super().__init__()
+ self.n_layers = n_layers
+ self.step_interval = step_interval
+ self.layers_attribute = layers_attribute
+ self.trainer = trainer
+
+            # Resolve the layer list once; getattr chains through layers_attribute
+            # (e.g. "model.layers") and fails fast if the path is invalid.
+            layers = reduce(
+                getattr, self.layers_attribute.split("."), self.trainer.model
+            )
+            self.total_layers = len(layers)
+            self.active_layers_indices = []
+
+            LOG.info(
+                f"LISA will activate {self.n_layers}/{self.total_layers} layers "
+                f"({100 * self.n_layers / self.total_layers:.1f}%) every {self.step_interval} steps"
+            )
+
+ def freeze_all_layers(self):
+ layers = reduce(
+ getattr, self.layers_attribute.split("."), self.trainer.model
+ )
+ for layer in layers:
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def on_step_begin(
+ self, args, state, control, **kwargs
+ ): # pylint: disable=unused-argument
+ # Check if it's time to switch active layers, including at step 0
+ if state.global_step % self.step_interval == 0 or state.global_step == 1:
+ self.switch_active_layers()
+
+ def switch_active_layers(self):
+ # First, disable gradients for all layers
+ self.freeze_all_layers()
+
+ # Randomly select n_layers to activate
+ layers = reduce(
+ getattr, self.layers_attribute.split("."), self.trainer.model
+ )
+ self.active_layers_indices = np.random.choice(
+ range(self.total_layers), self.n_layers, replace=False
+ )
+ LOG.info(
+ f"Activating layers at indices: {self.active_layers_indices} for the next steps."
+ )
+
+ # Enable gradients only for the selected layers
+ for idx in self.active_layers_indices:
+ for param in layers[idx].parameters():
+ param.requires_grad = True
+
+ lisa_callback = LISACallback(
+ n_layers=trainer.args.lisa_n_layers,
+ step_interval=trainer.args.lisa_step_interval,
+ trainer=trainer,
+ layers_attribute=trainer.args.lisa_layers_attribute,
+ )
+
+ return lisa_callback
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index c07c0ff75a..5a927602ff 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -370,6 +370,23 @@ class MLFlowConfig(BaseModel):
hf_mlflow_log_artifacts: Optional[bool] = None
+class LISAConfig(BaseModel):
+ """LISA options"""
+
+ lisa_n_layers: Optional[int] = Field(
+ default=None,
+        metadata={"help": "the number of active layers in LISA"},
+ )
+ lisa_step_interval: Optional[int] = Field(
+ default=None,
+ metadata={"help": "how often to switch layers in LISA"},
+ )
+ lisa_layers_attribute: Optional[str] = Field(
+ default="model.layers",
+ metadata={"help": "path under the model to access the layers"},
+ )
+
+
class WandbConfig(BaseModel):
"""wandb configuration subset"""
@@ -404,6 +421,7 @@ class AxolotlInputConfig(
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
+ LISAConfig,
RemappedParameters,
DeprecatedParameters,
BaseModel,