FSDP+QLoRA get ValueError: Must flatten tensors with uniform dtype but got torch.float32 and torch.bfloat16 #1494
Comments
Same issue
Seems like the same issue as this unresolved bug report: #1426
I'm looking into this today.
Same issue here
Same issue: ValueError: Must flatten tensors with uniform dtype but got torch.float32 and torch.uint8
Same issue here
There were some fixes for FSDP in #1462, so check out the FSDP + QLoRA examples provided in that PR. We probably need to update the appropriate YAML files for other model architectures. (A sketch of the underlying dtype fix follows this thread.)
It worked, many thanks!
Hey, this is still happening for Mixtral-8x22B with DPO.
@0-hero Would you mind opening a separate issue to track the DPO case, along with a reference YAML, please?
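Background on the error, for anyone landing here: FSDP flattens all parameters inside a wrapped unit into a single tensor, so every parameter in that unit must share one dtype. With QLoRA, the packed 4-bit base weights and the bf16/fp32 LoRA and norm weights can end up with different dtypes, which is exactly what the ValueError in this issue complains about. The sketch below shows the general shape of a loading setup that keeps the flattened dtypes uniform; it assumes a recent transformers whose BitsAndBytesConfig accepts bnb_4bit_quant_storage and uses a placeholder model id, so treat it as an illustration of the idea rather than the exact change from #1462.

# Minimal sketch, not the exact #1462 diff: keep every dtype FSDP will flatten uniform (bf16 here).
# Assumes a transformers version whose BitsAndBytesConfig accepts bnb_4bit_quant_storage.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    # Store the packed 4-bit weights as bf16 so they flatten together with
    # the other bf16 parameters inside each FSDP unit.
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",   # placeholder model id
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,   # keep the non-quantized params in bf16 too
)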
Please check that this issue hasn't been reported before.
Expected Behavior
Training should start and run without errors.
Current behaviour
Traceback (most recent call last):
File "/root/miniconda3/envs/python-app/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/root/miniconda3/envs/python-app/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/src/axolotl/cli/train.py", line 59, in
fire.Fire(do_cli)
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/src/axolotl/cli/train.py", line 35, in do_cli
return do_train(parsed_cfg, parsed_cli_args)
File "/src/axolotl/cli/train.py", line 55, in do_train
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
File "/src/axolotl/train.py", line 160, in train
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/transformers/trainer.py", line 1848, in train
return inner_training_loop(
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/transformers/trainer.py", line 1991, in _inner_training_loop
self.model = self.accelerator.prepare(self.model)
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/accelerate/accelerator.py", line 1263, in prepare
result = tuple(
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/accelerate/accelerator.py", line 1264, in
self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/accelerate/accelerator.py", line 1140, in _prepare_one
return self.prepare_model(obj, device_placement=device_placement)
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/accelerate/accelerator.py", line 1422, in prepare_model
model = FSDP(model, **kwargs)
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 477, in init
_auto_wrap(
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
[Previous line repeated 2 more times]
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
return wrapper_cls(module, **kwargs)
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 503, in init
_init_param_handle_from_module(
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 590, in _init_param_handle_from_module
_init_param_handle_from_params(state, managed_params, fully_sharded_module)
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 602, in _init_param_handle_from_params
handle = FlatParamHandle(
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 573, in init
self._init_flat_param_and_metadata(
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 623, in _init_flat_param_and_metadata
) = self._validate_tensors_to_flatten(params)
File "/root/miniconda3/envs/python-app/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 761, in _validate_tensors_to_flatten
raise ValueError(
ValueError: Must flatten tensors with uniform dtype but got torch.float32 and torch.bfloat16
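A hypothetical way to narrow down which weights carry the mismatched dtypes (not part of axolotl, just a diagnostic you could run on the loaded model before FSDP wrapping) is to group parameter names by dtype:

from collections import defaultdict

def dtype_report(model):
    # Group parameter names by dtype so mixed float32 / bfloat16 weights stand out.
    by_dtype = defaultdict(list)
    for name, param in model.named_parameters():
        by_dtype[param.dtype].append(name)
    for dtype, names in by_dtype.items():
        print(f"{dtype}: {len(names)} params, e.g. {names[:3]}")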
Steps to reproduce
Just run:
accelerate launch -m axolotl.cli.train examples/llama-2/qlora-fsdp.yml
Config yaml
examples/llama-2/qlora-fsdp.yml
Possible solution
No response
Which Operating Systems are you using?
Python Version
3.10
axolotl branch-commit
main/4d6490b
Acknowledgements