diff --git a/examples/llama-2/ia3.yml b/examples/llama-2/ia3.yml
new file mode 100644
index 0000000000..a914a91796
--- /dev/null
+++ b/examples/llama-2/ia3.yml
@@ -0,0 +1,72 @@
+base_model: meta-llama/Llama-2-7b-hf
+base_model_config: meta-llama/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: mhenrichsen/alpaca_2k_test
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./ia3-out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+adapter: ia3
+ia3_model_dir:
+ia3_target_modules:
+ - k_proj
+ - v_proj
+ - down_proj
+ia3_feedforward_modules:
+ - down_proj
+ia3_fan_in_fan_out: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 1
+micro_batch_size: 2
+num_epochs: 5
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 0.05
+eval_table_size:
+eval_table_max_new_tokens:
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.1
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 36607f2a2d..c8cdbeb849 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -430,6 +430,8 @@ def load_adapter(model, cfg, adapter, inference=False):
return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
+ if adapter in ["ia3"]:
+ return load_ia3(model, cfg, inference=inference)
if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter":
@@ -513,3 +515,36 @@ def load_lora(model, cfg, inference=False):
     model.print_trainable_parameters()
 
     return model, lora_config
+
+
+def load_ia3(model, cfg, inference=False):
+ # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+
+ from peft import IA3Config, PeftModel, get_peft_model
+
+ ia3_config_kwargs = {}
+ if cfg.ia3_init_ia3_weights is not None:
+ ia3_config_kwargs["init_ia3_weights"] = cfg.ia3_init_ia3_weights
+ if cfg.ia3_fan_in_fan_out is not None:
+ ia3_config_kwargs["fan_in_fan_out"] = cfg.ia3_fan_in_fan_out
+
+ ia3_config = IA3Config(
+ target_modules=cfg.ia3_target_modules,
+ feedforward_modules=cfg.ia3_feedforward_modules,
+ modules_to_save=cfg.ia3_modules_to_save,
+ **ia3_config_kwargs,
+ )
+
+ if cfg.ia3_model_dir:
+        LOG.debug("Loading pretrained PEFT - IA3")
+ model = PeftModel.from_pretrained(
+ model,
+ cfg.ia3_model_dir,
+ is_trainable=(not inference),
+ )
+ else:
+ model = get_peft_model(model, ia3_config)
+
+ model.print_trainable_parameters()
+
+ return model, ia3_config
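
For reference, a minimal standalone sketch of what the new `ia3` adapter path builds from the example config; it is illustrative only and not part of the diff. The target/feedforward module names and the base model ID come from examples/llama-2/ia3.yml, and the peft calls (IA3Config, get_peft_model) are the same ones load_ia3 uses.

    # Illustrative sketch only; mirrors load_ia3 with the ia3.yml settings.
    from peft import IA3Config, get_peft_model
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

    ia3_config = IA3Config(
        target_modules=["k_proj", "v_proj", "down_proj"],  # ia3_target_modules
        feedforward_modules=["down_proj"],  # ia3_feedforward_modules; must be a subset of target_modules
        fan_in_fan_out=False,  # ia3_fan_in_fan_out
    )

    model = get_peft_model(model, ia3_config)
    model.print_trainable_parameters()  # IA3 trains only small per-module scaling vectors

With adapter: ia3 set in the YAML, load_adapter dispatches to this path during training.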