From 716133c35b8342fe4976175db840ffcf0e1574a4 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 8 Mar 2024 13:23:14 -0500
Subject: [PATCH] fix for model_type and add mixtral support too

---
 examples/mistral/mixtral-qlora-fsdp.yml | 74 +++++++++++++++++++++++++
 src/axolotl/core/policies/auto_wrap.py  | 60 +++++---------------
 src/axolotl/core/trainer_builder.py     |  4 +-
 3 files changed, 90 insertions(+), 48 deletions(-)
 create mode 100644 examples/mistral/mixtral-qlora-fsdp.yml

diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral-qlora-fsdp.yml
new file mode 100644
index 0000000000..32db7073b7
--- /dev/null
+++ b/examples/mistral/mixtral-qlora-fsdp.yml
@@ -0,0 +1,74 @@
+base_model: mistralai/Mixtral-8x7B-v0.1
+model_type: AutoModelForCausalLM
+tokenizer_type: LlamaTokenizer
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.02
+output_dir: ./qlora-out
+
+model_config:
+  output_router_logits: true
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 1024
+sample_packing: false
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+weight_decay: 0.0
+fsdp:
+  - full_shard
+fsdp_config:
+  fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
+special_tokens:
diff --git a/src/axolotl/core/policies/auto_wrap.py b/src/axolotl/core/policies/auto_wrap.py
index 5d49facc2b..d42b62ee08 100644
--- a/src/axolotl/core/policies/auto_wrap.py
+++ b/src/axolotl/core/policies/auto_wrap.py
@@ -2,73 +2,43 @@
 import functools
 
 from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
-from torch import nn
 from torch.distributed.fsdp.wrap import (
     _or_policy,
     lambda_auto_wrap_policy,
     transformer_auto_wrap_policy,
 )
-from transformers.models.llama.modeling_llama import (
-    LLAMA_ATTENTION_CLASSES,
-    LlamaDecoderLayer,
-    LlamaMLP,
-)
-from transformers.models.mistral.modeling_mistral import (
-    MISTRAL_ATTENTION_CLASSES,
-    MistralDecoderLayer,
-    MistralMLP,
-)
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
 
 SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
-    "mistral",
     "llama",
+    "mistral",
+    "mixtral",
 ]
 
 
 def get_wrapping_policy_factory(model_type):
     if model_type == "llama":
-        attention_classes = LLAMA_ATTENTION_CLASSES
         layer_to_wrap = LlamaDecoderLayer
-        model_mlp = LlamaMLP
     elif model_type == "mistral":
-        attention_classes = MISTRAL_ATTENTION_CLASSES
         layer_to_wrap = MistralDecoderLayer
-        model_mlp = MistralMLP
+    elif model_type == "mixtral":
+        layer_to_wrap = MixtralDecoderLayer
 
-    def get_wrapping_policy(custom_policy: bool = False):
+    def get_wrapping_policy():
         """This checks for lora layers (has weight and requires_grad)"""
-        if custom_policy:
-
-            def lambda_policy_fn(module):
-                # LORA trainable layers.
-                return isinstance(module, nn.Sequential) and all(
-                    m.weight.requires_grad for m in module
-                )
-
-        else:
-
-            def lambda_policy_fn(module):
-                return (
-                    len(list(module.named_children())) == 0
-                    and getattr(module, "weight", None) is not None
-                    and module.weight.requires_grad
-                )
-
-        def self_attn_policy_fn(module):
-            # Check module name is self_attn.
-            return isinstance(module, tuple(attention_classes.values()))
-
-        def mlp_policy_fn(module):
-            # Check module name is self_attn.
-            return isinstance(module, model_mlp)
+        def lambda_policy_fn(module):
+            return (
+                len(list(module.named_children())) == 0
+                and getattr(module, "weight", None) is not None
+                and module.weight.requires_grad
+            )
 
         lambda_policy = functools.partial(
             lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
         )
-        self_attn_policy = functools.partial(
-            lambda_auto_wrap_policy, lambda_fn=self_attn_policy_fn
-        )
-        mlp_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=mlp_policy_fn)
         transformer_layer_name = layer_to_wrap
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
@@ -80,8 +50,6 @@ def mlp_policy_fn(module):
             ),
         )
         policies = [lambda_policy, transformer_wrap_policy]
-        if custom_policy:
-            policies.extend([self_attn_policy, mlp_policy])
         return functools.partial(_or_policy, policies=policies)
 
     return get_wrapping_policy
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index b28c385728..990d814d9f 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -501,9 +501,9 @@ def create_accelerator_and_postprocess(self):
         # load_param_skip_names = ['inv_freq']
 
         if self.is_fsdp_enabled:
-            wrapping_policy = get_wrapping_policy_factory(self.model.config.model_type)
+            wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
             fsdp_plugin = FullyShardedDataParallelPlugin(
-                auto_wrap_policy=wrapping_policy(False),
+                auto_wrap_policy=wrapping_policy(),
                 use_orig_params=False,
                 limit_all_gathers=True,
                 param_init_fn=lambda module: module.to_empty(
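
Usage sketch (not part of the patch): the auto_wrap.py hunks reduce get_wrapping_policy_factory() to a single lambda policy plus a transformer-layer policy keyed on model_type, and trainer_builder.py now reads that type from self.args.model_type instead of self.model.config.model_type. A minimal sketch of the resulting call shape, assuming axolotl and accelerate are installed; the hard-coded model_type and the standalone plugin construction are illustrative only:

    from accelerate.utils import FullyShardedDataParallelPlugin

    from axolotl.core.policies.auto_wrap import (
        SUPPORTED_AUTO_WRAP_MODEL_TYPES,
        get_wrapping_policy_factory,
    )

    # In the trainer this value now comes from self.args.model_type.
    model_type = "mixtral"
    assert model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES

    # The inner policy no longer takes a custom_policy flag, so it is called
    # with no arguments, matching the trainer_builder.py hunk above.
    wrapping_policy = get_wrapping_policy_factory(model_type)
    fsdp_plugin = FullyShardedDataParallelPlugin(
        auto_wrap_policy=wrapping_policy(),
        use_orig_params=False,
        limit_all_gathers=True,
    )

The new example config would typically be launched with something like: accelerate launch -m axolotl.cli.train examples/mistral/mixtral-qlora-fsdp.yml (per the axolotl README).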