
FSDP+QLoRA gets ValueError: Must flatten tensors with uniform dtype but got torch.float32 and torch.bfloat16 #1494

Closed
yaohwang opened this issue Apr 8, 2024 · 10 comments
Labels
bug Something isn't working

Comments

@yaohwang

yaohwang commented Apr 8, 2024

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

Training should run without errors.

Current behaviour

Traceback (most recent call last):
  File "/root/miniconda3/envs/python-app/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/root/miniconda3/envs/python-app/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/src/axolotl/cli/train.py", line 59, in <module>
    fire.Fire(do_cli)
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/src/axolotl/cli/train.py", line 35, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
  File "/src/axolotl/cli/train.py", line 55, in do_train
    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
  File "/src/axolotl/train.py", line 160, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/transformers/trainer.py", line 1848, in train
    return inner_training_loop(
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/transformers/trainer.py", line 1991, in _inner_training_loop
    self.model = self.accelerator.prepare(self.model)
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/accelerate/accelerator.py", line 1263, in prepare
    result = tuple(
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/accelerate/accelerator.py", line 1264, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/accelerate/accelerator.py", line 1140, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/accelerate/accelerator.py", line 1422, in prepare_model
    model = FSDP(model, **kwargs)
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 477, in __init__
    _auto_wrap(
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  [Previous line repeated 2 more times]
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
    return wrapper_cls(module, **kwargs)
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 503, in __init__
    _init_param_handle_from_module(
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 590, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 602, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 573, in __init__
    self._init_flat_param_and_metadata(
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 623, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 761, in _validate_tensors_to_flatten
    raise ValueError(
ValueError: Must flatten tensors with uniform dtype but got torch.float32 and torch.bfloat16

Steps to reproduce

Just run: accelerate launch -m axolotl.cli.train examples/llama-2/qlora-fsdp.yml

Config yaml

examples/llama-2/qlora-fsdp.yml

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10

axolotl branch-commit

main/4d6490b

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
yaohwang added the "bug" label on Apr 8, 2024
@Mohamad-Jaallouk

Same issue
accelerate launch -m axolotl.cli.train /workspace/mixtral-qlora-fsdp.yml

@accupham

accupham commented Apr 8, 2024

Seems like the same issue as this unresolved bug report: #1426

@winglian
Collaborator

winglian commented Apr 8, 2024

I'm looking into this today.

@jorge-tromero

Same issue here

@SicariusSicariiStuff

Same issue:

ValueError: Must flatten tensors with uniform dtype but got torch.float32 and torch.uint8
[2024-04-10 01:54:11,740] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0

@woojh3690

Same issue here

@winglian
Collaborator

There were some fixes for FSDP in #1462, so check out the FSDP + QLoRA examples provided in that PR. We probably need to update the appropriate YAML files for other model architectures.
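
For reference, a minimal sketch of the kind of fsdp / fsdp_config block those examples use is below. The field names and values are assumptions based on the llama-2 qlora-fsdp example rather than a copy of the PR, so verify them against the YAML that actually ships in #1462:

# Sketch of the FSDP-related section only, not a full config.
# Field names/values are assumed from the llama-2 qlora-fsdp example;
# verify against the YAML in the PR before using.
bf16: true  # keep the training dtype consistent with the bf16 base weights

fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT

For other architectures, fsdp_transformer_layer_cls_to_wrap has to name that model's decoder layer class (e.g. MixtralDecoderLayer for Mixtral), which is presumably part of why the other YAML files still need updating.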

@yaohwang
Author

There were some fixes for FSDP in #1462, so check out the FSDP + QLoRA examples provided in that PR. We probably need to update the appropriate YAML files for other model architectures.

It worked, many thanks!

@0-hero
Contributor

0-hero commented Apr 17, 2024

Hey, this is still happening for Mixtral-8x22B with DPO.

@winglian
Collaborator

Hey, this is still happening for Mixtral-8x22B with DPO.

@0-hero Would you mind opening a separate issue to track the DPO case, with a reference YAML, please?
