Skip to content

Commit

Permalink
Merge branch 'main' into fix-pkv
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Nov 3, 2023
2 parents 2cac863 + ca19481 commit 72ac4c3
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 402 deletions.
33 changes: 21 additions & 12 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from ...commands.export.onnx import parse_args_onnx
from ...utils import DEFAULT_DUMMY_SHAPES, ONNX_WEIGHTS_NAME, logging
from ...utils.modeling_utils import MODEL_TO_PATCH_FOR_PAST
from ...utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from ..error_utils import AtolError, OutputMatchError, ShapeError
from ..tasks import TasksManager
Expand Down Expand Up @@ -83,16 +84,12 @@ def _get_submodels_and_onnx_configs(
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="onnx", task=task
)
onnx_config_kwargs = {}
if task.startswith("text-generation") and legacy:
onnx_config_kwargs["no_position_ids"] = legacy

onnx_config = onnx_config_constructor(
model.config,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
**onnx_config_kwargs,
legacy=legacy,
)

onnx_config.variant = _variant
Expand Down Expand Up @@ -317,13 +314,6 @@ def main_export(
model_name_or_path, subfolder=subfolder, library_name=library_name
)

# get the shapes to be used to generate dummy inputs
input_shapes = {}
for input_name in DEFAULT_DUMMY_SHAPES.keys():
input_shapes[input_name] = (
kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
)

torch_dtype = None if fp16 is False else torch.float16

if task == "auto":
Expand Down Expand Up @@ -382,6 +372,25 @@ def main_export(
is_stable_diffusion = "stable-diffusion" in task
model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-")

# For MODEL_TO_PATCH_FOR_PAST architectures, when exporting the model with an input of sequence length of 1, a tracer that does not handle
# controlflows will trace incorrectly the mask generation, resulting in incorrect attention masks for other sequence lengthss.
# Reference: https://github.com/huggingface/transformers/blob/af3de8d87c717c4bb090f037d0d89413c195a42f/src/transformers/modeling_attn_mask_utils.py#L94
input_shapes = {}
for input_name in DEFAULT_DUMMY_SHAPES.keys():
input_shapes[input_name] = (
kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
)

# TODO: this may be moved rather to the OnnxConfig to avoid bloating this script.
if (
model_type in MODEL_TO_PATCH_FOR_PAST
and input_name == "sequence_length"
and kwargs_shapes.get(input_name) == 1
):
raise ValueError(
f"Exporting with a sequence length of 1 a {model_type} model is not supported and can yield unexpected results."
)

if legacy and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"):
logger.warning(
f"legacy=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `legacy=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381"
Expand Down
30 changes: 25 additions & 5 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def __init__(
preprocessors: Optional[List[Any]] = None,
int_dtype: str = "int64",
float_dtype: str = "fp32",
legacy: bool = False,
):
self.task = task
self.int_dtype = int_dtype
Expand All @@ -209,6 +210,7 @@ def __init__(
self._preprocessors = preprocessors
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.variant = "default"
self.legacy = legacy

def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
"""
Expand Down Expand Up @@ -565,14 +567,20 @@ def __init__(
use_past: bool = False,
use_past_in_inputs: bool = False,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
):
self.use_past = use_past
self.use_past_in_inputs = use_past_in_inputs

self.is_merged = False
self.use_cache_branch = None
super().__init__(
config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors
config=config,
task=task,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
legacy=legacy,
)

@property
Expand Down Expand Up @@ -628,11 +636,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
and "attention_mask" in dummy_inputs
):
# Obtain the past sequence length from the value instead of the key (Bloom).
past_length = dummy_inputs["past_key_values"][0][1].shape[-2]
past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[-2]

dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
dummy_inputs["attention_mask"],
desired_length=past_length + 1,
desired_length=past_present_length,
dim=1,
dtype=dummy_inputs["attention_mask"].dtype,
)
Expand All @@ -658,11 +666,15 @@ def overwrite_shape_and_generate_input(

# models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name
# while models from TextDecoderOnnxConfig use input_ids, hence the check for both

# TODO: The check `self.task != "text-generation" and self.legacy` is added following the use of a single ONNX for both without/with KV cache, without subgraphs.
# This overwrite may be moved to OnnxSeq2SeqConfigWithPast, but I am afraid it would break encoder-decoder models.
if (
self.use_past
and self.use_past_in_inputs
and self.use_cache_branch is not False
and input_name in ["decoder_input_ids", "input_ids", "position_ids"]
and ((self.task == "text-generation" and self.legacy) or self.task != "text-generation")
):
sequence_length = dummy_input_gen.sequence_length
# Use a sequence length of 1 when the KV cache is already populated.
Expand Down Expand Up @@ -768,6 +780,7 @@ def __init__(
use_past_in_inputs: bool = False,
behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
):
super().__init__(
config=config,
Expand All @@ -777,6 +790,7 @@ def __init__(
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
preprocessors=preprocessors,
legacy=legacy,
)
self._behavior = behavior

Expand Down Expand Up @@ -816,6 +830,7 @@ def with_behavior(
use_past_in_inputs=use_past_in_inputs,
behavior=behavior,
preprocessors=self._preprocessors,
legacy=self.legacy,
)
onnx_config.variant = self.variant
return onnx_config
Expand Down Expand Up @@ -1003,14 +1018,15 @@ class OnnxConfigWithLoss(OnnxConfig, ABC):

DUMMY_EXTRA_INPUT_GENERATOR_CLASSES = (DummyLabelsGenerator,)

def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: str = "fp32"):
def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: str = "fp32", legacy: bool = False):
self._onnx_config = config
self.task = self._onnx_config.task
self.int_dtype = int_dtype
self.float_dtype = float_dtype
self._normalized_config = self._onnx_config._normalized_config
self.PATCHING_SPECS = self._onnx_config.PATCHING_SPECS
self.variant = "default"
self.legacy = legacy

@classmethod
def from_onnx_config(cls, config: OnnxConfig) -> "OnnxConfigWithLoss":
Expand All @@ -1037,7 +1053,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
batch_size = dummy_inputs[input_name].shape[0]

# TODO: doesn't this break attention_mask generation?
if isinstance(self._onnx_config, OnnxConfigWithPast) and self._onnx_config.use_past_in_inputs is True:
if (
isinstance(self._onnx_config, OnnxConfigWithPast)
and self._onnx_config.use_past_in_inputs is True
and self.task != "text-generation"
):
kwargs["sequence_length"] = 1
else:
for input_name, dynamic_axes in self._tasks_to_extra_inputs[self.task].items():
Expand Down
18 changes: 14 additions & 4 deletions optimum/exporters/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@
)
from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import DecoderModelPatcher


if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel

from .model_patcher import ModelPatcher

if is_tf_available():
from transformers import TFPreTrainedModel

Expand Down Expand Up @@ -75,7 +78,7 @@ def __init__(
use_past: bool = False,
use_past_in_inputs: bool = False,
preprocessors: Optional[List[Any]] = None,
no_position_ids: bool = False,
legacy: bool = False,
):
super().__init__(
config=config,
Expand All @@ -85,9 +88,8 @@ def __init__(
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
preprocessors=preprocessors,
legacy=legacy,
)
# TODO: remove no_position_ids once optimum is sufficiently above 1.13
self.no_position_ids = no_position_ids

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
Expand Down Expand Up @@ -154,6 +156,12 @@ def post_process_exported_models(

return models_and_onnx_configs, onnx_files_subpaths

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
# Refer to DecoderModelPatcher.
return DecoderModelPatcher(self, model, model_kwargs=model_kwargs)


class TextDecoderWithPositionIdsOnnxConfig(TextDecoderOnnxConfig):
@property
Expand All @@ -163,7 +171,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
# Decoders based on GPT2 require a position_ids input to avoid
# generating wrong position_ids in the model itself:
# https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
if not self.no_position_ids and self.task in ["text-generation", "feature-extraction"]:
if not self.legacy and self.task in ["text-generation", "feature-extraction"]:
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

return common_inputs
Expand Down Expand Up @@ -316,6 +324,7 @@ def __init__(
use_past_in_inputs: bool = False,
behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
):
super().__init__(
config=config,
Expand All @@ -326,6 +335,7 @@ def __init__(
use_past_in_inputs=use_past_in_inputs,
behavior=behavior,
preprocessors=preprocessors,
legacy=legacy,
)

from ..tasks import TasksManager
Expand Down
Loading

0 comments on commit 72ac4c3

Please sign in to comment.