gpt2
Cyrilvallez committed Dec 17, 2024
1 parent 248a607 commit 4646014
Showing 4 changed files with 10 additions and 12 deletions.
@@ -138,6 +138,7 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
         attn_weights = attn_weights * head_mask
 
     attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2)
 
     return attn_output, attn_weights

@@ -252,6 +253,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
             attn_weights = attn_weights * head_mask
 
         attn_output = torch.matmul(attn_weights, value)
+        attn_output = attn_output.transpose(1, 2)
 
         return attn_output, attn_weights
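The change in both attention paths is the same: the transpose out of (batch, num_heads, seq_len, head_dim) now happens inside the attention function itself, so the caller always receives a (batch, seq_len, num_heads, head_dim) tensor. A minimal, self-contained sketch of that convention (illustrative only, not the library code; the function name and simplified signature are mine):

```python
import torch


def eager_attention_sketch(query, key, value, scale=None):
    # query, key, value: (batch, num_heads, seq_len, head_dim)
    scale = scale if scale is not None else query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scale
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value)  # (batch, num_heads, seq_len, head_dim)
    attn_output = attn_output.transpose(1, 2)        # (batch, seq_len, num_heads, head_dim)
    return attn_output, attn_weights


q = k = v = torch.randn(2, 12, 5, 64)  # batch=2, heads=12, seq_len=5, head_dim=64
out, _ = eager_attention_sketch(q, k, v)
assert out.shape == (2, 5, 12, 64)
```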

@@ -325,15 +327,12 @@ def forward(
             **kwargs,
         )
 
-        attn_output_reshaped = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.c_proj(attn_output_reshaped)
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
 
         outputs = (attn_output, present)
         if output_attentions:
-            # weird but needed to satisfy BC and tests (would normally be None)
-            if self.config._attn_implementation == "flash_attention_2":
-                attn_weights = attn_output_reshaped
             outputs += (attn_weights,)
 
         return outputs  # a, present, (attentions)
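With the attention functions returning (batch, seq_len, num_heads, head_dim), the caller no longer needs the separate attn_output_reshaped variable or the flash-attention-2 backward-compatibility branch: a single reshape flattens the head dimensions and the result goes straight through the output projection. A rough sketch of that output path under assumed shapes (nn.Linear stands in for GPT-2's Conv1D c_proj):

```python
import torch
from torch import nn

batch, seq_len, num_heads, head_dim = 2, 5, 12, 64
hidden_size = num_heads * head_dim
input_shape = (batch, seq_len)

attn_output = torch.randn(batch, seq_len, num_heads, head_dim)  # as returned by the attention function
c_proj = nn.Linear(hidden_size, hidden_size)  # stand-in for the module's Conv1D c_proj
resid_dropout = nn.Dropout(0.1)

attn_output = attn_output.reshape(*input_shape, -1).contiguous()  # (batch, seq_len, hidden_size)
attn_output = resid_dropout(c_proj(attn_output))
assert attn_output.shape == (batch, seq_len, hidden_size)
```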
src/transformers/models/gpt2/modeling_gpt2.py (9 changes: 4 additions & 5 deletions)
@@ -154,6 +154,7 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
         attn_weights = attn_weights * head_mask
 
     attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2)
 
     return attn_output, attn_weights

@@ -267,6 +268,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
             attn_weights = attn_weights * head_mask
 
         attn_output = torch.matmul(attn_weights, value)
+        attn_output = attn_output.transpose(1, 2)
 
         return attn_output, attn_weights

@@ -340,15 +342,12 @@ def forward(
             **kwargs,
         )
 
-        attn_output_reshaped = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.c_proj(attn_output_reshaped)
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
 
         outputs = (attn_output, present)
         if output_attentions:
-            # weird but needed to satisfy BC and tests (would normally be None)
-            if self.config._attn_implementation == "flash_attention_2":
-                attn_weights = attn_output_reshaped
             outputs += (attn_weights,)
 
         return outputs  # a, present, (attentions)
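The edit in modeling_gpt2.py mirrors the one above. For context, here is a self-contained round trip (illustrative only, with made-up shapes) showing why the in-function transpose pairs with the single reshape in forward(): heads are split before attention, the attention step transposes back, and the caller merges heads with one reshape.

```python
import torch

batch, seq_len, num_heads, head_dim = 2, 5, 12, 64
hidden_states = torch.randn(batch, seq_len, num_heads * head_dim)

# Split into heads: (batch, num_heads, seq_len, head_dim)
q = k = v = hidden_states.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)

attn_weights = torch.softmax((q @ k.transpose(-1, -2)) / head_dim**0.5, dim=-1)
attn_output = (attn_weights @ v).transpose(1, 2)  # back to (batch, seq_len, num_heads, head_dim)

# The caller merges heads with a single reshape, as in the updated forward()
merged = attn_output.reshape(batch, seq_len, -1).contiguous()
assert merged.shape == hidden_states.shape
```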
tests/test_modeling_flax_common.py (2 changes: 1 addition & 1 deletion)
@@ -22,8 +22,8 @@
 import numpy as np
 
 import transformers
-from transformers.cache_utils import DynamicCache
 from transformers import is_flax_available, is_torch_available
+from transformers.cache_utils import DynamicCache
 from transformers.models.auto import get_values
 from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, torch_device
 from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
utils/check_config_attributes.py (2 changes: 1 addition & 1 deletion)
@@ -41,7 +41,6 @@
         "expert_layer_offset",
         "expert_layer_period",
     ],
-    "LlamaConfig": ["pretraining_tp"],
     "Qwen2Config": ["use_sliding_window"],
     "Qwen2MoeConfig": ["use_sliding_window"],
     "Qwen2VLConfig": ["use_sliding_window"],
@@ -311,6 +310,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
         # rope attributes may not appear directly in the modeling but are used
         "rope_theta",
         "partial_rotary_factor",
+        "pretraining_tp",
     ]
     attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]

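On the utils side, pretraining_tp moves from the LlamaConfig-specific exception list into the general list of attributes that are used without appearing verbatim in the modeling code. A toy sketch of how such an allowlist can short-circuit a source scan (the function name and signature here are hypothetical, not the real check_config_attributes.py API):

```python
import re


def attribute_seems_used(attribute: str, source_strings: list[str]) -> bool:
    """Treat an attribute as used if it is on the allowlist or referenced as config.<attribute>."""
    attributes_used_in_modeling = ["rope_theta", "partial_rotary_factor", "pretraining_tp"]
    if attribute in attributes_used_in_modeling:
        return True
    pattern = rf"config\.{re.escape(attribute)}\b"
    return any(re.search(pattern, source) for source in source_strings)


print(attribute_seems_used("pretraining_tp", ["hidden = self.dense(x)"]))       # True via the allowlist
print(attribute_seems_used("hidden_act", ["act = ACT2FN[config.hidden_act]"]))  # True via the source scan
```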
