From 26d97e8fd671e5ad0ddbc20241478a7d93e4b2c0 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 20 Sep 2023 10:53:18 +0200 Subject: [PATCH] fix format --- optimum/exporters/onnx/model_patcher.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 77a0345a9c8..e8a1574128b 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -402,7 +402,11 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__(config, model, model_kwargs) - self.patch = self.real_config.task == "text-generation" and self.real_config.use_past and self.real_config._behavior == "decoder" + self.patch = ( + self.real_config.task == "text-generation" + and self.real_config.use_past + and self.real_config._behavior == "decoder" + ) if self.patch: self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") @@ -440,7 +444,6 @@ def __exit__(self, exc_type, exc_value, traceback): setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) - class MPTModelPatcher(BloomModelPatcher): pass @@ -455,4 +458,3 @@ class BlenderbotModelPatcher(BartModelPatcher): class PegasusModelPatcher(BartModelPatcher): pass -