From ace12de787389c7aa630a9a47edef6d3cf67bc22 Mon Sep 17 00:00:00 2001 From: Raghu Ramarao Date: Fri, 11 Oct 2024 16:50:40 +0530 Subject: [PATCH] Decision Transformer to ONNX V0.2 --- optimum/exporters/onnx/base.py | 8 ++++---- optimum/exporters/onnx/model_configs.py | 17 +++++++++++------ optimum/utils/input_generators.py | 18 +++++++++++------- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index ccf3a3f2bde..568535472fa 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -175,10 +175,10 @@ class OnnxConfig(ExportConfig, ABC): ), "reinforcement-learning": OrderedDict( { - "state_preds": {0: "batch_size", 1: "sequence_length"}, - "action_preds": {0: "batch_size", 1: "sequence_length"}, - "return_preds": {0: "batch_size", 1: "sequence_length"}, - "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, + "state_preds": {0: "batch_size", 1: "sequence_length", 2: "states"}, + "action_preds": {0: "batch_size", 1: "sequence_length", 2: "actions"}, + "return_preds": {0: "batch_size", 1: "sequence_length", 2: "returns"}, + "last_hidden_state": {0: "batch_size", 1: "sequence_length", 2: "last_hidden_state"}, } ), "semantic-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}), diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index b2932079dd2..d31fef36cf5 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union from packaging import version +from sipbuild.generator.parser.tokens import states from transformers.utils import is_tf_available from ...onnx import merge_decoders @@ -264,16 +265,20 @@ class DecisionTransformerOnnxConfig(GPT2OnnxConfig): @property def inputs(self) -> Dict[str, Dict[int, str]]: - dynamic_axis = {0: "batch_size", 1: "sequence_length"} + DEFAULT_DUMMY_SHAPES['actions'] = self._normalized_config.config.act_dim + DEFAULT_DUMMY_SHAPES['states'] = self._normalized_config.config.state_dim + DEFAULT_DUMMY_SHAPES['returns'] = 1 + DEFAULT_DUMMY_SHAPES['last_hidden_state'] = self._normalized_config.config.hidden_size return { - 'actions': dynamic_axis, - 'timesteps': dynamic_axis, - 'attention_mask': dynamic_axis, - 'returns_to_go': dynamic_axis, - 'states': dynamic_axis, + 'states': {0: 'batch_size', 1: 'sequence_length', 2: 'states'}, + 'actions': {0: 'batch_size', 1: 'sequence_length', 2: 'actions'}, + 'returns_to_go': {0: 'batch_size', 1: 'sequence_length', 2: 'returns'}, + 'timesteps': {0: 'batch_size', 1: 'sequence_length'}, + 'attention_mask': {0: 'batch_size', 1: 'sequence_length'}, } + class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 14 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads") diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index e6f2059a1fc..808213eee10 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -522,20 +522,24 @@ class DummyDecisionTransformerInputGenerator(DummyTextInputGenerator): def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + states = self.normalized_config.config.state_dim + actions = self.normalized_config.config.act_dim + max_ep_len = self.normalized_config.config.max_ep_len + if input_name == "states": - shape = [self.batch_size, self.normalized_config.config.state_dim] + shape = [self.batch_size, self.sequence_length, states] elif input_name == "actions": - shape = [self.batch_size, self.normalized_config.config.act_dim] + shape = [self.batch_size, self.sequence_length, actions] elif input_name == 'returns_to_go': - shape = [self.batch_size, 1] + shape = [self.batch_size, self.sequence_length, 1] elif input_name == 'timesteps': - shape = [self.normalized_config.config.state_dim, self.batch_size] - max_value = self.normalized_config.config.max_ep_len + shape = [self.batch_size, self.sequence_length] + max_value = max_ep_len return self.random_int_tensor(shape=shape, max_value = max_value, framework=framework, dtype=int_dtype) elif input_name == "attention_mask": - shape = [self.batch_size, self.normalized_config.config.state_dim] + shape = [self.batch_size, self.sequence_length] - return self.random_float_tensor(shape, min_value=-1., max_value=1., framework=framework, dtype=float_dtype) + return self.random_float_tensor(shape, min_value=-2., max_value=2., framework=framework, dtype=float_dtype) class DummySeq2SeqDecoderTextInputGenerator(DummyDecoderTextInputGenerator):