[core] Fix gradient checkpointing use_reentrant issues #26917

Closed
18 changes: 16 additions & 2 deletions src/transformers/modeling_utils.py
@@ -1819,7 +1819,7 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
 
         self.base_model._prune_heads(heads_to_prune)
 
-    def gradient_checkpointing_enable(self):
+    def gradient_checkpointing_enable(self, use_reentrant: bool = True) -> None:
         """
         Activates gradient checkpointing for the current model.
 
@@ -1828,7 +1828,21 @@ def gradient_checkpointing_enable(self):
         """
         if not self.supports_gradient_checkpointing:
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
-        self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+        _supports_use_reentrant = "use_reentrant" in list(
+            inspect.signature(self._set_gradient_checkpointing).parameters
+        )
+        gc_kwargs = {}
+
+        if not _supports_use_reentrant and not use_reentrant:
+            logger.warn(
+                f"{self.__class__.__name__} does not support the use_reentrant argument. The argument will be ignored."
+                " Please raise an issue on GitHub to support this argument if needed."
+            )
+        elif _supports_use_reentrant and not use_reentrant:
+            gc_kwargs["use_reentrant"] = use_reentrant
+
+        self.apply(partial(self._set_gradient_checkpointing, value=True, **gc_kwargs))
 
         if getattr(self, "_hf_peft_config_loaded", False):
             # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
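
For context, a minimal sketch of how the new argument is meant to be called from user code. This is illustrative and not part of the diff; the checkpoint name is arbitrary, and it assumes a model whose `_set_gradient_checkpointing` accepts `use_reentrant` (such as OPT after the change below).

```python
# Illustrative usage of the new `use_reentrant` argument (not part of this diff).
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

# Request non-reentrant activation checkpointing, as recommended by the PyTorch docs.
# Models whose `_set_gradient_checkpointing` does not accept the kwarg fall back to
# the default behaviour and emit the warning added above.
model.gradient_checkpointing_enable(use_reentrant=False)
```
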
12 changes: 11 additions & 1 deletion src/transformers/models/opt/modeling_opt.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch OPT model."""
+import inspect
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -411,9 +412,10 @@ def _init_weights(self, module):
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, value=False, use_reentrant=True):
        if isinstance(module, (OPTDecoder)):
            module.gradient_checkpointing = value
+            module.gradient_checkpointing_use_reentrant = use_reentrant
 
 
 OPT_INPUTS_DOCSTRING = r"""
@@ -520,6 +522,8 @@ def __init__(self, config: OPTConfig):
        self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
 
        self.gradient_checkpointing = False
+        # Use the default value
+        self.gradient_checkpointing_use_reentrant = True
        # Initialize weights and apply final processing
        self.post_init()
 
@@ -699,12 +703,18 @@ def custom_forward(*inputs):
 
                    return custom_forward
 
+                kwargs = {}
+
+                if "use_reentrant" in list(inspect.signature(torch.utils.checkpoint.checkpoint).parameters):
+                    kwargs["use_reentrant"] = self.gradient_checkpointing_use_reentrant
+
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    causal_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    None,
+                    **kwargs,
                )
            else:
                layer_outputs = decoder_layer(
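
The `inspect.signature` guard above exists because `use_reentrant` was only added to `torch.utils.checkpoint.checkpoint` in newer PyTorch releases. A standalone sketch of the same feature-detection pattern, with an illustrative helper name that is not part of this diff:

```python
# Standalone sketch of the version guard used in the OPT change above.
import inspect

import torch
import torch.utils.checkpoint


def _checkpoint_kwargs(use_reentrant: bool) -> dict:
    # Only forward `use_reentrant` if the installed torch version accepts it;
    # older releases would raise a TypeError on an unknown keyword argument.
    if "use_reentrant" in inspect.signature(torch.utils.checkpoint.checkpoint).parameters:
        return {"use_reentrant": use_reentrant}
    return {}


layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

# Checkpoint a single module, passing use_reentrant=False only when supported.
out = torch.utils.checkpoint.checkpoint(layer, x, **_checkpoint_kwargs(False))
out.sum().backward()
```
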
9 changes: 9 additions & 0 deletions src/transformers/training_args.py
@@ -572,6 +572,9 @@ class TrainingArguments:
             Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
         gradient_checkpointing (`bool`, *optional*, defaults to `False`):
             If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+        gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `True`):
+            If `False` use `use_reentrant=False` when calling gradient checkpointing as recommended per PyTorch
+            documentation (can fix some bugs and unexpected behaviours for distributed training).
         include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
             Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
             that need inputs, predictions and references for scoring calculation in Metric class.
@@ -1119,6 +1122,10 @@ class TrainingArguments:
             "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
         },
     )
+    gradient_checkpointing_use_reentrant: bool = field(
+        default=True,
+        metadata={"help": "If False use `use_reentrant=False` as recommended per PyTorch documentation."},
+    )
     include_inputs_for_metrics: bool = field(
         default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
     )
@@ -2102,6 +2109,7 @@ def set_training(
         gradient_accumulation_steps: int = 1,
         seed: int = 42,
         gradient_checkpointing: bool = False,
+        gradient_checkpointing_use_reentrant: bool = True,
     ):
         """
         A method that regroups all basic arguments linked to the training.
@@ -2165,6 +2173,7 @@ def set_training(
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.seed = seed
         self.gradient_checkpointing = gradient_checkpointing
+        self.gradient_checkpointing_use_reentrant = gradient_checkpointing_use_reentrant
         return self
 
     def set_evaluate(
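
A minimal sketch of the new flag from the user's side. Illustrative only; it assumes the `Trainer` forwards `gradient_checkpointing_use_reentrant` to `gradient_checkpointing_enable`, which is not shown in this diff.

```python
# Illustrative use of the new TrainingArguments flag (not part of this diff).
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
    gradient_checkpointing_use_reentrant=False,  # pass use_reentrant=False, per the PyTorch docs
)

# The chained setter gains the same flag:
args = args.set_training(
    learning_rate=2e-5,
    gradient_checkpointing=True,
    gradient_checkpointing_use_reentrant=False,
)
```
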