
ORPO results in Cannot flatten integer dtype tensors #1838

Open · 6 of 8 tasks
maziyarpanahi opened this issue Aug 20, 2024 · 5 comments
Labels
bug Something isn't working

Comments

@maziyarpanahi
Contributor

Please check that this issue hasn't been reported before.

  • I searched previous bug reports and didn't find any similar reports.

Expected Behavior

I expect ORPO to work properly with FSDP and DeepSpeed on Qwen2 models.

Current behaviour

Currently, it's not possible to use ORPO with FSDP or DeepSpeed. It fails with ValueError: Cannot flatten integer dtype tensors (full traceback below).

Possible issues:

  • LoRA / QLoRA
  • FSDP
  • DeepSpeed
  • ORPO
  • Qwen2
  • or a combination of the above
 File "/workspace/axolotl/src/axolotl/cli/train.py", line 67, in do_train
          [Previous line repeated 2 more times]
component_trace = _Fire(component, args, parsed_flag_args, context, name)wrapped_child, num_wrapped_params = _recursive_wrap(  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 555, 
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train

  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 531, in _init_param_handle_from_params
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 531, in _init_param_handle_from_params
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/accelerator.py", line 1181, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/accelerator.py", line 1477, in prepare_model
    model = FSDP(model, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 463, in __init__
    _auto_wrap(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  [Previous line repeated 2 more times]
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 555, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 484, in _wrap
    return wrapper_cls(module, **kwargs)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 487, in __init__
    _init_param_handle_from_module(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 519, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 531, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 537, in __init__
    self._init_flat_param_and_metadata(
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 585, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 720, in _validate_tensors_to_flatten
    raise ValueError("Cannot flatten integer dtype tensors")
ValueError: Cannot flatten integer dtype tensors
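
The check that fails is FSDP refusing to flatten parameters whose dtype is not floating point; with load_in_4bit: true, bitsandbytes stores the quantized weights as integer tensors, which is likely what FSDP is being asked to flatten here. A minimal diagnostic sketch (the helper name is made up and is not part of axolotl) to list such parameters before FSDP wrapping:

import torch

def find_unflattenable_params(model: torch.nn.Module):
    """Hypothetical helper: list parameters FSDP cannot flatten (non-floating-point dtypes)."""
    return [
        (name, param.dtype)
        for name, param in model.named_parameters()
        if not param.dtype.is_floating_point and not param.dtype.is_complex
    ]

# Example usage: print(find_unflattenable_params(model)) right before the model is
# handed to accelerator.prepare() / FSDP wrapping.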

Steps to reproduce

  • launch a RunPod instance with 8x A100 or H100 and 640 GB (or 320 GB) of total GPU memory
  • choose winglian/axolotl-runpod:main-latest template
  • follow these steps (check out commit 608a2f3 to avoid an FSDP issue in the latest changes):
rm -rf axolotl
git clone https://github.com/OpenAccess-AI-Collective/axolotl && \
cd axolotl && \
git checkout 608a2f3 && \
pip install setuptools && \
pip install -e .[flash-attn,deepspeed] && \
cd ..
  • then preprocess and train with the config.yaml below (see the sketch after this list)
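
For reference, a minimal sketch of that last step, assuming the standard axolotl CLI entry points and that the configuration below is saved as config.yaml in the axolotl checkout:

python -m axolotl.cli.preprocess config.yaml
accelerate launch -m axolotl.cli.train config.yaml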

Config yaml

base_model: arcee-ai/Arcee-Nova
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

save_safetensors: true

rl: orpo
orpo_alpha: 0.1
chat_template: chatml
datasets:
  - path: mlabonne/orpo-dpo-mix-40k
    type: chat_template.argilla
    chat_template: chatml

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./models/Arcee-Nova-ORPO-v0.1

adapter: qlora
lora_model_dir:

sequence_len: 1800
sample_packing: false
pad_to_sequence_len: false

lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - gate_proj
  - up_proj
  - down_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: false

bf16: auto
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 50
evals_per_epoch: 1
eval_table_size:
eval_table_max_new_tokens: 128
save_steps: 100
debug:
weight_decay: 0.05
fsdp:
   - full_shard
   - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
special_tokens:
  pad_token: "<|endoftext|>"
  eos_token: "<|im_end|>"

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10

axolotl branch-commit

608a2f3

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
@maziyarpanahi added the bug label on Aug 20, 2024
@winglian
Collaborator

I think we'll want to switch from our ORPO implementation to the trl ORPOTrainer implementation.
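
For reference, a minimal sketch of driving trl's ORPOTrainer directly (assumptions: the small Qwen/Qwen2-0.5B placeholder model and the tokenizer= keyword, whose name varies across trl versions; this is illustrative, not axolotl's actual wiring):

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig, ORPOTrainer

model_id = "Qwen/Qwen2-0.5B"  # small placeholder, not the reporter's Arcee-Nova
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Dataset from the config above; depending on the trl version, the columns may need
# to be mapped to the prompt/chosen/rejected format ORPOTrainer expects.
dataset = load_dataset("mlabonne/orpo-dpo-mix-40k", split="train")

args = ORPOConfig(output_dir="./orpo-out", beta=0.1)  # beta plays the role of orpo_alpha
trainer = ORPOTrainer(model=model, args=args, train_dataset=dataset, tokenizer=tokenizer)
trainer.train()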

@maziyarpanahi
Contributor Author

This is interesting! I would love to help test it if you have any work in progress for ORPOTrainer.

@maziyarpanahi
Contributor Author

Hi @winglian, any updates? Let me know if you need me to test anything.

@winglian
Collaborator

oh, hmm, we already use the ORPOTrainer (https://github.com/axolotl-ai-cloud/axolotl/pull/1551/files), will need to dig into this a bit deeper

@maziyarpanahi
Contributor Author

oh, hmm, we already use the ORPOTrainer (https://github.com/axolotl-ai-cloud/axolotl/pull/1551/files), will need to dig into this a bit deeper

Thanks a lot @winglian - is this something new? Should I try again? (my message is a few months old)
