fix for model_type and add mixtral support too
winglian committed Mar 8, 2024
1 parent 24340cd commit 716133c
Showing 3 changed files with 90 additions and 48 deletions.
74 changes: 74 additions & 0 deletions examples/mistral/mixtral-qlora-fsdp.yml
@@ -0,0 +1,74 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out

model_config:
  output_router_logits: true

adapter: qlora
lora_model_dir:

sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
special_tokens:
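
Note on the config above: output_router_logits: true under model_config is forwarded to the Hugging Face model config, so Mixtral's router auxiliary (load-balancing) loss is included during training, and the fsdp_config section wraps MixtralSparseMoeBlock. A rough sketch of the equivalent direct override, shown for illustration only (it bypasses axolotl's own config handling):

# Rough equivalent of the YAML's model_config.output_router_logits: true,
# applied straight to the Hugging Face config object (illustration only).
from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
config.output_router_logits = True  # adds the MoE load-balancing auxiliary loss

Training from this file uses the usual axolotl entry point, e.g. accelerate launch -m axolotl.cli.train examples/mistral/mixtral-qlora-fsdp.yml (shown as a usage reminder; the launch command is not part of this commit).
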
60 changes: 14 additions & 46 deletions src/axolotl/core/policies/auto_wrap.py
@@ -2,73 +2,43 @@
 import functools

 from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
-from torch import nn
 from torch.distributed.fsdp.wrap import (
     _or_policy,
     lambda_auto_wrap_policy,
     transformer_auto_wrap_policy,
 )
-from transformers.models.llama.modeling_llama import (
-    LLAMA_ATTENTION_CLASSES,
-    LlamaDecoderLayer,
-    LlamaMLP,
-)
-from transformers.models.mistral.modeling_mistral import (
-    MISTRAL_ATTENTION_CLASSES,
-    MistralDecoderLayer,
-    MistralMLP,
-)
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

 SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
-    "mistral",
     "llama",
+    "mistral",
+    "mixtral",
 ]


 def get_wrapping_policy_factory(model_type):
     if model_type == "llama":
-        attention_classes = LLAMA_ATTENTION_CLASSES
         layer_to_wrap = LlamaDecoderLayer
-        model_mlp = LlamaMLP
     elif model_type == "mistral":
-        attention_classes = MISTRAL_ATTENTION_CLASSES
         layer_to_wrap = MistralDecoderLayer
-        model_mlp = MistralMLP
+    elif model_type == "mixtral":
+        layer_to_wrap = MixtralDecoderLayer

-    def get_wrapping_policy(custom_policy: bool = False):
+    def get_wrapping_policy():
         """This checks for lora layers (has weight and requires_grad)"""
-        if custom_policy:

-            def lambda_policy_fn(module):
-                # LORA trainable layers.
-                return isinstance(module, nn.Sequential) and all(
-                    m.weight.requires_grad for m in module
-                )
-
-        else:
-
-            def lambda_policy_fn(module):
-                return (
-                    len(list(module.named_children())) == 0
-                    and getattr(module, "weight", None) is not None
-                    and module.weight.requires_grad
-                )
-
-        def self_attn_policy_fn(module):
-            # Check module name is self_attn.
-            return isinstance(module, tuple(attention_classes.values()))
-
-        def mlp_policy_fn(module):
-            # Check module name is self_attn.
-            return isinstance(module, model_mlp)
+        def lambda_policy_fn(module):
+            return (
+                len(list(module.named_children())) == 0
+                and getattr(module, "weight", None) is not None
+                and module.weight.requires_grad
+            )

         lambda_policy = functools.partial(
             lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
         )
-        self_attn_policy = functools.partial(
-            lambda_auto_wrap_policy, lambda_fn=self_attn_policy_fn
-        )
-        mlp_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=mlp_policy_fn)
         transformer_layer_name = layer_to_wrap
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
@@ -80,8 +50,6 @@ def mlp_policy_fn(module):
             ),
         )
         policies = [lambda_policy, transformer_wrap_policy]
-        if custom_policy:
-            policies.extend([self_attn_policy, mlp_policy])
         return functools.partial(_or_policy, policies=policies)

     return get_wrapping_policy
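
With the custom_policy branch removed, the factory now builds the same two-part policy for llama, mistral, and mixtral: a lambda_auto_wrap_policy that catches leaf modules whose weight is trainable (the LoRA/QLoRA adapter layers), combined through _or_policy with a transformer_auto_wrap_policy keyed on the architecture's decoder layer. A condensed standalone sketch of what the mixtral branch produces is below; the mixtral_policy name is illustrative, and the peft prompt-tuning classes imported at the top of the file (presumably part of the elided transformer_layer_cls) are omitted here.

# Condensed sketch of the policy the factory returns for model_type="mixtral".
# It mirrors the code above but is not a drop-in replacement for it.
import functools

from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer


def lambda_policy_fn(module):
    # Leaf modules with a trainable weight, i.e. the (Q)LoRA adapter layers.
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )


mixtral_policy = functools.partial(
    _or_policy,
    policies=[
        functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn),
        functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={MixtralDecoderLayer},
        ),
    ],
)
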
4 changes: 2 additions & 2 deletions src/axolotl/core/trainer_builder.py
@@ -501,9 +501,9 @@ def create_accelerator_and_postprocess(self):
         # load_param_skip_names = ['inv_freq']

         if self.is_fsdp_enabled:
-            wrapping_policy = get_wrapping_policy_factory(self.model.config.model_type)
+            wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
             fsdp_plugin = FullyShardedDataParallelPlugin(
-                auto_wrap_policy=wrapping_policy(False),
+                auto_wrap_policy=wrapping_policy(),
                 use_orig_params=False,
                 limit_all_gathers=True,
                 param_init_fn=lambda module: module.to_empty(
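
On the trainer side, the factory is now keyed off self.args.model_type (presumably the Hugging Face config's model_type, e.g. "mixtral", carried on the training arguments) rather than reaching through self.model.config, and the returned builder is called with no arguments since custom_policy no longer exists. An illustrative sketch of the wiring outside the trainer, assuming the import path from this commit and only the plugin arguments visible in the diff:

# Illustrative wiring only; not the trainer's actual code path.
from accelerate import FullyShardedDataParallelPlugin

from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory

wrapping_policy = get_wrapping_policy_factory("mixtral")  # "llama" / "mistral" / "mixtral"
fsdp_plugin = FullyShardedDataParallelPlugin(
    auto_wrap_policy=wrapping_policy(),  # no custom_policy flag any more
    use_orig_params=False,
    limit_all_gathers=True,
)
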
