[RewardTrainer] Enable gradient checkpointing for all multi-GPU training modes #835

Closed
lewtun opened this issue Oct 5, 2023 · 7 comments · Fixed by huggingface/peft#1036 or #912

@lewtun
Member

lewtun commented Oct 5, 2023

We currently have a few issues like #831 and #480 where gradient checkpointing + DDP does not work with the RewardTrainer.

Let's use this issue to collect the various training modes we'd like to support and track the status of their fixes. Ideally, these would all be covered by integration tests so that we catch any regressions in the future.

Mode Supported?
LoRA (unquantized)
LoRA (4-bit)
LoRA (8-bit)
DDP
ZeRO-1
ZeRO-2
ZeRO-3

More details on each mode are provided below.
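To make the "integration tests" idea concrete, the modes in the table could be enumerated as a test matrix built from the reproduction commands listed for each mode below. This is a minimal sketch, not TRL's test suite: `MODES`, `build_command` and `run_mode` are hypothetical names, and actually running a mode needs a multi-GPU machine with a TRL checkout as the working directory.

```python
import os
import subprocess

# Accelerate config file and extra CLI flags per mode, copied from the
# reproduction commands in this issue. The DeepSpeed (ZeRO) modes also need
# bf16=True set in the script's TrainingArguments, which is not a CLI flag here.
MODES = {
    "LoRA (unquantized)": ("multi_gpu.yaml", ["--use_peft", "True"]),
    "LoRA (4-bit)": ("multi_gpu.yaml", ["--use_peft", "True", "--load_in_4bit", "True"]),
    "LoRA (8-bit)": ("multi_gpu.yaml", ["--use_peft", "True", "--load_in_8bit", "True"]),
    "DDP": ("multi_gpu.yaml", []),
    "ZeRO-1": ("deepspeed_zero1.yaml", []),
    "ZeRO-2": ("deepspeed_zero2.yaml", []),
    "ZeRO-3": ("deepspeed_zero3.yaml", []),
}
COMMON_FLAGS = ["--gradient_checkpointing=True", "--gradient_accumulation_steps", "1"]


def build_command(mode):
    """Return the `accelerate launch` argument list for one training mode."""
    config, extra_flags = MODES[mode]
    return [
        "accelerate", "launch",
        f"--config_file=examples/accelerate_configs/{config}",
        "examples/scripts/reward_trainer.py",
        *extra_flags,
        *COMMON_FLAGS,
    ]


def run_mode(mode):
    """Hypothetical test body: launch one mode and fail on a non-zero exit code."""
    env = {**os.environ, "ACCELERATE_LOG_LEVEL": "info", "TRANSFORMERS_VERBOSITY": "info"}
    subprocess.run(build_command(mode), env=env, check=True)
```

Note that a non-zero exit code only catches the modes that crash; the unquantized LoRA case only emits a warning, so a test for it would also need to check the output or the gradients.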

1. LoRA without quantization

Currently, this mode emits a warning that none of the inputs require gradients, so the gradients are None (i.e. the model doesn't learn):

/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")

Command to reproduce:

ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml examples/scripts/reward_trainer.py --use_peft True --gradient_checkpointing=True --gradient_accumulation_steps 1
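The warning comes from PyTorch's reentrant checkpoint implementation: if none of the tensor inputs to a checkpointed segment require gradients, its output doesn't either, so the parameters inside it get no gradient. With LoRA the base model (embeddings included) is frozen, so the hidden states entering the first checkpointed layer likely don't require grad. A toy reproduction in plain PyTorch follows; the frozen `nn.Embedding` and the `nn.Linear` "adapter" are illustrative stand-ins, not the RewardTrainer code.

```python
import warnings

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def adapter_gets_grad(use_reentrant, inputs_require_grad=False):
    """Return True if the trainable layer inside the checkpoint receives a gradient."""
    torch.manual_seed(0)
    embed = nn.Embedding(10, 4)
    embed.weight.requires_grad_(False)  # frozen base weights, as with LoRA
    adapter = nn.Linear(4, 4)  # trainable toy stand-in for a LoRA adapter
    hidden = embed(torch.tensor([[1, 2, 3]]))  # requires_grad is False here
    if inputs_require_grad:
        # Force the checkpoint input to require grad, as a hook on the embeddings would.
        hidden.requires_grad_(True)
    with warnings.catch_warnings():
        # Hide the "None of the inputs have requires_grad=True" warning in this demo.
        warnings.simplefilter("ignore")
        out = checkpoint(adapter, hidden, use_reentrant=use_reentrant)
    if not out.requires_grad:
        # out.sum().backward() would raise
        # "element 0 of tensors does not require grad and does not have a grad_fn".
        return False
    out.sum().backward()
    return adapter.weight.grad is not None


print(adapter_gets_grad(use_reentrant=True))  # False
print(adapter_gets_grad(use_reentrant=False))  # True
print(adapter_gets_grad(use_reentrant=True, inputs_require_grad=True))  # True
```

The third case mirrors what `enable_input_require_grads()` in transformers does by registering a forward hook on the input embeddings; the second switches to PyTorch's non-reentrant checkpointing instead.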

2. LoRA with 4-bit or 8-bit quantization

Traceback (most recent call last):
  File "/fsx/lewis/git/trl/examples/scripts/reward_trainer.py", line 168, in <module>
    trainer.train()
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
    return inner_training_loop(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1892, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 2787, in training_step
    self.accelerator.backward(loss)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/accelerate/accelerator.py", line 1985, in backward
    loss.backward(**kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 93 has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration. You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print parameter names for further debugging.

Commands to reproduce:

# 8-bit
ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml examples/scripts/reward_trainer.py --use_peft True --load_in_8bit True --gradient_checkpointing=True --gradient_accumulation_steps 1

# 4-bit
ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml examples/scripts/reward_trainer.py --use_peft True --load_in_4bit True --gradient_checkpointing=True --gradient_accumulation_steps 1
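The DDP message itself points at parameters reused across several reentrant backward passes. A pairwise reward loss plausibly triggers this because the model runs twice per step (once on the chosen and once on the rejected inputs), assuming the RewardTrainer scores the two inputs with separate forward calls as TRL's implementation does. The sketch below, using a toy `nn.Linear` as the reward model, counts how often a gradient hook on one weight fires during a single backward pass, which is comparable to how often DDP or ZeRO would see that parameter become "ready".

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def grad_hook_calls(mode):
    """Count how often one weight's gradient hook fires in a single backward pass.

    mode: "off" (no checkpointing), "reentrant" or "non_reentrant".
    """
    torch.manual_seed(0)
    reward_head = nn.Linear(4, 1)  # toy reward model
    calls = []
    # A hook on a leaf parameter runs when autograd accumulates its gradient,
    # comparable to the per-parameter hooks DDP and ZeRO use to mark gradients ready.
    reward_head.weight.register_hook(lambda grad: calls.append(1))
    chosen = torch.randn(2, 4, requires_grad=True)
    rejected = torch.randn(2, 4, requires_grad=True)

    def score(x):
        if mode == "off":
            return reward_head(x)
        return checkpoint(reward_head, x, use_reentrant=(mode == "reentrant"))

    # Two forward passes through the same weights feed one pairwise loss.
    loss = -F.logsigmoid(score(chosen) - score(rejected)).mean()
    loss.backward()
    return len(calls)


for mode in ("off", "reentrant", "non_reentrant"):
    print(mode, grad_hook_calls(mode))
```

Under reentrant checkpointing each checkpointed call runs its own nested backward, so the same weight reports a gradient once per call. That is consistent with "marked as ready twice" under DDP and "already been reduced" under ZeRO-2, while the non-reentrant variant reports it once.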

3. Full training with DDP

Same error as with 4-bit / 8-bit LoRA under DDP:

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 386 has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration. You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print parameter names for further debugging.

Command to reproduce:

ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml examples/scripts/reward_trainer.py --gradient_checkpointing=True --gradient_accumulation_steps 1

4. Full training with ZeRO-1

ZeRO-1 works and can be tested with the following command (add bf16=True to the TrainingArguments):

ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero1.yaml examples/scripts/reward_trainer.py --gradient_checkpointing=True --gradient_accumulation_steps 1

5. Full training with ZeRO-2

ZeRO-2 throws the following error:

Traceback (most recent call last):
  File "/fsx/lewis/git/trl/examples/scripts/reward_trainer.py", line 169, in <module>
    trainer.train()
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
    return inner_training_loop(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1892, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 2787, in training_step
    self.accelerator.backward(loss)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/accelerate/accelerator.py", line 1979, in backward
    self.deepspeed_engine_wrapped.backward(loss, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/accelerate/utils/deepspeed.py", line 167, in backward
    self.engine.backward(loss, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1861, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1900, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 809, in reduce_partition_and_remove_grads
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1257, in reduce_ready_partitions_and_remove_grads
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 835, in reduce_independent_p_g_buckets_and_remove_grads
    self.reduce_ipg_grads()
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1226, in reduce_ipg_grads
    assert self.params_already_reduced[param_id] == False, \
AssertionError: The parameter 387 has already been reduced.                     Gradient computed twice for this partition.                     Multiple gradient reduction is currently not supported

Command to reproduce (add bf16=True to the TrainingArguments):

ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/reward_trainer.py --gradient_checkpointing=True --gradient_accumulation_steps 1

6. Full training with ZeRO-3

Error:

AssertionError
Traceback (most recent call last):
  File "/fsx/lewis/git/trl/examples/scripts/reward_trainer.py", line 169, in <module>
    trainer.train()
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
    return inner_training_loop(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1892, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 2787, in training_step
    self.accelerator.backward(loss)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/accelerate/accelerator.py", line 1979, in backward
    self.deepspeed_engine_wrapped.backward(loss, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/accelerate/utils/deepspeed.py", line 167, in backward
    self.engine.backward(loss, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1861, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 1993, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 1006, in reduce_partition_and_remove_grads
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 1286, in reduce_ready_partitions_and_remove_grads
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 1041, in reduce_independent_p_g_buckets_and_remove_grads
    self.__reduce_and_partition_ipg_grads()
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 1073, in __reduce_and_partition_ipg_grads
    assert len(set(p.ds_id for p in self.params_in_ipg_bucket)) == len(self.params_in_ipg_bucket)

Command to reproduce (add bf16=True to the TrainingArguments):

ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/reward_trainer.py --gradient_checkpointing=True --gradient_accumulation_steps 1
@younesbelkada
Contributor

Thanks for the deep investigation and report! Will have a look in the next few weeks.

@younesbelkada younesbelkada self-assigned this Oct 10, 2023
@mnoukhov
Contributor

@younesbelkada Related: are we applying gradient checkpointing correctly with PEFT and quantization when we

  1. call model.gradient_checkpointing_enable() (as part of prepare_model_for_kbit_training) and then
  2. create an adapter with get_peft_model?

e.g. as part of SFT Trainer https://github.com/huggingface/trl/blob/main/trl/trainer/sft_trainer.py#L150
or as part of Reward Trainer https://github.com/huggingface/trl/blob/main/trl/trainer/reward_trainer.py#L131

In huggingface/transformers#25841 (comment), it seems like you're saying that we should do it the other way around:

  1. add the adapters and then
  2. call model.gradient_checkpointing_enable()?

@bcol23

bcol23 commented Oct 18, 2023

I ran into exactly the same issue with full training under ZeRO-3, and I fixed it by adding {"reduce_bucket_size": 1e6} to the zero_opt_dict.

P.S. I don't quite understand why this works.
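For reference, here is a sketch of a ZeRO-3 DeepSpeed config carrying that setting, assuming "zero_opt_dict" refers to the `zero_optimization` section. Only `reduce_bucket_size` comes from the comment above; the other entries are placeholders, and the `"auto"` values are the kind resolved by the Hugging Face Trainer integration rather than by DeepSpeed itself. One possible reading, not confirmed in the thread, is that a smaller bucket is flushed more often, so a parameter whose gradient arrives twice is less likely to sit twice in the same bucket, which is what the ZeRO-3 assertion checks; that would hide the symptom rather than prevent the double gradient.

```python
import json

# Hypothetical ZeRO-3 DeepSpeed config: only "reduce_bucket_size" comes from the
# comment above; the other entries are placeholders so the config is complete.
ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        # Reported workaround: a much smaller reduce bucket than DeepSpeed's default.
        "reduce_bucket_size": int(1e6),
    },
    # "auto" values are filled in by the Hugging Face Trainer integration.
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

# Serialize to JSON, e.g. to save as a DeepSpeed config file.
config_json = json.dumps(ds_config, indent=2)
print(config_json)
```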

@nghuyong

I hit this issue and fixed it by switching to ZeRO-1.

@younesbelkada
Contributor

Hi everyone,
If you install trl, transformers and PEFT from source, passing gradient_checkpointing_kwargs={"use_reentrant": False} to the training arguments should fix this issue!
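A minimal sketch of where that option goes, assuming a transformers version recent enough to accept `gradient_checkpointing_kwargs` in `TrainingArguments`; `output_dir` is a placeholder path.

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="reward_model",  # placeholder path
    gradient_checkpointing=True,
    # Forwarded to the checkpointing call: use PyTorch's non-reentrant implementation.
    gradient_checkpointing_kwargs={"use_reentrant": False},
    gradient_accumulation_steps=1,
)
```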

@zztMermory

Regarding passing gradient_checkpointing_kwargs={"use_reentrant": False}: how is that different from setting gradient_checkpointing=False?
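The two settings are not equivalent: gradient_checkpointing=False turns checkpointing off entirely, so activations are kept in memory for the backward pass, while use_reentrant=False keeps checkpointing on and only changes which PyTorch implementation performs the recomputation. The plain-PyTorch sketch below (`CountingBlock` is a made-up toy module) counts how often a block's forward pass runs in each case.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CountingBlock(nn.Module):
    """Toy block that records how many times its forward pass runs."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return torch.tanh(self.linear(x))


def forward_passes(mode):
    """mode: "off" (like gradient_checkpointing=False), "reentrant" or "non_reentrant"."""
    torch.manual_seed(0)
    block = CountingBlock()
    x = torch.randn(2, 4, requires_grad=True)
    if mode == "off":
        out = block(x)  # intermediate activations are stored for backward
    else:
        # Intermediate activations are not kept; the block is rerun during backward.
        out = checkpoint(block, x, use_reentrant=(mode == "reentrant"))
    out.sum().backward()
    return block.calls


for mode in ("off", "reentrant", "non_reentrant"):
    print(mode, forward_passes(mode))
```

With checkpointing in either mode, the forward pass runs a second time during backward; that extra compute is the price paid for not storing the block's activations, which is the memory saving gradient_checkpointing=False gives up.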

@ycYiwei

ycYiwei commented Dec 18, 2023

I encountered
UserWarning: None of the inputs have requires_grad=True. Gradients will be None
and
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
when training the reward model with LoRA, even with gradient_checkpointing_kwargs={"use_reentrant": False}. Any help would be appreciated.
