do not merge #1427

Draft. Wants to merge 5 commits into base: main
Changes from all commits
9 changes: 9 additions & 0 deletions optimum/exporters/onnx/__main__.py
@@ -349,6 +349,13 @@ def main_export(
library_name=library_name,
)

if model.config.model_type == "whisper":
import torch
model = model.eval()
with torch.no_grad():
for i in range(len(model.model.decoder.layers)):
model.model.decoder.layers[i].encoder_attn = torch.jit.script(model.model.decoder.layers[i].encoder_attn)

custom_architecture = False
is_stable_diffusion = "stable-diffusion" in task
model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-")
@@ -574,10 +581,12 @@ def main_export(
logger.warning(
f"The ONNX export succeeded with the warning: {e}.\n The exported model was saved at: {output.as_posix()}"
)
"""
except Exception as e:
raise Exception(
f"An error occured during validation, but the model was saved nonetheless at {output.as_posix()}. Detailed error: {e}."
)
"""


def main():
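The hunk above replaces each cross-attention module of the Whisper decoder with its TorchScript-compiled version before the export proceeds. As a rough, non-authoritative sketch (not part of this diff), the same step on a standalone checkpoint might look as follows; the checkpoint id is an arbitrary example, and whether `torch.jit.script` accepts the module depends on the transformers version:

```python
# Illustrative sketch only, mirroring the loop added in main_export above.
# "openai/whisper-tiny" is just an example checkpoint.
import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()

with torch.no_grad():
    for layer in model.model.decoder.layers:
        # Swap the eager cross-attention block for a scripted one so its control
        # flow is captured as a graph rather than traced with fixed shapes.
        layer.encoder_attn = torch.jit.script(layer.encoder_attn)
```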
97 changes: 54 additions & 43 deletions optimum/exporters/onnx/base.py
@@ -325,7 +325,7 @@ def fix_dynamic_axes(
onnx_inputs[name] = value

for name, value in onnx_inputs.items():
if value.dtype == np.float32 and dtype == "fp16":
if value is not None and (value.dtype == np.float32 and dtype == "fp16"):
onnx_inputs[name] = onnx_inputs[name].astype(np.float16)

outputs = session.run(None, onnx_inputs)
@@ -579,14 +579,15 @@ def __init__(

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
if not self.use_past_in_inputs:
common_outputs = super().outputs
#if not self.use_past_in_inputs:
# common_outputs = super().outputs

# In the other cases, the sequence_length axis is not dynamic, always of length 1
elif self.task == "feature-extraction":
common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}})
if self.task == "feature-extraction":
common_outputs = OrderedDict({"last_hidden_state": {}})
else:
common_outputs = OrderedDict({"logits": {0: "batch_size"}})
if self.use_past:
common_outputs = OrderedDict({"logits": {}})
# When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
self.add_past_key_values(common_outputs, direction="outputs")
return common_outputs
@@ -602,7 +603,8 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):

dummy_inputs = {}
input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]
if self.use_past_in_inputs and self.use_cache_branch is not False:

if self.use_past_in_inputs:
input_names.append("past_key_values")

for input_name in input_names:
@@ -623,30 +625,30 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
)

# refer to https://github.com/huggingface/optimum/pull/764
if (
self.use_past_in_inputs
and self.PAD_ATTENTION_MASK_TO_PAST
and self.use_cache_branch is not False
and "attention_mask" in dummy_inputs
):
# Obtain the past sequence length from the value instead of the key (Bloom).
past_length = dummy_inputs["past_key_values"][0][1].shape[-2]

dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
dummy_inputs["attention_mask"],
desired_length=past_length + 1,
dim=1,
dtype=dummy_inputs["attention_mask"].dtype,
)

if self.use_past_in_inputs and self.use_cache_branch is not False and "decoder_attention_mask" in dummy_inputs:
past_length = dummy_inputs["past_key_values"][0][0].shape[2]
dummy_inputs["decoder_attention_mask"] = DummyInputGenerator.pad_input_on_dim(
dummy_inputs["decoder_attention_mask"],
desired_length=past_length + 1,
dim=1,
dtype=dummy_inputs["decoder_attention_mask"].dtype,
)
# if (
# self.use_past_in_inputs
# and self.PAD_ATTENTION_MASK_TO_PAST
# and self.use_cache_branch is not False
# and "attention_mask" in dummy_inputs
# ):
# # Obtain the past sequence length from the value instead of the key (Bloom).
# past_length = dummy_inputs["past_key_values"][0][1].shape[-2]

# dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
# dummy_inputs["attention_mask"],
# desired_length=past_length + 1,
# dim=1,
# dtype=dummy_inputs["attention_mask"].dtype,
# )

# if self.use_past_in_inputs and self.use_cache_branch is not False and "decoder_attention_mask" in dummy_inputs:
# past_length = dummy_inputs["past_key_values"][0][0].shape[2]
# dummy_inputs["decoder_attention_mask"] = DummyInputGenerator.pad_input_on_dim(
# dummy_inputs["decoder_attention_mask"],
# desired_length=past_length + 1,
# dim=1,
# dtype=dummy_inputs["decoder_attention_mask"].dtype,
# )

return dummy_inputs

@@ -702,8 +704,8 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
name = "present"

for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.key"] = {} # static shapes
inputs_or_outputs[f"{name}.{i}.value"] = {}

def flatten_past_key_values(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.key"] = t[0]
@@ -821,7 +823,16 @@ def with_behavior(

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = super(OnnxConfigWithPast, self).outputs
# In the other cases, the sequence_length axis is not dynamic, always of length 1
if self.task == "feature-extraction":
common_outputs = OrderedDict({"last_hidden_state": {}})
else:
common_outputs = OrderedDict({"logits": {}})

# When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
self.add_past_key_values(common_outputs, direction="outputs")

"""
# Renaming the outputs axes properly.
for name, axes_names in common_outputs.items():
if self._behavior is ConfigBehavior.ENCODER or "encoder" in name:
@@ -840,10 +851,12 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
else:
new_axes_names[axis_idx] = axis_name
common_outputs[name] = new_axes_names


if self.use_past:
# When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
self.add_past_key_values(common_outputs, direction="outputs")
"""

return common_outputs

@@ -859,19 +872,17 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
name = "present"

for i in range(self._normalized_config.decoder_num_layers):
inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch_size", 2: decoder_sequence_name}

inputs_or_outputs[f"{name}.{i}.decoder.key"] = {}
inputs_or_outputs[f"{name}.{i}.decoder.value"] = {}
if (
self.is_merged is True
or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs)
or direction == "inputs"
self._behavior is ConfigBehavior.DECODER
):
# TODO: we only need to call it encoder_sequence_length_out in the merge case - but at torch.onnx.export()
# time we have currently no case to check whether we will merge at a later step or not (self.is_merged is
# not yet set at this time)
inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch_size", 2: "encoder_sequence_length_out"}
inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch_size", 2: "encoder_sequence_length_out"}
inputs_or_outputs[f"{name}.{i}.encoder.key"] = {}
inputs_or_outputs[f"{name}.{i}.encoder.value"] = {}

def flatten_past_key_values(self, flattened_output, name, idx, t):
if len(t) not in [2, 4]:
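For context (not asserted by the diff itself): the axis dictionaries returned by these `inputs`/`outputs` properties end up as the `dynamic_axes` argument of `torch.onnx.export`, so an empty dict pins the corresponding tensor to the exact shape of the dummy input. A minimal sketch of the difference, using a stand-in module rather than a real decoder:

```python
# Illustrative sketch only: symbolic axes (as on main) vs. empty mappings (this draft).
import torch

class TinyDecoder(torch.nn.Module):
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return input_ids.float() * 2.0

model = TinyDecoder().eval()
dummy = torch.zeros(1, 8, dtype=torch.long)

# Symbolic axes: batch_size and sequence_length remain dynamic in the exported graph.
torch.onnx.export(
    model, (dummy,), "decoder_dynamic.onnx",
    input_names=["input_ids"], output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"},
    },
)

# No dynamic_axes (the effect of the empty dicts): every dimension is frozen to the
# dummy input's shape, here 1 x 8.
torch.onnx.export(
    model, (dummy,), "decoder_static.onnx",
    input_names=["input_ids"], output_names=["logits"],
)
```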
82 changes: 41 additions & 41 deletions optimum/exporters/onnx/config.py
@@ -92,13 +92,13 @@ def __init__(
@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if self.use_past_in_inputs:
common_inputs = {"input_ids": {0: "batch_size"}}
common_inputs = {"input_ids": {}}
self.add_past_key_values(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
common_inputs["attention_mask"] = {}
else:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"input_ids": {},
"attention_mask": {},
}
return common_inputs

@@ -109,7 +109,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
else:
# in the merged case, we need to allow the `sequence_length` to be variable, as it is not 1
# during the first pass without past key values
common_outputs = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}})
common_outputs = OrderedDict({"logits": {}})
self.add_past_key_values(common_outputs, direction="outputs")
return common_outputs

@@ -125,32 +125,32 @@ def post_process_exported_models(
path, models_and_onnx_configs, onnx_files_subpaths
)

# Attempt to merge only if the decoder-only was exported separately without/with past
if self.use_past is True and len(models_and_onnx_configs) == 2:
decoder_path = Path(path, onnx_files_subpaths[0])
decoder_with_past_path = Path(path, onnx_files_subpaths[1])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
try:
merge_decoders(
decoder=decoder_path,
decoder_with_past=decoder_with_past_path,
save_path=decoder_merged_path,
)
except Exception as e:
raise Exception(f"Unable to merge decoders. Detailed error: {e}")

# In order to do the validation of the two branches on the same file
onnx_files_subpaths = [decoder_merged_path.name, decoder_merged_path.name]

# We validate the two branches of the decoder model then
models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True
models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False

# Past key values won't be generated by default, but added in the input
models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True

models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True
models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True
# # Attempt to merge only if the decoder-only was exported separately without/with past
# if self.use_past is True and len(models_and_onnx_configs) == 2:
# decoder_path = Path(path, onnx_files_subpaths[0])
# decoder_with_past_path = Path(path, onnx_files_subpaths[1])
# decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
# try:
# merge_decoders(
# decoder=decoder_path,
# decoder_with_past=decoder_with_past_path,
# save_path=decoder_merged_path,
# )
# except Exception as e:
# raise Exception(f"Unable to merge decoders. Detailed error: {e}")

# # In order to do the validation of the two branches on the same file
# onnx_files_subpaths = [decoder_merged_path.name, decoder_merged_path.name]

# # We validate the two branches of the decoder model then
# models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True
# models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False

# # Past key values won't be generated by default, but added in the input
# models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True

# models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True
# models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True

return models_and_onnx_configs, onnx_files_subpaths

@@ -165,9 +165,9 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
# https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
if not self.no_position_ids and self.task == "text-generation":
if self.use_past_in_inputs:
common_inputs["position_ids"] = {0: "batch_size"}
common_inputs["position_ids"] = {}
else:
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
common_inputs["position_ids"] = {}

return common_inputs

@@ -272,26 +272,26 @@ class AudioToTextOnnxConfig(OnnxSeq2SeqConfigWithPast):
DummyAudioInputGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
DummyTextInputGenerator,
)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {}

if self._behavior is not ConfigBehavior.DECODER:
common_inputs["input_features"] = {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"}
common_inputs["input_features"] = {}

if self._behavior is not ConfigBehavior.ENCODER:
if self.use_past_in_inputs:
common_inputs["decoder_input_ids"] = {0: "batch_size"}
else:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
common_inputs["decoder_input_ids"] = {}

if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")
self.add_past_key_values(common_inputs, direction="inputs")

common_inputs["decoder_attention_mask"] = {}
common_inputs["position_ids"] = {}

if self._behavior is ConfigBehavior.DECODER:
common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
common_inputs["encoder_outputs"] = {}

return common_inputs

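As a side note (not from the diff): one way to confirm the effect of the empty axis mappings above is to inspect the exported graph and check whether each input dimension is symbolic or fixed. A small sketch with the `onnx` package; the file path is an assumption:

```python
# Illustrative sketch: list which input dimensions are symbolic (dynamic) and which
# are fixed. "decoder_model.onnx" is an example path, not a file produced by this PR.
import onnx

model = onnx.load("decoder_model.onnx")
for value_info in model.graph.input:
    dims = value_info.type.tensor_type.shape.dim
    # dim_param is set for symbolic axes (e.g. "batch_size"); dim_value holds fixed sizes.
    print(value_info.name, [d.dim_param if d.dim_param else d.dim_value for d in dims])
```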