fix gpt bigcode model loading with fp16 weights precision #1098

Merged 2 commits on Jan 9, 2025
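For context, a minimal sketch (not part of this PR) of the scenario being fixed: exporting a GPT-BigCode checkpoint whose weights are kept in fp16 to OpenVINO, where the SDPA attention mask dtype could previously end up different from the query dtype. The checkpoint id below is an illustrative assumption.

from optimum.intel import OVModelForCausalLM

model_id = "bigcode/gpt_bigcode-santacoder"  # assumed to store fp16 weights, for illustration only
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)  # export triggers the patcher added below
ov_model.save_pretrained("gpt_bigcode_openvino")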
20 changes: 20 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -29,6 +29,7 @@
    CodeGenOnnxConfig,
    FalconOnnxConfig,
    GemmaOnnxConfig,
    GPTBigCodeOnnxConfig,
    GPTJOnnxConfig,
    GPTNeoOnnxConfig,
    GPTNeoXOnnxConfig,
@@ -73,6 +74,7 @@
    FalconModelPatcher,
    FluxTransfromerModelPatcher,
    Gemma2ModelPatcher,
    GptBigCodeModelPatcher,
    GptJModelPatcher,
    GptNeoModelPatcher,
    GptNeoxJapaneseModelPatcher,
@@ -2591,3 +2593,21 @@ def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> ModelPatcher:
        return GraniteMoEModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
    "gpt-bigcode",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "text-generation",
        "text-generation-with-past",
        "text-classification",
    ],
    library_name="transformers",
)
class GPTBigCodeOpenVINOConfig(GPTBigCodeOnnxConfig):
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return GptBigCodeModelPatcher(self, model, model_kwargs=model_kwargs)
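Once registered this way, the OpenVINO exporter resolves GPTBigCodeOpenVINOConfig for the gpt-bigcode architecture and obtains the patcher from it. A rough sketch of that flow, assuming a locally loaded transformers model (the checkpoint id and task are illustrative, not taken from this PR):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype="auto")
export_config = GPTBigCodeOpenVINOConfig(model.config, task="text-generation")
patcher = export_config.patch_model_for_export(model)  # returns a GptBigCodeModelPatcher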
90 changes: 90 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -3650,3 +3650,93 @@ def __exit__(self, exc_type, exc_value, traceback):
            block_sparse_moe.router.forward = block_sparse_moe.router._orig_forward
            block_sparse_moe.input_linear.forward = block_sparse_moe.input_linear._orig_forward
            block_sparse_moe.output_linear.forward = block_sparse_moe.output_linear._orig_forward


# copied from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L401
def gpt_bigcode_attn(self, query, key, value, attention_mask=None, head_mask=None):
    if head_mask is not None:
        # The super dispatch is done in the forward.
        raise ValueError("PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository.")

    scale = None
    if not self.scale_attn_weights:
        scale = 1

    # MQA models: (batch_size, query_length, num_heads * head_dim)
    # MHA models: (batch_size, num_heads, query_length, head_dim)
    query_shape = query.shape
    batch_size = query_shape[0]
    key.shape[-2]  # no-op statement kept as-is from the copied upstream implementation

    if self.multi_query:
        query_length = query_shape[1]

        # SDPA requires the dimension [..., sequence_length, head_dim].
        query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)

        # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)

        # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
        # and flash attention backend (No available kernel. Aborting execution.) from the shapes
        # query = [batch_size, num_heads, query_length, head_dim]
        # key = [batch_size, 1, past_length, head_dim]
        # value = [batch_size, 1, past_length, head_dim]
        #
        # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
        if is_torch_version(">=", "2.2.0"):
            key = key.expand(-1, self.num_heads, -1, -1)
            value = value.expand(-1, self.num_heads, -1, -1)
    else:
        query_length = query_shape[-1]

        # See the comment above.
        if query.device.type == "cuda" and attention_mask is not None:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
    # create a causal mask in case query_length == 1.
    is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
    # Different from the upstream implementation: when the model weights are kept in their original
    # precision (e.g. fp16), the dtype of transformer.wte (and therefore of the attention mask) may
    # differ from the query dtype, so cast the mask to match before calling SDPA.
    if attention_mask is not None:
        attention_mask = attention_mask.to(query.dtype)
    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=self.attn_pdrop if self.training else 0.0,
        is_causal=is_causal,
        scale=scale,
    )

    if self.multi_query:
        # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
        sdpa_result = sdpa_result.transpose(1, 2)

        # Reshape is kind of expensive here, as it does a memory copy,
        # but I did not manage to make away without it (logits do not match when using view)
        # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
        sdpa_result = sdpa_result.reshape(query_shape)

    return sdpa_result, None


class GptBigCodeModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
        if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
            for layer in self._model.transformer.h:
                layer.attn._orig_attn = layer.attn._attn
                layer.attn._attn = types.MethodType(gpt_bigcode_attn, layer.attn)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
            for layer in self._model.transformer.h:
                layer.attn._attn = layer.attn._orig_attn
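The patcher is intended to be used as a context manager around the export, so the SDPA override above is only active while the model is being traced. A hedged sketch of that usage, reusing the export_config and model from the earlier sketch:

patcher = GptBigCodeModelPatcher(export_config, model)  # same arguments patch_model_for_export passes
with patcher:
    # inside the context every layer.attn._attn is gpt_bigcode_attn, so the attention
    # mask is cast to the query dtype before scaled_dot_product_attention is called
    ...  # model tracing / export would run here
# on exit the original layer.attn._attn methods are restored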