diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index fa39a47abf..adda0e6d8d 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -1675,6 +1675,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: inputs = { "pixel_values": {0: "batch_size"}, "input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"}, + "input_labels": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"}, } else: if self.vision_encoder: @@ -1684,6 +1685,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: "image_positional_embeddings": {0: "batch_size"}, "image_embeddings": {0: "batch_size"}, "input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"}, + "input_labels": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"}, } return inputs diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index db52493b6f..cc448ee68b 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -510,6 +510,7 @@ def __init__( def patched_forward( pixel_values=None, input_points=None, + input_labels=None, image_embeddings=None, image_positional_embeddings=None, return_dict=True, @@ -519,6 +520,7 @@ def patched_forward( return self.orig_forward( pixel_values=pixel_values, input_points=input_points, + input_labels=input_labels, image_embeddings=image_embeddings, return_dict=return_dict, **kwargs, @@ -549,11 +551,7 @@ def patched_forward( "image_positional_embeddings": image_positional_embeddings, } else: - if input_points is not None: - input_labels = torch.ones_like( - input_points[:, :, :, 0], dtype=torch.int, device=input_points.device - ) - else: + if input_points is None: raise ValueError("input_points is required to export the prompt encoder / mask decoder.") sparse_embeddings, dense_embeddings = model.prompt_encoder( diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 1a9024db7a..0c82808131 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -836,7 +836,7 @@ class DummyPointsGenerator(DummyInputGenerator): Generates dummy time step inputs. """ - SUPPORTED_INPUT_NAMES = ("input_points",) + SUPPORTED_INPUT_NAMES = ("input_points", "input_labels") def __init__( self, @@ -854,8 +854,12 @@ def __init__( self.nb_points_per_image = nb_points_per_image def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): - shape = [self.batch_size, self.point_batch_size, self.nb_points_per_image, 2] - return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + if input_name == "input_points": + shape = [self.batch_size, self.point_batch_size, self.nb_points_per_image, 2] + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + else: # input_labels + shape = [self.batch_size, self.point_batch_size, self.nb_points_per_image] + return self.random_int_tensor(shape, min_value=0, max_value=1, framework=framework, dtype=int_dtype) class DummyVisionEmbeddingsGenerator(DummyInputGenerator):