fix dropout with training=True export
fxmarty committed Oct 5, 2023
1 parent b88ed06 commit 918893e
Showing 1 changed file with 51 additions and 0 deletions.
optimum/exporters/onnx/model_patcher.py
@@ -15,6 +15,7 @@
import dataclasses
import functools
import inspect
import types
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
@@ -355,7 +356,55 @@ def patched_forward(
        self.patched_forward = patched_forward


def patched_speecht5_prenet_forward(
    self,
    input_values: torch.Tensor,
    speaker_embeddings: Optional[torch.Tensor] = None,
):
    # Dropout is always applied, even when evaluating. See §2.2 in https://arxiv.org/abs/1712.05884.

    inputs_embeds = input_values
    for layer in self.layers:
        inputs_embeds = torch.nn.functional.relu(layer(inputs_embeds))

        # NOTE: we patch the prenet to avoid torch.nn.functional.dropout, which is exported as a `Dropout`
        # node in the ONNX graph and is ignored at inference time by some runtimes, such as ONNX Runtime.
        # Reference: https://github.com/microsoft/onnxruntime/issues/9333 & https://github.com/microsoft/onnxruntime/issues/5549
        mask = torch.rand(inputs_embeds.shape, device=inputs_embeds.device) > self.config.speech_decoder_prenet_dropout
        inputs_embeds = inputs_embeds * mask / (1 - self.config.speech_decoder_prenet_dropout)

        # inputs_embeds = nn.functional.dropout(
        #     inputs_embeds, self.config.speech_decoder_prenet_dropout, training=True
        # )

    inputs_embeds = self.final_layer(inputs_embeds)
    inputs_embeds = self.encode_positions(inputs_embeds)

    if speaker_embeddings is not None:
        speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings)
        speaker_embeddings = speaker_embeddings.unsqueeze(1)
        speaker_embeddings = speaker_embeddings.expand(-1, inputs_embeds.size(1), -1)
        inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1)
        inputs_embeds = torch.nn.functional.relu(self.speaker_embeds_layer(inputs_embeds))

    return inputs_embeds


class SpeechT5ModelPatcher(ModelPatcher):
    def __enter__(self):
        self.patch_ops()
        self._model.speecht5.decoder.prenet.forward = types.MethodType(
            patched_speecht5_prenet_forward, self._model.speecht5.decoder.prenet
        )
        setattr(self._model, self.orig_forward_name, self.patched_forward)

    def __exit__(self, exc_type, exc_value, traceback):
        self.restore_ops()
        setattr(self._model, self.orig_forward_name, self.orig_forward)
        self._model.speecht5.decoder.prenet.forward = types.MethodType(
            self.original_speecht5_prenet_forward, self._model.speecht5.decoder.prenet
        )

    def __init__(
        self,
        config: "OnnxConfig",
@@ -364,6 +413,8 @@ def __init__(
    ):
        super().__init__(config, model, model_kwargs)

        self.original_speecht5_prenet_forward = model.speecht5.decoder.prenet.forward

        model.vocoder = model_kwargs["vocoder_model"].eval()

        def patched_forward(
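The substance of the change is the manual inverted dropout in `patched_speecht5_prenet_forward`: `torch.nn.functional.dropout(..., training=True)` exports as a `Dropout` node that runtimes such as ONNX Runtime skip at inference, so the patch rebuilds the same math from primitive tensor ops that cannot be elided. A minimal standalone sketch of the equivalence (not part of the commit; `p` stands in for `config.speech_decoder_prenet_dropout`):

import torch

p = 0.5  # stands in for config.speech_decoder_prenet_dropout
x = torch.randn(2, 10, 256)

# Manual inverted dropout, as in the patch: keep each element with
# probability 1 - p, then rescale survivors so the expected value matches x.
mask = torch.rand(x.shape, device=x.device) > p
y_manual = x * mask / (1 - p)

# The built-in call the patch replaces; equivalent in distribution, but it
# traces to a `Dropout` node in the exported ONNX graph.
y_builtin = torch.nn.functional.dropout(x, p, training=True)

# Realized masks differ (both are random), but each zeroes ~p of the elements
# and scales the rest by 1 / (1 - p), so the means agree in expectation.
print(x.mean().item(), y_manual.mean().item(), y_builtin.mean().item())

Because the manual version is expressed as a comparison, a multiplication, and a division, it stays active in the exported graph, which is what SpeechT5's prenet needs: per §2.2 of the Tacotron 2 paper, its dropout is applied at inference time as well.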

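For context on the patching mechanics: `types.MethodType` binds a plain function to a single instance, which is how `SpeechT5ModelPatcher.__enter__` swaps the prenet's `forward` without touching the class, and how `__exit__` restores it afterwards. A toy sketch of the pattern (hypothetical `Prenet` class, not from the commit):

import types

class Prenet:
    def forward(self, x):
        return x + 1

def patched_forward(self, x):
    return x + 2

prenet = Prenet()

# Bind the replacement to this one instance; other Prenet instances
# keep the class-level forward.
prenet.forward = types.MethodType(patched_forward, prenet)
assert prenet.forward(1) == 3

# Removing the instance attribute falls back to the class implementation.
del prenet.forward
assert prenet.forward(1) == 2

Used as a context manager, the patcher therefore only alters the model for the duration of the export, restoring both the top-level forward and the prenet forward on exit.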