Commit

Merge branch 'main' into mistral

echarlaix committed Oct 6, 2023
2 parents a11c3e8 + ba113e5 commit aca22df

Showing 10 changed files with 105 additions and 26 deletions.
10 changes: 9 additions & 1 deletion optimum/bettertransformer/models/attention.py
@@ -242,6 +242,7 @@ def opt_forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
raise_on_head_mask(layer_head_mask)

@@ -336,6 +337,7 @@ def t5_forward(
query_length=None,
use_cache=False,
output_attentions=False,
**kwargs,
):
raise_on_head_mask(layer_head_mask)

@@ -466,6 +468,7 @@ def bart_forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
raise_on_head_mask(layer_head_mask)
@@ -583,6 +586,7 @@ def llama_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions is True:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
@@ -768,6 +772,7 @@ def gpt_bigcode_forward(
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions is True:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
@@ -826,6 +831,7 @@ def bloom_forward(
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
**kwargs,
):
raise_on_head_mask(head_mask)

@@ -907,9 +913,11 @@ def falcon_forward(
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
**kwargs,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
@@ -930,7 +938,7 @@ def falcon_forward(
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)

past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)

if layer_past is not None:
past_key, past_value = layer_past
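Note on the change above: adding **kwargs to each patched attention forward lets these functions accept keyword arguments that newer transformers releases may pass (for example a padding_mask argument) without raising a TypeError, and the Falcon forward additionally accepts position_ids and threads them into the rotary embedding. A minimal sketch of the signature pattern, with illustrative names rather than the actual optimum code:

# Minimal sketch (not the optimum implementation): a patched forward that absorbs
# unused keyword arguments so newer callers do not break it.
from typing import Optional, Tuple

import torch


def patched_forward(
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
    **kwargs,  # silently absorbs extra keywords such as padding_mask
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if output_attentions:
        raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
    return hidden_states, None


out, _ = patched_forward(torch.randn(1, 4, 8), padding_mask=None)  # extra kwarg is ignored
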
2 changes: 1 addition & 1 deletion optimum/bettertransformer/models/decoder_models.py
@@ -44,7 +44,7 @@
else:
from ...utils.dummy_bettertransformer_objects import BarkSelfAttention

if check_if_transformers_greater("4.32"):
if check_if_transformers_greater("4.34"):
from transformers.models.falcon.modeling_falcon import FalconAttention
else:
from ...utils.dummy_bettertransformer_objects import FalconAttention
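The version gate above moves from transformers 4.32 to 4.34, presumably because the patched Falcon attention now relies on the 4.34 signature (position_ids passed to the rotary embedding). A rough sketch of the gating pattern, assuming a packaging-based version check similar to check_if_transformers_greater:

# Rough sketch of a version-gated import; the helper is assumed to compare against
# the installed transformers version with packaging.
from packaging import version

import transformers


def check_if_transformers_greater(target: str) -> bool:
    return version.parse(transformers.__version__) >= version.parse(target)


if check_if_transformers_greater("4.34"):
    from transformers.models.falcon.modeling_falcon import FalconAttention  # noqa: F401
else:
    FalconAttention = None  # optimum substitutes a dummy object here instead
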
41 changes: 41 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -217,7 +217,48 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class LlamaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
**kwargs,
)
self.num_key_value_heads = normalized_config.num_key_value_heads

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
shape = (
self.batch_size,
self.num_key_value_heads,
self.sequence_length,
self.hidden_size // self.num_attention_heads,
)
return [
(
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
)
for _ in range(self.num_layers)
]


class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, LlamaDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = LlamaDummyPastKeyValuesGenerator

DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

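The new LlamaDummyPastKeyValuesGenerator sizes the dummy KV cache with num_key_value_heads instead of num_attention_heads, which matters for grouped-query-attention checkpoints where the two differ. A small sketch of the resulting shape; the config values are illustrative (roughly Llama-2-70B-like), not taken from the diff:

# Illustrative shape computation for a grouped-query-attention model.
batch_size, sequence_length = 2, 16
num_attention_heads, num_key_value_heads, hidden_size = 64, 8, 8192

past_kv_shape = (
    batch_size,
    num_key_value_heads,                 # KV heads, not attention heads
    sequence_length,
    hidden_size // num_attention_heads,  # per-head dimension
)
print(past_kv_shape)  # (2, 8, 16, 128)
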
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_diffusion.py
@@ -315,7 +315,7 @@ def _from_pretrained(
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=["*.msgpack", "*.safetensors", "*.bin"],
ignore_patterns=["*.msgpack", "*.safetensors", "*.bin", "*.xml"],
)
new_model_save_dir = Path(model_id)

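Adding "*.xml" to ignore_patterns likely keeps the Hub snapshot download from pulling OpenVINO IR files alongside the ONNX weights. A hedged sketch of the call, assuming huggingface_hub's snapshot_download; the repo id and allow_patterns below are placeholders, not from the diff:

# Sketch only: "some-org/some-onnx-stable-diffusion" is a placeholder repo id.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    "some-org/some-onnx-stable-diffusion",
    allow_patterns=["*.json", "*.txt", "*.onnx"],
    ignore_patterns=["*.msgpack", "*.safetensors", "*.bin", "*.xml"],  # skip non-ONNX weight formats
)
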
15 changes: 8 additions & 7 deletions optimum/onnxruntime/utils.py
@@ -101,8 +101,8 @@ class ORTConfigManager:
"albert": "bert",
"bart": "bart",
"bert": "bert",
"big_bird": "bert",
# "bigbird_pegasus": None, # bug in `fusion_skiplayernorm.py`
"big-bird": "bert",
# "bigbird-pegasus": None, # bug in `fusion_skiplayernorm.py`
"blenderbot": "bert",
"bloom": "gpt2",
"camembert": "bert",
@@ -112,9 +112,9 @@ class ORTConfigManager:
"distilbert": "bert",
"electra": "bert",
"gpt2": "gpt2",
"gpt_bigcode": "gpt2",
"gpt_neo": "gpt2",
"gpt_neox": "gpt2",
"gpt-bigcode": "gpt2",
"gpt-neo": "gpt2",
"gpt-neox": "gpt2",
"gptj": "gpt2",
# longt5 with O4 results in segmentation fault
"longt5": "bert",
@@ -123,7 +123,7 @@ class ORTConfigManager:
"mbart": "bart",
"mistral": "gpt2",
"mt5": "bart",
"m2m_100": "bart",
"m2m-100": "bart",
"nystromformer": "bert",
"pegasus": "bert",
"roberta": "bert",
@@ -135,6 +135,7 @@ class ORTConfigManager:

@classmethod
def get_model_ort_type(cls, model_type: str) -> str:
model_type = model_type.replace("_", "-")
cls.check_supported_model(model_type)
return cls._conf[model_type]

@@ -162,7 +163,7 @@ def check_optimization_supported_model(cls, model_type: str, optimization_config
"vit",
"swin",
]

model_type = model_type.replace("_", "-")
if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization):
raise NotImplementedError(
f"ONNX Runtime doesn't support the graph optimization of {model_type} yet. Only {list(cls._conf.keys())} are supported. "
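The lookup table keys above switch from underscore to hyphen spellings ("gpt_bigcode" to "gpt-bigcode", and so on), and both lookup paths now normalize the incoming model_type, so either spelling resolves. A minimal, illustrative reimplementation of that normalization, not the optimum code:

# Minimal sketch of the lookup normalization shown above.
_conf = {"gpt-bigcode": "gpt2", "big-bird": "bert", "m2m-100": "bart"}


def get_model_ort_type(model_type: str) -> str:
    model_type = model_type.replace("_", "-")  # config.model_type uses underscores
    if model_type not in _conf:
        raise KeyError(f"{model_type} is not supported")
    return _conf[model_type]


print(get_model_ort_type("gpt_bigcode"))  # gpt2
print(get_model_ort_type("m2m-100"))      # bart
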
4 changes: 2 additions & 2 deletions optimum/utils/dummy_bettertransformer_objects.py
@@ -16,10 +16,10 @@ def __init__(self, *args, **kwargs):


class FalconAttention(metaclass=DummyObject):
_backends = ["transformers_432"]
_backends = ["transformers_434"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers_432"])
requires_backends(self, ["transformers_434"])


def _llama_prepare_decoder_attention_mask(*args, **kwargs):
4 changes: 4 additions & 0 deletions optimum/utils/import_utils.py
@@ -201,6 +201,10 @@ def require_numpy_strictly_lower(version: str, message: str):
"transformers_432",
(lambda: check_if_transformers_greater("4.32"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.32")),
),
(
"transformers_434",
(lambda: check_if_transformers_greater("4.34"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.34")),
),
]
)

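The registered "transformers_434" backend is what the FalconAttention dummy object above now requires: when transformers is older than 4.34, instantiating the stub raises an informative error instead of failing later with an obscure import problem. A simplified sketch of the dummy-object idea, not the exact transformers/optimum implementation:

# Simplified, illustrative version of the DummyObject / requires_backends pattern.
class DummyObject(type):
    def __call__(cls, *args, **kwargs):
        raise ImportError(
            f"{cls.__name__} requires {', '.join(cls._backends)} to be available."
        )


class FalconAttention(metaclass=DummyObject):
    _backends = ["transformers_434"]


try:
    FalconAttention()
except ImportError as err:
    print(err)  # FalconAttention requires transformers_434 to be available.
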
33 changes: 28 additions & 5 deletions optimum/utils/input_generators.py
@@ -862,21 +862,44 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int


class FalconDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
self.num_kv_heads = 1
head_dim = self.hidden_size // self.num_attention_heads
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
**kwargs,
)
self.num_kv_heads = (
normalized_config.num_kv_heads
if (normalized_config.new_decoder_architecture or not normalized_config.multi_query)
else 1
)
self.head_dim = self.hidden_size // self.num_attention_heads

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
past_key_shape = (
self.batch_size,
self.num_kv_heads,
self.sequence_length,
head_dim,
self.head_dim,
)
past_value_shape = (
self.batch_size,
self.num_kv_heads,
self.sequence_length,
head_dim,
self.head_dim,
)
return [
(
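The Falcon generator above resolves the number of KV heads from the config: the full num_kv_heads when the new decoder architecture is used (or multi-query attention is off), and a single shared head otherwise. A small sketch of that rule with illustrative values:

# Illustrative helper mirroring the head-count rule above; not part of optimum.
def falcon_kv_heads(new_decoder_architecture: bool, multi_query: bool, num_kv_heads: int) -> int:
    if new_decoder_architecture or not multi_query:
        return num_kv_heads
    return 1


print(falcon_kv_heads(new_decoder_architecture=False, multi_query=True, num_kv_heads=8))  # 1
print(falcon_kv_heads(new_decoder_architecture=True, multi_query=True, num_kv_heads=8))   # 8
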
18 changes: 10 additions & 8 deletions optimum/utils/normalized_config.py
@@ -211,9 +211,11 @@ class NormalizedConfigManager:
# "big_bird": NormalizedTextConfig,
# "bigbird_pegasus": BartLikeNormalizedTextConfig,
"blenderbot": BartLikeNormalizedTextConfig,
"blenderbot_small": BartLikeNormalizedTextConfig,
"blenderbot-small": BartLikeNormalizedTextConfig,
"bloom": NormalizedTextConfig.with_args(num_layers="n_layer"),
"falcon": NormalizedTextConfig.with_args(num_layers="num_hidden_layers", num_attention_heads="num_kv_heads"),
"falcon": NormalizedTextConfig.with_args(
num_layers="num_hidden_layers", num_attention_heads="num_attention_heads"
),
"camembert": NormalizedTextConfig,
"codegen": GPT2LikeNormalizedTextConfig,
"cvt": NormalizedVisionConfig,
@@ -225,9 +227,9 @@ class NormalizedConfigManager:
"electra": NormalizedTextConfig,
"encoder-decoder": NormalizedEncoderDecoderConfig,
"gpt2": GPT2LikeNormalizedTextConfig,
"gpt-bigcode": GPT2LikeNormalizedTextConfig,
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt_neox": NormalizedTextConfig,
"gpt-bigcode": GPTBigCodeNormalizedTextConfig,
"gpt-neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt-neox": NormalizedTextConfig,
"llama": NormalizedTextConfig,
"gptj": GPT2LikeNormalizedTextConfig,
"imagegpt": GPT2LikeNormalizedTextConfig,
@@ -236,7 +238,7 @@ class NormalizedConfigManager:
"mbart": BartLikeNormalizedTextConfig,
"mistral": MistralNormalizedTextConfig,
"mt5": T5LikeNormalizedTextConfig,
"m2m_100": BartLikeNormalizedTextConfig,
"m2m-100": BartLikeNormalizedTextConfig,
"nystromformer": NormalizedTextConfig,
"opt": NormalizedTextConfig,
"pegasus": BartLikeNormalizedTextConfig,
@@ -245,7 +247,7 @@ class NormalizedConfigManager:
"regnet": NormalizedVisionConfig,
"resnet": NormalizedVisionConfig,
"roberta": NormalizedTextConfig,
"speech_to_text": SpeechToTextLikeNormalizedTextConfig,
"speech-to-text": SpeechToTextLikeNormalizedTextConfig,
"splinter": NormalizedTextConfig,
"t5": T5LikeNormalizedTextConfig,
"trocr": TrOCRLikeNormalizedTextConfig,
@@ -255,7 +257,6 @@ class NormalizedConfigManager:
"xlm-roberta": NormalizedTextConfig,
"yolos": NormalizedVisionConfig,
"mpt": MPTNormalizedTextConfig,
"gpt_bigcode": GPTBigCodeNormalizedTextConfig,
}

@classmethod
@@ -269,5 +270,6 @@ def check_supported_model(cls, model_type: str):

@classmethod
def get_normalized_config_class(cls, model_type: str) -> Type:
model_type = model_type.replace("_", "-")
cls.check_supported_model(model_type)
return cls._conf[model_type]
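As in the ONNX Runtime utils, the mapping now stores hyphenated keys and get_normalized_config_class normalizes the incoming model_type, while the Falcon entry maps num_attention_heads to the config's actual attention-head count rather than num_kv_heads. A minimal sketch of the attribute-mapping idea behind NormalizedTextConfig.with_args; this is a simplified, illustrative stand-in, not the optimum classes:

# Simplified stand-in for the normalized-config adapters used above.
class SimpleNormalizedConfig:
    def __init__(self, config, **attribute_map):
        self._config = config
        self._map = attribute_map  # e.g. num_layers -> "num_hidden_layers"

    def __getattr__(self, name):
        return getattr(self._config, self._map.get(name, name))


class FakeFalconConfig:  # illustrative config object, not a real transformers config
    num_hidden_layers = 32
    num_attention_heads = 71


normalized = SimpleNormalizedConfig(FakeFalconConfig(), num_layers="num_hidden_layers")
print(normalized.num_layers, normalized.num_attention_heads)  # 32 71
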
2 changes: 1 addition & 1 deletion tests/bettertransformer/testing_utils.py
@@ -43,7 +43,7 @@
"distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
"electra": "hf-internal-testing/tiny-random-ElectraModel",
"ernie": "hf-internal-testing/tiny-random-ErnieModel",
"falcon": "Rocketknight1/tiny-random-falcon-7b",
"falcon": "fxmarty/really-tiny-falcon-testing",
"fsmt": "hf-internal-testing/tiny-random-FSMTModel",
"gpt2": "hf-internal-testing/tiny-random-GPT2Model",
# NOTE: this tiny model does not use attention_softmax_in_fp32=True (contrary to e.g. starcoder)
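The Falcon test checkpoint is swapped for a smaller tiny model, which keeps the BetterTransformer test suite light. A sketch of how such a checkpoint is typically exercised; the exact test body is not shown in this diff:

# Sketch only: loads the tiny checkpoint referenced above and applies BetterTransformer.
from transformers import AutoModelForCausalLM

from optimum.bettertransformer import BetterTransformer

model = AutoModelForCausalLM.from_pretrained("fxmarty/really-tiny-falcon-testing")
model = BetterTransformer.transform(model)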
