
StackLLaMA 2 DPO training failed: 8-bit model can't train with multiple GPUs #1348

Closed
fancyerii opened this issue Feb 22, 2024 · 6 comments

Comments

@fancyerii
Contributor

fancyerii commented Feb 22, 2024

I am following https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2/scripts.

I ran with:

accelerate launch --config_file 7b.yaml examples/research_projects/stack_llama_2/scripts/dpo_llama2.py     --model_name_or_path="sft/final_checkpoint"     --output_dir="dpo"

7b.yaml file:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: 0,1,2,3
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

error message:

Traceback (most recent call last):
  File "/nas/lili/codes/pt/ft/trl/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py", line 213, in <module>
    dpo_trainer.train()
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1687, in _inner_training_loop
    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/accelerate/accelerator.py", line 1227, in prepare
    result = tuple(
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/accelerate/accelerator.py", line 1228, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/accelerate/accelerator.py", line 1104, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/accelerate/accelerator.py", line 1330, in prepare_model
    raise ValueError(
ValueError: You can't train a model that has been loaded in 8-bit precision on a different device than the one you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device() or device_map={'':torch.xpu.current_device()}

I searched for this issue, and I am not sure whether it is using naive pipeline parallelism (naive PP).

my environment:

transformers             4.37.2
accelerate               0.26.1
peft                     0.8.2
bitsandbytes             0.43.0.dev0 # latest built from source
trl                      0.7.11.dev0 # latest built from source
torch                    2.2.0
python 3.9.18
@younesbelkada
Contributor

Hi @fancyerii
Thanks for the issue!
You first need to install the latest version of accelerate from PyPI (pip install -U accelerate) and load the model with device_map={"": Accelerator().process_index}, similarly to:

device_map={"": Accelerator().local_process_index},

If that works, would you be happy to submit a fix through a PR on that DPO script?
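
A minimal sketch of how that suggestion could look where the script loads the policy model (the load_in_8bit flag and the checkpoint path are illustrative assumptions, not necessarily the exact arguments in dpo_llama2.py):

from accelerate import Accelerator
from transformers import AutoModelForCausalLM

# Place each process's quantized model on that process's own GPU instead of
# sharding it across devices; sharding is what triggers the ValueError above
# when training with DDP on multiple GPUs.
model = AutoModelForCausalLM.from_pretrained(
    "sft/final_checkpoint",  # checkpoint path from the command above
    load_in_8bit=True,       # assumption: the quantization flag used by the script
    device_map={"": Accelerator().local_process_index},
)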

@fancyerii
Contributor Author

fancyerii commented Feb 22, 2024

Hi @fancyerii Thanks for the issue! You first need to install the latest version of accelerate from PyPI (pip install -U accelerate) and load the model with device_map={"": Accelerator().process_index}, similarly to:

device_map={"": Accelerator().local_process_index},

If that works, would you be happy to submit a fix through a PR on that DPO script?

I upgraded accelerate to 0.27.2, but it failed with a new error:

Traceback (most recent call last):
  File "/nas/lili/codes/pt/ft/trl/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py", line 215, in <module>
    dpo_trainer.train()
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1869, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 2781, in training_step
    self.accelerator.backward(loss)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/accelerate/accelerator.py", line 1966, in backward
    loss.backward(**kwargs)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 319, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/ft-zSqjAXBp-py3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 191 with name base_model.model.model.layers.31.self_attn.v_proj.lora_B.default.weight has been marked as ready twice. This means that multiple autograd engine hooks have fired for this particular parameter during this iteration.

@fancyerii
Contributor Author

I searched for this issue, and after I added the following line, it worked:

training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
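
For context, a minimal sketch of how this setting fits around the script's TrainingArguments (output_dir and gradient_checkpointing are shown here as illustrative assumptions):

from transformers import TrainingArguments

# Non-reentrant activation checkpointing avoids the DDP error above, where the
# same parameter is marked ready twice during a single backward pass.
training_args = TrainingArguments(
    output_dir="dpo",  # assumption: matches the --output_dir passed on the command line
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)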

@younesbelkada
Contributor

Indeed! Adding that line should solve it!
Would you like to submit a PR with all the fixes?

fancyerii added a commit to fancyerii/trl that referenced this issue Feb 22, 2024
fix "ValueError: You can't train a model that has been loaded in 8-bit precision on a different device than the one you're training on."
see huggingface#1348
@fancyerii
Contributor Author

Indeed! Adding that line should solve it! Would you like to submit a PR with all the fixes?

I have submitted a PR.

@younesbelkada
Contributor

Thanks so much @fancyerii !

fancyerii pushed a commit to fancyerii/trl that referenced this issue Feb 22, 2024
younesbelkada pushed a commit that referenced this issue Feb 23, 2024
* fix 8-bit multi-gpu training bug see #1348

* Update dpo_llama2.py

make gradient_checkpointing_kwargs configurable.

* Update dpo_llama2.py

remove unnecessary config of device_map

* format with make precommit

---------

Co-authored-by: ubuntu <[email protected]>
kashif closed this as completed Feb 23, 2024
lapp0 pushed a commit to lapp0/trl that referenced this issue May 10, 2024