diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 346da82d86f5b3..6672ce79add52d 100755
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -592,7 +592,10 @@ def custom_forward(*inputs):
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
 
-        hidden_states = self.norm(hidden_states)
+        if self.gradient_checkpointing and self.training:
+            hidden_states = torch.utils.checkpoint.checkpoint(self.norm, hidden_states)
+        else:
+            hidden_states = self.norm(hidden_states)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
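
For context, a minimal standalone sketch of the pattern this patch applies: routing a module call through `torch.utils.checkpoint.checkpoint` during training so its activation is recomputed in the backward pass instead of being stored. The `TinyModel` class and its attributes below are illustrative stand-ins, not the transformers API; the actual patch wraps `self.norm` inside `LlamaModel.forward`.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyModel(nn.Module):
    """Illustrative module mirroring the checkpointed-norm pattern from the diff."""

    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        # Stand-in for the flag toggled by gradient_checkpointing_enable() in transformers.
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.gradient_checkpointing and self.training:
            # Recompute the norm during backward rather than caching its activation.
            # use_reentrant=False is an explicit choice here; the diff relies on the default.
            hidden_states = torch.utils.checkpoint.checkpoint(
                self.norm, hidden_states, use_reentrant=False
            )
        else:
            hidden_states = self.norm(hidden_states)
        return hidden_states


# Usage: checkpointing only takes effect in training mode, matching the diff's condition.
model = TinyModel()
model.gradient_checkpointing = True
model.train()
x = torch.randn(2, 4, 16, requires_grad=True)
model(x).sum().backward()
```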