diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 8d067bea615..0fb55baabb0 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -396,7 +396,7 @@ def __exit__(self, exc_type, exc_value, traceback): setattr(self._model.model, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) -class OPTModelPatcher(ModelPatcher): +class BartModelPatcher(Seq2SeqModelPatcher): def __init__( self, config: "OnnxConfig", @@ -404,32 +404,33 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__(config, model, model_kwargs) - self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") + if self.real_config._behavior == "decoder" and self.real_config.task == "text-generation" and self.real_config.use_past: + self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") def __enter__(self): super().__enter__() - if self.real_config.task == "text-generation" and self.real_config.use_past: + if self.real_config._behavior == "decoder" and self.real_config.task == "text-generation" and self.real_config.use_past: setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if self.real_config.task == "text-generation" and self.real_config.use_past: + if self.real_config._behavior == "decoder" and self.real_config.task == "text-generation" and self.real_config.use_past: setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) -class BlenderbotSmallModelPatcher(OPTModelPatcher): +class BlenderbotSmallModelPatcher(BartModelPatcher): pass -class BlenderbotModelPatcher(OPTModelPatcher): +class BlenderbotModelPatcher(BartModelPatcher): pass -class PegasusModelPatcher(OPTModelPatcher): +class PegasusModelPatcher(BartModelPatcher): pass -class BartModelPatcher(OPTModelPatcher): +class OPTModelPatcher(BartModelPatcher): pass