diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 3a48a579c2..82628bd28d 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -824,6 +824,83 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         return common_outputs
 
 
+class DepthAnythingOnnxConfig(ViTOnnxConfig):
+    """ONNX export configuration for `depth-anything` models.
+
+    Inherits everything from `ViTOnnxConfig` without overriding anything.
+    """
+
+
+class DummyPromptDepthInputGenerator(DummyVisionInputGenerator):
+    """Dummy input generator that supports `pixel_values` and `prompt_depth`."""
+
+    SUPPORTED_INPUT_NAMES = (
+        "pixel_values",
+        "prompt_depth",
+    )
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedVisionConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
+        width: int = DEFAULT_DUMMY_SHAPES["width"],
+        height: int = DEFAULT_DUMMY_SHAPES["height"],
+        prompt_height: int = DEFAULT_DUMMY_SHAPES["prompt_height"],
+        prompt_width: int = DEFAULT_DUMMY_SHAPES["prompt_width"],
+        **kwargs,
+    ):
+        """Forward the image-related arguments to the parent and store the prompt size.
+
+        Args:
+            prompt_height: Height used for the generated `prompt_depth` tensor.
+            prompt_width: Width used for the generated `prompt_depth` tensor.
+        """
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            num_channels=num_channels,
+            width=width,
+            height=height,
+            **kwargs,
+        )
+        self.prompt_height = prompt_height
+        self.prompt_width = prompt_width
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        """Return a dummy tensor for `input_name`.
+
+        `prompt_depth` is built with `random_float_tensor` using the shape
+        `(batch_size, 1, prompt_height, prompt_width)`; every other name is delegated to the parent generator.
+        """
+        if input_name == "prompt_depth":
+            return self.random_float_tensor(
+                (self.batch_size, 1, self.prompt_height, self.prompt_width),
+                framework=framework,
+                dtype=float_dtype,
+            )
+        return super().generate(
+            input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
+        )
+
+
+class PromptDepthAnythingOnnxConfig(DepthAnythingOnnxConfig):
+    """ONNX export configuration for `prompt-depth-anything` models."""
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyPromptDepthInputGenerator,)
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        """Map each model input name to its dynamic axes."""
+        # Axis 1 of `prompt_depth` is not declared dynamic; the dummy generator emits size 1 there.
+        return {
+            "pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
+            "prompt_depth": {0: "batch_size", 2: "prompt_height", 3: "prompt_width"},
+        }
+
+
 class CvTOnnxConfig(ViTOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     ATOL_FOR_VALIDATION = 1e-2
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 7cb5a31d2d..8e48fd14d3 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -602,6 +602,11 @@ class TasksManager:
             "masked-im",
             onnx="DeiTOnnxConfig",
         ),
+        "depth-anything": supported_tasks_mapping(
+            "feature-extraction",
+            "depth-estimation",
+            onnx="DepthAnythingOnnxConfig",
+        ),
         "detr": supported_tasks_mapping(
             "feature-extraction",
             "object-detection",
@@ -1033,6 +1038,11 @@ class TasksManager:
             "image-classification",
             onnx="PoolFormerOnnxConfig",
         ),
+        "prompt-depth-anything": supported_tasks_mapping(
+            "feature-extraction",
+            "depth-estimation",
+            onnx="PromptDepthAnythingOnnxConfig",
+        ),
         "pvt": supported_tasks_mapping(
             "feature-extraction",
             "image-classification",
diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index 18a2a5a3fd..9512881415 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -63,6 +63,8 @@ def wrapper(*args, **kwargs):
     "num_channels": 3,
     "point_batch_size": 3,
     "nb_points_per_image": 2,
+    "prompt_width": 32,
+    "prompt_height": 32,
     # audio
     "feature_size": 80,
     "nb_max_frames": 3000,