Remove attn mask patching #1509

Merged · 10 commits · Nov 2, 2023
33 changes: 21 additions & 12 deletions optimum/exporters/onnx/__main__.py
@@ -24,6 +24,7 @@

from ...commands.export.onnx import parse_args_onnx
from ...utils import DEFAULT_DUMMY_SHAPES, ONNX_WEIGHTS_NAME, logging
from ...utils.modeling_utils import MODEL_TO_PATCH_FOR_PAST
from ...utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from ..error_utils import AtolError, OutputMatchError, ShapeError
from ..tasks import TasksManager
@@ -83,16 +84,12 @@ def _get_submodels_and_onnx_configs(
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="onnx", task=task
)
onnx_config_kwargs = {}
if task.startswith("text-generation") and legacy:
onnx_config_kwargs["no_position_ids"] = legacy

onnx_config = onnx_config_constructor(
model.config,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
**onnx_config_kwargs,
legacy=legacy,
)

onnx_config.variant = _variant
@@ -317,13 +314,6 @@ def main_export(
model_name_or_path, subfolder=subfolder, library_name=library_name
)

# get the shapes to be used to generate dummy inputs
input_shapes = {}
for input_name in DEFAULT_DUMMY_SHAPES.keys():
input_shapes[input_name] = (
kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
)

torch_dtype = None if fp16 is False else torch.float16

if task == "auto":
@@ -382,6 +372,25 @@
is_stable_diffusion = "stable-diffusion" in task
model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-")

# For MODEL_TO_PATCH_FOR_PAST architectures, when exporting the model with an input of sequence length of 1, a tracer that does not handle
# control flow will incorrectly trace the mask generation, resulting in incorrect attention masks for other sequence lengths.
# Reference: https://github.com/huggingface/transformers/blob/af3de8d87c717c4bb090f037d0d89413c195a42f/src/transformers/modeling_attn_mask_utils.py#L94
input_shapes = {}
for input_name in DEFAULT_DUMMY_SHAPES.keys():
input_shapes[input_name] = (
kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
)

# TODO: this may rather be moved to the OnnxConfig to avoid bloating this script.
if (
model_type in MODEL_TO_PATCH_FOR_PAST
and input_name == "sequence_length"
and kwargs_shapes.get(input_name) == 1
):
raise ValueError(
f"Exporting with a sequence length of 1 a {model_type} model is not supported and can yield unexpected results."
)

if legacy and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"):
logger.warning(
f"legacy=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `legacy=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381"
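For context, a minimal standalone sketch of the guard added above, assuming placeholder values; the real DEFAULT_DUMMY_SHAPES and MODEL_TO_PATCH_FOR_PAST registries live in optimum.utils and differ from the ones below.

DEFAULT_DUMMY_SHAPES = {"batch_size": 2, "sequence_length": 16}   # placeholder values
MODEL_TO_PATCH_FOR_PAST = {"bloom", "llama", "opt"}               # illustrative subset only

def resolve_input_shapes(model_type: str, **kwargs_shapes) -> dict:
    # Resolve dummy input shapes from user overrides, falling back to the defaults,
    # and reject a sequence length of 1 for architectures whose mask generation
    # would be traced incorrectly at that length.
    input_shapes = {}
    for input_name, default in DEFAULT_DUMMY_SHAPES.items():
        input_shapes[input_name] = kwargs_shapes.get(input_name, default)
        if (
            model_type in MODEL_TO_PATCH_FOR_PAST
            and input_name == "sequence_length"
            and kwargs_shapes.get(input_name) == 1
        ):
            raise ValueError(
                f"Exporting a {model_type} model with a sequence length of 1 is not supported."
            )
    return input_shapes

resolve_input_shapes("llama", batch_size=1)        # ok: the default sequence_length is used
# resolve_input_shapes("llama", sequence_length=1) # raises ValueError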
30 changes: 25 additions & 5 deletions optimum/exporters/onnx/base.py
@@ -200,6 +200,7 @@ def __init__(
preprocessors: Optional[List[Any]] = None,
int_dtype: str = "int64",
float_dtype: str = "fp32",
legacy: bool = False,
):
self.task = task
self.int_dtype = int_dtype
@@ -209,6 +210,7 @@
self._preprocessors = preprocessors
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.variant = "default"
self.legacy = legacy

def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
"""
@@ -565,14 +567,20 @@ def __init__(
use_past: bool = False,
use_past_in_inputs: bool = False,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
):
self.use_past = use_past
self.use_past_in_inputs = use_past_in_inputs

self.is_merged = False
self.use_cache_branch = None
super().__init__(
config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors
config=config,
task=task,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
legacy=legacy,
)

@property
@@ -628,11 +636,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
and "attention_mask" in dummy_inputs
):
# Obtain the past sequence length from the value instead of the key (Bloom).
past_length = dummy_inputs["past_key_values"][0][1].shape[-2]
past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[-2]

dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
dummy_inputs["attention_mask"],
desired_length=past_length + 1,
desired_length=past_present_length,
dim=1,
dtype=dummy_inputs["attention_mask"].dtype,
)
@@ -658,11 +666,15 @@ def overwrite_shape_and_generate_input(

# models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name
# while models from TextDecoderOnnxConfig use input_ids, hence the check for both

# TODO: The condition on `self.task` and `self.legacy` below is added because, for non-legacy text-generation, a single ONNX without subgraphs is used both without and with KV cache.
# This overwrite may be moved to OnnxSeq2SeqConfigWithPast, but I am afraid it would break encoder-decoder models.
if (
self.use_past
and self.use_past_in_inputs
and self.use_cache_branch is not False
and input_name in ["decoder_input_ids", "input_ids", "position_ids"]
and ((self.task == "text-generation" and self.legacy) or self.task != "text-generation")
):
sequence_length = dummy_input_gen.sequence_length
# Use a sequence length of 1 when the KV cache is already populated.
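Roughly, the new condition reduces to the following sketch (simplified; the real check also consults use_cache_branch and the input name).

def forces_sequence_length_of_one(task, legacy, use_past, use_past_in_inputs):
    # With a populated KV cache the decoder normally only sees the new token, so the
    # dummy sequence length is forced to 1, except for the non-legacy text-generation
    # export, where a single ONNX without subgraphs handles both cache cases.
    return (
        use_past
        and use_past_in_inputs
        and (task != "text-generation" or legacy)
    )

assert forces_sequence_length_of_one("text-generation", legacy=True, use_past=True, use_past_in_inputs=True)
assert not forces_sequence_length_of_one("text-generation", legacy=False, use_past=True, use_past_in_inputs=True)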
@@ -768,6 +780,7 @@ def __init__(
use_past_in_inputs: bool = False,
behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
):
super().__init__(
config=config,
@@ -777,6 +790,7 @@
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
preprocessors=preprocessors,
legacy=legacy,
)
self._behavior = behavior

@@ -816,6 +830,7 @@ def with_behavior(
use_past_in_inputs=use_past_in_inputs,
behavior=behavior,
preprocessors=self._preprocessors,
legacy=self.legacy,
)
onnx_config.variant = self.variant
return onnx_config
@@ -1003,14 +1018,15 @@ class OnnxConfigWithLoss(OnnxConfig, ABC):

DUMMY_EXTRA_INPUT_GENERATOR_CLASSES = (DummyLabelsGenerator,)

def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: str = "fp32"):
def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: str = "fp32", legacy: bool = False):
self._onnx_config = config
self.task = self._onnx_config.task
self.int_dtype = int_dtype
self.float_dtype = float_dtype
self._normalized_config = self._onnx_config._normalized_config
self.PATCHING_SPECS = self._onnx_config.PATCHING_SPECS
self.variant = "default"
self.legacy = legacy

@classmethod
def from_onnx_config(cls, config: OnnxConfig) -> "OnnxConfigWithLoss":
@@ -1037,7 +1053,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
batch_size = dummy_inputs[input_name].shape[0]

# TODO: doesn't this break attention_mask generation?
if isinstance(self._onnx_config, OnnxConfigWithPast) and self._onnx_config.use_past_in_inputs is True:
if (
isinstance(self._onnx_config, OnnxConfigWithPast)
and self._onnx_config.use_past_in_inputs is True
and self.task != "text-generation"
):
kwargs["sequence_length"] = 1
else:
for input_name, dynamic_axes in self._tasks_to_extra_inputs[self.task].items():
18 changes: 14 additions & 4 deletions optimum/exporters/onnx/config.py
@@ -35,11 +35,14 @@
)
from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import DecoderModelPatcher


if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel

from .model_patcher import ModelPatcher

if is_tf_available():
from transformers import TFPreTrainedModel

@@ -75,7 +78,7 @@ def __init__(
use_past: bool = False,
use_past_in_inputs: bool = False,
preprocessors: Optional[List[Any]] = None,
no_position_ids: bool = False,
legacy: bool = False,
):
super().__init__(
config=config,
@@ -85,9 +88,8 @@
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
preprocessors=preprocessors,
legacy=legacy,
)
# TODO: remove no_position_ids once optimum is sufficiently above 1.13
self.no_position_ids = no_position_ids

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -154,6 +156,12 @@ def post_process_exported_models(

return models_and_onnx_configs, onnx_files_subpaths

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
# Refer to DecoderModelPatcher.
return DecoderModelPatcher(self, model, model_kwargs=model_kwargs)


class TextDecoderWithPositionIdsOnnxConfig(TextDecoderOnnxConfig):
@property
@@ -163,7 +171,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
# Decoders based on GPT2 require a position_ids input to avoid
# generating wrong position_ids in the model itself:
# https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
if not self.no_position_ids and self.task in ["text-generation", "feature-extraction"]:
if not self.legacy and self.task in ["text-generation", "feature-extraction"]:
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

return common_inputs
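For reference, an explicit position_ids input matters with left padding because the positions cannot be recovered as a plain 0..N-1 range; a sketch of how they are typically derived from the attention mask at generation time (similar in spirit to transformers' prepare_inputs_for_generation, exact code may differ).

import torch

attention_mask = torch.tensor([[0, 0, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# -> tensor([[1, 1, 0, 1],
#            [0, 1, 2, 3]])  real tokens in row 0 get positions 0 and 1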
@@ -316,6 +324,7 @@ def __init__(
use_past_in_inputs: bool = False,
behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
):
super().__init__(
config=config,
@@ -326,6 +335,7 @@
use_past_in_inputs=use_past_in_inputs,
behavior=behavior,
preprocessors=preprocessors,
legacy=legacy,
)

from ..tasks import TasksManager