unify io binding api with non io binding
IlyasMoutawwakil committed Jan 15, 2025
1 parent e9abe6a commit c9b45ee
Showing 4 changed files with 85 additions and 74 deletions.
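The unification replaces the `**kwargs`/`*args` splatting in `_prepare_onnx_inputs` and `_prepare_onnx_outputs` with plain containers: inputs travel as one dict, and outputs stay the list that `session.run` returns — the same shape of data the IO-binding path receives. A toy sketch of the new convention (the stand-in functions below are simplified illustrations, not the actual optimum implementation):

```python
from typing import Dict, List, Union

import numpy as np
import torch

# Simplified stand-ins for the reworked ORTModel helpers: dict in, list out.
def prepare_onnx_inputs(
    use_torch: bool, model_inputs: Dict[str, Union[torch.Tensor, np.ndarray]]
) -> Dict[str, np.ndarray]:
    # torch tensors are detached to numpy; numpy arrays pass through
    return {name: (value.numpy(force=True) if use_torch else value) for name, value in model_inputs.items()}

def prepare_onnx_outputs(
    use_torch: bool, onnx_outputs: List[np.ndarray], output_names: Dict[str, int]
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
    # session.run returns a positional list; re-key it by output name
    outputs = {name: onnx_outputs[idx] for name, idx in output_names.items()}
    return {k: torch.from_numpy(v) for k, v in outputs.items()} if use_torch else outputs

onnx_inputs = prepare_onnx_inputs(True, {"input_ids": torch.ones(1, 4, dtype=torch.int64)})
onnx_outputs = [np.zeros((1, 4, 32), dtype=np.float32)]  # stand-in for a session.run result
model_outputs = prepare_onnx_outputs(True, onnx_outputs, {"logits": 0})
print(type(model_outputs["logits"]))  # <class 'torch.Tensor'>
```

With one container end to end, every `forward` in the diff below collapses to the same three lines regardless of whether IO binding is enabled.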
8 changes: 4 additions & 4 deletions optimum/onnxruntime/base.py
@@ -136,9 +136,9 @@ def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor,

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

last_hidden_state = model_outputs["last_hidden_state"]

@@ -365,9 +365,9 @@ def forward(
else:
raise ValueError("Unsupported num_pkv")
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

# TODO: using a new variable out_past_key_values is memory inefficient,
# past_key_values is not used anymore at this point
4 changes: 2 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
@@ -275,9 +275,9 @@ def forward(
if "loss" in self.output_names:
loss = output_buffers["loss"].view(output_shapes["loss"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

loss = model_outputs.get("loss", None)
logits = model_outputs["logits"]
105 changes: 65 additions & 40 deletions optimum/onnxruntime/modeling_ort.py
@@ -889,32 +889,57 @@ def raise_on_numpy_input_io_binding(self, use_torch: bool):
)

def _prepare_onnx_inputs(
- self, use_torch: bool, **inputs: Union[torch.Tensor, np.ndarray]
+ self, use_torch: bool, model_inputs: Dict[str, Union[torch.Tensor, np.ndarray]]
) -> Dict[str, np.ndarray]:
"""
Prepares the inputs for ONNX Runtime by converting them to numpy arrays with the expected dtype.
Args:
use_torch (`bool`):
Whether the inputs are torch.Tensor or not.
inputs (`Dict[str, Union[torch.Tensor, np.ndarray]]`):
The inputs to prepare for ONNX Runtime.
Returns:
`Dict[str, np.ndarray]`: The inputs prepared for ONNX Runtime.
"""

onnx_inputs = {}
- # converts pytorch inputs into numpy inputs for onnx
- for input_name in self.input_names.keys():
- onnx_inputs[input_name] = inputs.pop(input_name)
-
- if onnx_inputs[input_name] is None:
+ for input_name in self.input_names.keys():
+ if model_inputs.get(input_name, None) is None:
raise ValueError(f"Input {input_name} is required by model but not provided.")

if use_torch:
- onnx_inputs[input_name] = onnx_inputs[input_name].numpy(force=True)
+ onnx_inputs[input_name] = model_inputs[input_name].numpy(force=True)
+ else:
+ onnx_inputs[input_name] = model_inputs[input_name]

- if onnx_inputs[input_name].dtype != self.input_dtypes[input_name]:
- onnx_inputs[input_name] = onnx_inputs[input_name].astype(
- TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name])
- )
+ expected_dtype = TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name])
+
+ if onnx_inputs[input_name].dtype != expected_dtype:
+ onnx_inputs[input_name] = onnx_inputs[input_name].astype(expected_dtype)

return onnx_inputs

def _prepare_onnx_outputs(
- self, use_torch: bool, *onnx_outputs: np.ndarray
+ self, use_torch: bool, onnx_outputs: List[np.ndarray]
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
"""
Prepares the outputs from ONNX Runtime by converting them to torch.Tensor if requested.
Args:
use_torch (`bool`):
Whether the outputs should be torch.Tensor or not.
onnx_outputs (`List[np.ndarray]`):
The outputs from ONNX Runtime.
Returns:
`Dict[str, Union[torch.Tensor, np.ndarray]]`: The outputs prepared for the user.
"""

model_outputs = {}

# converts onnxruntime outputs into tensor for standard outputs
for output_name, idx in self.output_names.items():
model_outputs[output_name] = onnx_outputs[idx]
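Taken together, the reworked `_prepare_onnx_inputs` above performs a three-step recipe per input: presence check, torch-to-numpy conversion, dtype coercion. A standalone rehearsal of that per-input logic (simplified sketch: the `ort_to_numpy` dict stands in for `TypeHelper.ort_type_to_numpy_type`, and `prepare_one` is a hypothetical helper, not optimum API):

```python
import numpy as np
import torch

input_dtypes = {"input_ids": "tensor(int64)"}  # as reported by the ONNX session
ort_to_numpy = {"tensor(int64)": np.int64, "tensor(float)": np.float32}

def prepare_one(name: str, value, use_torch: bool) -> np.ndarray:
    if value is None:
        raise ValueError(f"Input {name} is required by model but not provided.")
    # torch tensor -> numpy (detached, moved to CPU); numpy passes through
    array = value.numpy(force=True) if use_torch else value
    # cast to the dtype the exported graph declares, if it differs
    expected_dtype = ort_to_numpy[input_dtypes[name]]
    return array.astype(expected_dtype) if array.dtype != expected_dtype else array

# an int32 tensor is silently cast to the int64 the graph declares
print(prepare_one("input_ids", torch.ones(1, 8, dtype=torch.int32), use_torch=True).dtype)  # int64
```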

@@ -1068,9 +1093,9 @@ def forward(

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

if "last_hidden_state" in self.output_names:
last_hidden_state = model_outputs["last_hidden_state"]
@@ -1225,9 +1250,9 @@ def forward(

logits = output_buffers["logits"].view(output_shapes["logits"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

logits = model_outputs["logits"]

@@ -1314,9 +1339,9 @@ def forward(
start_logits = output_buffers["start_logits"].view(output_shapes["start_logits"])
end_logits = output_buffers["end_logits"].view(output_shapes["end_logits"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

start_logits = model_outputs["start_logits"]
end_logits = model_outputs["end_logits"]
@@ -1421,9 +1446,9 @@ def forward(

logits = output_buffers["logits"].view(output_shapes["logits"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

logits = model_outputs["logits"]

@@ -1513,9 +1538,9 @@ def forward(

logits = output_buffers["logits"].view(output_shapes["logits"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

logits = model_outputs["logits"]

@@ -1598,9 +1623,9 @@ def forward(

logits = output_buffers["logits"].view(output_shapes["logits"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

logits = model_outputs["logits"]

@@ -1685,9 +1710,9 @@ def forward(

logits = output_buffers["logits"].view(output_shapes["logits"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

logits = model_outputs["logits"]

@@ -1772,9 +1797,9 @@ def forward(

logits = output_buffers["logits"].view(output_shapes["logits"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

logits = model_outputs["logits"]

@@ -1899,9 +1924,9 @@ def forward(

logits = output_buffers["logits"].view(output_shapes["logits"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

logits = model_outputs["logits"]

@@ -1984,9 +2009,9 @@ def forward(

logits = output_buffers["logits"].view(output_shapes["logits"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

logits = model_outputs["logits"]

@@ -2069,9 +2094,9 @@ def forward(
embeddings = output_buffers["embeddings"].view(output_shapes["embeddings"])

else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

logits = model_outputs["logits"]
embeddings = model_outputs["embeddings"]
@@ -2136,9 +2161,9 @@ def forward(
else:
model_inputs = {"input_values": input_values}

- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

logits = model_outputs["logits"]

@@ -2211,9 +2236,9 @@ def forward(

reconstruction = output_buffers["reconstruction"].view(output_shapes["reconstruction"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
reconstruction = model_outputs["reconstruction"]
return ImageSuperResolutionOutput(reconstruction=reconstruction)

@@ -2282,9 +2307,9 @@ def forward(self, **model_inputs: Union[torch.Tensor, np.ndarray]):
model_outputs[name] = IOBindingHelper.to_pytorch(output)

else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

# converts output to namedtuple for pipelines post-processing
return ModelOutput(**model_outputs)
42 changes: 14 additions & 28 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -363,39 +363,25 @@ def forward(
use_torch = isinstance(input_features, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)

+ model_inputs = {
+ "input_features": input_features,
+ "attention_mask": attention_mask,
+ }

if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding:
- model_inputs = (
- [input_features, attention_mask] if "attention_mask" in self.input_names else [input_features]
- )
- io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, *model_inputs)
+ io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, model_inputs)

io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
- if use_torch:
- onnx_inputs = {"input_features": input_features.cpu().detach().numpy()}
- if "attention_mask" in self.input_names:
- onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy()
- else:
- onnx_inputs = {"input_features": input_features}
- if "attention_mask" in self.input_names:
- onnx_inputs["attention_mask"] = attention_mask
-
- # TODO: Replace with a better solution
- # attention_mask is exported with int64 datatype and tokenizer produces int32 input
- # for speech2text model. Hence, the input is type casted for inference.
- if "attention_mask" in self.input_names:
- if self.session.get_inputs()[1].type == "tensor(int64)":
- onnx_inputs["attention_mask"] = onnx_inputs["attention_mask"].astype(np.int64)
-
- outputs = self.session.run(None, onnx_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
+ onnx_outputs = self.session.run(None, onnx_inputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

- last_hidden_state = outputs[self.output_names["last_hidden_state"]]
- if use_torch:
- last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)
+ last_hidden_state = model_outputs["last_hidden_state"]

return BaseModelOutput(last_hidden_state=last_hidden_state)
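The rewrite above also retires the TODO workaround: the speech2text tokenizer emits an int32 `attention_mask` while the exported graph declares int64, and the generic dtype check in `_prepare_onnx_inputs` (see modeling_ort.py above) now performs that cast for every input. A minimal illustration of the same coercion (shapes and values are made up):

```python
import numpy as np

# The tokenizer-side dtype mismatch the old TODO block worked around:
attention_mask = np.ones((1, 584), dtype=np.int32)  # tokenizer output (int32)

# _prepare_onnx_inputs now applies this cast generically for every input;
# np.int64 is what the declared "tensor(int64)" resolves to.
expected_dtype = np.int64
if attention_mask.dtype != expected_dtype:
    attention_mask = attention_mask.astype(expected_dtype)

assert attention_mask.dtype == np.int64
```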

@@ -431,9 +417,9 @@ def forward(

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

last_hidden_state = model_outputs["last_hidden_state"]

@@ -473,9 +459,9 @@ def forward(

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
- onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+ onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
- model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+ model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

last_hidden_state = model_outputs["last_hidden_state"]

