diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py
index 84f7191581..1e0780ae76 100644
--- a/optimum/onnxruntime/base.py
+++ b/optimum/onnxruntime/base.py
@@ -136,9 +136,9 @@ def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor,
 
             last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
         else:
-            onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+            onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
             onnx_outputs = self.session.run(None, onnx_inputs)
-            model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+            model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
 
             last_hidden_state = model_outputs["last_hidden_state"]
@@ -365,9 +365,9 @@ def forward(
                 else:
                     raise ValueError("Unsupported num_pkv")
         else:
-            onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+            onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
             onnx_outputs = self.session.run(None, onnx_inputs)
-            model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+            model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
 
             # TODO: using a new variable out_past_key_values is memory inefficient,
             # past_key_values is not used anymore at this point
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index 9d3535384a..8b7e22558d 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -275,9 +275,9 @@ def forward(
             if "loss" in self.output_names:
                 loss = output_buffers["loss"].view(output_shapes["loss"])
         else:
-            onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+            onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
             onnx_outputs = self.model.run(None, onnx_inputs)
-            model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
+            model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
 
             loss = model_outputs.get("loss", None)
             logits = model_outputs["logits"]
diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py
index 35a448fbb8..c61de8cf29 100644
--- a/optimum/onnxruntime/modeling_ort.py
+++ b/optimum/onnxruntime/modeling_ort.py
@@ -889,32 +889,57 @@ def raise_on_numpy_input_io_binding(self, use_torch: bool):
             )
 
     def _prepare_onnx_inputs(
-        self, use_torch: bool, **inputs: Union[torch.Tensor, np.ndarray]
+        self, use_torch: bool, model_inputs: Dict[str, Union[torch.Tensor, np.ndarray]]
     ) -> Dict[str, np.ndarray]:
+        """
+        Prepares the inputs for ONNX Runtime by converting them to numpy arrays with the expected dtype.
+
+        Args:
+            use_torch (`bool`):
+                Whether the inputs are torch.Tensor or not.
+            model_inputs (`Dict[str, Union[torch.Tensor, np.ndarray]]`):
+                The inputs to prepare for ONNX Runtime.
+
+        Returns:
+            `Dict[str, np.ndarray]`: The inputs prepared for ONNX Runtime.
+ """ + onnx_inputs = {} - # converts pytorch inputs into numpy inputs for onnx - for input_name in self.input_names.keys(): - onnx_inputs[input_name] = inputs.pop(input_name) - if onnx_inputs[input_name] is None: + for input_name in self.input_names.keys(): + if model_inputs.get(input_name, None) is None: raise ValueError(f"Input {input_name} is required by model but not provided.") if use_torch: - onnx_inputs[input_name] = onnx_inputs[input_name].numpy(force=True) + onnx_inputs[input_name] = model_inputs[input_name].numpy(force=True) + else: + onnx_inputs[input_name] = model_inputs[input_name] - if onnx_inputs[input_name].dtype != self.input_dtypes[input_name]: - onnx_inputs[input_name] = onnx_inputs[input_name].astype( - TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name]) - ) + expected_dtype = TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name]) + + if onnx_inputs[input_name].dtype != expected_dtype: + onnx_inputs[input_name] = onnx_inputs[input_name].astype(expected_dtype) return onnx_inputs def _prepare_onnx_outputs( - self, use_torch: bool, *onnx_outputs: np.ndarray + self, use_torch: bool, onnx_outputs: List[np.ndarray] ) -> Dict[str, Union[torch.Tensor, np.ndarray]]: + """ + Prepares the outputs from ONNX Runtime by converting them to torch.Tensor if requested. + + Args: + use_torch (`bool`): + Whether the outputs should be torch.Tensor or not. + onnx_outputs (`List[np.ndarray]`): + The outputs from ONNX Runtime. + + Returns: + `Dict[str, Union[torch.Tensor, np.ndarray]]`: The outputs prepared for the user. + """ + model_outputs = {} - # converts onnxruntime outputs into tensor for standard outputs for output_name, idx in self.output_names.items(): model_outputs[output_name] = onnx_outputs[idx] @@ -1068,9 +1093,9 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) if "last_hidden_state" in self.output_names: last_hidden_state = model_outputs["last_hidden_state"] @@ -1225,9 +1250,9 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) logits = model_outputs["logits"] @@ -1314,9 +1339,9 @@ def forward( start_logits = output_buffers["start_logits"].view(output_shapes["start_logits"]) end_logits = output_buffers["end_logits"].view(output_shapes["end_logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) start_logits = model_outputs["start_logits"] end_logits = model_outputs["end_logits"] @@ -1421,9 +1446,9 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, 
**model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) logits = model_outputs["logits"] @@ -1513,9 +1538,9 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) logits = model_outputs["logits"] @@ -1598,9 +1623,9 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) logits = model_outputs["logits"] @@ -1685,9 +1710,9 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) logits = model_outputs["logits"] @@ -1772,9 +1797,9 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) logits = model_outputs["logits"] @@ -1899,9 +1924,9 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) logits = model_outputs["logits"] @@ -1984,9 +2009,9 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) logits = model_outputs["logits"] @@ -2069,9 +2094,9 @@ def forward( embeddings = output_buffers["embeddings"].view(output_shapes["embeddings"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) logits = 
model_outputs["logits"] embeddings = model_outputs["embeddings"] @@ -2136,9 +2161,9 @@ def forward( else: model_inputs = {"input_values": input_values} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) logits = model_outputs["logits"] @@ -2211,9 +2236,9 @@ def forward( reconstruction = output_buffers["reconstruction"].view(output_shapes["reconstruction"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) reconstruction = model_outputs["reconstruction"] return ImageSuperResolutionOutput(reconstruction=reconstruction) @@ -2282,9 +2307,9 @@ def forward(self, **model_inputs: Union[torch.Tensor, np.ndarray]): model_outputs[name] = IOBindingHelper.to_pytorch(output) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) # converts output to namedtuple for pipelines post-processing return ModelOutput(**model_outputs) diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 8ab19c9951..180dbf87bb 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -363,11 +363,13 @@ def forward( use_torch = isinstance(input_features, torch.Tensor) self.parent_model.raise_on_numpy_input_io_binding(use_torch) + model_inputs = { + "input_features": input_features, + "attention_mask": attention_mask, + } + if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding: - model_inputs = ( - [input_features, attention_mask] if "attention_mask" in self.input_names else [input_features] - ) - io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, *model_inputs) + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, model_inputs) io_binding.synchronize_inputs() self.session.run_with_iobinding(io_binding) @@ -375,27 +377,11 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - if use_torch: - onnx_inputs = {"input_features": input_features.cpu().detach().numpy()} - if "attention_mask" in self.input_names: - onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy() - else: - onnx_inputs = {"input_features": input_features} - if "attention_mask" in self.input_names: - onnx_inputs["attention_mask"] = attention_mask - - # TODO: Replace with a better solution - # attention_mask is exported with int64 datatype and tokenizer produces int32 input - # for speech2text model. Hence, the input is type casted for inference. 
- if "attention_mask" in self.input_names: - if self.session.get_inputs()[1].type == "tensor(int64)": - onnx_inputs["attention_mask"] = onnx_inputs["attention_mask"].astype(np.int64) - - outputs = self.session.run(None, onnx_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) - last_hidden_state = outputs[self.output_names["last_hidden_state"]] - if use_torch: - last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device) + last_hidden_state = model_outputs["last_hidden_state"] return BaseModelOutput(last_hidden_state=last_hidden_state) @@ -431,9 +417,9 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.session.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) last_hidden_state = model_outputs["last_hidden_state"] @@ -473,9 +459,9 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) onnx_outputs = self.session.run(None, onnx_inputs) - model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs) last_hidden_state = model_outputs["last_hidden_state"]