diff --git a/examples/llama-2/ia3.yml b/examples/llama-2/ia3.yml
new file mode 100644
index 0000000000..a914a91796
--- /dev/null
+++ b/examples/llama-2/ia3.yml
@@ -0,0 +1,72 @@
+base_model: meta-llama/Llama-2-7b-hf
+base_model_config: meta-llama/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./ia3-out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+adapter: ia3
+ia3_model_dir:
+ia3_target_modules:
+  - k_proj
+  - v_proj
+  - down_proj
+ia3_feedforward_modules:
+  - down_proj
+ia3_fan_in_fan_out: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 1
+micro_batch_size: 2
+num_epochs: 5
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 0.05
+eval_table_size:
+eval_table_max_new_tokens:
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.1
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 36607f2a2d..c8cdbeb849 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -430,6 +430,8 @@ def load_adapter(model, cfg, adapter, inference=False):
         return model, None
     if hasattr(model, "enable_input_require_grads"):
         model.enable_input_require_grads()
+    if adapter in ["ia3"]:
+        return load_ia3(model, cfg, inference=inference)
     if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg, inference=inference)
     if adapter == "llama-adapter":
@@ -513,3 +515,36 @@ def load_lora(model, cfg, inference=False):
     model.print_trainable_parameters()
 
     return model, lora_config
+
+
+def load_ia3(model, cfg, inference=False):
+    # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+
+    from peft import IA3Config, PeftModel, get_peft_model
+
+    ia3_config_kwargs = {}
+    if cfg.ia3_init_ia3_weights is not None:
+        ia3_config_kwargs["init_ia3_weights"] = cfg.ia3_init_ia3_weights
+    if cfg.ia3_fan_in_fan_out is not None:
+        ia3_config_kwargs["fan_in_fan_out"] = cfg.ia3_fan_in_fan_out
+
+    ia3_config = IA3Config(
+        target_modules=cfg.ia3_target_modules,
+        feedforward_modules=cfg.ia3_feedforward_modules,
+        modules_to_save=cfg.ia3_modules_to_save,
+        **ia3_config_kwargs,
+    )
+
+    if cfg.ia3_model_dir:
+        LOG.debug("Loading pretrained PEFT - IA3")
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.ia3_model_dir,
+            is_trainable=(not inference),
+        )
+    else:
+        model = get_peft_model(model, ia3_config)
+
+    model.print_trainable_parameters()
+
+    return model, ia3_config
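
For reviewers who want to see what the new `load_ia3` path sets up, here is a minimal, self-contained sketch using PEFT directly. It is an illustration only: the model id, dtype, and module names are copied from examples/llama-2/ia3.yml, and it deliberately skips the 8-bit loading, gradient checkpointing, and trainer wiring that axolotl itself handles around this call.

```python
# Minimal sketch of the (IA)^3 setup that load_ia3 performs via PEFT.
# Assumes `transformers` and `peft` are installed; model id and module
# names mirror examples/llama-2/ia3.yml rather than anything new.
import torch
from transformers import AutoModelForCausalLM
from peft import IA3Config, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,
)

# (IA)^3 learns per-channel rescaling vectors for the targeted modules;
# feedforward_modules must be a subset of target_modules.
ia3_config = IA3Config(
    task_type="CAUSAL_LM",
    target_modules=["k_proj", "v_proj", "down_proj"],
    feedforward_modules=["down_proj"],
)

model = get_peft_model(model, ia3_config)
model.print_trainable_parameters()  # only the (IA)^3 vectors are trainable
```

The example config itself should then train like the other examples/llama-2 configs on this branch; the only IA3-specific knobs are `adapter: ia3`, `ia3_target_modules`, and `ia3_feedforward_modules` (the latter has to be a subset of the former, per PEFT's IA3Config validation).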