Please check that this issue hasn't been reported before.
I searched previous Bug Reports and didn't find any similar reports.
Expected Behavior
I expect ORPO to work properly with FSDP and DeepSpeed on Qwen2 models.
Current behaviour
Currently, it's not possible to use ORPO with FSDP or DeepSpeed; training fails during model preparation with the ValueError shown in the traceback further below.
Possible issues:
LoRA / QLoRA
FSDP
DeepSpeed
ORPO
Qwen2
or a combination of these may be causing the issue (a quick way to narrow it down is sketched right after this list).
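A minimal way to narrow this down (sketched here as an assumption, not something from the original report) is to check, before FSDP wrapping, which tensors have a non-floating-point dtype: if the only offenders are the uint8-packed 4-bit weights that bitsandbytes creates for QLoRA, the quantized adapter setup is the likely trigger rather than ORPO itself or the Qwen2 architecture. The model id and quantization flags below are illustrative, not the values from the failing run.

```python
# Hypothetical diagnostic sketch (not from the original report): list every
# tensor that FSDP's FlatParamHandle would refuse to flatten, i.e. any
# parameter whose dtype is not floating point. With QLoRA these are usually
# the uint8-packed 4-bit weights created by bitsandbytes.
# Model id and quantization flags are illustrative assumptions; substitute
# the base_model/adapter settings from the failing config.
# Requires a CUDA GPU with bitsandbytes installed.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-7B",  # illustrative Qwen2-family checkpoint
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

integer_params = [
    (name, param.dtype)
    for name, param in model.named_parameters()
    if not param.dtype.is_floating_point
]
for name, dtype in integer_params:
    print(f"{name}: {dtype}  <- cannot be flattened by FSDP")
```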
File "/workspace/axolotl/src/axolotl/cli/train.py", line 67, in do_train
[Previous line repeated 2 more times]
component_trace = _Fire(component, args, parsed_flag_args, context, name)wrapped_child, num_wrapped_params = _recursive_wrap( File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 555,
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 531, in _init_param_handle_from_params
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 531, in _init_param_handle_from_params
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/accelerator.py", line 1181, in _prepare_one
return self.prepare_model(obj, device_placement=device_placement)
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/accelerate/accelerator.py", line 1477, in prepare_model
model = FSDP(model, **kwargs)
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 463, in __init__
_auto_wrap(
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 537, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
[Previous line repeated 2 more times]
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 555, in _recursive_wrap
return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 484, in _wrap
return wrapper_cls(module, **kwargs)
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 487, in __init__
_init_param_handle_from_module(
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 519, in _init_param_handle_from_module
_init_param_handle_from_params(state, managed_params, fully_sharded_module)
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 531, in _init_param_handle_from_params
handle = FlatParamHandle(
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 537, in __init__
self._init_flat_param_and_metadata(
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 585, in _init_flat_param_and_metadata
) = self._validate_tensors_to_flatten(params)
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/distributed/fsdp/flat_param.py", line 720, in _validate_tensors_to_flatten
raise ValueError("Cannot flatten integer dtype tensors")
ValueError: Cannot flatten integer dtype tensors
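For context, this ValueError comes from FSDP's _validate_tensors_to_flatten (visible in the last frames above), which rejects any parameter with an integer dtype. If the integer tensors turn out to be the bitsandbytes 4-bit weights (stored as uint8 by default), one direction worth trying, sketched below under that assumption rather than as a confirmed fix, is to keep the packed 4-bit weights in a float storage dtype via bnb_4bit_quant_storage so FSDP can flatten them:

```python
# Hedged workaround sketch, assuming the integer tensors are bitsandbytes
# 4-bit weights (uint8 storage by default); not a confirmed fix for this issue.
# bnb_4bit_quant_storage keeps the packed weights in a float dtype that FSDP
# can flatten. Model id is illustrative; requires a transformers/bitsandbytes
# version that supports bnb_4bit_quant_storage.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,  # float storage instead of uint8
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-7B",  # illustrative model id, not from the original config
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
```

The equivalent setting would need to be expressed through the axolotl config yaml; whether the image's installed transformers/bitsandbytes versions expose this option is an assumption here.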
Steps to reproduce
launch a RunPod instance with 8x A100 or H100 GPUs (640 GB of GPU memory, or 320 GB)
use the winglian/axolotl-runpod:main-latest template
check out the 608a2f3 commit (to avoid the FSDP issue with the latest changes)
Config yaml
Possible solution
No response
Which Operating Systems are you using?
Python Version
3.10
axolotl branch-commit
608a2f3
Acknowledgements