Skip to content

Commit

Permalink
Add specific warning for exports with sequence_length set to 1
Browse files Browse the repository at this point in the history
  • Loading branch information
baskrahmer committed Oct 21, 2023
1 parent 9e15d6f commit 03bb480
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,15 +317,6 @@ def main_export(
model_name_or_path, subfolder=subfolder, library_name=library_name
)

# get the shapes to be used to generate dummy inputs
input_shapes = {}
for input_name in DEFAULT_DUMMY_SHAPES.keys():
input_shapes[input_name] = (
kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
)

torch_dtype = None if fp16 is False else torch.float16

if task == "auto":
try:
task = TasksManager.infer_task_from_model(model_name_or_path)
Expand All @@ -338,6 +329,21 @@ def main_export(
f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)

# get the shapes to be used to generate dummy inputs
input_shapes = {}
for input_name in DEFAULT_DUMMY_SHAPES.keys():
input_shapes[input_name] = (
kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
)
if (
input_name == "sequence_length"
and kwargs_shapes.get(input_name) == 1
and task.startswith("text-generation")
):
logger.warning("Exporting with a sequence length of 1 for text generation models is not supported and can yield unexpected results.")

torch_dtype = None if fp16 is False else torch.float16

if library_name == "transformers":
config = AutoConfig.from_pretrained(
model_name_or_path,
Expand Down

0 comments on commit 03bb480

Please sign in to comment.