Skip to content

Commit

Permalink
Decision Transformer to ONNX V0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
ra9hur committed Oct 11, 2024
1 parent d85062a commit ace12de
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 17 deletions.
8 changes: 4 additions & 4 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ class OnnxConfig(ExportConfig, ABC):
),
"reinforcement-learning": OrderedDict(
{
"state_preds": {0: "batch_size", 1: "sequence_length"},
"action_preds": {0: "batch_size", 1: "sequence_length"},
"return_preds": {0: "batch_size", 1: "sequence_length"},
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
"state_preds": {0: "batch_size", 1: "sequence_length", 2: "states"},
"action_preds": {0: "batch_size", 1: "sequence_length", 2: "actions"},
"return_preds": {0: "batch_size", 1: "sequence_length", 2: "returns"},
"last_hidden_state": {0: "batch_size", 1: "sequence_length", 2: "last_hidden_state"},
}
),
"semantic-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
Expand Down
17 changes: 11 additions & 6 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union

from packaging import version
from sipbuild.generator.parser.tokens import states
from transformers.utils import is_tf_available

from ...onnx import merge_decoders
Expand Down Expand Up @@ -264,16 +265,20 @@ class DecisionTransformerOnnxConfig(GPT2OnnxConfig):

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
    """Return the ONNX input specification for the Decision Transformer export.

    Maps each model input name to its dynamic axes: axis 0 is the batch,
    axis 1 the episode sequence length, and (where present) axis 2 the
    per-step feature dimension of that input.

    Returns:
        Dict[str, Dict[int, str]]: input name -> {axis index: axis name}.
    """
    # Register the model-specific feature widths so the dummy-input
    # generators can build tensors of the right size.
    # NOTE(review): mutating the module-level DEFAULT_DUMMY_SHAPES from a
    # property is a shared side effect across all configs in the process —
    # consider threading these values through the generator instead.
    DEFAULT_DUMMY_SHAPES['actions'] = self._normalized_config.config.act_dim
    DEFAULT_DUMMY_SHAPES['states'] = self._normalized_config.config.state_dim
    DEFAULT_DUMMY_SHAPES['returns'] = 1
    DEFAULT_DUMMY_SHAPES['last_hidden_state'] = self._normalized_config.config.hidden_size

    seq_axes = {0: 'batch_size', 1: 'sequence_length'}
    return {
        'states': {**seq_axes, 2: 'states'},
        'actions': {**seq_axes, 2: 'actions'},
        'returns_to_go': {**seq_axes, 2: 'returns'},
        'timesteps': dict(seq_axes),
        'attention_mask': dict(seq_axes),
    }


class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads")
Expand Down
18 changes: 11 additions & 7 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,20 +522,24 @@ class DummyDecisionTransformerInputGenerator(DummyTextInputGenerator):

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
    """Generate one dummy input tensor for the Decision Transformer exporter.

    Shapes follow the (batch, sequence, feature) layout expected by the model:
      - "states":         (batch_size, sequence_length, state_dim)
      - "actions":        (batch_size, sequence_length, act_dim)
      - "returns_to_go":  (batch_size, sequence_length, 1)
      - "timesteps":      (batch_size, sequence_length), random ints < max_ep_len
      - "attention_mask": (batch_size, sequence_length)

    Args:
        input_name: which model input to generate.
        framework: target tensor framework (e.g. "pt").
        int_dtype / float_dtype: dtypes for the integer / float tensors.

    Returns:
        A random integer tensor for "timesteps"; otherwise a random float
        tensor with values in [-2, 2].
    """
    state_dim = self.normalized_config.config.state_dim
    act_dim = self.normalized_config.config.act_dim
    max_ep_len = self.normalized_config.config.max_ep_len

    if input_name == "states":
        shape = [self.batch_size, self.sequence_length, state_dim]
    elif input_name == "actions":
        shape = [self.batch_size, self.sequence_length, act_dim]
    elif input_name == "returns_to_go":
        shape = [self.batch_size, self.sequence_length, 1]
    elif input_name == "timesteps":
        # Timesteps are integer episode indices, bounded by the maximum
        # episode length from the model config.
        shape = [self.batch_size, self.sequence_length]
        return self.random_int_tensor(shape=shape, max_value=max_ep_len, framework=framework, dtype=int_dtype)
    elif input_name == "attention_mask":
        # NOTE(review): the mask is sampled as floats in [-2, 2] like the
        # other inputs; a real attention mask is 0/1 — confirm this is the
        # intended dummy input for export tracing.
        shape = [self.batch_size, self.sequence_length]

    # Unreachable for "timesteps" (returned above); any unknown input_name
    # would raise NameError on `shape`, matching the original behavior.
    return self.random_float_tensor(shape, min_value=-2., max_value=2., framework=framework, dtype=float_dtype)


class DummySeq2SeqDecoderTextInputGenerator(DummyDecoderTextInputGenerator):
Expand Down

0 comments on commit ace12de

Please sign in to comment.