Commit
restore SDPA in gpt neo after 4.45 (#1092)
* restore SDPA in gpt neo after 4.45

* fix accuracy

* left padding
eaidova authored Dec 30, 2024
1 parent 014a840 commit 1733791
Showing 3 changed files with 119 additions and 1 deletion.
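
For context, a minimal usage sketch (not part of this commit) of the export path these changes affect, assuming the optimum-intel Python API; the checkpoint id is only an example:

from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "EleutherAI/gpt-neo-125m"  # example checkpoint; any GPT-Neo model should behave the same
# export=True converts the PyTorch checkpoint to OpenVINO IR; the GptNeoModelPatcher added below
# runs during this conversion when transformers >= 4.45 and torch >= 2.1 are installed.
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Today is a nice day", return_tensors="pt")
print(tokenizer.decode(ov_model.generate(**inputs, max_new_tokens=16)[0]))
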
20 changes: 20 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -30,6 +30,7 @@
    FalconOnnxConfig,
    GemmaOnnxConfig,
    GPTJOnnxConfig,
    GPTNeoOnnxConfig,
    GPTNeoXOnnxConfig,
    IBertOnnxConfig,
    LlamaOnnxConfig,
@@ -68,6 +69,7 @@
    FluxTransfromerModelPatcher,
    Gemma2ModelPatcher,
    GptJModelPatcher,
    GptNeoModelPatcher,
    GptNeoxJapaneseModelPatcher,
    GptNeoxModelPatcher,
    IBertModelPatcher,
@@ -790,6 +792,24 @@ def patch_model_for_export(
        return GptNeoxJapaneseModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
    "gpt-neo",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "text-generation",
        "text-generation-with-past",
        "text-classification",
    ],
    library_name="transformers",
)
class GPTNeoOpenVINOConfig(GPTNeoOnnxConfig):
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return GptNeoModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"gptj",
*[
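
The @register_in_tasks_manager decorator above ties the "gpt-neo" model type and its supported tasks to the new OpenVINO export config. As a rough sketch of how a decorator-based registry of this shape can work (an illustration only, not optimum's actual implementation; register_ov_config and _REGISTRY are hypothetical names):

from typing import Callable, Dict, Tuple, Type

# Hypothetical registry keyed by (model_type, task) -> export config class.
_REGISTRY: Dict[Tuple[str, str], Type] = {}

def register_ov_config(model_type: str, *tasks: str, library_name: str = "transformers") -> Callable[[Type], Type]:
    # library_name only mirrors the real decorator's signature; it is unused in this sketch.
    def decorator(config_cls: Type) -> Type:
        # Record the config class for every (model_type, task) pair.
        for task in tasks:
            _REGISTRY[(model_type, task)] = config_cls
        return config_cls
    return decorator

@register_ov_config("gpt-neo", "text-generation", "text-generation-with-past")
class DummyGPTNeoConfig:
    pass

print(_REGISTRY[("gpt-neo", "text-generation-with-past")])  # -> DummyGPTNeoConfig
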
90 changes: 90 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Expand Up @@ -2681,6 +2681,96 @@ def __exit__(self, exc_type, exc_value, traceback):
        unpatch_update_causal_mask(self._model, "gpt_neox_japanese")


def _gpt_neo_attn_forward(
    self,
    hidden_states,
    attention_mask=None,
    layer_past=None,
    head_mask=None,
    use_cache=False,
    output_attentions=False,
    cache_position=None,
):
    if output_attentions:
        self._attn = self._orig_attn

    return self._orig_forward(
        hidden_states,
        attention_mask=attention_mask,
        layer_past=layer_past,
        head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        cache_position=cache_position,
    )


# Adapted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py#L185
def _gpt_neo_attn_sdpa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
):
    batch_size = query.shape[0]

    mask_value = torch.finfo(torch.float16).min
    mask_value = torch.full([], mask_value, dtype=value.dtype)

    dropout_p = float(self.config.attention_dropout) if self.training else 0.0
    if (batch_size == 1 or self.training) and self.attention_type == "global":
        if query.shape[2] > 1:
            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
            )
        else:
            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=1.0
            )
    else:
        query_length, key_length = query.size(-2), key.size(-2)

        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

        causal_mask = torch.where(causal_mask, 0, mask_value)
        if batch_size > 1:
            # torch.Tensor.expand does no memory copy
            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)

        if attention_mask is not None:
            attention_mask = causal_mask + attention_mask

        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False, scale=1.0
        )

    return sdpa_result, None


class GptNeoModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
        if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
            self._model.config._attn_implementation = "sdpa"
            for layer in self._model.transformer.h:
                self_attn = layer.attn.attention
                self_attn._orig_attn = self_attn._attn
                self_attn._attn = types.MethodType(_gpt_neo_attn_sdpa, self_attn)
                self_attn._orig_forward = types.MethodType(_gpt_neo_attn_forward, self_attn)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        if hasattr(self._model.config, "_orig_attn_implementation"):
            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
            for layer in self._model.transformer.h:
                for layer in self._model.transformer.h:
                    layer.attn.attention.forward = layer.attn.attention._orig_forward
                    layer.attn.attention._attn = layer.attn.attention._orig_attn


class Gemma2ModelPatcher(LlamaModelPatcher):
    def __init__(
        self,
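
The patched _gpt_neo_attn_sdpa passes scale=1.0 because GPT-Neo's eager attention does not scale the query-key product by 1/sqrt(head_dim). A small standalone check (not part of the commit) of that equivalence for the masked branch, using only plain torch with random tensors:

import torch

torch.manual_seed(0)
batch, heads, q_len, kv_len, head_dim = 2, 4, 5, 8, 16
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Boolean causal bias of the shape GPT-Neo keeps in `self.bias`: True where attention is allowed.
bias = torch.tril(torch.ones(kv_len, kv_len, dtype=torch.bool)).view(1, 1, kv_len, kv_len)
causal_mask = bias[:, :, kv_len - q_len : kv_len, :kv_len]
mask_value = torch.full([], torch.finfo(value.dtype).min, dtype=value.dtype)

# Eager GPT-Neo-style attention: unscaled QK^T, boolean mask, softmax, weighted sum over values.
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
eager_out = torch.matmul(torch.softmax(attn_weights, dim=-1), value)

# SDPA path used by the patch: additive mask (0 where allowed, large negative elsewhere), scale=1.0.
additive_mask = torch.where(causal_mask, 0.0, mask_value)
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=additive_mask, dropout_p=0.0, is_causal=False, scale=1.0
)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # expected: True
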
10 changes: 9 additions & 1 deletion tests/openvino/test_modeling.py
@@ -1296,7 +1296,15 @@ def test_beam_search(self, model_arch):
        transformers_model._supports_cache_class = True
        from transformers.cache_utils import DynamicCache
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
        tokenization_args = {}
        if is_transformers_version(">=", "4.45") and model_arch == "gpt_neo":
            tokenization_args["padding_side"] = "left"
        tokens = tokenizer(
            ["Today is a nice day and I am longer", "This is me"],
            return_tensors="pt",
            padding=True,
            **tokenization_args,
        )
        ov_model_stateful.generation_config.eos_token_id = None
        ov_model_stateless.generation_config.eos_token_id = None
        transformers_model.generation_config.eos_token_id = None
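
The test change switches GPT-Neo to left padding on transformers >= 4.45 because batched generation with decoder-only models expects the real tokens to sit at the right edge of each row. A hedged illustration (not part of the commit) of the same idea with a plain tokenizer; the checkpoint name is only an example:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")  # example checkpoint
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"  # same effect as passing padding_side="left" per call on recent transformers

tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
# With left padding the zeros in the attention mask appear at the start of the shorter row,
# so the last position of every row is a real token and generation continues from it.
print(tokens["attention_mask"])
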
