diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index cc34cd5cb1c..e35c7e723b7 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1357,8 +1357,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                     " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`."
                 )
 
+        model._accelerate_original_forward_id = id(model.forward)
         if self.native_amp:
-            model._original_forward = model.forward
             model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward
             autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
             new_forward = autocast_context(model_forward_func)
@@ -1372,7 +1372,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                 with torch.no_grad():
                     convert_model(model)
                 model._converted_to_transformer_engine = True
-            model._original_forward = model.forward
 
             kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {}
             if "fp8_format" in kwargs:
diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py
index 970b2969478..6dc84b54cfa 100644
--- a/src/accelerate/utils/other.py
+++ b/src/accelerate/utils/other.py
@@ -86,14 +86,16 @@ def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True):
         model = model.module
 
     if not keep_fp32_wrapper:
+        original_forward_id = getattr(model, "_accelerate_original_forward_id", None)
         forward = getattr(model, "forward")
-        original_forward = model.__dict__.pop("_original_forward", None)
-        if original_forward is not None:
+        forward_func = getattr(forward, "__func__", None)
+        if (original_forward_id is not None) and (id(forward) != original_forward_id):
             while hasattr(forward, "__wrapped__"):
                 forward = forward.__wrapped__
-                if forward == original_forward:
+                forward_func = getattr(forward, "__func__", None)
+                if id(forward) == original_forward_id or id(forward_func) == original_forward_id:
                     break
-            model.forward = MethodType(forward, model)
+            model.forward = MethodType(forward_func if forward_func is not None else forward, model)
 
     if getattr(model, "_converted_to_transformer_engine", False):
         convert_model(model, to_transformer_engine=False)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 2bf86314eed..eb3c65ffe31 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -132,7 +132,7 @@ def test_patch_environment_key_exists(self):
 
     def test_can_undo_convert_outputs(self):
         model = RegressionModel()
-        model._original_forward = model.forward
+        model._accelerate_original_forward_id = id(model.forward)
         model.forward = convert_outputs_to_fp32(model.forward)
         model = extract_model_from_parallel(model, keep_fp32_wrapper=False)
         _ = pickle.dumps(model)
@@ -140,7 +140,7 @@ def test_can_undo_convert_outputs(self):
     @require_cuda
     def test_can_undo_fp16_conversion(self):
         model = RegressionModel()
-        model._original_forward = model.forward
+        model._accelerate_original_forward_id = id(model.forward)
         model.forward = torch.cuda.amp.autocast(dtype=torch.float16)(model.forward)
         model.forward = convert_outputs_to_fp32(model.forward)
         model = extract_model_from_parallel(model, keep_fp32_wrapper=False)
@@ -150,7 +150,7 @@ def test_can_undo_fp16_conversion(self):
     @require_torch_min_version(version="2.0")
     def test_dynamo(self):
         model = RegressionModel()
-        model._original_forward = model.forward
+        model._accelerate_original_forward_id = id(model.forward)
         model.forward = torch.cuda.amp.autocast(dtype=torch.float16)(model.forward)
         model.forward = convert_outputs_to_fp32(model.forward)
         model.forward = torch.compile(model.forward, backend="inductor")
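
The unwrapping in `extract_model_from_parallel` relies on two properties of the forward wrappers: they are built with `functools.wraps`, so each one exposes the callable it wraps via `__wrapped__`, and the recovered function can be re-bound to the instance with `types.MethodType`. The sketch below reproduces that wrap/unwrap cycle outside of accelerate; `TinyModel` and `make_wrapper` are illustrative stand-ins (not accelerate APIs), and, unlike the patch, it keys on `id(model.forward.__func__)`, which stays stable because the function object lives on the class rather than being a per-access bound method.

```python
# A minimal sketch (not accelerate code) of the wrap/unwrap mechanism.
import functools
from types import MethodType

import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)


def make_wrapper(fn):
    # functools.wraps stores `fn` on `wrapper.__wrapped__`, which is the
    # attribute the unwrapping loop walks.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


model = TinyModel()

# Record the identity of the underlying function (stable, it lives on the
# class), then stack two wrappers, mimicking autocast + fp32 conversion.
original_forward_func_id = id(model.forward.__func__)
model.forward = make_wrapper(make_wrapper(model.forward))

# Unwrap: follow the __wrapped__ chain until the recorded id is found,
# then re-bind the recovered function to the instance.
forward = model.forward
while hasattr(forward, "__wrapped__"):
    forward = forward.__wrapped__
    func = getattr(forward, "__func__", forward)
    if id(func) == original_forward_func_id:
        break

model.forward = MethodType(getattr(forward, "__func__", forward), model)
assert model.forward.__func__ is TinyModel.forward
model(torch.randn(3, 2))  # plain forward again, no wrappers in the way
```

The loop in the patched `extract_model_from_parallel` compares both `id(forward)` and `id(forward.__func__)` against the recorded value because the end of the `__wrapped__` chain may be either the original bound method or its underlying function, depending on how the wrapper stack was assembled.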