
Commit

add ugly code
fxmarty committed Oct 3, 2023
1 parent ca7b9d8 commit 9ac3402
Showing 8 changed files with 91 additions and 57 deletions.
7 changes: 7 additions & 0 deletions optimum/exporters/onnx/__main__.py
@@ -348,6 +348,11 @@ def main_export(
        device=device,
        library_name=library_name,
    )
    import torch
    model = model.eval()
    with torch.no_grad():
        for i in range(len(model.model.decoder.layers)):
            model.model.decoder.layers[i].encoder_attn = torch.jit.script(model.model.decoder.layers[i].encoder_attn)

    custom_architecture = False
    is_stable_diffusion = "stable-diffusion" in task
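The hunk above swaps every decoder cross-attention module for a TorchScript-compiled version before the rest of the export pipeline runs. A minimal standalone sketch of the same idea, assuming a Whisper-style checkpoint where the cross-attention lives at `model.model.decoder.layers[i].encoder_attn` (whether `torch.jit.script` accepts the module depends on the installed `transformers` version):

```python
import torch
from transformers import WhisperForConditionalGeneration  # assumption: Whisper is the export target

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()

with torch.no_grad():
    for layer in model.model.decoder.layers:
        # Script (rather than trace) the cross-attention so its data-dependent
        # branches stay as control flow instead of being frozen to one path.
        layer.encoder_attn = torch.jit.script(layer.encoder_attn)
```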
@@ -574,10 +579,12 @@ def main_export(
            logger.warning(
                f"The ONNX export succeeded with the warning: {e}.\n The exported model was saved at: {output.as_posix()}"
            )
        """
        except Exception as e:
            raise Exception(
                f"An error occured during validation, but the model was saved nonetheless at {output.as_posix()}. Detailed error: {e}."
            )
        """


def main():
45 changes: 29 additions & 16 deletions optimum/exporters/onnx/base.py
@@ -325,7 +325,7 @@ def fix_dynamic_axes(
                onnx_inputs[name] = value

        for name, value in onnx_inputs.items():
            if value.dtype == np.float32 and dtype == "fp16":
            if value is not None and (value.dtype == np.float32 and dtype == "fp16"):
                onnx_inputs[name] = onnx_inputs[name].astype(np.float16)

        outputs = session.run(None, onnx_inputs)
@@ -579,14 +579,16 @@ def __init__(

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        if not self.use_past_in_inputs:
            common_outputs = super().outputs
        #if not self.use_past_in_inputs:
        #    common_outputs = super().outputs
        self.use_past_in_inputs = True

        # In the other cases, the sequence_length axis is not dynamic, always of length 1
        elif self.task == "feature-extraction":
        if self.task == "feature-extraction":
            common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}})
        else:
            common_outputs = OrderedDict({"logits": {0: "batch_size"}})
        if self.use_past:
            common_outputs = OrderedDict({"logits": {}})
            # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
            self.add_past_key_values(common_outputs, direction="outputs")
        return common_outputs
@@ -602,7 +604,9 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):

        dummy_inputs = {}
        input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]
        if self.use_past_in_inputs and self.use_cache_branch is not False:

        print("self._behavior", self._behavior)
        if self._behavior is not ConfigBehavior.ENCODER:
            input_names.append("past_key_values")

        for input_name in input_names:
@@ -821,7 +825,16 @@ def with_behavior(

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        common_outputs = super(OnnxConfigWithPast, self).outputs
        # In the other cases, the sequence_length axis is not dynamic, always of length 1
        if self.task == "feature-extraction":
            common_outputs = OrderedDict({"last_hidden_state": {}})
        else:
            common_outputs = OrderedDict({"logits": {}})

        # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
        self.add_past_key_values(common_outputs, direction="outputs")

        """
        # Renaming the outputs axes properly.
        for name, axes_names in common_outputs.items():
            if self._behavior is ConfigBehavior.ENCODER or "encoder" in name:
@@ -840,10 +853,12 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
                else:
                    new_axes_names[axis_idx] = axis_name
            common_outputs[name] = new_axes_names
        if self.use_past:
            # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
            self.add_past_key_values(common_outputs, direction="outputs")
        """

        return common_outputs

@@ -859,19 +874,17 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
            name = "present"

        for i in range(self._normalized_config.decoder_num_layers):
            inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch_size", 2: decoder_sequence_name}
            inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch_size", 2: decoder_sequence_name}

            inputs_or_outputs[f"{name}.{i}.decoder.key"] = {}
            inputs_or_outputs[f"{name}.{i}.decoder.value"] = {}
            if (
                self.is_merged is True
                or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs)
                or direction == "inputs"
                self._behavior is ConfigBehavior.DECODER
            ):
                # TODO: we only need to call it encoder_sequence_length_out in the merge case - but at torch.onnx.export()
                # time we have currently no case to check whether we will merge at a later step or not (self.is_merged is
                # not yet set at this time)
                inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch_size", 2: "encoder_sequence_length_out"}
                inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch_size", 2: "encoder_sequence_length_out"}
                inputs_or_outputs[f"{name}.{i}.encoder.key"] = {}
                inputs_or_outputs[f"{name}.{i}.encoder.value"] = {}

    def flatten_past_key_values(self, flattened_output, name, idx, t):
        if len(t) not in [2, 4]:
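Mapping every key/value entry to an empty dict matters downstream: these per-tensor mappings are what ends up in the `dynamic_axes` argument of `torch.onnx.export`, and an empty mapping means no axis is declared dynamic, so the exported shape is frozen to whatever the dummy input had. A toy illustration of the difference (the module and names below are placeholders, not Optimum code):

```python
import torch

class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2

model = Doubler().eval()
dummy = torch.zeros(1, 128)

# With named axes, the ONNX input keeps symbolic batch/sequence dimensions.
torch.onnx.export(
    model, (dummy,), "dynamic.onnx",
    input_names=["x"], output_names=["y"],
    dynamic_axes={"x": {0: "batch_size", 1: "sequence_length"}},
)

# With an empty mapping (as in the hunk above), the graph keeps the concrete
# (1, 128) shape taken from the dummy input.
torch.onnx.export(
    model, (dummy,), "static.onnx",
    input_names=["x"], output_names=["y"],
    dynamic_axes={"x": {}},
)
```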
16 changes: 8 additions & 8 deletions optimum/exporters/onnx/config.py
@@ -272,26 +272,26 @@ class AudioToTextOnnxConfig(OnnxSeq2SeqConfigWithPast):
        DummyAudioInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
        DummySeq2SeqPastKeyValuesGenerator,
        DummyTextInputGenerator,
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = {}

        if self._behavior is not ConfigBehavior.DECODER:
            common_inputs["input_features"] = {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"}
            common_inputs["input_features"] = {}

        if self._behavior is not ConfigBehavior.ENCODER:
            if self.use_past_in_inputs:
                common_inputs["decoder_input_ids"] = {0: "batch_size"}
            else:
                common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
            common_inputs["decoder_input_ids"] = {}

            if self.use_past_in_inputs:
                self.add_past_key_values(common_inputs, direction="inputs")
            self.add_past_key_values(common_inputs, direction="inputs")

            common_inputs["decoder_attention_mask"] = {}
            common_inputs["position_ids"] = {}

        if self._behavior is ConfigBehavior.DECODER:
            common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}
            common_inputs["encoder_outputs"] = {}

        return common_inputs
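For reference, a sketch of what this property would now return for the decoder (`ConfigBehavior.DECODER`) of a 2-layer model, assuming the usual `past_key_values.{i}.{decoder,encoder}.{key,value}` naming the exporter uses; every value is an empty dict, i.e. no dynamic axes at all:

```python
# Illustrative only; the exact past_key_values entries come from add_past_key_values above.
expected_decoder_inputs = {
    "decoder_input_ids": {},
    "past_key_values.0.decoder.key": {},
    "past_key_values.0.decoder.value": {},
    "past_key_values.0.encoder.key": {},
    "past_key_values.0.encoder.value": {},
    "past_key_values.1.decoder.key": {},
    "past_key_values.1.decoder.value": {},
    "past_key_values.1.encoder.key": {},
    "past_key_values.1.encoder.value": {},
    "decoder_attention_mask": {},
    "position_ids": {},
    "encoder_outputs": {},
}
```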

28 changes: 20 additions & 8 deletions optimum/exporters/onnx/convert.py
@@ -528,13 +528,6 @@ def export_pytorch(
        model.config.return_dict = True
        model.eval()

        # Check if we need to override certain configuration item
        if config.values_override is not None:
            logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
            for override_config_key, override_config_value in config.values_override.items():
                logger.info(f"\t- {override_config_key} -> {override_config_value}")
                setattr(model.config, override_config_key, override_config_value)

        if input_shapes is None:
            input_shapes = {} # will use the defaults from DEFAULT_DUMMY_SHAPES

@@ -555,6 +548,7 @@ def remap(value):

        dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)

        print("---------------- EXPORT")
        # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
        # so we check the torch version for backwards compatibility
        if is_torch_less_than_1_11:
@@ -567,11 +561,28 @@ def remap(value):
                input_names = list(inputs.keys())
                output_names = list(config.outputs.keys())

                print("input_names", input_names)
                print("output_names", output_names)
                print("dummy_inputs keys", dummy_inputs.keys())

                for name, inp in dummy_inputs.items():
                    if isinstance(inp, torch.Tensor):
                        print(name, inp.shape)
                    else:
                        print(name, type(inp))

                if config._behavior == "decoder":
                    dummy_inputs = (dummy_inputs["decoder_input_ids"], dummy_inputs["decoder_attention_mask"], dummy_inputs["encoder_outputs"][0], dummy_inputs["past_key_values"], dummy_inputs["position_ids"])

                    model = torch.jit.trace(model, dummy_inputs)
                else:
                    dummy_inputs = (dummy_inputs,)

                # Export can work with named args but the dict containing named args has to be the last element of the args
                # tuple.
                onnx_export(
                    model,
                    (dummy_inputs,),
                    dummy_inputs,
                    f=output.as_posix(),
                    input_names=input_names,
                    output_names=output_names,
@@ -879,6 +890,7 @@ def export(
            "You either provided a PyTorch model with only TensorFlow installed, or a TensorFlow model with only PyTorch installed."
        )

    print("------ FIXING")
    if not disable_dynamic_axes_fix:
        config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype)
    return export_output
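The decoder branch above now builds a positional tuple of dummy inputs, runs `torch.jit.trace` on the model with those inputs, and only then hands the traced module to `torch.onnx.export`. A self-contained toy version of that flow (the module, shapes and names below are stand-ins, not the actual Whisper decoder):

```python
import torch

class ToyDecoder(torch.nn.Module):
    """Stand-in taking the same positional argument order as the tuple built above."""

    def forward(self, decoder_input_ids, decoder_attention_mask, encoder_hidden_states):
        # Trivial computation that touches every input so tracing keeps them all.
        step = decoder_input_ids.float().unsqueeze(-1) + decoder_attention_mask.float().sum(-1, keepdim=True).unsqueeze(-1)
        return encoder_hidden_states.mean(dim=1, keepdim=True) + step

model = ToyDecoder().eval()
args = (
    torch.ones(1, 1, dtype=torch.int64),    # decoder_input_ids
    torch.ones(1, 128, dtype=torch.int64),  # decoder_attention_mask, fixed length
    torch.zeros(1, 1500, 384),              # encoder hidden states, fixed length
)

# torch.jit.trace records a single execution with these exact shapes; the
# resulting ScriptModule (not the eager model) is what gets exported.
traced = torch.jit.trace(model, args)
torch.onnx.export(
    traced,
    args,
    "toy_decoder.onnx",
    input_names=["decoder_input_ids", "decoder_attention_mask", "encoder_hidden_states"],
    output_names=["hidden_states"],
)
```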
11 changes: 0 additions & 11 deletions optimum/exporters/onnx/model_configs.py
@@ -1130,19 +1130,8 @@ class WhisperOnnxConfig(AudioToTextOnnxConfig):
    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = super().inputs
        if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False:
            common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
        return common_inputs

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        common_outputs = super().outputs
        if self._behavior is ConfigBehavior.ENCODER:
            # For Whisper, we need to name the second axis as encoder_sequence_length / 2 as the axis name is used for
            # dummy input generation
            common_outputs["last_hidden_state"][1] = f"{common_outputs['last_hidden_state'][1]} / 2"
        return common_outputs


class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator):
    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
20 changes: 10 additions & 10 deletions optimum/exporters/onnx/model_patcher.py
@@ -193,13 +193,15 @@ def patched_forward(*args, **kwargs):

            outputs = self.orig_forward(*args, **kwargs)

            """
            # Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
            filterd_outputs = {}
            for name, value in outputs.items():
                # filterd_outputs[name] = value
                onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                if (
                    onnx_output_name in config.outputs
                    or (allow_past_in_outputs and name.startswith("past_key_values"))
                    or name.startswith("past_key_values")
                    or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
                ):
                    if name != "past_key_values":
@@ -209,16 +211,14 @@ def patched_forward(*args, **kwargs):
                        else:
                            filterd_outputs[name] = value
                    else:
                        if self.real_config._behavior == "monolith" or (
                            self.real_config._behavior == "decoder"
                            and (self.real_config.is_merged or not self.real_config.use_past_in_inputs)
                        ):
                            filterd_outputs[name] = value
                        elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
                            # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
                            filterd_outputs[name] = tuple([v[:2] for v in value])
                            # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
                            # filterd_outputs[name] = tuple([v[:2] for v in value])
                            filterd_outputs[name] = value
            """

            return filterd_outputs
            #print("filterd_outputs", filterd_outputs.keys())

            return outputs

        self.patched_forward = patched_forward

1 change: 1 addition & 0 deletions optimum/exporters/tasks.py
@@ -1489,6 +1489,7 @@ def infer_library_from_model(
        Returns:
            `str`: The library name automatically detected from the model repo.
        """
        return "transformers" # working offline
        if library_name is not None:
            return library_name

20 changes: 16 additions & 4 deletions optimum/utils/input_generators.py
@@ -360,7 +360,13 @@ def __init__(
    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        min_value = 0
        max_value = 2 if input_name != "input_ids" else self.vocab_size
        shape = [self.batch_size, self.sequence_length]

        # TODO: fix
        if input_name == "decoder_attention_mask":
            shape = [self.batch_size, 128]
        else:
            shape = [self.batch_size, self.sequence_length]

        if self.task == "multiple-choice":
            shape = [self.batch_size, self.num_choices, self.sequence_length]
        return self.random_int_tensor(shape, max_value, min_value=min_value, framework=framework, dtype=int_dtype)
@@ -414,7 +420,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
        if input_name in ["encoder_outputs", "encoder_hidden_states"]:
            return (
                self.random_float_tensor(
                    shape=[self.batch_size, self.sequence_length, self.hidden_size],
                    shape=[self.batch_size, 1500, self.hidden_size],
                    min_value=0,
                    max_value=1,
                    framework=framework,
@@ -492,6 +498,8 @@ def __init__(
        **kwargs,
    ):
        self.normalized_config = normalized_config
        self.batch_size = 1
        """
        if random_batch_size_range:
            low, high = random_batch_size_range
            self.batch_size = random.randint(low, high)
@@ -505,6 +513,10 @@ def __init__(
        self.encoder_sequence_length = (
            self.sequence_length if encoder_sequence_length is None else encoder_sequence_length
        )
        """
        self.encoder_sequence_length = 1500
        self.sequence_length = 128


    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        encoder_shape = (
@@ -519,15 +531,15 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
            self.sequence_length,
            self.normalized_config.hidden_size // self.normalized_config.decoder_num_attention_heads,
        )
        return [
        return tuple(
            (
                self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype),
                self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype),
                self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype),
                self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype),
            )
            for _ in range(self.normalized_config.decoder_num_layers)
        ]
        )


# TODO: should it just be merged to DummyTextInputGenerator?
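The generator above now returns a tuple of per-layer tuples (presumably because `torch.jit.trace` handles nested tuples more gracefully than lists), with the encoder length pinned to 1500 frames and the decoder length to 128. A standalone sketch of the resulting structure, with made-up layer/head/hidden sizes:

```python
import torch

num_layers, num_heads, head_dim = 4, 6, 64
batch_size, decoder_len, encoder_len = 1, 128, 1500

# One (decoder_key, decoder_value, encoder_key, encoder_value) entry per layer,
# mirroring what DummySeq2SeqPastKeyValuesGenerator.generate returns above.
past_key_values = tuple(
    (
        torch.rand(batch_size, num_heads, decoder_len, head_dim),
        torch.rand(batch_size, num_heads, decoder_len, head_dim),
        torch.rand(batch_size, num_heads, encoder_len, head_dim),
        torch.rand(batch_size, num_heads, encoder_len, head_dim),
    )
    for _ in range(num_layers)
)

assert len(past_key_values) == num_layers
assert past_key_values[0][2].shape == (batch_size, num_heads, encoder_len, head_dim)
```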

0 comments on commit 9ac3402
