Add ONNX and ORT support for Falcon (#1391)
* add onnx and ort falcon

* add back ort support

* hopefully working ort inference

* address review

* style

* remove diff in base.py
fxmarty authored Oct 18, 2023
1 parent e7bd60d commit 1ae95a7
Showing 15 changed files with 548 additions and 30 deletions.
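
For context before the per-file diff, here is a minimal end-to-end sketch of what this change enables, assuming the standard optimum.onnxruntime API (ORTModelForCausalLM and AutoTokenizer are not part of this diff) and using tiiuae/falcon-7b purely as an example checkpoint:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

# Export the PyTorch checkpoint to ONNX on the fly and run it with ONNX Runtime.
model = ORTModelForCausalLM.from_pretrained("tiiuae/falcon-7b", export=True)
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")

inputs = tokenizer("ONNX Runtime is", return_tensors="pt")
# generate() reuses the past key/value outputs declared by the new FalconOnnxConfig.
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))

The same export should also be reachable from the command line via optimum-cli export onnx.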
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
@@ -41,6 +41,7 @@ Supported architectures:
- Donut-Swin
- Electra
- Encoder Decoder
- Falcon
- Flaubert
- GPT-2
- GPT-BigCode
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/config.py
@@ -163,7 +163,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
# Decoders based on GPT2 require a position_ids input to avoid
# generating wrong position_ids in the model itself:
# https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
if not self.no_position_ids and self.task == "text-generation":
if not self.no_position_ids and self.task in ["text-generation", "feature-extraction"]:
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

return common_inputs
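
To make the comment above concrete, a small illustrative sketch (not part of the diff) of the explicit position_ids input the exported decoder receives; shapes and values are placeholders:

import torch

batch_size, past_length, seq_length = 2, 8, 1  # e.g. one decoding step with a cache of length 8

# Positions continue from the cached length instead of restarting at 0,
# which is what the traced model would otherwise compute on its own.
position_ids = torch.arange(past_length, past_length + seq_length, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)  # (batch_size, sequence_length)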
86 changes: 83 additions & 3 deletions optimum/exporters/onnx/model_configs.py
@@ -35,6 +35,7 @@
DummyVisionInputGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
MistralDummyPastKeyValuesGenerator,
MultiQueryPastKeyValuesGenerator,
NormalizedConfig,
NormalizedEncoderDecoderConfig,
NormalizedSeq2SeqConfig,
@@ -59,6 +60,7 @@
from .model_patcher import (
BartModelPatcher,
BloomModelPatcher,
FalconModelPatcher,
LlamaModelPatcher,
MistralModelPatcher,
OPTModelPatcher,
@@ -279,9 +281,6 @@ class BloomOnnxConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
"""
Refer to OnnxConfigWithPast in base.py
"""
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

@@ -337,6 +336,87 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.key_value"] = t


class FalconOnnxConfig(TextDecoderOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
MultiQueryPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DEFAULT_ONNX_OPSET = 14 # Falcon uses aten::triu that requires opset>=14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
DUMMY_PKV_GENERATOR_CLASS = MultiQueryPastKeyValuesGenerator

def __init__(
self,
config: "PretrainedConfig",
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
use_past: bool = False,
use_past_in_inputs: bool = False,
preprocessors: Optional[List[Any]] = None,
no_position_ids: bool = False,
):
super().__init__(
config=config,
task=task,
int_dtype=int_dtype,
float_dtype=float_dtype,
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
preprocessors=preprocessors,
no_position_ids=no_position_ids,
)
# For some reason Falcon config.num_kv_heads can not be trusted, see in Transformers:
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L337
self._normalized_config.num_kv_heads = (
self._normalized_config.num_kv_heads
if (self._normalized_config.new_decoder_architecture or not self._normalized_config.multi_query)
else 1
)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = super().inputs

if (
not self.no_position_ids
and not self._config.alibi
and self.task in ["text-generation", "feature-extraction"]
):
# When alibi is used, position_ids are not used in Falcon.
# Reference: https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L1116
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

return common_inputs

# we need to set output_attentions=True in the model input to avoid calling
# torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return FalconModelPatcher(self, model, model_kwargs=model_kwargs)

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
name = "present"

for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {
0: "batch_size x num_heads",
1: decoder_sequence_name,
}
inputs_or_outputs[f"{name}.{i}.value"] = {
0: "batch_size x num_heads",
1: decoder_sequence_name,
}
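
To illustrate the two Falcon-specific pieces above (the corrected num_kv_heads and the flattened "batch_size x num_heads" cache axis), a rough sketch with made-up sizes; only the first two axes are fixed by the config above, the trailing head_dim axis and the concrete values are assumptions:

import torch

# Effective number of KV heads, mirroring the __init__ fix above: trust config.num_kv_heads
# only for the new decoder architecture or for non-multi-query models.
new_decoder_architecture, multi_query, num_kv_heads = False, True, 71
effective_kv_heads = num_kv_heads if (new_decoder_architecture or not multi_query) else 1  # -> 1

# One layer of the dummy cache, matching the dynamic axes declared in add_past_key_values:
# axis 0 is "batch_size x num_heads", axis 1 is "past_sequence_length".
batch_size, past_sequence_length, head_dim = 2, 16, 64
past_key = torch.zeros(batch_size * effective_kv_heads, past_sequence_length, head_dim)
past_value = torch.zeros(batch_size * effective_kv_heads, past_sequence_length, head_dim)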


class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
encoder_shape = (
238 changes: 237 additions & 1 deletion optimum/exporters/onnx/model_patcher.py
@@ -15,11 +15,16 @@
import dataclasses
import functools
import inspect
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
import types
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

import transformers
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.falcon.modeling_falcon import FalconModel, build_alibi_tensor
from transformers.utils import is_torch_available

from ...utils.modeling_utils import (
_falcon_prepare_attn_mask,
_prepare_attn_mask,
_prepare_decoder_attention_mask,
_prepare_decoder_sliding_window_attention_mask,
@@ -229,6 +234,237 @@ def patched_forward(*args, **kwargs):
self.patched_forward = patched_forward


def _make_causal_mask_falcon_patched(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
target_length, target_length+past_key_values_length]`.
"""
batch_size, target_length = input_ids_shape

# NOTE: ONNX Runtime is not able to run ONNX Trilu node with bool input. As a workaround, we pass a float input
# and cast to bool here. Reference: https://github.com/microsoft/onnxruntime/issues/16189
mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.float, device=device), diagonal=1).to(
torch.bool
)

# If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
# This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
# way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
mask = torch.cat([past_mask, mask], dim=-1)
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
return expanded_mask
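
As a standalone illustration of the workaround described in the NOTE above (not part of the patch), both formulations give the same mask in eager mode, but only the float-then-cast version exports to an ONNX Trilu node that ONNX Runtime accepts:

import torch

target_length = 4

# A boolean triu exports to a Trilu node with a bool input, which ONNX Runtime rejects ...
bool_mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool), diagonal=1)

# ... so the patched mask is built in float and only then cast to bool.
float_mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.float), diagonal=1).to(torch.bool)

assert torch.equal(bool_mask, float_mask)  # identical values, different exported graphs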


def falcon_model_forward_without_kv_reformatting(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

if past_key_values is None:
past_key_values = tuple([None] * len(self.h))

# NOTE: here we removed the _convert_to_rw_cache call

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)

hidden_states = inputs_embeds

presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None

# Compute alibi tensor: check build_alibi_tensor documentation
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)

if self.use_alibi:
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
# NOTE: here we use expand(batch_size, -1) instead of transformers view(-1, seq_length) that is bugged
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
else:
position_ids = position_ids.view(-1, seq_length).long()

causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)

for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)

hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)

if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

# Add last hidden state
hidden_states = self.ln_f(hidden_states)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

# NOTE: here we removed the _convert_cache_to_standard_format call

if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
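
For reference, the two reformatting calls removed above (see the NOTE comments) are essentially reshapes between the 4D per-head cache layout and the flattened layout this forward keeps throughout; a rough sketch with illustrative shapes only:

import torch

batch_size, num_heads, kv_length, head_dim = 2, 4, 8, 16
standard_key = torch.zeros(batch_size, num_heads, kv_length, head_dim)

# Roughly what _convert_to_rw_cache would do on the way in ...
flat_key = standard_key.view(batch_size * num_heads, kv_length, head_dim)

# ... and what _convert_cache_to_standard_format would undo on the way out.
roundtrip_key = flat_key.view(batch_size, num_heads, kv_length, head_dim)
assert torch.equal(standard_key, roundtrip_key)

# Skipping both keeps the exported decoder's cache in one stable layout, matching the
# "batch_size x num_heads" axis declared in FalconOnnxConfig.add_past_key_values.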


class FalconModelPatcher(ModelPatcher):
def __enter__(self):
self.patch_ops()

transformers.models.falcon.modeling_falcon._make_causal_mask = _make_causal_mask_falcon_patched

if self.real_config.task == "text-generation":
self._model.transformer.forward = types.MethodType(
falcon_model_forward_without_kv_reformatting, self._model.transformer
)

# In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length.
if isinstance(self._model, FalconModel):
self._model._prepare_attn_mask = _falcon_prepare_attn_mask
else:
self._model.transformer._prepare_attn_mask = _falcon_prepare_attn_mask

setattr(self._model, self.orig_forward_name, self.patched_forward)

def __exit__(self, exc_type, exc_value, traceback):
self.restore_ops()

setattr(self._model, self.orig_forward_name, self.orig_forward)

if self.real_config.task == "text-generation":
self._model.transformer.forward = types.MethodType(
self.original_model_transformer_forward, self._model.transformer
)

transformers.models.falcon.modeling_falcon._make_causal_mask = self.original_make_causal

# In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length.
if isinstance(self._model, FalconModel):
self._model._prepare_attn_mask = self.original_falcon_prepare_attn_mask
else:
self._model.transformer._prepare_attn_mask = self.original_falcon_prepare_attn_mask

def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

if config.task == "text-generation":
self.original_model_transformer_forward = model.transformer.forward

self.original_make_causal = transformers.models.falcon.modeling_falcon._make_causal_mask

if isinstance(model, FalconModel):
self.original_falcon_prepare_attn_mask = model._prepare_attn_mask
else:
self.original_falcon_prepare_attn_mask = model.transformer._prepare_attn_mask

self._model = model

self.orig_forward_name = "forward" if hasattr(self._model, "forward") else "call"
self.orig_forward = getattr(self._model, self.orig_forward_name)

allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past

@functools.wraps(self.orig_forward)
def patched_forward(*args, **kwargs):
model_kwargs = self.model_kwargs
# setting output_attentions=True in the model input to avoid calling torch.nn.functional.scaled_dot_product_attention
# in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/falcon/modeling_falcon.py#L425
model_kwargs["output_attentions"] = True
signature = inspect.signature(self.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=model_kwargs)

outputs = self.orig_forward(*args, **kwargs)

filterd_outputs = {}
for name, value in outputs.items():
onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
if (
onnx_output_name in config.outputs
or (allow_past_in_outputs and name.startswith("past_key_values"))
or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
):
filterd_outputs[name] = value
return filterd_outputs

self.patched_forward = patched_forward
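
A rough sketch of how this patcher is applied around a forward pass or export; the tiny random config is only to keep the example self-contained, and in practice the optimum export pipeline (e.g. optimum-cli export onnx) drives this automatically:

import torch
from transformers import FalconConfig, FalconForCausalLM

from optimum.exporters.onnx.model_configs import FalconOnnxConfig

# A tiny random Falcon, purely for illustration.
config = FalconConfig(vocab_size=128, hidden_size=64, num_hidden_layers=2, num_attention_heads=4)
model = FalconForCausalLM(config).eval()

onnx_config = FalconOnnxConfig(config, task="text-generation", use_past=True)
patcher = onnx_config.patch_model_for_export(model)

with patcher:
    # Inside the context, _make_causal_mask is replaced and, for the text-generation task,
    # the transformer forward skips the cache reformatting, so tracing or exporting here
    # follows the ONNX-friendly code paths (eager attention, no Trilu on a bool input).
    outputs = model(input_ids=torch.randint(0, 128, (1, 8)))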


class WavLMModelPatcher(ModelPatcher):
def __init__(
self,
1 change: 1 addition & 0 deletions optimum/exporters/onnx/utils.py
@@ -68,6 +68,7 @@

MODEL_TYPES_REQUIRING_POSITION_IDS = {
"codegen",
"falcon",
"gpt2",
"gpt-bigcode",
"gpt-neo",
The remaining 10 changed files are not shown.