Add mistral model patcher
echarlaix committed Oct 9, 2023
1 parent 52ce2d7 commit b76f43a
Showing 4 changed files with 90 additions and 57 deletions.
6 changes: 6 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -60,6 +60,7 @@
     BartModelPatcher,
     BloomModelPatcher,
     LlamaModelPatcher,
+    MistralModelPatcher,
     OPTModelPatcher,
     SAMModelPatcher,
     WavLMModelPatcher,
@@ -250,6 +251,11 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return MistralModelPatcher(self, model, model_kwargs=model_kwargs)


 class MPTOnnxConfig(TextDecoderOnnxConfig):
     # MPT does not require position_ids input.
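
Note on usage: `patch_model_for_export` returns a context manager that the export code enters around the actual tracing call, so Mistral's attention-mask preparation is only replaced for the duration of the export and restored afterwards. A minimal sketch of that flow, not part of this commit; the checkpoint name and the constructor arguments shown here are assumptions:

# Hypothetical usage sketch (not from this commit): entering the patcher around export.
from transformers import AutoConfig, AutoModelForCausalLM

from optimum.exporters.onnx.model_configs import MistralOnnxConfig

model_id = "mistralai/Mistral-7B-v0.1"  # assumed checkpoint name
model = AutoModelForCausalLM.from_pretrained(model_id)
onnx_config = MistralOnnxConfig(AutoConfig.from_pretrained(model_id), task="text-generation", use_past=True)

# The patch is only active while the context manager is entered; on exit the
# original _prepare_decoder_attention_mask is restored on model.model.
with onnx_config.patch_model_for_export(model):
    pass  # torch.onnx.export(...) runs here in the real export path
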
98 changes: 43 additions & 55 deletions optimum/exporters/onnx/model_patcher.py
@@ -19,7 +19,11 @@

 from transformers.utils import is_torch_available

-from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask
+from ...utils.modeling_utils import (
+    _prepare_attn_mask,
+    _prepare_decoder_attention_mask,
+    _prepare_decoder_sliding_window_attention_mask,
+)


 if is_torch_available():
@@ -346,7 +350,7 @@ def patched_forward(
         self.patched_forward = patched_forward


-class BloomModelPatcher(ModelPatcher):
+class CausalAttentionMaskModelPatcher(ModelPatcher):
     def __init__(
         self,
         config: "OnnxConfig",
@@ -357,95 +361,79 @@ def __init__(

         self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
         if self.patch:
-            self.orig_prepare_attn_mask = self._model.transformer._prepare_attn_mask
+            self._orig_func = getattr(self._model_to_patch, self._orig_func_name)

     def __enter__(self):
         super().__enter__()
         if self.patch:
-            self._model.transformer._prepare_attn_mask = _prepare_attn_mask.__get__(self._model.transformer)
+            setattr(self._model_to_patch, self._orig_func_name, self._patch_func.__get__(self._model_to_patch))

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         if self.patch:
-            self._model.transformer._prepare_attn_mask = self.orig_prepare_attn_mask.__get__(self._model.transformer)
+            setattr(self._model_to_patch, self._orig_func_name, self._orig_func.__get__(self._model_to_patch))


-class LlamaModelPatcher(ModelPatcher):
+class BloomModelPatcher(CausalAttentionMaskModelPatcher):
     def __init__(
         self,
         config: "OnnxConfig",
         model: Union["PreTrainedModel", "TFPreTrainedModel"],
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        self._model_to_patch = model.transformer
+        self._patch_func = _prepare_attn_mask
+        self._orig_func_name = "_prepare_attn_mask"
         super().__init__(config, model, model_kwargs)

-        self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
-        if self.patch:
-            self.orig_prepare_attn_mask = self._model.model._prepare_decoder_attention_mask
-
-    def __enter__(self):
-        super().__enter__()
-        if self.patch:
-            self._model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask.__get__(
-                self._model.model
-            )
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-        if self.patch:
-            self._model.model._prepare_decoder_attention_mask = self.orig_prepare_attn_mask.__get__(self._model.model)


-class BartModelPatcher(Seq2SeqModelPatcher):
+class OPTModelPatcher(CausalAttentionMaskModelPatcher):
     def __init__(
         self,
         config: "OnnxConfig",
         model: Union["PreTrainedModel", "TFPreTrainedModel"],
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        self._model_to_patch = model.model.decoder
+        self._patch_func = _prepare_decoder_attention_mask
+        self._orig_func_name = "_prepare_decoder_attention_mask"
         super().__init__(config, model, model_kwargs)

-        self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
-        if self.patch:
-            self.orig_prepare_attn_mask = self._model.model.decoder._prepare_decoder_attention_mask
-
-    def __enter__(self):
-        super().__enter__()
-        if self.patch:
-            self._model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask.__get__(
-                self._model.model.decoder
-            )
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-        if self.patch:
-            self._model.model.decoder._prepare_decoder_attention_mask = self.orig_prepare_attn_mask.__get__(
-                self._model.model.decoder
-            )
+class LlamaModelPatcher(CausalAttentionMaskModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        self._model_to_patch = model.model
+        self._patch_func = _prepare_decoder_attention_mask
+        self._orig_func_name = "_prepare_decoder_attention_mask"
+        super().__init__(config, model, model_kwargs)


-class OPTModelPatcher(ModelPatcher):
+class MistralModelPatcher(CausalAttentionMaskModelPatcher):
     def __init__(
         self,
         config: "OnnxConfig",
         model: Union["PreTrainedModel", "TFPreTrainedModel"],
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        self._model_to_patch = model.model
+        self._patch_func = _prepare_decoder_sliding_window_attention_mask
+        self._orig_func_name = "_prepare_decoder_attention_mask"
         super().__init__(config, model, model_kwargs)
-        self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
-        if self.patch:
-            self.orig_prepare_attn_mask = self._model.model.decoder._prepare_decoder_attention_mask
-
-    def __enter__(self):
-        super().__enter__()
-        if self.patch:
-            self._model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask.__get__(
-                self._model.model.decoder
-            )
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-        if self.patch:
-            self._model.model.decoder._prepare_decoder_attention_mask = self.orig_prepare_attn_mask.__get__(
-                self._model.model.decoder
-            )
+class BartModelPatcher(CausalAttentionMaskModelPatcher, Seq2SeqModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        self._model_to_patch = model.model.decoder
+        self._patch_func = _prepare_decoder_attention_mask
+        self._orig_func_name = "_prepare_decoder_attention_mask"
+        super().__init__(config, model, model_kwargs)
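
Note on the refactor: the change above collapses four near-identical patchers into `CausalAttentionMaskModelPatcher`, which records which submodule to patch (`_model_to_patch`), the replacement function (`_patch_func`), and the attribute name to override (`_orig_func_name`), then swaps the method in on `__enter__` and restores it on `__exit__`. The self-contained toy below illustrates the underlying mechanism; the `Toy*` names are purely illustrative, not optimum APIs.

def replacement_mask_fn(self, *args, **kwargs):
    # Stand-in for e.g. _prepare_decoder_sliding_window_attention_mask
    return "patched mask"


class ToyDecoder:
    def _prepare_decoder_attention_mask(self, *args, **kwargs):
        return "original mask"


class ToyPatcher:
    def __init__(self, target, func_name, patch_func):
        self._target = target
        self._func_name = func_name
        self._patch_func = patch_func
        self._orig_func = getattr(target, func_name)  # bound original method

    def __enter__(self):
        # __get__ binds the plain function to the instance so it behaves like a method.
        setattr(self._target, self._func_name, self._patch_func.__get__(self._target))
        return self._target

    def __exit__(self, exc_type, exc_value, traceback):
        setattr(self._target, self._func_name, self._orig_func)


decoder = ToyDecoder()
with ToyPatcher(decoder, "_prepare_decoder_attention_mask", replacement_mask_fn):
    assert decoder._prepare_decoder_attention_mask() == "patched mask"
assert decoder._prepare_decoder_attention_mask() == "original mask"
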
7 changes: 5 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
@@ -332,8 +332,11 @@ def prepare_past_key_values(
         # Generate dummy past for the first forward if uses a merged decoder
         if past_key_values is None:
             batch_size = input_ids.shape[0]
-            num_attention_heads = self.normalized_config.num_attention_heads
-            embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
+            if self.config.model_type in {"mistral", "llama"}:
+                num_attention_heads = self.normalized_config.num_key_value_heads
+            else:
+                num_attention_heads = self.normalized_config.num_attention_heads
+            embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
             dtype = constructor.float16 if self.use_fp16 else constructor.float32
             # TODO: find a way to better handle this controlflow
             # "1" is the dummy sequence length
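
Note on the head-count change: it matters for grouped-query attention. Mistral (and Llama configurations with GQA) cache only `num_key_value_heads` key/value heads per layer, while the per-head size is still derived from the full `num_attention_heads`. A quick shape check with assumed Mistral-7B-style values (hidden size 4096, 32 attention heads, 8 key/value heads), for illustration only:

import torch

# Assumed Mistral-7B-style configuration values.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
batch_size = 2

# Per-head size comes from the full attention head count, not the KV head count.
embed_size_per_head = hidden_size // num_attention_heads  # 128

# Dummy past key/value for one layer: dummy sequence length 1, KV heads on the head axis.
shape = (batch_size, num_key_value_heads, 1, embed_size_per_head)
dummy_key = torch.zeros(shape, dtype=torch.float32)
dummy_value = torch.zeros(shape, dtype=torch.float32)
print(dummy_key.shape)  # torch.Size([2, 8, 1, 128])
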
36 changes: 36 additions & 0 deletions optimum/utils/modeling_utils.py
@@ -24,6 +24,7 @@
"blenderbot-small",
"bloom",
"llama",
"mistral",
"mpt",
"opt",
"pegasus",
@@ -142,3 +143,38 @@ def _prepare_decoder_attention_mask(
         )

     return combined_attention_mask
+
+
+# Modified from transformers.models.mistral.modeling_mistral._prepare_decoder_attention_mask
+def _prepare_decoder_sliding_window_attention_mask(
+    self,
+    attention_mask: torch.Tensor,
+    input_shape: Tuple[int, int],
+    inputs_embeds: torch.Tensor,
+    past_key_values_length: int,
+    sliding_window: int,
+):
+    from transformers.models.mistral.modeling_mistral import _make_sliding_window_causal_mask, _expand_mask
+
+    # create causal mask
+    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+    combined_attention_mask = None
+
+    combined_attention_mask = _make_sliding_window_causal_mask(
+        input_shape,
+        device=inputs_embeds.device,
+        dtype=inputs_embeds.dtype,
+        past_key_values_length=past_key_values_length,
+        sliding_window=sliding_window,
+    )
+
+    if attention_mask is not None:
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+            inputs_embeds.device
+        )
+        combined_attention_mask = (
+            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+        )
+
+    return combined_attention_mask
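
Note on the new helper: a sliding-window causal mask lets position i attend only to the most recent `sliding_window` positions up to and including i, instead of the whole prefix. The sketch below builds that boolean pattern with plain torch; it mirrors the intent of `_make_sliding_window_causal_mask` rather than its exact implementation (which produces an additive float mask and accounts for `past_key_values_length`).

import torch


def sliding_window_causal_bool_mask(seq_len: int, sliding_window: int) -> torch.Tensor:
    """True where attention is allowed: causal, and at most `sliding_window` tokens back."""
    positions = torch.arange(seq_len)
    query = positions.unsqueeze(-1)  # (seq_len, 1)
    key = positions.unsqueeze(0)     # (1, seq_len)
    causal = key <= query
    within_window = key > query - sliding_window
    return causal & within_window


print(sliding_window_causal_bool_mask(5, sliding_window=3).int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]])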
