Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Sep 20, 2023
1 parent b05f599 commit 26d97e8
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,11 @@ def __init__(
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)
self.patch = self.real_config.task == "text-generation" and self.real_config.use_past and self.real_config._behavior == "decoder"
self.patch = (
self.real_config.task == "text-generation"
and self.real_config.use_past
and self.real_config._behavior == "decoder"
)
if self.patch:
self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask")

Expand Down Expand Up @@ -440,7 +444,6 @@ def __exit__(self, exc_type, exc_value, traceback):
setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask)



class MPTModelPatcher(BloomModelPatcher):
pass

Expand All @@ -455,4 +458,3 @@ class BlenderbotModelPatcher(BartModelPatcher):

class PegasusModelPatcher(BartModelPatcher):
pass

0 comments on commit 26d97e8

Please sign in to comment.