FSDP with SFTrainer: expected dtype float for end but got dtype c10::BFloat16 #34702

Closed

asc-raynor opened this issue Nov 12, 2024 · 11 comments

@asc-raynor

System Info

PyTorch 2.2 and 2.4 (both tested)
transformers 4.46.2
4 × A6000 Ada

Who can help?

@muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I ran the FSDP training code from https://huggingface.co/docs/peft/accelerate/fsdp,
but got an "expected dtype float for `end` but got dtype c10::BFloat16" error.
I tried changing the dtype (float16, float32, bfloat16) but the code still failed to run.
What's the problem?

import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser, set_seed
from trl import SFTConfig, SFTTrainer
from src.utils import create_and_prepare_model, create_datasets


# Define and parse arguments.
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/configs/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    chat_template_format: Optional[str] = field(
        default="none",
        metadata={
            "help": "chatml|zephyr|none. Pass `none` if the dataset is already formatted with the chat template."
        },
    )
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=64)
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "comma separated list of target modules to apply LoRA layers to"},
    )
    use_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4bit base models"},
    )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="float16",
        metadata={"help": "Compute dtype for 4bit base models"},
    )
    bnb_4bit_quant_storage_dtype: Optional[str] = field(
        default="uint8",
        metadata={"help": "Quantization storage dtype for 4bit base models"},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type fp4 or nf4"},
    )
    use_flash_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash attention for training."},
    )
    use_peft_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables PEFT LoRA for training."},
    )
    use_8bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 8bit."},
    )
    use_4bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 4bit."},
    )
    use_reentrant: Optional[bool] = field(
        default=False,
        metadata={"help": "Gradient Checkpointing param. Refer the related docs"},
    )
    use_unsloth: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables UnSloth for training."},
    )


@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The preference dataset to use."},
    )
    append_concat_token: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
    )
    add_special_tokens: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, tokenizers adds special tokens to each sample being packed."},
    )
    splits: Optional[str] = field(
        default="train,test",
        metadata={"help": "Comma separate list of the splits to use from the dataset."},
    )


def main(model_args, data_args, training_args):
    # Set seed for reproducibility
    set_seed(training_args.seed)

    # model
    model, peft_config, tokenizer = create_and_prepare_model(model_args, data_args, training_args)

    # gradient ckpt
    model.config.use_cache = not training_args.gradient_checkpointing
    training_args.gradient_checkpointing = training_args.gradient_checkpointing and not model_args.use_unsloth
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": model_args.use_reentrant}

    training_args.dataset_kwargs = {
        "append_concat_token": data_args.append_concat_token,
        "add_special_tokens": data_args.add_special_tokens,
    }

    # datasets
    train_dataset, eval_dataset = create_datasets(
        tokenizer,
        data_args,
        training_args,
        apply_chat_template=model_args.chat_template_format != "none",
    )

    # trainer
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
    )
    trainer.accelerator.print(f"{trainer.model}")
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()

    # train
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    trainer.train(resume_from_checkpoint=checkpoint)

    # saving final model
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model()


if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py#L151
    # training_args.optim = "adamw_torch_4bit"

    main(model_args, data_args, training_args)

# src/utils.py

import os
from enum import Enum

import packaging.version
import torch
import transformers
from datasets import DatasetDict, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from peft import LoraConfig


DEFAULT_CHATML_CHAT_TEMPLATE = "{% for message in messages %}\n{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% if loop.last and add_generation_prompt %}{{'<|im_start|>assistant\n' }}{% endif %}{% endfor %}"
DEFAULT_ZEPHYR_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"


class ZephyrSpecialTokens(str, Enum):
    user = "<|user|>"
    assistant = "<|assistant|>"
    system = "<|system|>"
    eos_token = "</s>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]


class ChatmlSpecialTokens(str, Enum):
    user = "<|im_start|>user"
    assistant = "<|im_start|>assistant"
    system = "<|im_start|>system"
    eos_token = "<|im_end|>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]


def create_datasets(tokenizer, data_args, training_args, apply_chat_template=False):
    def preprocess(samples):
        batch = []
        for conversation in samples["messages"]:
            batch.append(tokenizer.apply_chat_template(conversation, tokenize=False))
        return {"content": batch}

    raw_datasets = DatasetDict()
    for split in data_args.splits.split(","):
        try:
            # Try first if dataset on a Hub repo
            dataset = load_dataset(data_args.dataset_name, split=split)
        except DatasetGenerationError:
            # If not, check local dataset
            dataset = load_from_disk(os.path.join(data_args.dataset_name, split))

        if "train" in split:
            raw_datasets["train"] = dataset
        elif "test" in split:
            raw_datasets["test"] = dataset
        else:
            raise ValueError(f"Split type {split} not recognized as one of test or train.")

    if apply_chat_template:
        raw_datasets = raw_datasets.map(
            preprocess,
            batched=True,
            remove_columns=raw_datasets["train"].column_names,
        )

    train_data = raw_datasets["train"]
    valid_data = raw_datasets["test"]
    print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
    print(f"A sample of train dataset: {train_data[0]}")

    return train_data, valid_data


def create_and_prepare_model(args, data_args, training_args):
    if args.use_unsloth:
        from unsloth import FastLanguageModel
    bnb_config = None
    quant_storage_dtype = None

    if (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and torch.distributed.get_world_size() > 1
        and args.use_unsloth
    ):
        raise NotImplementedError("Unsloth is not supported in distributed training")

    if args.use_4bit_quantization:
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
        quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=args.use_4bit_quantization,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=args.use_nested_quant,
            bnb_4bit_quant_storage=quant_storage_dtype,
        )

        if compute_dtype == torch.float16 and args.use_4bit_quantization:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
                print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
                print("=" * 80)
    elif args.use_8bit_quantization:
        bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization)

    if args.use_unsloth:
        # Load model
        model, _ = FastLanguageModel.from_pretrained(
            model_name=args.model_name_or_path,
            max_seq_length=data_args.max_seq_length,
            dtype=None,
            load_in_4bit=args.use_4bit_quantization,
        )
    else:
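        # For QLoRA + FSDP, the 4-bit quantization storage dtype must match the dtype
        # the rest of the model is loaded in so that FSDP can flatten all parameters
        # into uniform-dtype shards; torch_dtype therefore mirrors quant_storage_dtype.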
        torch_dtype = (
            quant_storage_dtype if quant_storage_dtype and quant_storage_dtype.is_floating_point else torch.float32
        )
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            quantization_config=bnb_config,
            trust_remote_code=True,
            attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
            torch_dtype=torch_dtype,
        )

    peft_config = None
    if args.use_peft_lora and not args.use_unsloth:
        peft_config = LoraConfig(
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
        )

    special_tokens = None
    chat_template = None
    if args.chat_template_format == "chatml":
        special_tokens = ChatmlSpecialTokens
        chat_template = DEFAULT_CHATML_CHAT_TEMPLATE
    elif args.chat_template_format == "zephyr":
        special_tokens = ZephyrSpecialTokens
        chat_template = DEFAULT_ZEPHYR_CHAT_TEMPLATE

    if special_tokens is not None:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            pad_token=special_tokens.pad_token.value,
            bos_token=special_tokens.bos_token.value,
            eos_token=special_tokens.eos_token.value,
            additional_special_tokens=special_tokens.list(),
            trust_remote_code=True,
        )
        tokenizer.chat_template = chat_template

        # make embedding resizing configurable?
        # Transformers 4.46.0+ uses mean_resizing by default, which fails with QLoRA + FSDP because the
        # embedding could be on the meta device; therefore, we set mean_resizing=False in that case (i.e. the
        # status quo ante). See https://github.com/huggingface/accelerate/issues/1620.
        uses_transformers_4_46 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.46.0")
        uses_fsdp = os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
        if (bnb_config is not None) and uses_fsdp and uses_transformers_4_46:
            model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8, mean_resizing=False)
        else:
            model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token

    if args.use_unsloth:
        # Do model patching and add fast LoRA weights
        model = FastLanguageModel.get_peft_model(
            model,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            random_state=training_args.seed,
            max_seq_length=data_args.max_seq_length,
        )

    return model, peft_config, tokenizer

Accelerate FSDP config (configs/fsdp_config.yaml):

compute_environment: LOCAL_MACHINE
debug: true
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Launch command:

accelerate launch --config_file "configs/fsdp_config.yaml" train.py \
--seed 100 \
--model_name_or_path "meta-llama/Llama-3.2-1B-Instruct" \
--dataset_name "smangrul/ultrachat-10k-chatml" \
--chat_template_format "chatml" \
--add_special_tokens False \
--append_concat_token False \
--splits "train,test" \
--max_seq_len 2048 \
--num_train_epochs 1 \
--logging_steps 5 \
--log_level "info" \
--logging_strategy "steps" \
--eval_strategy "epoch" \
--save_strategy "epoch" \
--push_to_hub \
--hub_private_repo True \
--hub_strategy "every_save" \
--bf16 True \
--packing True \
--learning_rate 1e-4 \
--lr_scheduler_type "cosine" \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--output_dir "llama-sft-qlora-fsdp" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 2 \
--gradient_checkpointing True \
--use_reentrant True \
--dataset_text_field "content" \
--use_flash_attn True \
--use_peft_lora True \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--bnb_4bit_quant_storage_dtype "bfloat16"

Expected behavior

FSDP training runs to completion without the dtype error.

@asc-raynor asc-raynor added the bug label Nov 12, 2024
@LysandreJik
Member

Thanks all for the report and sorry for the delay; we're looking into it. cc @muellerzr @SunMarc

@alexdauenhauer

alexdauenhauer commented Nov 15, 2024

Same issue, but with DPOTrainer (I probably also have it with SFTTrainer, but haven't tested). The error only occurs for me in multi-worker/multi-GPU/multi-node training; with FSDP on a single GPU there is no error. The issue is also not present in 4.45.2. I am wondering if it is due to this change?

In v4.45.2 (modeling_mistral.py):

hidden_states = outputs[0]
if labels is None and not is_torchdynamo_compiling():
    logger.warning_once(
        "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
    )
    
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

In v4.46.2:

hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

My traceback looks like this:

  File "/tmp/ray/session_2024-11-13_12-57-50_682472_12/runtime_resources/working_dir_files/_ray_pkg_92dffa2da1edbd43/fine_tune/main.py", line 87, in train_func
    trainer.train()
  File "/tmp/ray/session_2024-11-13_12-57-50_682472_12/runtime_resources/pip/885b4123dae986bae1106a4662ccedcbc5ae220d/virtualenv/lib/python3.11/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-11-13_12-57-50_682472_12/runtime_resources/pip/885b4123dae986bae1106a4662ccedcbc5ae220d/virtualenv/lib/python3.11/site-packages/transformers/trainer.py", line 2534, in _inner_training_loop
    self.optimizer.step()
  File "/tmp/ray/session_2024-11-13_12-57-50_682472_12/runtime_resources/pip/885b4123dae986bae1106a4662ccedcbc5ae220d/virtualenv/lib/python3.11/site-packages/accelerate/optimizer.py", line 171, in step
    self.optimizer.step(closure)
  File "/tmp/ray/session_2024-11-13_12-57-50_682472_12/runtime_resources/pip/885b4123dae986bae1106a4662ccedcbc5ae220d/virtualenv/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-11-13_12-57-50_682472_12/runtime_resources/pip/885b4123dae986bae1106a4662ccedcbc5ae220d/virtualenv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 487, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-11-13_12-57-50_682472_12/runtime_resources/pip/885b4123dae986bae1106a4662ccedcbc5ae220d/virtualenv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-11-13_12-57-50_682472_12/runtime_resources/pip/885b4123dae986bae1106a4662ccedcbc5ae220d/virtualenv/lib/python3.11/site-packages/torch/optim/adamw.py", line 220, in step
    adamw(
  File "/tmp/ray/session_2024-11-13_12-57-50_682472_12/runtime_resources/pip/885b4123dae986bae1106a4662ccedcbc5ae220d/virtualenv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 154, in maybe_fallback
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-11-13_12-57-50_682472_12/runtime_resources/pip/885b4123dae986bae1106a4662ccedcbc5ae220d/virtualenv/lib/python3.11/site-packages/torch/optim/adamw.py", line 782, in adamw
    func(
  File "/tmp/ray/session_2024-11-13_12-57-50_682472_12/runtime_resources/pip/885b4123dae986bae1106a4662ccedcbc5ae220d/virtualenv/lib/python3.11/site-packages/torch/optim/adamw.py", line 375, in _single_tensor_adamw
    exp_avg.lerp_(grad, 1 - beta1)
RuntimeError: expected dtype float for `end` but got dtype c10::BFloat16
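
The failing call is torch.Tensor.lerp_, which requires `end` (here, the gradient) to have the same dtype as the tensor it is called on (the fp32 optimizer state). A minimal sketch that reproduces the same error outside the Trainer, assuming an fp32 AdamW state meeting a bf16 gradient (an illustration, not code from torch or transformers internals):

import torch

exp_avg = torch.zeros(4, dtype=torch.float32)  # fp32 AdamW state (exp_avg)
grad = torch.randn(4, dtype=torch.bfloat16)    # gradient left in bf16 by mixed precision
exp_avg.lerp_(grad, 0.1)
# RuntimeError: expected dtype float for `end` but got dtype c10::BFloat16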

@asc-raynor
Author

The latest version of TRL (0.12.0) seems to have some issues, but version 0.11.3 works fine.

@benjamin-marie

Same issue. I can't use FSDP with TRL anymore. Everything works again if I downgrade Accelerate + Transformers + TRL to their September versions.
It's not related to PyTorch (I tried 2.1 through 2.6).
It might be related to something introduced in Transformers 4.46.2. To be confirmed.
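
For reference, a downgrade along the lines described in this thread (trl 0.11.3 is the version reported working above, transformers 4.45.2 the version reported unaffected; the accelerate pin is an assumed September-era release):

pip install "transformers==4.45.2" "trl==0.11.3" "accelerate==0.34.2"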

@muellerzr
Contributor

Hi! Thanks for the bug report. This should be fixed via #34645, can you install transformers via pip install git+https://github.com/huggingface/transformers? Thanks for your patience while we figure out ripple effects from the grad accum changes 🤗

@benjamin-marie

I confirm that it works with the most recent version of Transformers (already available through pip).

@alexdauenhauer

@muellerzr I am still seeing this error with transformers 4.46.3 and trl 0.12.1, but it only happens occasionally. I had a training run with 351 steps; it made it through 172 steps and then I got this error on step 173. I have tried with both the SFTTrainer and DPOTrainer.

@alexdauenhauer

@muellerzr it appears to happen exactly halfway through max_steps. Looking at the fix in the PR you linked, I can't explain this behavior, but I thought it might provide a clue.

@benjamin-marie

I don't have this issue anymore on my side. Could you provide the traceback?

@alexdauenhauer

alexdauenhauer commented Dec 3, 2024

@benjamin-marie thanks, here is my traceback:

  File "/tmp/ray/session_2024-12-02_11-03-27_501367_12/runtime_resources/pip/74b671a31be4649681b5b250a141caa5a98ab328/virtualenv/lib/python3.11/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-12-02_11-03-27_501367_12/runtime_resources/pip/74b671a31be4649681b5b250a141caa5a98ab328/virtualenv/lib/python3.11/site-packages/transformers/trainer.py", line 2534, in _inner_training_loop
    self.optimizer.step()
  File "/tmp/ray/session_2024-12-02_11-03-27_501367_12/runtime_resources/pip/74b671a31be4649681b5b250a141caa5a98ab328/virtualenv/lib/python3.11/site-packages/accelerate/optimizer.py", line 171, in step
    self.optimizer.step(closure)
  File "/tmp/ray/session_2024-12-02_11-03-27_501367_12/runtime_resources/pip/74b671a31be4649681b5b250a141caa5a98ab328/virtualenv/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 130, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-12-02_11-03-27_501367_12/runtime_resources/pip/74b671a31be4649681b5b250a141caa5a98ab328/virtualenv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 484, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-12-02_11-03-27_501367_12/runtime_resources/pip/74b671a31be4649681b5b250a141caa5a98ab328/virtualenv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 89, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-12-02_11-03-27_501367_12/runtime_resources/pip/74b671a31be4649681b5b250a141caa5a98ab328/virtualenv/lib/python3.11/site-packages/torch/optim/adamw.py", line 227, in step
    adamw(
  File "/tmp/ray/session_2024-12-02_11-03-27_501367_12/runtime_resources/pip/74b671a31be4649681b5b250a141caa5a98ab328/virtualenv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 161, in maybe_fallback
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-12-02_11-03-27_501367_12/runtime_resources/pip/74b671a31be4649681b5b250a141caa5a98ab328/virtualenv/lib/python3.11/site-packages/torch/optim/adamw.py", line 767, in adamw
    func(
  File "/tmp/ray/session_2024-12-02_11-03-27_501367_12/runtime_resources/pip/74b671a31be4649681b5b250a141caa5a98ab328/virtualenv/lib/python3.11/site-packages/torch/optim/adamw.py", line 380, in _single_tensor_adamw
    exp_avg.lerp_(grad, 1 - beta1)
RuntimeError: expected dtype float for `end` but got dtype c10::BFloat16

I can't understand why it can successfully complete 175/351 steps but then fail. I have tried different datasets, using both SFT and DPO from TRL, and it always fails at the halfway step.

@alexdauenhauer

@benjamin-marie I found out how this happened for me. We run our training in a multi-node setup, and I was incorrectly calculating per_device_train_batch_size as well as max_steps from my desired total_batch_size and the total size of my dataset, which led to odd calculations of epochs and update steps in _inner_training_loop.
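
To make that bookkeeping concrete, a minimal sketch of the arithmetic (hypothetical numbers chosen to match the 351 steps mentioned above; an illustration, not the actual _inner_training_loop code):

per_device_train_batch_size = 2
gradient_accumulation_steps = 2
world_size = 8  # GPUs across all nodes; getting this wrong skews everything below

total_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size
dataset_size = 11_232  # hypothetical
steps_per_epoch = dataset_size // total_batch_size  # optimizer updates per epoch
print(total_batch_size, steps_per_epoch)  # 32 351

# If world_size is counted at twice its real value (easy to do in a multi-node
# setup), max_steps is computed for an effective batch twice the real one, and
# the schedule drifts out of sync with the data at roughly the halfway point.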
