Skip to content

Commit

Permalink
Add input_labels input to SAM model export (#1638)
Browse files Browse the repository at this point in the history
* Add `input_labels` input to SAM model export

* Add back error guard
xenova authored Jan 11, 2024
1 parent c668b26 commit 4d25ed2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
2 changes: 2 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1675,6 +1675,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
inputs = {
"pixel_values": {0: "batch_size"},
"input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
"input_labels": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
}
else:
if self.vision_encoder:
Expand All @@ -1684,6 +1685,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
"image_positional_embeddings": {0: "batch_size"},
"image_embeddings": {0: "batch_size"},
"input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
"input_labels": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
}
return inputs

Expand Down
8 changes: 3 additions & 5 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def __init__(
def patched_forward(
pixel_values=None,
input_points=None,
input_labels=None,
image_embeddings=None,
image_positional_embeddings=None,
return_dict=True,
Expand All @@ -519,6 +520,7 @@ def patched_forward(
return self.orig_forward(
pixel_values=pixel_values,
input_points=input_points,
input_labels=input_labels,
image_embeddings=image_embeddings,
return_dict=return_dict,
**kwargs,
Expand Down Expand Up @@ -549,11 +551,7 @@ def patched_forward(
"image_positional_embeddings": image_positional_embeddings,
}
else:
if input_points is not None:
input_labels = torch.ones_like(
input_points[:, :, :, 0], dtype=torch.int, device=input_points.device
)
else:
if input_points is None:
raise ValueError("input_points is required to export the prompt encoder / mask decoder.")

sparse_embeddings, dense_embeddings = model.prompt_encoder(
Expand Down
10 changes: 7 additions & 3 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ class DummyPointsGenerator(DummyInputGenerator):
Generates dummy point inputs (`input_points`) and their labels (`input_labels`).
"""

SUPPORTED_INPUT_NAMES = ("input_points",)
SUPPORTED_INPUT_NAMES = ("input_points", "input_labels")

def __init__(
self,
Expand All @@ -854,8 +854,12 @@ def __init__(
self.nb_points_per_image = nb_points_per_image

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
shape = [self.batch_size, self.point_batch_size, self.nb_points_per_image, 2]
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
if input_name == "input_points":
shape = [self.batch_size, self.point_batch_size, self.nb_points_per_image, 2]
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
else: # input_labels
shape = [self.batch_size, self.point_batch_size, self.nb_points_per_image]
return self.random_int_tensor(shape, min_value=0, max_value=1, framework=framework, dtype=int_dtype)


class DummyVisionEmbeddingsGenerator(DummyInputGenerator):
Expand Down

0 comments on commit 4d25ed2

Please sign in to comment.