Skip to content

Commit

Permalink
fix bart model patcher
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Sep 19, 2023
1 parent 33957af commit c2ec382
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,40 +396,41 @@ def __exit__(self, exc_type, exc_value, traceback):
setattr(self._model.model, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask)


class OPTModelPatcher(ModelPatcher):
class BartModelPatcher(Seq2SeqModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask")
if self.real_config._behavior == "decoder" and self.real_config.task == "text-generation" and self.real_config.use_past:
self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask")

def __enter__(self):
super().__enter__()
if self.real_config.task == "text-generation" and self.real_config.use_past:
if self.real_config._behavior == "decoder" and self.real_config.task == "text-generation" and self.real_config.use_past:
setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if self.real_config.task == "text-generation" and self.real_config.use_past:
if self.real_config._behavior == "decoder" and self.real_config.task == "text-generation" and self.real_config.use_past:
setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask)


class BlenderbotSmallModelPatcher(OPTModelPatcher):
class BlenderbotSmallModelPatcher(BartModelPatcher):
pass


class BlenderbotModelPatcher(OPTModelPatcher):
class BlenderbotModelPatcher(BartModelPatcher):
pass


class PegasusModelPatcher(OPTModelPatcher):
class PegasusModelPatcher(BartModelPatcher):
pass


class BartModelPatcher(OPTModelPatcher):
class OPTModelPatcher(BartModelPatcher):
pass


0 comments on commit c2ec382

Please sign in to comment.