diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index 851be1b8f6f..bcd49b32302 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -317,15 +317,6 @@ def main_export(
         model_name_or_path, subfolder=subfolder, library_name=library_name
     )
 
-    # get the shapes to be used to generate dummy inputs
-    input_shapes = {}
-    for input_name in DEFAULT_DUMMY_SHAPES.keys():
-        input_shapes[input_name] = (
-            kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
-        )
-
-    torch_dtype = None if fp16 is False else torch.float16
-
     if task == "auto":
         try:
             task = TasksManager.infer_task_from_model(model_name_or_path)
@@ -338,6 +329,21 @@
                 f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
             )
 
+    # get the shapes to be used to generate dummy inputs
+    input_shapes = {}
+    for input_name in DEFAULT_DUMMY_SHAPES.keys():
+        input_shapes[input_name] = (
+            kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
+        )
+        if (
+            input_name == "sequence_length"
+            and kwargs_shapes.get(input_name) == 1
+            and task.startswith("text-generation")
+        ):
+            logger.warning("Exporting with a sequence length of 1 for text generation models is not supported and can yield unexpected results.")
+
+    torch_dtype = None if fp16 is False else torch.float16
+
     if library_name == "transformers":
         config = AutoConfig.from_pretrained(
             model_name_or_path,