From 60461ff7c479b9ea60757ec18279ad3cf84d29cb Mon Sep 17 00:00:00 2001
From: wejoncy <247153481@qq.com>
Date: Tue, 3 Dec 2024 20:44:59 +0800
Subject: [PATCH] Fix: Resolve #3060, `preload_module_classes` is lost for
 nested modules (#3248)

* resolve 3060

* format

* add tests

* fix

* fix

* format
---
 src/accelerate/hooks.py   |  8 ++++-
 tests/test_accelerator.py | 63 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/hooks.py b/src/accelerate/hooks.py
index 14d57e33661..50098eaac01 100644
--- a/src/accelerate/hooks.py
+++ b/src/accelerate/hooks.py
@@ -436,7 +436,13 @@ def attach_execution_device_hook(
         return
 
     for child in module.children():
-        attach_execution_device_hook(child, execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map)
+        attach_execution_device_hook(
+            child,
+            execution_device,
+            skip_keys=skip_keys,
+            preload_module_classes=preload_module_classes,
+            tied_params_map=tied_params_map,
+        )
 
 
 def attach_align_device_hook(
diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index 9b18fe5c909..651dc17da5e 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -30,6 +30,7 @@
 from accelerate.state import GradientState, PartialState
 from accelerate.test_utils import (
     require_bnb,
+    require_huggingface_suite,
     require_multi_gpu,
     require_non_cpu,
     require_transformer_engine,
@@ -762,3 +763,65 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights
         assert torch.allclose(original_linear1, new_linear1)
         assert torch.allclose(original_batchnorm, new_batchnorm)
         assert torch.allclose(original_linear2, new_linear2)
+
+    @require_cuda
+    @require_huggingface_suite
+    def test_nested_hook(self):
+        from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
+
+        class MyLinear(torch.nn.Module):
+            def __init__(self, device=None, dtype=None):
+                factory_kwargs = {"device": device, "dtype": dtype}
+                super().__init__()
+                self.centroid = torch.nn.Embedding(1, 2)
+                self.indices = torch.nn.Parameter(torch.empty((1, 2, 2), **factory_kwargs))
+
+            def forward(self, x):
+                orig_shape = x.shape
+                x = torch.abs(x + self.indices).long()
+                x = x % 2
+                x = x.sum(-1)
+                x = (self.centroid.weight + x).reshape(orig_shape)
+                return x
+
+        class MySubModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer = MyLinear()
+
+            def forward(self, x):
+                return self.layer(x)
+
+        class MyModel(PreTrainedModel):
+            def __init__(self, config):
+                super().__init__(config)
+                self.layer = torch.nn.ModuleList([MySubModel() for i in range(4)])
+
+            def forward(self, x):
+                for layer in self.layer:
+                    x = layer(x)
+                return x
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            check_point = tmpdirname
+            offload_folder = check_point + "/offload"
+            os.makedirs(offload_folder, exist_ok=True)
+            config = PretrainedConfig()
+            m = MyModel(config)
+            m.save_pretrained(check_point)
+
+            with init_empty_weights():
+                my_model = MyModel(config)
+            my_model = load_checkpoint_and_dispatch(
+                my_model,
+                checkpoint=check_point,
+                max_memory={"cpu": 60, 0: 60},
+                device_map="auto",
+                no_split_module_classes=["MySubModel"],
+                offload_folder=offload_folder,
+                preload_module_classes=["MyLinear"],
+            )
+            # Before the fix, this forward pass raised an error:
+            #   "weight is on the meta device, we need a `value` to put in on 0"
+            x = torch.randn(1, 2)
+            my_model(x)
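
For context on what the forwarded `preload_module_classes` enables, here is a minimal sketch of the dispatch pattern involved, assuming only the public accelerate API already used by the test above (`init_empty_weights`, `load_checkpoint_and_dispatch`); the `OuterBlock`/`InnerProj` classes, the temporary checkpoint, and the tensor shapes are illustrative stand-ins, not code from this patch.

import os
import tempfile

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch


class InnerProj(torch.nn.Module):
    # Nested one level below the blocks the device map splits on.
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # Uses the parameter directly, so it must be materialized before forward runs.
        return x @ self.weight


class OuterBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = InnerProj()

    def forward(self, x):
        return self.proj(x)


# Save a real checkpoint so the sketch runs end to end.
full_model = torch.nn.Sequential(OuterBlock(), OuterBlock())
tmpdir = tempfile.mkdtemp()
ckpt_file = os.path.join(tmpdir, "model.bin")
torch.save(full_model.state_dict(), ckpt_file)

# Rebuild the skeleton on the meta device, then dispatch the checkpoint.
with init_empty_weights():
    model = torch.nn.Sequential(OuterBlock(), OuterBlock())

model = load_checkpoint_and_dispatch(
    model,
    checkpoint=ckpt_file,
    device_map="auto",
    no_split_module_classes=["OuterBlock"],
    # Before the fix, this option was dropped while recursing into nested
    # children, so InnerProj weights could be left on the meta device.
    preload_module_classes=["InnerProj"],
)
print(model(torch.randn(2, 4)).shape)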