Skip to content

Commit

Permalink
Add ONNX export support for PatchTST (#2101)
Browse files Browse the repository at this point in the history
* Add ONNX export support for `PatchTST`

* Add unit test for patchtst

* Add listed support for PatchTST

* Add ONNX export support for patchtsmixer

* Add task=feature-extraction

* Fix ONNX compatible unfold

* Formatting

* Correctly handle negative indexing for onnx compatible unfold

* Update tests/exporters/exporters_utils.py

Co-authored-by: Ella Charlaix <[email protected]>

* Update optimum/exporters/tasks.py

Co-authored-by: Ella Charlaix <[email protected]>

* Move dummy patch tst input generator to input_generators.py

* Code formatting

---------

Co-authored-by: Ella Charlaix <[email protected]>
  • Loading branch information
xenova and echarlaix authored Jan 9, 2025
1 parent 8d1347f commit b9fa9aa
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- OLMo
- OLMo2
- OWL-ViT
- PatchTST
- PatchTSMixer
- Pegasus
- Perceiver
- Phi
Expand Down
1 change: 1 addition & 0 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ class OnnxConfig(ExportConfig, ABC):
"text2text-generation": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
"text-classification": OrderedDict({"logits": {0: "batch_size"}}),
"text-generation": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"time-series-forecasting": OrderedDict({"prediction_outputs": {0: "batch_size"}}),
"token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"visual-question-answering": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"zero-shot-image-classification": OrderedDict(
Expand Down
23 changes: 23 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
DummyInputGenerator,
DummyIntGenerator,
DummyPastKeyValuesGenerator,
DummyPatchTSTInputGenerator,
DummyPix2StructInputGenerator,
DummyPointsGenerator,
DummySeq2SeqDecoderTextInputGenerator,
Expand All @@ -58,6 +59,7 @@
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
NormalizedTextConfigWithGQA,
NormalizedTimeSeriesForecastingConfig,
NormalizedVisionConfig,
is_diffusers_available,
is_diffusers_version,
Expand Down Expand Up @@ -2619,3 +2621,24 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig

DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.


class PatchTSTOnnxConfig(OnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTimeSeriesForecastingConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyPatchTSTInputGenerator,)
ATOL_FOR_VALIDATION = 1e-4

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {"past_values": {0: "batch_size", 1: "sequence_length"}}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
if self.task == "feature-extraction":
return {"last_hidden_state": {0: "batch_size"}}
else:
return super().outputs


class PatchTSMixerOnnxConfig(PatchTSTOnnxConfig):
pass
53 changes: 51 additions & 2 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,53 @@ class PatchingSpec:
op_wrapper: Optional[Callable] = None


# An ONNX-export-compatible version of `tensor.unfold`. Without this, we get:
# torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible.
# See https://github.com/pytorch/pytorch/issues/81871 for more information
def onnx_compatible_unfold(input_tensor, dimension, size, step):
"""
Custom implementation of torch.unfold without using torch.unfold.
Args:
input_tensor (torch.Tensor): The input tensor.
dimension (int): The dimension to unfold.
size (int): The size of each slice.
step (int): The step size between slices.
Returns:
torch.Tensor: The unfolded tensor.
"""
# Check if dimension is within the valid range
if not (-input_tensor.dim() <= dimension < input_tensor.dim()):
raise ValueError(
f"Dimension out of range (expected to be in range of [{-input_tensor.dim()}, {input_tensor.dim() - 1}], but got {dimension})"
)

# Normalize negative dimension
dimension = dimension % input_tensor.dim()

# Compute the shape of the unfolded output
input_size = input_tensor.size(dimension)
num_slices = (input_size - size) // step + 1

# Permute dimension to the end for easier indexing
input_tensor = input_tensor.transpose(dimension, -1)

# Extract slices
slices = []
for i in range(num_slices):
start = i * step
end = start + size
slices.append(input_tensor[..., start:end])

# Stack slices and permute dimensions back
result = torch.stack(slices, dim=-2).transpose(dimension, -2)
return result


UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]


class ModelPatcher:
def __init__(
self,
Expand All @@ -122,9 +169,11 @@ def __init__(
):
self._model = model

patching_specs = config.PATCHING_SPECS
patching_specs = config.PATCHING_SPECS or []
patching_specs.extend(UNSUPPORTED_OPS_PATCHING_SPEC)

self._patching_specs = []
for spec in patching_specs if patching_specs is not None else []:
for spec in patching_specs:
final_spec = spec
if spec.orig_op is None:
final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
Expand Down
12 changes: 12 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ class TasksManager:
}

_CUSTOM_CLASSES = {
("pt", "patchtsmixer", "time-series-forecasting"): ("transformers", "PatchTSMixerForPrediction"),
("pt", "patchtst", "time-series-forecasting"): ("transformers", "PatchTSTForPrediction"),
("pt", "pix2struct", "image-to-text"): ("transformers", "Pix2StructForConditionalGeneration"),
("pt", "pix2struct", "visual-question-answering"): ("transformers", "Pix2StructForConditionalGeneration"),
("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"),
Expand Down Expand Up @@ -962,6 +964,16 @@ class TasksManager:
"text-classification",
onnx="OPTOnnxConfig",
),
"patchtst": supported_tasks_mapping(
"feature-extraction",
"time-series-forecasting",
onnx="PatchTSTOnnxConfig",
),
"patchtsmixer": supported_tasks_mapping(
"feature-extraction",
"time-series-forecasting",
onnx="PatchTSMixerOnnxConfig",
),
"qwen2": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
2 changes: 2 additions & 0 deletions optimum/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
DummyIntGenerator,
DummyLabelsGenerator,
DummyPastKeyValuesGenerator,
DummyPatchTSTInputGenerator,
DummyPix2StructInputGenerator,
DummyPointsGenerator,
DummySeq2SeqDecoderTextInputGenerator,
Expand Down Expand Up @@ -98,5 +99,6 @@
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
NormalizedTextConfigWithGQA,
NormalizedTimeSeriesForecastingConfig,
NormalizedVisionConfig,
)
27 changes: 27 additions & 0 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,3 +1532,30 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
return self.random_float_tensor(shape, min_value=0, max_value=1, framework=framework, dtype=float_dtype)

return super().generate(input_name, framework, int_dtype, float_dtype)


class DummyPatchTSTInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ("past_values",)

def __init__(
self,
task: str,
normalized_config: NormalizedConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
**kwargs,
):
self.task = task
self.normalized_config = normalized_config

self.batch_size = batch_size
self.context_length = normalized_config.context_length
self.num_input_channels = normalized_config.num_input_channels

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
return self.random_float_tensor(
shape=[self.batch_size, self.context_length, self.num_input_channels],
min_value=-1,
max_value=1,
framework=framework,
dtype=float_dtype,
)
5 changes: 5 additions & 0 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def has_attribute(self, attr_name):
return True


class NormalizedTimeSeriesForecastingConfig(NormalizedConfig):
NUM_INPUT_CHANNELS = "num_input_channels"
CONTEXT_LENGTH = "context_length"


class NormalizedTextConfig(NormalizedConfig):
VOCAB_SIZE = "vocab_size"
HIDDEN_SIZE = "hidden_size"
Expand Down
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@
"opt": "hf-internal-testing/tiny-random-OPTModel",
"owlv2": "hf-internal-testing/tiny-random-Owlv2Model",
"owlvit": "hf-tiny-model-private/tiny-random-OwlViTModel",
"patchtst": "ibm/test-patchtst",
"patchtsmixer": "ibm/test-patchtsmixer",
"pegasus": "hf-internal-testing/tiny-random-PegasusModel",
"perceiver": {
"hf-internal-testing/tiny-random-language_perceiver": ["fill-mask", "text-classification"],
Expand Down

0 comments on commit b9fa9aa

Please sign in to comment.