Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Nov 20, 2024
1 parent 0696597 commit 5b9c8cd
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
8 changes: 0 additions & 8 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,6 @@ class OnnxConfig(ExportConfig, ABC):
"end_logits": {0: "batch_size", 1: "sequence_length"},
}
),
"reinforcement-learning": OrderedDict(
{
"state_preds": {0: "batch_size", 1: "sequence_length", 2: "state_dim"},
"action_preds": {0: "batch_size", 1: "sequence_length", 2: "act_dim"},
"return_preds": {0: "batch_size", 1: "sequence_length"},
"last_hidden_state": {0: "batch_size", 1: "sequence_length", 2: "last_hidden_state"},
}
),
"semantic-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
"text2text-generation": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
"text-classification": OrderedDict({"logits": {0: "batch_size"}}),
Expand Down
14 changes: 13 additions & 1 deletion optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,11 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig):
pass


class DecisionTransformerOnnxConfig(GPT2OnnxConfig):
class DecisionTransformerOnnxConfig(OnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyDecisionTransformerInputGenerator,)
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
act_dim="act_dim", state_dim="state_dim", max_ep_len="max_ep_len", hidden_size="hidden_size", allow_new=True
)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
Expand All @@ -277,6 +280,15 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
"states": {0: "batch_size", 1: "sequence_length", 2: "state_dim"},
}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"state_preds": {0: "batch_size", 1: "sequence_length", 2: "state_dim"},
"action_preds": {0: "batch_size", 1: "sequence_length", 2: "act_dim"},
"return_preds": {0: "batch_size", 1: "sequence_length"},
"last_hidden_state": {0: "batch_size", 1: "sequence_length", 2: "last_hidden_state"},
}


class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14
Expand Down
5 changes: 1 addition & 4 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,7 @@ class TasksManager:
"multiple-choice": "AutoModelForMultipleChoice",
"object-detection": "AutoModelForObjectDetection",
"question-answering": "AutoModelForQuestionAnswering",
"reinforcement-learning": (
"AutoModel",
), # multiple auto model families can be used for reinforcement-learning
"reinforcement-learning": "AutoModel",
"semantic-segmentation": "AutoModelForSemanticSegmentation",
"text-to-audio": ("AutoModelForTextToSpectrogram", "AutoModelForTextToWaveform"),
"text-generation": "AutoModelForCausalLM",
Expand Down Expand Up @@ -579,7 +577,6 @@ class TasksManager:
),
"decision-transformer": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"reinforcement-learning",
onnx="DecisionTransformerOnnxConfig",
),
Expand Down

0 comments on commit 5b9c8cd

Please sign in to comment.