diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index d567b9438a0d00..2c0eb1b9622475 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1819,7 +1819,7 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
         self.base_model._prune_heads(heads_to_prune)

-    def gradient_checkpointing_enable(self):
+    def gradient_checkpointing_enable(self, use_reentrant: bool = True) -> None:
         """
         Activates gradient checkpointing for the current model.

@@ -1828,7 +1828,21 @@ def gradient_checkpointing_enable(self):
         """
         if not self.supports_gradient_checkpointing:
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
-        self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+        _supports_use_reentrant = "use_reentrant" in list(
+            inspect.signature(self._set_gradient_checkpointing).parameters
+        )
+        gc_kwargs = {}
+
+        if not _supports_use_reentrant and not use_reentrant:
+            logger.warn(
+                f"{self.__class__.__name__} does not support the use_reentrant argument. The argument will be ignored."
+                " Please raise an issue on GitHub to support this argument if needed."
+            )
+        elif _supports_use_reentrant and not use_reentrant:
+            gc_kwargs["use_reentrant"] = use_reentrant
+
+        self.apply(partial(self._set_gradient_checkpointing, value=True, **gc_kwargs))

         if getattr(self, "_hf_peft_config_loaded", False):
             # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 8f3f246524348d..0a8a180f3b9457 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch OPT model."""
+import inspect
 from typing import List, Optional, Tuple, Union

 import torch
@@ -411,9 +412,10 @@ def _init_weights(self, module):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, value=False, use_reentrant=True):
         if isinstance(module, (OPTDecoder)):
             module.gradient_checkpointing = value
+            module.gradient_checkpointing_use_reentrant = use_reentrant


 OPT_INPUTS_DOCSTRING = r"""
@@ -520,6 +522,8 @@ def __init__(self, config: OPTConfig):
         self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

         self.gradient_checkpointing = False
+        # Use the default value
+        self.gradient_checkpointing_use_reentrant = True
         # Initialize weights and apply final processing
         self.post_init()

@@ -699,12 +703,18 @@ def custom_forward(*inputs):

                     return custom_forward

+                kwargs = {}
+
+                if "use_reentrant" in list(inspect.signature(torch.utils.checkpoint.checkpoint).parameters):
+                    kwargs["use_reentrant"] = self.gradient_checkpointing_use_reentrant
+
                 layer_outputs = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     causal_attention_mask,
                     head_mask[idx] if head_mask is not None else None,
                     None,
+                    **kwargs,
                 )
             else:
                 layer_outputs = decoder_layer(
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 428db905e73c84..3d5ffb3ff1d050 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -572,6 +572,9 @@ class TrainingArguments:
             Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
         gradient_checkpointing (`bool`, *optional*, defaults to `False`):
             If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+        gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `True`):
+            If `False` use `use_reentrant=False` when calling gradient checkpointing as recommended per PyTorch
+            documentation (can fix some bugs and unexpected behaviours for distributed training).
         include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
             Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
             that need inputs, predictions and references for scoring calculation in Metric class.
@@ -1119,6 +1122,10 @@ class TrainingArguments:
             "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
         },
     )
+    gradient_checkpointing_use_reentrant: bool = field(
+        default=True,
+        metadata={"help": "If False use `use_reentrant=False` as recommended per PyTorch documentation."},
+    )
     include_inputs_for_metrics: bool = field(
         default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
     )
@@ -2102,6 +2109,7 @@ def set_training(
         gradient_accumulation_steps: int = 1,
         seed: int = 42,
         gradient_checkpointing: bool = False,
+        gradient_checkpointing_use_reentrant: bool = True,
     ):
         """
         A method that regroups all basic arguments linked to the training.
@@ -2165,6 +2173,7 @@ def set_training(
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.seed = seed
         self.gradient_checkpointing = gradient_checkpointing
+        self.gradient_checkpointing_use_reentrant = gradient_checkpointing_use_reentrant
         return self

     def set_evaluate(