diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py index 550b8f052eb..b217feedc23 100644 --- a/src/accelerate/big_modeling.py +++ b/src/accelerate/big_modeling.py @@ -43,6 +43,7 @@ parse_flag_from_env, retie_parameters, ) +from .utils.other import recursive_getattr logger = logging.getLogger(__name__) @@ -395,7 +396,22 @@ def dispatch_model( else: weights_map = None + # When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for the + # tied parameters. The dictionary tied_params_map keeps track of the already allocated data for a given tied parameter (represented by its + # original pointer) on each device. tied_params = find_tied_parameters(model) + + tied_params_map = {} + for group in tied_params: + for param_name in group: + # data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need + # to care about views of tensors through storage_offset. + data_ptr = recursive_getattr(model, param_name).data_ptr() + tied_params_map[data_ptr] = {} + + # Note: To handle the disk offloading case, we cannot simply use weights_map[param_name].data_ptr() as the reference pointer, + # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer. + attach_align_device_hook_on_blocks( model, execution_device=execution_device, @@ -404,6 +420,7 @@ def dispatch_model( weights_map=weights_map, skip_keys=skip_keys, preload_module_classes=preload_module_classes, + tied_params_map=tied_params_map, ) # warn if there is any params on the meta device diff --git a/src/accelerate/hooks.py b/src/accelerate/hooks.py index d87f1c18db3..c6ccd472199 100644 --- a/src/accelerate/hooks.py +++ b/src/accelerate/hooks.py @@ -27,6 +27,7 @@ set_module_tensor_to_device, ) from .utils.modeling import get_non_persistent_buffers +from .utils.other import recursive_getattr class ModelHook: @@ -116,7 +117,9 @@ def detach_hook(self, module): return module -def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False): +def add_hook_to_module( + module: nn.Module, hook: ModelHook, append: bool = False, init_hook_kwargs: Optional[Dict] = None +): """ Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove this behavior and restore the original `forward` method, use `remove_hook_from_module`. @@ -135,6 +138,8 @@ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False) The hook to attach. append (`bool`, *optional*, defaults to `False`): Whether the hook should be chained with an existing one (if module already contains a hook) or not. + init_hook_kwargs (Optional[Dict], *optional*, defaults to `None`): + Optional arguments to pass to the hook initialization.
Returns: `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can @@ -153,7 +158,10 @@ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False) old_forward = module.forward module._old_forward = old_forward - module = hook.init_hook(module) + if init_hook_kwargs is None: + init_hook_kwargs = {} + + module = hook.init_hook(module, **init_hook_kwargs) module._hf_hook = hook def new_forward(module, *args, **kwargs): @@ -240,6 +248,7 @@ def __init__( self.input_device = None self.param_original_devices = {} self.buffer_original_devices = {} + self.tied_params_names = set() def __repr__(self): return ( @@ -248,10 +257,14 @@ def __repr__(self): f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})" ) - def init_hook(self, module): + def init_hook(self, module, tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None): + # In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero. + if self.execution_device == "meta" or self.execution_device == torch.device("meta"): + tied_params_map = None + if not self.offload and self.execution_device is not None: for name, _ in named_module_tensors(module, recurse=self.place_submodules): - set_module_tensor_to_device(module, name, self.execution_device) + set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map) elif self.offload: self.original_devices = { name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules) @@ -266,13 +279,25 @@ def init_hook(self, module): for name, _ in named_module_tensors( module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True ): + # When using disk offloading, we cannot rely on `weights_map[name].data_ptr()` as the reference pointer, + # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer. + # As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: List[str] + # to add pointers to `tied_params_map` on the fly in the pre_forward call. + if tied_params_map is not None and recursive_getattr(module, name).data_ptr() in tied_params_map: + self.tied_params_names.add(name) + set_module_tensor_to_device(module, name, "meta") + if not self.offload_buffers and self.execution_device is not None: for name, _ in module.named_buffers(recurse=self.place_submodules): - set_module_tensor_to_device(module, name, self.execution_device) + set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map) elif self.offload_buffers and self.execution_device is not None: for name in get_non_persistent_buffers(module, recurse=self.place_submodules): - set_module_tensor_to_device(module, name, self.execution_device) + set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map) + + # The hook's pre_forward/post_forward methods need access to this dictionary, as with offloading we want to avoid duplicating memory + # for tied weights already loaded on the target execution device.
+ self.tied_params_map = tied_params_map return module @@ -280,6 +305,8 @@ def pre_forward(self, module, *args, **kwargs): if self.io_same_device: self.input_device = find_device([args, kwargs]) if self.offload: + self.tied_pointers_to_remove = set() + for name, _ in named_module_tensors( module, include_buffers=self.offload_buffers, @@ -287,11 +314,32 @@ def pre_forward(self, module, *args, **kwargs): remove_non_persistent=True, ): fp16_statistics = None + value = self.weights_map[name] if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys(): - if self.weights_map[name].dtype == torch.int8: + if value.dtype == torch.int8: fp16_statistics = self.weights_map[name.replace("weight", "SCB")] + + # In case we are using offloading with tied weights, we need to keep track of the offloaded weights + # that are loaded on device at this point, as we will need to remove them as well from the dictionary + # self.tied_params_map in order to allow the memory to be freed. + if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map: + self.tied_params_map[value.data_ptr()] = {} + + if ( + value is not None + and self.tied_params_map is not None + and value.data_ptr() in self.tied_params_map + and self.execution_device not in self.tied_params_map[value.data_ptr()] + ): + self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device)) + set_module_tensor_to_device( - module, name, self.execution_device, value=self.weights_map[name], fp16_statistics=fp16_statistics + module, + name, + self.execution_device, + value=value, + fp16_statistics=fp16_statistics, + tied_params_map=self.tied_params_map, ) return send_to_device(args, self.execution_device), send_to_device( @@ -311,6 +359,12 @@ def post_forward(self, module, output): module.state.SCB = None module.state.CxB = None + # We may have loaded tied weights into self.tied_params_map (avoiding loading them several times, e.g. in submodules): remove them from + # this dictionary to allow the garbage collector to do its job. + for value_pointer, device in self.tied_pointers_to_remove: + del self.tied_params_map[value_pointer][device] + self.tied_pointers_to_remove = None + if self.io_same_device and self.input_device is not None: output = send_to_device(output, self.input_device, skip_keys=self.skip_keys) @@ -329,6 +383,7 @@ def attach_execution_device_hook( execution_device: Union[int, str, torch.device], skip_keys: Optional[Union[str, List[str]]] = None, preload_module_classes: Optional[List[str]] = None, + tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None, ): """ Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right @@ -346,16 +401,24 @@ def attach_execution_device_hook( of the forward. This should only be used for classes that have submodules which are registered but not called directly during the forward, for instance if a `dense` linear layer is registered, but at forward, `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly. + tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`): + A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution + device, this parameter is useful to reuse the first available pointer of a shared weight for all others, + instead of duplicating memory.
""" if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0: - add_hook_to_module(module, AlignDevicesHook(execution_device, skip_keys=skip_keys)) + add_hook_to_module( + module, + AlignDevicesHook(execution_device, skip_keys=skip_keys), + init_hook_kwargs={"tied_params_map": tied_params_map}, + ) # Break the recursion if we get to a preload module. if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes: return for child in module.children(): - attach_execution_device_hook(child, execution_device) + attach_execution_device_hook(child, execution_device, tied_params_map=tied_params_map) def attach_align_device_hook( @@ -367,6 +430,7 @@ def attach_align_device_hook( module_name: str = "", skip_keys: Optional[Union[str, List[str]]] = None, preload_module_classes: Optional[List[str]] = None, + tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None, ): """ Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or @@ -392,6 +456,10 @@ def attach_align_device_hook( of the forward. This should only be used for classes that have submodules which are registered but not called directly during the forward, for instance if a `dense` linear layer is registered, but at forward, `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly. + tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`): + A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution + device, this parameter is useful to reuse the first available pointer of a shared weight for all others, + instead of duplicating memory. """ # Attach the hook on this module if it has any direct tensor. directs = named_module_tensors(module) @@ -413,7 +481,7 @@ def attach_align_device_hook( place_submodules=full_offload, skip_keys=skip_keys, ) - add_hook_to_module(module, hook, append=True) + add_hook_to_module(module, hook, append=True, init_hook_kwargs={"tied_params_map": tied_params_map}) # We stop the recursion in case we hit the full offload. if full_offload: @@ -431,6 +499,7 @@ def attach_align_device_hook( module_name=child_name, preload_module_classes=preload_module_classes, skip_keys=skip_keys, + tied_params_map=tied_params_map, ) @@ -455,6 +524,7 @@ def attach_align_device_hook_on_blocks( module_name: str = "", skip_keys: Optional[Union[str, List[str]]] = None, preload_module_classes: Optional[List[str]] = None, + tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None, ): """ Attaches `AlignDevicesHook` to all blocks of a given model as needed. @@ -481,6 +551,10 @@ def attach_align_device_hook_on_blocks( of the forward. This should only be used for classes that have submodules which are registered but not called directly during the forward, for instance if a `dense` linear layer is registered, but at forward, `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly. + tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`): + A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution + device, this parameter is useful to reuse the first available pointer of a shared weight for all others, + instead of duplicating memory. """ # If one device and one offload, we've got one hook. 
if not isinstance(execution_device, Mapping) and not isinstance(offload, dict): @@ -488,7 +562,7 @@ def attach_align_device_hook_on_blocks( hook = AlignDevicesHook( execution_device=execution_device, io_same_device=True, skip_keys=skip_keys, place_submodules=True ) - add_hook_to_module(module, hook) + add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map}) else: attach_align_device_hook( module, @@ -498,6 +572,7 @@ def attach_align_device_hook_on_blocks( offload_buffers=offload_buffers, module_name=module_name, skip_keys=skip_keys, + tied_params_map=tied_params_map, ) return @@ -514,8 +589,8 @@ def attach_align_device_hook_on_blocks( place_submodules=True, skip_keys=skip_keys, ) - add_hook_to_module(module, hook) - attach_execution_device_hook(module, execution_device[module_name]) + add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map}) + attach_execution_device_hook(module, execution_device[module_name], tied_params_map=tied_params_map) elif module_name in execution_device and module_name in offload: attach_align_device_hook( module, @@ -526,21 +601,23 @@ def attach_align_device_hook_on_blocks( module_name=module_name, skip_keys=skip_keys, preload_module_classes=preload_module_classes, + tied_params_map=tied_params_map, ) if not hasattr(module, "_hf_hook"): hook = AlignDevicesHook( execution_device=execution_device[module_name], io_same_device=(module_name == ""), skip_keys=skip_keys ) - add_hook_to_module(module, hook) + add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map}) attach_execution_device_hook( module, execution_device[module_name], preload_module_classes=preload_module_classes, skip_keys=skip_keys, + tied_params_map=tied_params_map, ) elif module_name == "": hook = AlignDevicesHook(execution_device=execution_device.get(""), io_same_device=True, skip_keys=skip_keys) - add_hook_to_module(module, hook) + add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map}) for child_name, child in module.named_children(): child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name @@ -553,6 +630,7 @@ def attach_align_device_hook_on_blocks( module_name=child_name, preload_module_classes=preload_module_classes, skip_keys=skip_keys, + tied_params_map=tied_params_map, ) diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 4c9cd006547..c299284856c 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -188,6 +188,7 @@ is_port_in_use, merge_dicts, patch_environment, + recursive_getattr, save, wait_for_everyone, write_basic_config, diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 201cb1cff96..83b49206447 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -267,6 +267,7 @@ def set_module_tensor_to_device( value: Optional[torch.Tensor] = None, dtype: Optional[Union[str, torch.dtype]] = None, fp16_statistics: Optional[torch.HalfTensor] = None, + tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None, ): """ A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing @@ -286,6 +287,10 @@ def set_module_tensor_to_device( the dtype of the existing parameter in the model. fp16_statistics (`torch.HalfTensor`, *optional*): The list of fp16 statistics to set on the module, used for 8 bit model serialization. 
+ tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`): + A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given + execution device, this parameter is useful to reuse the first available pointer of a shared weight on the + device for all others, instead of duplicating memory. """ # Recurse if needed if "." in tensor_name: @@ -302,6 +307,24 @@ def set_module_tensor_to_device( is_buffer = tensor_name in module._buffers old_value = getattr(module, tensor_name) + # Treat the case where old_value (or a custom `value`, typically offloaded to RAM/disk) belongs to a tied group, and one of the weights + # in the tied group has already been dispatched to the device, by avoiding reallocating memory on the device and just copying the pointer. + if ( + value is not None + and tied_params_map is not None + and value.data_ptr() in tied_params_map + and device in tied_params_map[value.data_ptr()] + ): + module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device] + return + elif ( + tied_params_map is not None + and old_value.data_ptr() in tied_params_map + and device in tied_params_map[old_value.data_ptr()] + ): + module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][device] + return + if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") @@ -367,6 +390,7 @@ def set_module_tensor_to_device( new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device) else: new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device) + module._parameters[tensor_name] = new_value if fp16_statistics is not None: setattr(module._parameters[tensor_name], "SCB", fp16_statistics.to(device)) @@ -397,6 +421,22 @@ def set_module_tensor_to_device( else: torch.cuda.empty_cache() + # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in + # order to avoid duplicating memory, see above. + if ( + tied_params_map is not None + and old_value.data_ptr() in tied_params_map + and device not in tied_params_map[old_value.data_ptr()] + ): + tied_params_map[old_value.data_ptr()][device] = new_value + elif ( + value is not None + and tied_params_map is not None + and value.data_ptr() in tied_params_map + and device not in tied_params_map[value.data_ptr()] + ): + tied_params_map[value.data_ptr()][device] = new_value + def named_module_tensors( module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False @@ -832,6 +872,7 @@ def get_balanced_memory( The model to analyze. max_memory (`Dict`, *optional*): A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset. + Example: `max_memory={0: "1GB"}`. no_split_module_classes (`List[str]`, *optional*): A list of layer class names that should never be split across device (for instance any layer that has a residual connection). @@ -989,6 +1030,7 @@ def infer_auto_device_map( The model to analyze. max_memory (`Dict`, *optional*): A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset. + Example: `max_memory={0: "1GB"}`.
no_split_module_classes (`List[str]`, *optional*): A list of layer class names that should never be split across device (for instance any layer that has a residual connection). diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py index 1ab20264717..e9bf8f142fc 100644 --- a/src/accelerate/utils/other.py +++ b/src/accelerate/utils/other.py @@ -18,7 +18,7 @@ import re import socket from contextlib import contextmanager -from functools import partial +from functools import partial, reduce from types import MethodType from typing import OrderedDict @@ -320,3 +320,20 @@ def check_os_kernel(): "cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher." ) logger.warning(msg, main_process_only=True) + + +def recursive_getattr(obj, attr: str): + """ + Recursive `getattr`. + + Args: + obj: + A class instance holding the attribute. + attr (`str`): + The attribute that is to be retrieved, e.g. 'attribute1.attribute2'. + """ + + def _getattr(obj, attr): + return getattr(obj, attr) + + return reduce(_getattr, [obj] + attr.split(".")) diff --git a/tests/test_big_modeling.py b/tests/test_big_modeling.py index 51ce4a899e4..759ade2ab3d 100644 --- a/tests/test_big_modeling.py +++ b/tests/test_big_modeling.py @@ -14,6 +14,7 @@ import copy import os import unittest +from collections import OrderedDict from tempfile import TemporaryDirectory import torch @@ -363,6 +364,235 @@ def test_dispatch_model_tied_weights(self): dispatch_model(model, device_map) self.assertIs(model.linear2.weight, model.linear1.weight) + @require_multi_gpu + def test_dispatch_model_tied_weights_memory(self): + # Test that we do not duplicate tied weights at any point during dispatch_model call. + + torch.cuda.empty_cache() # Needed in case we run several tests in a row. + + model = nn.Sequential( + OrderedDict( + [ + ("linear0", nn.Linear(5000, 5000, bias=False)), + ("linear1", nn.Linear(5000, 5000, bias=False)), + ("linear2", nn.Linear(5000, 5000, bias=False)), + ("linear3", nn.Linear(5000, 5000, bias=False)), + ("linear4", nn.Linear(5000, 5000, bias=False)), + ] + ) + ) + model.linear2.weight = model.linear0.weight + model.linear3.weight = model.linear0.weight + model.linear4.weight = model.linear0.weight + + x = torch.randn(5, 5000) + with torch.no_grad(): + expected = model(x) + + # We should need only 5000 * 5000 * 32 // 8 * 1e-6 = 100 MB on the device 0 for the four linear weights. + device_map = {"linear0": 0, "linear1": 1, "linear2": 0, "linear3": 0, "linear4": 0} + + # Just to initialize CUDA context. + a = torch.rand(5).to("cuda:0") # noqa: F841 + + free_memory_bytes = torch.cuda.mem_get_info("cuda:0")[0] + required_memory_bytes = 5000 * 5000 * (32 // 8) + + # Leaving 50 MB of free memory for possible buffers, etc. + n_vals = (free_memory_bytes - required_memory_bytes - int(50e6)) // (32 // 8) + foo = torch.rand(n_vals, device="cuda:0") # noqa: F841 + + # If this does OOM: there is an issue somewhere in dispatch_model, memory of tied weights is duplicated. + try: + dispatch_model(model, device_map) + except torch.cuda.OutOfMemoryError as e: + raise torch.cuda.OutOfMemoryError( + f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory. 
{e}" + ) + except Exception as e: + raise e + + with torch.no_grad(): + output = model(x) + self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5)) + + @require_cuda + def test_dispatch_model_tied_weights_memory_with_nested_offload_cpu(self): + # Test that we do not duplicate tied weights at any point during dispatch_model call. + + torch.cuda.empty_cache() # Needed in case we run several tests in a row. + + class SubModule(torch.nn.Module): + def __init__(self, ref_to_parameter): + super().__init__() + self.parameter = ref_to_parameter + + def forward(self, x): + return x + torch.max(self.parameter) + + class LinearModuleAndSubModule(torch.nn.Linear): + def __init__(self, in_features, out_features): + super().__init__(in_features, out_features, bias=False) + self.weight_submodule = SubModule(self.weight) + self.weight_submodule2 = SubModule(self.weight) + self.weight_submodule3 = SubModule(self.weight) + self.weight_submodule4 = SubModule(self.weight) + + def forward(self, x): + a = torch.nn.functional.linear(self.weight_submodule(x), self.weight) + b = torch.nn.functional.linear(self.weight_submodule2(x), self.weight) + c = torch.nn.functional.linear(self.weight_submodule3(x), self.weight) + d = torch.nn.functional.linear(self.weight_submodule4(x), self.weight) + return a + b + c + d + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.compute = LinearModuleAndSubModule(5000, 5000) + self.compute1 = LinearModuleAndSubModule(5000, 5000) + + def forward(self, x): + a = self.compute(x) + b = self.compute1(x) + return a + b + + # We should need only 2 * 5000 * 5000 * 32 // 8 * 1e-6 = 200 MB on the device 0 for the whole model forward, and not 600 MB. + device_map = {"compute": 0, "compute1": "cpu"} + + model = Model() + + x = torch.randn(1, 5000) + with torch.no_grad(): + expected = model(x) + + # Just to initialize CUDA context. + a = torch.rand(5).to("cuda:0") # noqa: F841 + + free_memory_bytes = torch.cuda.mem_get_info("cuda:0")[0] + required_memory_bytes = 2 * 5000 * 5000 * (32 // 8) # 200 MB + + # Leaving 150 MB of free memory for possible buffers, etc. + n_vals = (free_memory_bytes - required_memory_bytes - int(150e6)) // (32 // 8) + foo = torch.rand(n_vals, device="cuda:0") # noqa: F841 + + free_memory_bytes_before_dispatch = torch.cuda.mem_get_info("cuda:0")[0] + dispatch_model(model, device_map) + free_memory_bytes_after_dispatch = torch.cuda.mem_get_info("cuda:0")[0] + + self.assertTrue((free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130) + + original_pointer = model.compute1._hf_hook.weights_map["weight"].data_ptr() + + with torch.no_grad(): + try: + output = model(x) + except torch.cuda.OutOfMemoryError as e: + raise torch.cuda.OutOfMemoryError( + f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory_with_nested_offload_cpu. {e}" + ) + except Exception as e: + raise e + + self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5)) + + torch.cuda.empty_cache() + + free_memory_bytes_after_infer = torch.cuda.mem_get_info("cuda:0")[0] + + # Check that we have no more references on GPU for the offloaded tied weight. 
+ self.assertTrue(len(model.compute1.weight_submodule._hf_hook.tied_params_map[original_pointer]) == 0) + self.assertTrue(len(model.compute1._hf_hook.tied_params_map[original_pointer]) == 0) + self.assertTrue((free_memory_bytes_after_infer - free_memory_bytes_after_dispatch) * 1e-6 < 130) + + @require_cuda + def test_dispatch_model_tied_weights_memory_with_nested_offload_disk(self): + # Test that we do not duplicate tied weights at any point during dispatch_model call. + + torch.cuda.empty_cache() # Needed in case we run several tests in a row. + + class SubModule(torch.nn.Module): + def __init__(self, ref_to_parameter): + super().__init__() + self.parameter = ref_to_parameter + + def forward(self, x): + return x + torch.max(self.parameter) + + class LinearModuleAndSubModule(torch.nn.Linear): + def __init__(self, in_features, out_features): + super().__init__(in_features, out_features, bias=False) + self.weight_submodule = SubModule(self.weight) + self.weight_submodule2 = SubModule(self.weight) + self.weight_submodule3 = SubModule(self.weight) + self.weight_submodule4 = SubModule(self.weight) + + def forward(self, x): + a = torch.nn.functional.linear(self.weight_submodule(x), self.weight) + b = torch.nn.functional.linear(self.weight_submodule2(x), self.weight) + c = torch.nn.functional.linear(self.weight_submodule3(x), self.weight) + d = torch.nn.functional.linear(self.weight_submodule4(x), self.weight) + return a + b + c + d + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.compute = LinearModuleAndSubModule(5000, 5000) + self.compute1 = LinearModuleAndSubModule(5000, 5000) + + def forward(self, x): + a = self.compute(x) + b = self.compute1(x) + return a + b + + # We should need only 2 * 5000 * 5000 * 32 // 8 * 1e-6 = 200 MB on the device 0 for the whole model forward, and not 600 MB. + device_map = {"compute": 0, "compute1": "disk"} + + model = Model() + + x = torch.randn(1, 5000) + with torch.no_grad(): + expected = model(x) + + # Just to initialize CUDA context. + a = torch.rand(5).to("cuda:0") # noqa: F841 + + free_memory_bytes = torch.cuda.mem_get_info("cuda:0")[0] + required_memory_bytes = 2 * 5000 * 5000 * (32 // 8) # 200 MB + + # Leaving 200 MB of free memory for possible buffers, etc. + n_vals = (free_memory_bytes - required_memory_bytes - int(200e6)) // (32 // 8) + foo = torch.rand(n_vals, device="cuda:0") # noqa: F841 + + free_memory_bytes_before_dispatch = torch.cuda.mem_get_info("cuda:0")[0] + with TemporaryDirectory() as tmp_dir: + dispatch_model(model, device_map, offload_dir=tmp_dir) + free_memory_bytes_after_dispatch = torch.cuda.mem_get_info("cuda:0")[0] + + self.assertTrue((free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130) + + original_pointer = model.compute1._hf_hook.weights_map["weight"].data_ptr() + + with torch.no_grad(): + try: + output = model(x) + except torch.cuda.OutOfMemoryError as e: + raise torch.cuda.OutOfMemoryError( + f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory_with_nested_offload_disk. {e}" + ) + except Exception as e: + raise e + + self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5)) + + torch.cuda.empty_cache() + + free_memory_bytes_after_infer = torch.cuda.mem_get_info("cuda:0")[0] + + # Check that we have no more references on GPU for the offloaded tied weight. 
+ self.assertTrue(len(model.compute1.weight_submodule._hf_hook.tied_params_map[original_pointer]) == 0) + self.assertTrue(len(model.compute1._hf_hook.tied_params_map[original_pointer]) == 0) + self.assertTrue((free_memory_bytes_after_infer - free_memory_bytes_after_dispatch) * 1e-6 < 130) + @require_multi_gpu def test_dispatch_model_multi_gpu(self): model = BiggerModelForTest()
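As an illustration of the mechanism above (not part of the diff): a minimal sketch, assuming a machine with at least one CUDA device and a toy two-layer tied-weights model similar to the one in the existing tests. It mirrors the bookkeeping that dispatch_model now performs before attaching the hooks: the map is keyed by the data pointer of each tied parameter, and set_module_tensor_to_device then stores, per device, the first tensor dispatched for a group and reuses it for the other members instead of allocating again.

from collections import OrderedDict

import torch.nn as nn

from accelerate import dispatch_model
from accelerate.utils import find_tied_parameters, recursive_getattr

model = nn.Sequential(
    OrderedDict(
        [("linear1", nn.Linear(10, 10, bias=False)), ("linear2", nn.Linear(10, 10, bias=False))]
    )
)
model.linear2.weight = model.linear1.weight  # tie the two weights

# Same bookkeeping as the new code in dispatch_model: one empty per-device dict per tied
# parameter, keyed by its original data pointer.
tied_params_map = {}
for group in find_tied_parameters(model):  # e.g. [["linear1.weight", "linear2.weight"]]
    for param_name in group:
        tied_params_map[recursive_getattr(model, param_name).data_ptr()] = {}

# After dispatch, both parameters should point at a single allocation on cuda:0.
dispatch_model(model, device_map={"linear1": 0, "linear2": 0})
assert model.linear1.weight.data_ptr() == model.linear2.weight.data_ptr()

For offloaded tied weights, the map cannot be pre-filled from weights_map (safetensors may return a different pointer on every file.get_tensor() call), which is why AlignDevicesHook records the affected names in tied_params_names, fills the pointers in lazily in pre_forward, and drops them again in post_forward so the memory can be reclaimed.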