diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 5bb96d8c224..cc34cd5cb1c 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -1372,7 +1372,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e with torch.no_grad(): convert_model(model) model._converted_to_transformer_engine = True - #model._original_forward = model.forward + model._original_forward = model.forward kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {} if "fp8_format" in kwargs: