From 32b3bf622fb38ccf623fabb20ebbbaa7b15a5d9b Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Mon, 4 Dec 2023 23:26:24 +0100
Subject: [PATCH 1/3] cpu_offload

---
 src/accelerate/big_modeling.py | 15 +++++++++++--
 src/accelerate/hooks.py        | 30 ++++++++++++++++++++++++------
 tests/test_big_modeling.py     | 43 +++++++++++++++++++++++++++++++++++
 3 files changed, 80 insertions(+), 8 deletions(-)

diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py
index fc0dbb32f63..dd0ed1047d0 100644
--- a/src/accelerate/big_modeling.py
+++ b/src/accelerate/big_modeling.py
@@ -302,6 +302,7 @@ def dispatch_model(
     offload_dir: Optional[Union[str, os.PathLike]] = None,
     offload_index: Optional[Dict[str, str]] = None,
     offload_buffers: bool = False,
+    cpu_offload: bool = False,
     skip_keys: Optional[Union[str, List[str]]] = None,
     preload_module_classes: Optional[List[str]] = None,
     force_hooks: bool = False,
@@ -321,6 +322,8 @@ def dispatch_model(
             `"disk"`.
         state_dict (`Dict[str, torch.Tensor]`, *optional*):
             The state dict of the part of the model that will be kept on CPU.
+        cpu_offload (`bool`, *optional*, defaults to `False`):
+            Whether CPU-offloaded weights should stay inside the module, moving to the execution device only for the forward pass.
         offload_dir (`str` or `os.PathLike`):
             The folder in which to offload the model weights (or where the model weights are already offloaded).
         offload_index (`Dict`, *optional*):
@@ -358,7 +361,7 @@ def dispatch_model(
     else:
         main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
 
-    if main_device != "cpu":
+    if main_device != "cpu" and not cpu_offload:
         cpu_modules = [name for name, device in device_map.items() if device == "cpu"]
         if state_dict is None and len(cpu_modules) > 0:
             state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)
@@ -381,8 +384,11 @@ def dispatch_model(
             name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items()
         }
         execution_device[""] = main_device
-        offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
+        offloaded_devices = (
+            ["disk"] if cpu_offload or main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
+        )
         offload = {name: device in offloaded_devices for name, device in device_map.items()}
+        cpu_offload = {name: device == "cpu" for name, device in device_map.items()}
         save_folder = offload_dir if len(disk_modules) > 0 else None
         if state_dict is not None or save_folder is not None or offload_index is not None:
             device = main_device if offload_index is not None else None
@@ -397,6 +403,7 @@ def dispatch_model(
         model,
         execution_device=execution_device,
         offload=offload,
+        cpu_offload=cpu_offload,
        offload_buffers=offload_buffers,
         weights_map=weights_map,
         skip_keys=skip_keys,
@@ -450,6 +457,7 @@ def load_checkpoint_and_dispatch(
     no_split_module_classes: Optional[List[str]] = None,
     offload_folder: Optional[Union[str, os.PathLike]] = None,
     offload_buffers: bool = False,
+    cpu_offload: bool = False,
     dtype: Optional[Union[str, torch.dtype]] = None,
     offload_state_dict: Optional[bool] = None,
     skip_keys: Optional[Union[str, List[str]]] = None,
@@ -484,6 +492,8 @@ def load_checkpoint_and_dispatch(
         offload_buffers (`bool`, *optional*, defaults to `False`):
             In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
             well as the parameters.
+        cpu_offload (`bool`, *optional*, defaults to `False`):
+            Whether CPU-offloaded weights should stay inside the module, moving to the execution device only for the forward pass.
         dtype (`str` or `torch.dtype`, *optional*):
             If provided, the weights will be converted to that type when loaded.
         offload_state_dict (`bool`, *optional*):
@@ -558,6 +568,7 @@ def load_checkpoint_and_dispatch(
         device_map=device_map,
         offload_dir=offload_folder,
         offload_buffers=offload_buffers,
+        cpu_offload=cpu_offload,
         skip_keys=skip_keys,
         preload_module_classes=preload_module_classes,
         force_hooks=force_hooks,
diff --git a/src/accelerate/hooks.py b/src/accelerate/hooks.py
index d87f1c18db3..7ff1da63684 100644
--- a/src/accelerate/hooks.py
+++ b/src/accelerate/hooks.py
@@ -208,6 +208,8 @@ class AlignDevicesHook(ModelHook):
             The device on which inputs and model weights should be placed before the forward pass.
         offload (`bool`, *optional*, defaults to `False`):
             Whether or not the weights should be offloaded after the forward pass.
+        cpu_offload (`bool`, *optional*, defaults to `False`):
+            Whether CPU-offloaded weights should stay inside the module, moving to the execution device only for the forward pass.
         io_same_device (`bool`, *optional*, defaults to `False`):
             Whether or not the output should be placed on the same device as the input was.
         weights_map (`Mapping[str, torch.Tensor]`, *optional*):
@@ -222,6 +224,7 @@ def __init__(
         self,
         execution_device: Optional[Union[int, str, torch.device]] = None,
         offload: bool = False,
+        cpu_offload: bool = False,
         io_same_device: bool = False,
         weights_map: Optional[Mapping] = None,
         offload_buffers: bool = False,
@@ -230,12 +233,12 @@ def __init__(
     ):
         self.execution_device = execution_device
         self.offload = offload
+        self.cpu_offload = cpu_offload
         self.io_same_device = io_same_device
         self.weights_map = weights_map
         self.offload_buffers = offload_buffers
         self.place_submodules = place_submodules
         self.skip_keys = skip_keys
-
         # Will contain the input device when `io_same_device=True`.
         self.input_device = None
         self.param_original_devices = {}
 
@@ -249,7 +252,7 @@ def __repr__(self):
         )
 
     def init_hook(self, module):
-        if not self.offload and self.execution_device is not None:
+        if not self.offload and self.execution_device is not None and not self.cpu_offload:
             for name, _ in named_module_tensors(module, recurse=self.place_submodules):
                 set_module_tensor_to_device(module, name, self.execution_device)
         elif self.offload:
@@ -273,7 +276,9 @@ def init_hook(self, module):
         elif self.offload_buffers and self.execution_device is not None:
             for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
                 set_module_tensor_to_device(module, name, self.execution_device)
-
+        elif self.cpu_offload:
+            for name, _ in named_module_tensors(module, recurse=self.place_submodules):
+                set_module_tensor_to_device(module, name, "cpu")
         return module
 
     def pre_forward(self, module, *args, **kwargs):
@@ -293,7 +298,9 @@ def pre_forward(self, module, *args, **kwargs):
                 set_module_tensor_to_device(
                     module, name, self.execution_device, value=self.weights_map[name], fp16_statistics=fp16_statistics
                 )
-
+        elif self.cpu_offload:
+            for name, _ in named_module_tensors(module, recurse=self.place_submodules):
+                set_module_tensor_to_device(module, name, self.execution_device)
         return send_to_device(args, self.execution_device), send_to_device(
             kwargs, self.execution_device, skip_keys=self.skip_keys
         )
@@ -310,7 +317,9 @@ def post_forward(self, module, output):
             if type(module).__name__ == "Linear8bitLt":
                 module.state.SCB = None
                 module.state.CxB = None
-
+        elif self.cpu_offload:
+            for name, _ in named_module_tensors(module, recurse=self.place_submodules):
+                set_module_tensor_to_device(module, name, "cpu")
         if self.io_same_device and self.input_device is not None:
             output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)
 
@@ -450,6 +459,7 @@ def attach_align_device_hook_on_blocks(
     module: nn.Module,
     execution_device: Optional[Union[torch.device, Dict[str, torch.device]]] = None,
     offload: Union[bool, Dict[str, bool]] = False,
+    cpu_offload: bool = False,
     weights_map: Mapping = None,
     offload_buffers: bool = False,
     module_name: str = "",
@@ -468,6 +478,8 @@ def attach_align_device_hook_on_blocks(
         offload (`bool`, *optional*, defaults to `False`):
             Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
             module, or a dictionary mapping module name to boolean.
+        cpu_offload (`bool`, *optional*, defaults to `False`):
+            Whether CPU-offloaded weights should stay inside the module, moving to the execution device only for the forward pass.
         weights_map (`Mapping[str, torch.Tensor]`, *optional*):
             When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
         offload_buffers (`bool`, *optional*, defaults to `False`):
@@ -506,9 +518,14 @@ def attach_align_device_hook_on_blocks(
     if not isinstance(offload, Mapping):
         offload = {key: offload for key in execution_device.keys()}
 
-    if module_name in execution_device and module_name in offload and not offload[module_name]:
+    if (
+        module_name in execution_device
+        and module_name in offload
+        and (not offload[module_name] or cpu_offload[module_name])
+    ):
         hook = AlignDevicesHook(
             execution_device=execution_device[module_name],
+            cpu_offload=cpu_offload[module_name],
             offload_buffers=offload_buffers,
             io_same_device=(module_name == ""),
             place_submodules=True,
@@ -548,6 +565,7 @@ def attach_align_device_hook_on_blocks(
             child,
             execution_device=execution_device,
             offload=offload,
+            cpu_offload=cpu_offload,
             weights_map=weights_map,
             offload_buffers=offload_buffers,
             module_name=child_name,
diff --git a/tests/test_big_modeling.py b/tests/test_big_modeling.py
index 51ce4a899e4..37fc0efc7a9 100644
--- a/tests/test_big_modeling.py
+++ b/tests/test_big_modeling.py
@@ -338,6 +338,24 @@ def test_dispatch_model_with_non_persistent_buffers(self):
 
         with TemporaryDirectory() as tmp_dir:
             dispatch_model(model, device_map, offload_dir=tmp_dir, offload_buffers=True)
+
+            output = model(x)
+            self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
+
+    def test_dispatch_model_with_cpu_offload(self):
+        model = ModelForTest()
+        device_map = {"linear1": "disk", "batchnorm": "cpu", "linear2": 0}
+
+        x = torch.randn(2, 3)
+        expected = model(x)
+
+        with TemporaryDirectory() as tmp_dir:
+            dispatch_model(model, device_map, offload_dir=tmp_dir, cpu_offload=True)
+
+            self.assertEqual(model.linear1.weight.device, torch.device("meta"))
+            self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
+            self.assertEqual(model.linear2.weight.device, torch.device(0))
+
             output = model(x)
             self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
 
@@ -548,6 +566,31 @@ def test_load_checkpoint_and_dispatch(self):
         output = new_model(x)
         self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
 
+    @require_cuda
+    def test_load_checkpoint_and_dispatch_with_cpu_offload(self):
+        model = ModelForTest()
+        device_map = {"linear1": "cpu", "batchnorm": "disk", "linear2": 0}
+
+        x = torch.randn(2, 3)
+        expected = model(x)
+
+        with TemporaryDirectory() as tmp_dir:
+            checkpoint = os.path.join(tmp_dir, "pt_model.bin")
+            torch.save(model.state_dict(), checkpoint)
+
+            new_model = ModelForTest()
+            new_model = load_checkpoint_and_dispatch(
+                new_model, checkpoint, device_map=device_map, cpu_offload=True, offload_folder=tmp_dir
+            )
+
+            # With cpu_offload=True, CPU weights stay in the module; only disk-offloaded weights wait on meta.
+            self.assertEqual(new_model.linear1.weight.device, torch.device("cpu"))
+            self.assertEqual(new_model.batchnorm.weight.device, torch.device("meta"))
+            self.assertEqual(new_model.linear2.weight.device, torch.device(0))
+
+            output = new_model(x)
+            self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
+
     @require_mps
     def test_load_checkpoint_and_dispatch_mps(self):
         model = ModelForTest()
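
PATCH 1 is the substance of the series: with `cpu_offload=True`, CPU-mapped weights stay inside the module and the `AlignDevicesHook` shuttles them to the execution device for each forward pass. A minimal usage sketch, assuming the series is applied and a CUDA device is available; `TinyModel` and the `device_map` are illustrative stand-ins, not part of the patch:

    import torch
    from torch import nn

    from accelerate import dispatch_model

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(3, 4)
            self.linear2 = nn.Linear(4, 5)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    model = TinyModel()

    # With cpu_offload=True, linear1's weights stay inside the module on the CPU
    # (no separate state dict is extracted); the hook moves them to GPU 0 for
    # each forward pass and back to the CPU afterwards.
    model = dispatch_model(model, device_map={"linear1": "cpu", "linear2": 0}, cpu_offload=True)

    out = model(torch.randn(2, 3))  # linear1's weights visit cuda:0 only during this call
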
From 223a946aeb8f8f5ca55b7ac7e55b964a27ceeca1 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Mon, 4 Dec 2023 23:58:20 +0100
Subject: [PATCH 2/3] fix log

---
 src/accelerate/big_modeling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py
index dd0ed1047d0..b2852ec20fc 100644
--- a/src/accelerate/big_modeling.py
+++ b/src/accelerate/big_modeling.py
@@ -412,7 +412,7 @@ def dispatch_model(
 
         # warn if there is any params on the meta device
         offloaded_devices_str = " and ".join(
-            [device for device in set(device_map.values()) if device in ("cpu", "disk")]
+            [device for device in set(device_map.values()) if device in offloaded_devices]
         )
         if len(offloaded_devices_str) > 0:
             logging.warning(
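
The `fix log` commit above makes the offload warning derive its device list from the `offloaded_devices` local instead of the hard-coded `("cpu", "disk")`, so CPU modules kept in place by `cpu_offload=True` are no longer reported as offloaded. A standalone sketch of the difference (the `device_map` is illustrative):

    device_map = {"linear1": "cpu", "batchnorm": "disk", "linear2": 0}
    offloaded_devices = ["disk"]  # the value computed when cpu_offload=True

    offloaded_devices_str = " and ".join(
        [device for device in set(device_map.values()) if device in offloaded_devices]
    )
    print(offloaded_devices_str)  # "disk"; the old hard-coded tuple also reported "cpu"

PATCH 3, below, then guards the bool-to-dict conversion in `dispatch_model` and normalizes `cpu_offload` in `attach_align_device_hook_on_blocks` the same way `offload` and `execution_device` are already normalized. The normalization step in isolation (module names illustrative):

    from collections.abc import Mapping

    execution_device = {"linear1": 0, "linear2": 0}
    cpu_offload = True  # a single bool passed by the caller

    if not isinstance(cpu_offload, Mapping):
        cpu_offload = {key: cpu_offload for key in execution_device.keys()}

    print(cpu_offload)  # {'linear1': True, 'linear2': True}
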
From 08e16f0482ba0ca17cf1a9347fff513ca70c1721 Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Tue, 5 Dec 2023 15:49:21 +0100
Subject: [PATCH 3/3] fix

---
 src/accelerate/big_modeling.py | 3 ++-
 src/accelerate/hooks.py        | 7 ++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py
index b2852ec20fc..0ca443ad778 100644
--- a/src/accelerate/big_modeling.py
+++ b/src/accelerate/big_modeling.py
@@ -388,7 +388,8 @@ def dispatch_model(
             ["disk"] if cpu_offload or main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
         )
         offload = {name: device in offloaded_devices for name, device in device_map.items()}
-        cpu_offload = {name: device == "cpu" for name, device in device_map.items()}
+        if cpu_offload:
+            cpu_offload = {name: device == "cpu" for name, device in device_map.items()}
         save_folder = offload_dir if len(disk_modules) > 0 else None
         if state_dict is not None or save_folder is not None or offload_index is not None:
             device = main_device if offload_index is not None else None
diff --git a/src/accelerate/hooks.py b/src/accelerate/hooks.py
index 7ff1da63684..c058dd507f5 100644
--- a/src/accelerate/hooks.py
+++ b/src/accelerate/hooks.py
@@ -459,7 +459,7 @@ def attach_align_device_hook_on_blocks(
     module: nn.Module,
     execution_device: Optional[Union[torch.device, Dict[str, torch.device]]] = None,
     offload: Union[bool, Dict[str, bool]] = False,
-    cpu_offload: bool = False,
+    cpu_offload: Union[bool, Dict[str, bool]] = False,
     weights_map: Mapping = None,
     offload_buffers: bool = False,
     module_name: str = "",
@@ -478,7 +478,7 @@ def attach_align_device_hook_on_blocks(
         offload (`bool`, *optional*, defaults to `False`):
             Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
             module, or a dictionary mapping module name to boolean.
-        cpu_offload (`bool`, *optional*, defaults to `False`):
+        cpu_offload (`Union[bool, Dict[str, bool]]`, *optional*, defaults to `False`):
             Whether CPU-offloaded weights should stay inside the module, moving to the execution device only for the forward pass.
         weights_map (`Mapping[str, torch.Tensor]`, *optional*):
             When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
@@ -517,7 +517,8 @@ def attach_align_device_hook_on_blocks(
         execution_device = {key: execution_device for key in offload.keys()}
     if not isinstance(offload, Mapping):
         offload = {key: offload for key in execution_device.keys()}
-
+    if not isinstance(cpu_offload, Mapping):
+        cpu_offload = {key: cpu_offload for key in execution_device.keys()}
     if (
         module_name in execution_device
         and module_name in offload