
DBRX Model Support #1462

Merged
merged 15 commits into main from dbrx on Apr 12, 2024
Conversation

@winglian (Collaborator) commented Mar 30, 2024

DBRX MoE

Currently, for LoRA, only the q_proj, k_proj, v_proj, out_proj, and layer Linear layers are trainable.

We are using the "converted" base models, based on this issue, where the experts are fused as an nn.Parameter rather than an nn.Linear layer. However, the implementation is still a bit buggy, and attempting to train a LoRA adapter over those w1, w2 and v1 layers results in the trainer hanging.
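For reference, a minimal PEFT config along those lines might look like the sketch below. The module names come from the list above; the r/alpha values are placeholders, not what this PR ships in examples/dbrx.

from peft import LoraConfig

# Sketch only: LoRA over the attention projections listed above; avoids the
# expert w1/w2/v1 layers that currently hang the trainer.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    task_type="CAUSAL_LM",
)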

FSDP

We've tested using the LnL-AI/dbrx-base-converted-v2 model as the base model for FSDP.

The high memory usage seen w/ FSDP is due to FSDP not supporting 8-bit optimizers.

  • 16-bit LoRA w/ FSDP
    • ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
    • ❌ w/ CPU Offload - paged_adamw_8bit optimizer errors from being on cpu
  • ✅ 8-bit LoRA w/ FSDP
  • ❌ 4-bit QLoRA w/ FSDP - errors w/: Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu
  • ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)

Deepspeed

WIP

@Qubitium

@winglian https://huggingface.co/LnL-AI/dbrx-base-converted-v2 is up, with potentially better quant compatibility due to split q,k,v layers. If possible, can you test/validate its inference quality vs the original? I ran out of GPU to sanity-check base inference since all my GPUs are maxed out testing training on it.
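Something like the following would be enough for a rough inference sanity check (a sketch only; it assumes enough GPU memory for device_map="auto" on a ~132B-parameter model, or a quantized load):

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Rough generation sanity check for the converted checkpoint (sketch).
model_id = "LnL-AI/dbrx-base-converted-v2"
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
inputs = tok("The capital of France is", return_tensors="pt").to(model.device)
print(tok.decode(model.generate(**inputs, max_new_tokens=20)[0]))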

@NanoCode012 (Collaborator) commented Mar 30, 2024

Could we also add the above to the readme support matrix for easy viewing?


Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu

I wonder if this is due to bnb.

NanoCode012 linked an issue Mar 30, 2024 that may be closed by this pull request
NanoCode012 mentioned this pull request Mar 30, 2024
@erenup commented Mar 30, 2024

I will test the 8-bit LoRA and 16-bit LoRA.

@erenup commented Mar 30, 2024

I am testing examples/dbrx/8bit-lora.yaml and will update when I get the results.

return dist.is_available() and dist.is_initialized()
global distributed_state  # pylint: disable=global-statement
if not distributed_state:
    timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
Collaborator

May be good to document this env
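For documentation purposes, a sketch of what the env var controls (the real code path in this PR goes through accelerate's PartialState rather than calling torch.distributed directly):

import os
from datetime import timedelta

import torch.distributed as dist

# Sketch only: AXOLOTL_NCCL_TIMEOUT is the NCCL timeout in seconds
# (default 1800, i.e. 30 minutes) applied when the process group is set up.
timeout_s = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
if dist.is_available() and not dist.is_initialized():
    dist.init_process_group(backend="nccl", timeout=timedelta(seconds=timeout_s))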

def n_loading_workers(quant_method: str, param_count: float):
    devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
    left = int(os.cpu_count() / torch.cuda.device_count())
    model_params_b = 70
Collaborator

Should this be moved to a function param instead of being hardcoded?

Collaborator Author

Yeah, not sure what to do about this atm. Even the answer.ai fsdp-qlora example has this hardcoded. I'm not sure there is a good way to get the number of parameters in the model before we actually load the model.
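One untested possibility would be counting parameters on the meta device, so nothing is actually loaded; a sketch, not validated against the custom DBRX modeling code:

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Instantiate the model on the meta device from its config (no weights loaded)
# and count parameters there, instead of hardcoding model_params_b = 70.
config = AutoConfig.from_pretrained("LnL-AI/dbrx-base-converted-v2", trust_remote_code=True)
with init_empty_weights():
    meta_model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
param_count_b = sum(p.numel() for p in meta_model.parameters()) / 1e9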

elif cfg.adapter == "lora" and cfg.load_in_8bit:
    bnb_config = {
        "load_in_8bit": True,
    }
Collaborator

Previously, we did not have this. What is the effect of this change?

Collaborator Author

There is currently a warning in transformers that passing load_in_8bit will be deprecated soon and that we should use the quantization config instead.
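i.e. roughly this pattern (a sketch; the model id is just the converted base used in this PR):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Pass a quantization config instead of the soon-to-be-deprecated load_in_8bit flag.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "LnL-AI/dbrx-base-converted-v2",
    quantization_config=bnb_config,
    trust_remote_code=True,
)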

@winglian (Collaborator Author) commented Apr 1, 2024

Using offload_param for DeepSpeed ZeRO-3 (from #1466), per-GPU VRAM utilization is ~50GB/gpu @ batch size 1.
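The relevant ZeRO-3 knob, as a sketch (the actual JSON config shipped with the repo may set additional fields):

# Sketch of ZeRO-3 parameter offload to CPU, expressed as a Python dict.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
    },
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
}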

@Qubitium commented Apr 3, 2024

@winglian Please test using the tokenizer from https://huggingface.co/LnL-AI/dbrx-base-tokenizer and not the tiktoken one from DBRX, which has several problems. It resolves 3 issues I found that negatively affect training:

  1. eos == pad (many models do this, but it is not good). If it were fine, why did they use a distinct pad_token for instruct?
  2. The original tokenizer padded the incorrect vocab size up to the correct size with <|extra_id_N> tokens, but the encoder was not updated, so you can't use them.
  3. config.json vocab size != tokenizer.vocab_size (100352 vs 100277)

My tokenizer is based on the one created by HF staff Xenova (https://huggingface.co/Xenova/dbrx-instruct-tokenizer). I am trying to validate with him whether the tokenizer is 100% encode/decode compatible: https://huggingface.co/Xenova/dbrx-instruct-tokenizer/discussions/1 . We are also going to do some internal testing on this.

My changes:

  1. Remove non-base-model tokens
  2. Keep/add the <|pad|> special token to make sure padding can be differentiated from eos/bos
  3. Expose 15 unused/reserved <|extra_N|> tokens for use
  4. [NOT FIXED] config.json vocab size != tokenizer.vocab_size (100352 vs 100277)
# pad token
"100256": {
    "content": "<|pad|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false,
    "special": true
},

# 15 unused/reserved extra tokens
"<|extra_0|>": 100261
"<|extra_1|>": 100262
...
"<|extra_14|>": 100275

EDIT: removed wrong attribution that len(tokenizer) != tokenizer.vocab_size
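A quick sanity check for the pad/eos split would be something like the sketch below (the expected values come from the listing above; trust_remote_code is passed just in case and should be harmless):

from transformers import AutoTokenizer

# Verify the replacement tokenizer separates pad from eos (sketch).
tok = AutoTokenizer.from_pretrained("LnL-AI/dbrx-base-tokenizer", trust_remote_code=True)
print(tok.pad_token, tok.pad_token_id)  # expected: <|pad|> 100256
print(tok.eos_token, tok.eos_token_id)
assert tok.pad_token_id != tok.eos_token_id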

@NanoCode012 (Collaborator)

@Qubitium , hey, was wondering whether that new tokenizer vocab has the same size as the model embed size or whether that needs to be resized as well?

@Qubitium commented Apr 3, 2024

@Qubitium , hey, was wondering whether that new tokenizer vocab has the same size as the model embed size or whether that needs to be resized as well?

It's the same size. I did not add any new tokens beyond the original embed size.

@Qubitium commented Apr 3, 2024

@NanoCode012 Correction. If you go by the embed size "vocab_size": 100352 from config.json, then the tokenizer is smaller than the embed size. But a BPE-based tokenizer should not be resized post-training. My tokenizer's size is the same as the original base tokenizer's.

EDIT: I am unsure about the resize too now.

EDIT: removed the wrong attribution that len(tokenizer) != tokenizer.vocab_size in the original tokenizer. So the remaining issues are just that the extra tokens are never exposed to the encoder, and pad token == eos.

@winglian (Collaborator Author) commented Apr 4, 2024

@NanoCode012 Correction. If you go by the embed size "vocab_size": 100352 from config.json, then the tokenizer is smaller than the embed size. But a BPE-based tokenizer should not be resized post-training. My tokenizer's size is the same as the original base tokenizer's.

EDIT: I am unsure about the resize too now.

EDIT: removed the wrong attribution that len(tokenizer) != tokenizer.vocab_size in the original tokenizer. So the remaining issues are just that the extra tokens are never exposed to the encoder, and pad token == eos.

Is there anything we need to review from here? I'm hoping to get this merged today if possible, even if it's only preliminary.

@NanoCode012 (Collaborator)

@winglian, I think this can be merged first. The only consequence of the comment above, I believe, is that for adapter training the embedding layer and lm_head need to be targeted due to the resize.
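i.e. roughly the sketch below; the exact module names are an assumption and depend on the converted DBRX checkpoint (the embedding layer may not be named embed_tokens there):

from peft import LoraConfig

# Sketch: if embeddings are resized (e.g. to add a pad token), train/save the
# embedding and output head in full alongside the LoRA adapters.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    modules_to_save=["embed_tokens", "lm_head"],  # assumed names; verify against the checkpoint
    task_type="CAUSAL_LM",
)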

winglian merged commit 132eb74 into main on Apr 12, 2024
7 checks passed
winglian deleted the dbrx branch on April 12, 2024 13:02
@daje0601

When running FSDP with meta-llama/Meta-Llama-3.1-70B-Instruct, I'm getting the error ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32.

So I tested it with meta-llama/Meta-Llama-3-70B-Instruct to see if it trains normally, and it does.

How can I resolve this error? Please give me a hint!

I'll share the error, config, and training code below for your reference.

error message

    trainer.train(resume_from_checkpoint=checkpoint)
  File "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py", line 361, in train
    output = super().train(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2085, in _inner_training_loop
    self.model = self.accelerator.prepare(self.model)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1274, in prepare
    result = tuple(
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1275, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1151, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1434, in prepare_model
    model = FSDP(model, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 477, in __init__
    _auto_wrap(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  [Previous line repeated 2 more times]
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
    return wrapper_cls(module, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 503, in __init__
    _init_param_handle_from_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_init_utils.py", line 590, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_init_utils.py", line 602, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_flat_param.py", line 573, in __init__
    self._init_flat_param_and_metadata(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_flat_param.py", line 623, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_flat_param.py", line 761, in _validate_tensors_to_flatten
    raise ValueError(
ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
[2024-09-10 04:25:43,508] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 29390) of binary: /usr/bin/python
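For reference, a dtype census right after from_pretrained (before the trainer wraps the model in FSDP) usually shows which parameters stayed in float32 and would trip the uniform-dtype check; a sketch:

from collections import Counter

import torch

def dtype_census(model: torch.nn.Module) -> None:
    # Summarize parameter dtypes, then list any float32 stragglers.
    print(Counter(p.dtype for p in model.parameters()))
    for name, param in model.named_parameters():
        if param.dtype == torch.float32:
            print("float32 param:", name)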

config

dataset_path: ""
max_seq_length: 2048
output_dir: "./model/20240909_model"
report_to: "wandb"
learning_rate: 0.00005
lr_scheduler_type: "constant"
num_train_epochs: 2
per_device_train_batch_size: 4
per_device_eval_batch_size: 4
gradient_accumulation_steps: 8
optim: "adamw_torch"
logging_steps: 10
save_strategy: "epoch"
max_grad_norm: 0.3
warmup_ratio: 0.03
bf16: true
tf32: true
gradient_checkpointing: true
fsdp: "full_shard auto_wrap"
fsdp_config:
  backward_prefetch: "backward_pre"
  forward_prefetch: "false"
  use_orig_params: "false"

training

from dataclasses import dataclass, field
import os
import json
import random

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    set_seed,
)
from trl import SFTTrainer, setup_chat_format
from trl.commands.cli_utils import TrlParser
from peft import LoraConfig

from huggingface_hub import login

login(
    token=os.environ["HF_TOKEN"],  # do not hardcode access tokens
    add_to_git_credential=True,
)

@dataclass
class ScriptArguments:
    dataset_path: str = field(
        default=None,
        metadata={
            "help": "Path to the dataset"
        },
    )
    model_id: str = field(
        default=None, metadata={"help": "Model ID to use for SFT training"}
    )
    max_seq_length: int = field(
        default=512, metadata={"help": "The maximum sequence length for SFT Trainer"}
    )


def training_function(script_args, training_args):
    ################
    # Dataset
    ################

    train_dataset = load_dataset(
        "json",
        data_files=os.path.join(script_args.dataset_path, "chat_template_train.jsonl"),
        split="train",
    )
    test_dataset = load_dataset(
        "json",
        data_files=os.path.join(script_args.dataset_path, "chat_template_test.jsonl"),
        split="train",
    )

    ################
    # Model & Tokenizer
    ################

    # Tokenizer        
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, use_fast=True)
    tokenizer.model_max_length = script_args.max_seq_length
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

    # template dataset
    def template_dataset(examples):
        messages = json.loads(examples["text"])["messages"]
        return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}

    train_dataset = train_dataset.map(template_dataset, remove_columns=["text"])
    test_dataset = test_dataset.map(template_dataset, remove_columns=["text"])
    # print random sample
    with training_args.main_process_first(
        desc="Log a few random samples from the processed training set"
    ):
        for index in random.sample(range(len(train_dataset)), 2):
            print(train_dataset[index]["text"])

    # Model    
    torch_dtype = torch.bfloat16
    quant_storage_dtype = torch.bfloat16

    quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_quant_storage=quant_storage_dtype,
        )

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_id,
        quantization_config=quantization_config,
        attn_implementation="sdpa", # use sdpa, alternatively use "flash_attention_2"
        torch_dtype=quant_storage_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,  # this is needed for gradient checkpointing
        low_cpu_mem_usage=True,
    )

    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    ################
    # PEFT
    ################

    peft_config = LoraConfig(
        lora_alpha=128,
        lora_dropout=0.05,
        r=256,
        bias="none",
        target_modules=[
            "q_proj",
            # "up_proj",
            "o_proj",
            "k_proj",
            # "down_proj",
            # "gate_proj",
            "v_proj",
        ],
        task_type="CAUSAL_LM",
        # modules_to_save = ["lm_head", "embed_tokens"] # add if you want to use the Llama 3 instruct template
    )

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        dataset_text_field="text",
        eval_dataset=test_dataset,
        peft_config=peft_config,
        max_seq_length=script_args.max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        dataset_kwargs={
            "add_special_tokens": False,  # We template with special tokens
            "append_concat_token": False,  # No need to add additional separator token
        },
    )
    if trainer.accelerator.is_main_process:
        trainer.model.print_trainable_parameters()

    ##########################
    # Train model
    ##########################
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    trainer.train(resume_from_checkpoint=checkpoint)

    ##########################
    # SAVE MODEL FOR SAGEMAKER
    ##########################
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model()
    
if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, TrainingArguments))
    script_args, training_args = parser.parse_args_and_config()    
    
    # configure gradient checkpointing kwargs
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
    # set seed
    set_seed(training_args.seed)
  
    # launch training
    training_function(script_args, training_args)

djsaunde pushed a commit that referenced this pull request Dec 17, 2024
* wip for dbrx finetuning

* add fastcore for parallel loading of sharded weights

* fix dtype for load, use PartialState instead of accelerator to init process group, remove redundant wandb callback

* update to use v2 of the converted model

* more fixes for dbrx loras

* make sure to enable fsdp activation checkpointing

* fix support for 8bit loras too for dbrx

* apply z3 leaf moe fix for DBRX with deepspeed

* don't raise value error since child module searches could fail and be ok

* revert a previous change to fix fsdp

* update mistral/mistral qlora+fsdp yamls

* fix qlora+fsdp quant storage type

* more edge cases for qlora-fsdp

* fixes for fsdp+qlora w optimizer in 8bit

* add bigstral z3 config and make sure to use full_state_dict for fsdp
Successfully merging this pull request may close these issues.

DBRX training
5 participants