diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 18c75953c3..8efaebbd8c 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -82,6 +82,8 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra - OLMo - OLMo2 - OWL-ViT +- PatchTST +- PatchTSMixer - Pegasus - Perceiver - Phi diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 7d90e36056..e6618568c0 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -180,6 +180,7 @@ class OnnxConfig(ExportConfig, ABC): "text2text-generation": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}), "text-classification": OrderedDict({"logits": {0: "batch_size"}}), "text-generation": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), + "time-series-forecasting": OrderedDict({"prediction_outputs": {0: "batch_size"}}), "token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "visual-question-answering": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "zero-shot-image-classification": OrderedDict( diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 63bf220ca9..b587f16316 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -34,6 +34,7 @@ DummyInputGenerator, DummyIntGenerator, DummyPastKeyValuesGenerator, + DummyPatchTSTInputGenerator, DummyPix2StructInputGenerator, DummyPointsGenerator, DummySeq2SeqDecoderTextInputGenerator, @@ -58,6 +59,7 @@ NormalizedTextAndVisionConfig, NormalizedTextConfig, NormalizedTextConfigWithGQA, + NormalizedTimeSeriesForecastingConfig, NormalizedVisionConfig, is_diffusers_available, is_diffusers_version, @@ -2619,3 +2621,24 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig 
DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. + + +class PatchTSTOnnxConfig(OnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTimeSeriesForecastingConfig + DUMMY_INPUT_GENERATOR_CLASSES = (DummyPatchTSTInputGenerator,) + ATOL_FOR_VALIDATION = 1e-4 + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + return {"past_values": {0: "batch_size", 1: "sequence_length"}} + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self.task == "feature-extraction": + return {"last_hidden_state": {0: "batch_size"}} + else: + return super().outputs + + +class PatchTSMixerOnnxConfig(PatchTSTOnnxConfig): + pass diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 083bc12799..80293e7b95 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -113,6 +113,53 @@ class PatchingSpec: op_wrapper: Optional[Callable] = None +# An ONNX-export-compatible version of `tensor.unfold`. Without this, we get: +# torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible. +# See https://github.com/pytorch/pytorch/issues/81871 for more information +def onnx_compatible_unfold(input_tensor, dimension, size, step): + """ + Custom implementation of torch.unfold without using torch.unfold. + + Args: + input_tensor (torch.Tensor): The input tensor. + dimension (int): The dimension to unfold. + size (int): The size of each slice. + step (int): The step size between slices. + + Returns: + torch.Tensor: The unfolded tensor. 
+ """ + # Check if dimension is within the valid range + if not (-input_tensor.dim() <= dimension < input_tensor.dim()): + raise ValueError( + f"Dimension out of range (expected to be in range of [{-input_tensor.dim()}, {input_tensor.dim() - 1}], but got {dimension})" + ) + + # Normalize negative dimension + dimension = dimension % input_tensor.dim() + + # Compute the shape of the unfolded output + input_size = input_tensor.size(dimension) + num_slices = (input_size - size) // step + 1 + + # Permute dimension to the end for easier indexing + input_tensor = input_tensor.transpose(dimension, -1) + + # Extract slices + slices = [] + for i in range(num_slices): + start = i * step + end = start + size + slices.append(input_tensor[..., start:end]) + + # Stack slices and permute dimensions back + result = torch.stack(slices, dim=-2).transpose(dimension, -2) + return result + + +UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)] + + class ModelPatcher: def __init__( self, @@ -122,9 +169,12 @@ def __init__( ): self._model = model - patching_specs = config.PATCHING_SPECS + # Copy first: extending config.PATCHING_SPECS in place would mutate the config's own list each time a ModelPatcher is built. + patching_specs = list(config.PATCHING_SPECS or []) + patching_specs.extend(UNSUPPORTED_OPS_PATCHING_SPEC) + self._patching_specs = [] - for spec in patching_specs if patching_specs is not None else []: + for spec in patching_specs: final_spec = spec if spec.orig_op is None: final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name)) diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 59c066ac38..5651537162 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -322,6 +322,8 @@ class TasksManager: } _CUSTOM_CLASSES = { + ("pt", "patchtsmixer", "time-series-forecasting"): ("transformers", "PatchTSMixerForPrediction"), + ("pt", "patchtst", "time-series-forecasting"): ("transformers", "PatchTSTForPrediction"), ("pt", "pix2struct", "image-to-text"): ("transformers", 
"Pix2StructForConditionalGeneration"), ("pt", "pix2struct", "visual-question-answering"): ("transformers", "Pix2StructForConditionalGeneration"), ("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"), @@ -962,6 +964,16 @@ class TasksManager: "text-classification", onnx="OPTOnnxConfig", ), + "patchtst": supported_tasks_mapping( + "feature-extraction", + "time-series-forecasting", + onnx="PatchTSTOnnxConfig", + ), + "patchtsmixer": supported_tasks_mapping( + "feature-extraction", + "time-series-forecasting", + onnx="PatchTSMixerOnnxConfig", + ), "qwen2": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index b80cf4b8c4..c870e49fad 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -69,6 +69,7 @@ DummyIntGenerator, DummyLabelsGenerator, DummyPastKeyValuesGenerator, + DummyPatchTSTInputGenerator, DummyPix2StructInputGenerator, DummyPointsGenerator, DummySeq2SeqDecoderTextInputGenerator, @@ -98,5 +99,6 @@ NormalizedTextAndVisionConfig, NormalizedTextConfig, NormalizedTextConfigWithGQA, + NormalizedTimeSeriesForecastingConfig, NormalizedVisionConfig, ) diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index b0be9f3a3f..e4545a8473 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -1532,3 +1532,30 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int return self.random_float_tensor(shape, min_value=0, max_value=1, framework=framework, dtype=float_dtype) return super().generate(input_name, framework, int_dtype, float_dtype) + + +class DummyPatchTSTInputGenerator(DummyInputGenerator): + SUPPORTED_INPUT_NAMES = ("past_values",) + + def __init__( + self, + task: str, + normalized_config: NormalizedConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + **kwargs, + ): + self.task = task + 
self.normalized_config = normalized_config + + self.batch_size = batch_size + self.context_length = normalized_config.context_length + self.num_input_channels = normalized_config.num_input_channels + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + return self.random_float_tensor( + shape=[self.batch_size, self.context_length, self.num_input_channels], + min_value=-1, + max_value=1, + framework=framework, + dtype=float_dtype, + ) diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 9fde2bd469..053417b20b 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -77,6 +77,11 @@ def has_attribute(self, attr_name): return True +class NormalizedTimeSeriesForecastingConfig(NormalizedConfig): + NUM_INPUT_CHANNELS = "num_input_channels" + CONTEXT_LENGTH = "context_length" + + class NormalizedTextConfig(NormalizedConfig): VOCAB_SIZE = "vocab_size" HIDDEN_SIZE = "hidden_size" diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index d256e16dd4..ee31397fd8 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -136,6 +136,8 @@ "opt": "hf-internal-testing/tiny-random-OPTModel", "owlv2": "hf-internal-testing/tiny-random-Owlv2Model", "owlvit": "hf-tiny-model-private/tiny-random-OwlViTModel", + "patchtst": "ibm/test-patchtst", + "patchtsmixer": "ibm/test-patchtsmixer", "pegasus": "hf-internal-testing/tiny-random-PegasusModel", "perceiver": { "hf-internal-testing/tiny-random-language_perceiver": ["fill-mask", "text-classification"],