Commit 1a972b3

working whisper & opt stati

fxmarty committed Oct 5, 2023
1 parent 21f3b07
Showing 6 changed files with 82 additions and 60 deletions.
58 changes: 28 additions & 30 deletions optimum/exporters/onnx/base.py
@@ -581,11 +581,10 @@ def __init__(
     def outputs(self) -> Dict[str, Dict[int, str]]:
         #if not self.use_past_in_inputs:
         # common_outputs = super().outputs
-        self.use_past_in_inputs = True

         # In the other cases, the sequence_length axis is not dynamic, always of length 1
         if self.task == "feature-extraction":
-            common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}})
+            common_outputs = OrderedDict({"last_hidden_state": {}})
         else:
             common_outputs = OrderedDict({"logits": {}})

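Note on the change above: in these ONNX configs, each axes dict names the dynamic dimensions of a tensor, so an empty dict declares every dimension static. A minimal sketch of how such a spec typically maps onto torch.onnx.export's dynamic_axes (illustrative only, not optimum's actual export path):

io_spec = {
    "last_hidden_state": {},        # static: all dims fixed at export time
    "logits": {0: "batch_size"},    # dynamic: dim 0 may vary at runtime
}

# Only entries with at least one named axis need to reach dynamic_axes.
dynamic_axes = {name: axes for name, axes in io_spec.items() if axes}
print(dynamic_axes)  # {'logits': {0: 'batch_size'}}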
@@ -605,8 +604,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
         dummy_inputs = {}
         input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]

-        print("self._behavior", self._behavior)
-        if self._behavior is not ConfigBehavior.ENCODER:
+        if self.use_past_in_inputs:
             input_names.append("past_key_values")

         for input_name in input_names:
@@ -627,30 +625,30 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
             )

         # refer to https://github.com/huggingface/optimum/pull/764
-        if (
-            self.use_past_in_inputs
-            and self.PAD_ATTENTION_MASK_TO_PAST
-            and self.use_cache_branch is not False
-            and "attention_mask" in dummy_inputs
-        ):
-            # Obtain the past sequence length from the value instead of the key (Bloom).
-            past_length = dummy_inputs["past_key_values"][0][1].shape[-2]
-
-            dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
-                dummy_inputs["attention_mask"],
-                desired_length=past_length + 1,
-                dim=1,
-                dtype=dummy_inputs["attention_mask"].dtype,
-            )
-
-        if self.use_past_in_inputs and self.use_cache_branch is not False and "decoder_attention_mask" in dummy_inputs:
-            past_length = dummy_inputs["past_key_values"][0][0].shape[2]
-            dummy_inputs["decoder_attention_mask"] = DummyInputGenerator.pad_input_on_dim(
-                dummy_inputs["decoder_attention_mask"],
-                desired_length=past_length + 1,
-                dim=1,
-                dtype=dummy_inputs["decoder_attention_mask"].dtype,
-            )
+        # if (
+        #     self.use_past_in_inputs
+        #     and self.PAD_ATTENTION_MASK_TO_PAST
+        #     and self.use_cache_branch is not False
+        #     and "attention_mask" in dummy_inputs
+        # ):
+        #     # Obtain the past sequence length from the value instead of the key (Bloom).
+        #     past_length = dummy_inputs["past_key_values"][0][1].shape[-2]
+
+        #     dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
+        #         dummy_inputs["attention_mask"],
+        #         desired_length=past_length + 1,
+        #         dim=1,
+        #         dtype=dummy_inputs["attention_mask"].dtype,
+        #     )
+
+        # if self.use_past_in_inputs and self.use_cache_branch is not False and "decoder_attention_mask" in dummy_inputs:
+        #     past_length = dummy_inputs["past_key_values"][0][0].shape[2]
+        #     dummy_inputs["decoder_attention_mask"] = DummyInputGenerator.pad_input_on_dim(
+        #         dummy_inputs["decoder_attention_mask"],
+        #         desired_length=past_length + 1,
+        #         dim=1,
+        #         dtype=dummy_inputs["decoder_attention_mask"].dtype,
+        #     )

         return dummy_inputs
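For context on the block disabled above: with a KV cache holding past_length tokens and a single new input token, the attention mask must cover past_length + 1 positions. A hedged sketch of what pad_input_on_dim is assumed to do here, written in plain torch:

import torch

batch_size, past_length = 2, 16
attention_mask = torch.ones(batch_size, 1, dtype=torch.int64)  # one new token

# Grow the mask so it also covers the cached positions.
pad = (past_length + 1) - attention_mask.shape[1]
attention_mask = torch.cat(
    [torch.ones(batch_size, pad, dtype=attention_mask.dtype), attention_mask], dim=1
)
print(attention_mask.shape)  # torch.Size([2, 17]) == past_length + 1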

@@ -706,8 +704,8 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
name = "present"

for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.key"] = {} # static shapes
inputs_or_outputs[f"{name}.{i}.value"] = {}

def flatten_past_key_values(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.key"] = t[0]
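For reference, the flat names built above are the per-layer cache tensors the exported graph exposes. A small sketch with toy tensors (shapes illustrative):

import torch

num_layers = 2
# (key, value) per layer; (batch, heads, seq_len, head_dim) is illustrative
cache = [(torch.zeros(1, 8, 16, 64), torch.zeros(1, 8, 16, 64)) for _ in range(num_layers)]

flattened = {}
for idx, t in enumerate(cache):
    flattened[f"present.{idx}.key"] = t[0]
    flattened[f"present.{idx}.value"] = t[1]

print(list(flattened))
# ['present.0.key', 'present.0.value', 'present.1.key', 'present.1.value']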
14 changes: 7 additions & 7 deletions optimum/exporters/onnx/config.py
@@ -92,13 +92,13 @@ def __init__(
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         if self.use_past_in_inputs:
-            common_inputs = {"input_ids": {0: "batch_size"}}
+            common_inputs = {"input_ids": {}}
             self.add_past_key_values(common_inputs, direction="inputs")
-            common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
+            common_inputs["attention_mask"] = {}
         else:
             common_inputs = {
-                "input_ids": {0: "batch_size", 1: "sequence_length"},
-                "attention_mask": {0: "batch_size", 1: "sequence_length"},
+                "input_ids": {},
+                "attention_mask": {},
             }
         return common_inputs

@@ -109,7 +109,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         else:
             # in the merged case, we need to allow the `sequence_length` to be variable, as it is not 1
             # during the first pass without past key values
-            common_outputs = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}})
+            common_outputs = OrderedDict({"logits": {}})
         self.add_past_key_values(common_outputs, direction="outputs")
         return common_outputs

@@ -165,9 +165,9 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         # https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
         if not self.no_position_ids and self.task == "text-generation":
             if self.use_past_in_inputs:
-                common_inputs["position_ids"] = {0: "batch_size"}
+                common_inputs["position_ids"] = {}
             else:
-                common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
+                common_inputs["position_ids"] = {}

         return common_inputs
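With a cache, position_ids only needs the position of the single new token, which is why the sequence axis disappears in the use_past_in_inputs branch. A sketch of the usual convention (an assumption here, mirroring the cumsum computation transformers models use):

import torch

attention_mask = torch.ones(2, 17, dtype=torch.int64)  # past_length + 1 positions

# Position of the new token = number of previously attended tokens.
position_ids = attention_mask.cumsum(-1)[:, -1:] - 1
print(position_ids)  # tensor([[16], [16]]), shape (batch_size, 1)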

Expand Down
37 changes: 18 additions & 19 deletions optimum/exporters/onnx/convert.py
@@ -151,24 +151,23 @@ def validate_models_outputs(
                else output_dir.joinpath(model_name + ".onnx")
            )
            onnx_paths.append(onnx_model_path)
-            try:
-                # Model validation is done in subprocesses, as ONNX Runtime has the bad habit of
-                # not releasing memory once an InferenceSession is initialized.
-                # Reference: https://github.com/huggingface/optimum/pull/1115
-                validate_model_outputs(
-                    config=sub_onnx_config,
-                    reference_model=submodel,
-                    onnx_model=onnx_model_path,
-                    onnx_named_outputs=onnx_named_outputs[i],
-                    atol=atol,
-                    input_shapes=input_shapes,
-                    device=device,
-                    dtype=dtype,
-                    use_subprocess=use_subprocess,
-                    model_kwargs=model_kwargs,
-                )
-            except Exception as e:
-                exceptions.append(e)
+            # Model validation is done in subprocesses, as ONNX Runtime has the bad habit of
+            # not releasing memory once an InferenceSession is initialized.
+            # Reference: https://github.com/huggingface/optimum/pull/1115
+            validate_model_outputs(
+                config=sub_onnx_config,
+                reference_model=submodel,
+                onnx_model=onnx_model_path,
+                onnx_named_outputs=onnx_named_outputs[i],
+                atol=atol,
+                input_shapes=input_shapes,
+                device=device,
+                dtype=dtype,
+                use_subprocess=use_subprocess,
+                model_kwargs=model_kwargs,
+            )
+            #except Exception as e:
+            #    exceptions.append(e)

    if len(exceptions) != 0:
        for i, exception in enumerate(exceptions[:-1]):
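The comment kept in the diff refers to running validation in a child process so that ONNX Runtime's memory is returned to the OS when the process exits. A minimal sketch of that pattern (illustrative; optimum routes it through the use_subprocess argument):

import multiprocessing as mp

def _validate(onnx_path, queue):
    # An onnxruntime.InferenceSession created here would die with this
    # process, releasing its memory even if ORT never frees it internally.
    queue.put(f"validated {onnx_path}")

if __name__ == "__main__":
    queue = mp.Queue()
    p = mp.Process(target=_validate, args=("model.onnx", queue))
    p.start()
    p.join()
    print(queue.get())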
@@ -571,7 +570,7 @@ def remap(value):
        else:
            print(name, type(inp))

-    if config._behavior == "decoder":
+    if model.config.model_type == "whisper" and config._behavior == "decoder":
        dummy_inputs = (dummy_inputs["decoder_input_ids"], dummy_inputs["decoder_attention_mask"], dummy_inputs["encoder_outputs"][0], dummy_inputs["past_key_values"], dummy_inputs["position_ids"])

    model = torch.jit.trace(model, dummy_inputs)
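torch.jit.trace takes positional example inputs, which is why the dummy-input dict is reordered into a tuple for the whisper decoder above. A toy illustration:

import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids, attention_mask):
        return input_ids * attention_mask

example = (
    torch.ones(1, 4, dtype=torch.int64),
    torch.ones(1, 4, dtype=torch.int64),
)
traced = torch.jit.trace(Toy(), example)  # inputs are passed positionally
print(traced(*example).shape)  # torch.Size([1, 4])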
28 changes: 27 additions & 1 deletion optimum/exporters/onnx/model_configs.py
@@ -210,7 +210,7 @@ class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


-class OPTOnnxConfig(TextDecoderOnnxConfig):
+class OPTOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     # OPT does not require position_ids input.
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
@@ -1127,6 +1127,32 @@ class WhisperOnnxConfig(AudioToTextOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig
     ATOL_FOR_VALIDATION = 1e-3

+    def __init__(
+        self,
+        config: "PretrainedConfig",
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        use_past: bool = False,
+        use_past_in_inputs: bool = False,
+        behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
+        preprocessors: Optional[List[Any]] = None,
+    ):
+        super().__init__(
+            config=config,
+            task=task,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+            use_past=use_past,
+            use_past_in_inputs=use_past_in_inputs,
+            behavior=behavior,
+            preprocessors=preprocessors,
+        )
+        if self._behavior is ConfigBehavior.ENCODER:
+            self.use_past_in_inputs = False
+        else:
+            self.use_past_in_inputs = True
+
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = super().inputs
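Net effect of the new __init__: the encoder subgraph never takes past key values as inputs, while monolith and decoder configs always do, regardless of the flags passed in. A usage sketch (assuming this branch's classes and import paths):

from transformers import AutoConfig
from optimum.exporters.onnx.base import ConfigBehavior
from optimum.exporters.onnx.model_configs import WhisperOnnxConfig

config = AutoConfig.from_pretrained("openai/whisper-tiny")
enc = WhisperOnnxConfig(config, behavior=ConfigBehavior.ENCODER)
dec = WhisperOnnxConfig(config, behavior=ConfigBehavior.DECODER)
print(enc.use_past_in_inputs, dec.use_past_in_inputs)  # False True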
4 changes: 1 addition & 3 deletions optimum/exporters/onnx/model_patcher.py
@@ -102,7 +102,7 @@ def __init__(
         else:
             self.real_config = config

-        allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
+        allow_past_in_outputs = (hasattr(self.real_config, "use_past") and self.real_config.use_past)

         @functools.wraps(self.orig_forward)
         def patched_forward(*args, **kwargs):
@@ -216,8 +216,6 @@ def patched_forward(*args, **kwargs):
                    filterd_outputs[name] = value
            """

-            #print("filterd_outputs", filterd_outputs.keys())
-
            return outputs

        self.patched_forward = patched_forward
1 change: 1 addition & 0 deletions optimum/utils/input_generators.py
@@ -369,6 +369,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int

        if self.task == "multiple-choice":
            shape = [self.batch_size, self.num_choices, self.sequence_length]
+
        return self.random_int_tensor(shape, max_value, min_value=min_value, framework=framework, dtype=int_dtype)


