From 2a278f44030ca198950c68f6cda7e51e2eaa2b00 Mon Sep 17 00:00:00 2001 From: notsyncing Date: Fri, 2 Feb 2024 13:09:02 +0800 Subject: [PATCH 1/3] Check if the buffers fit GPU memory after device map auto inferred * For some models, like TheBloke/WizardCoder-33B-V1.1-GPTQ, contain a huge buffer, which may cause OOM on GPU memory if not using offload_buffers. This commit adds a check for such case. --- src/accelerate/big_modeling.py | 6 ++- src/accelerate/utils/modeling.py | 71 +++++++++++++++++++++++++++++++- tests/test_modeling_utils.py | 67 ++++++++++++++++++++++++++++++ 3 files changed, 142 insertions(+), 2 deletions(-) diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py index 7c0582773a6..9a005d21f40 100644 --- a/src/accelerate/big_modeling.py +++ b/src/accelerate/big_modeling.py @@ -578,7 +578,11 @@ def load_checkpoint_and_dispatch( low_zero=(device_map == "balanced_low_0"), ) device_map = infer_auto_device_map( - model, max_memory=max_memory, no_split_module_classes=no_split_module_classes, dtype=dtype + model, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + dtype=dtype, + offload_buffers=offload_buffers, ) if offload_state_dict is None and device_map is not None and "disk" in device_map.values(): offload_state_dict = True diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 8ca67fa15a1..166f140393f 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -22,6 +22,7 @@ import re import shutil import tempfile +import warnings from collections import OrderedDict, defaultdict from typing import Dict, List, Optional, Tuple, Union @@ -690,6 +691,7 @@ def compute_module_sizes( model: nn.Module, dtype: Optional[Union[str, torch.device]] = None, special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, + buffers_only: bool = False, ): """ Compute the size of each submodule of a given model. @@ -701,7 +703,15 @@ def compute_module_sizes( special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()} special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()} module_sizes = defaultdict(int) - for name, tensor in named_module_tensors(model, recurse=True): + + module_list = [] + + if not buffers_only: + module_list = named_module_tensors(model, recurse=True) + else: + module_list = model.named_buffers(recurse=True) + + for name, tensor in module_list: if special_dtypes is not None and name in special_dtypes: size = tensor.numel() * special_dtypes_size[name] elif dtype is None: @@ -719,6 +729,18 @@ def compute_module_sizes( return module_sizes +def compute_module_total_buffer_size( + model: nn.Module, + dtype: Optional[Union[str, torch.device]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, +): + """ + Compute the total size of buffers in each submodule of a given model. 
+ """ + module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes, buffers_only=True) + return module_sizes.get("", 0) + + def get_max_layer_size( modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str] ): @@ -1030,6 +1052,7 @@ def infer_auto_device_map( special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None, verbose: bool = False, clean_result: bool = True, + offload_buffers: bool = False, ): """ Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, @@ -1066,6 +1089,9 @@ def infer_auto_device_map( Whether or not to provide debugging statements as the function builds the device_map. clean_result (`bool`, *optional*, defaults to `True`): Clean the resulting device_map by grouping all submodules that go on the same device together. + offload_buffers (`bool`, *optional*, defaults to `False`): + In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as + well as the parameters. """ # Get default / clean up max_memory max_memory = get_max_memory(max_memory) @@ -1098,6 +1124,9 @@ def infer_auto_device_map( device_map = OrderedDict() current_device = 0 current_memory_used = 0 + device_memory_used = {} + device_memory_reserved = {} + device_buffer_sizes = {} # Direct submodules and parameters modules_to_treat = ( @@ -1150,6 +1179,7 @@ def infer_auto_device_map( # Reduce max size available by the largest layer. if devices[current_device] in main_devices: current_max_size = current_max_size - max_layer_size + device_memory_reserved[current_device] = max_layer_size # Case 1 -> We're too big! if current_max_size is not None and current_memory_used + module_size > current_max_size: # Split or not split? @@ -1167,6 +1197,8 @@ def infer_auto_device_map( # -> no split, we go to the next device if verbose: print("This module cannot be split, going to the next device.") + + device_memory_used[device] = current_memory_used current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat current_memory_used = 0 @@ -1218,6 +1250,12 @@ def infer_auto_device_map( modules_to_treat.pop(tied_module_index) device_map[tied_module_name] = devices[current_device] + if not offload_buffers and isinstance(module, nn.Module): + current_buffer_size = compute_module_total_buffer_size( + module, dtype=dtype, special_dtypes=special_dtypes + ) + device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size + else: # We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it # smaller or do we need to go on the next device? @@ -1258,6 +1296,8 @@ def infer_auto_device_map( # If the tied module is not split, we go to the next device if verbose: print("None of the tied module can be split, going to the next device.") + + device_memory_used[device] = current_memory_used current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat current_memory_used = 0 @@ -1272,10 +1312,39 @@ def infer_auto_device_map( f"(available={current_max_size - current_memory_used})." 
) current_memory_used += module_size + device_memory_used[device] = current_memory_used device_map[name] = devices[current_device] + if not offload_buffers and isinstance(module, nn.Module): + current_buffer_size = compute_module_total_buffer_size( + module, dtype=dtype, special_dtypes=special_dtypes + ) + device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size + if clean_result: device_map = clean_device_map(device_map) + + if not offload_buffers: + non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0) + + is_buffer_fit_any_gpu = False + for gpu_device, gpu_max_memory in max_memory.items(): + if gpu_device == "cpu" or gpu_device == "disk": + continue + + if not is_buffer_fit_any_gpu: + gpu_memory_used = device_memory_used.get(gpu_device, 0) + device_memory_reserved.get(gpu_device, 0) + + if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used: + is_buffer_fit_any_gpu = True + + if len(gpus) > 0 and not is_buffer_fit_any_gpu: + warnings.warn( + f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does " + f"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using " + f"offload_buffers=True." + ) + return device_map diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 69b80a4b12f..a023a441844 100644 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -16,6 +16,7 @@ import os import tempfile import unittest +import warnings from collections import OrderedDict import torch @@ -28,6 +29,7 @@ check_device_map, clean_device_map, compute_module_sizes, + compute_module_total_buffer_size, convert_file_size_to_int, find_tied_parameters, get_balanced_memory, @@ -297,6 +299,18 @@ def test_compute_module_sizes(self): module_sizes = compute_module_sizes(model) assert module_sizes == expected_sizes + def test_compute_module_total_buffer_size(self): + model = ModelForTest() + model.linear1.register_buffer("test_buffer", torch.zeros(10, 10)) + model.register_buffer("test_buffer2", torch.zeros(20, 10)) + + buffer_size = compute_module_total_buffer_size(model) + assert buffer_size == 1240 + + model.half() + buffer_size = compute_module_total_buffer_size(model) + assert buffer_size == 624 + def test_check_device_map(self): model = ModelForTest() check_device_map(model, {"": 0}) @@ -604,6 +618,59 @@ def test_infer_auto_device_map_on_t0pp(self): assert device_map["encoder.embed_tokens"] == 0 assert device_map["decoder.embed_tokens"] == 0 + def test_infer_auto_device_map_with_buffer_check(self): + model = ModelForTest() + model.linear1.register_buffer("test_buffer1", torch.zeros(10, 2)) + model.batchnorm.register_buffer("test_buffer2", torch.zeros(10, 3)) + model.linear2.register_buffer("test_buffer3", torch.zeros(10, 3)) + # model has size 236(parameters) + 360(buffers): linear1 64 + 80, batchnorm 72 + 160, linear2 100 + 120 + + # Only linear1 (144) fits on device 0, and remaining buffers (batchnorm's 160 + linear2's 120 = 280) won't fit + # device 0, because they will also be loaded to device 0 all at once when inferencing without offload_buffers + # Should print a warning as intended in such case + with self.assertWarns(Warning): + device_map = infer_auto_device_map(model, max_memory={0: 400, "cpu": "1GB"}) + assert device_map == {"linear1": 0, "batchnorm": "cpu", "linear2": "cpu"} + + # Only linear1 (144) fits on device 0, and remaining buffers (batchnorm's 160 + linear2's 120 = 280) won't fit + # device 0, 
but with offload_buffers they won't be loaded to device 0 all at once, so it's ok now + # Should NOT print a warning in such case + with warnings.catch_warnings(): + warnings.simplefilter("error") + device_map = infer_auto_device_map(model, max_memory={0: 400, "cpu": "1GB"}, offload_buffers=True) + assert device_map == {"linear1": 0, "batchnorm": "cpu", "linear2": "cpu"} + + def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self): + model = ModelForTest() + model.linear1.register_buffer("test_buffer1", torch.zeros(10, 2)) + model.batchnorm.register_buffer("test_buffer2", torch.zeros(10, 3)) + model.linear2.register_buffer("test_buffer3", torch.zeros(10, 3)) + model.linear3 = nn.Linear(4, 5) + model.linear3.register_buffer("test_buffer4", torch.zeros(10, 2)) + # model has size 336(parameters) + 440(buffers): linear1 64 + 80, batchnorm 72 + 160, linear2 100 + 120, + # linear3 100 + 80 + + # Now we have two devices, linear1 will fit on device 0, batchnorm will fit on device 1, and the second device + # can hold all remaining buffers + # Should NOT print a warning in such case + with warnings.catch_warnings(): + warnings.simplefilter("error") + device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 400, "cpu": "1GB"}) + assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"} + + # Now we have two devices, but neither the first nor the second device can hold all remaining buffers + # Should print a warning as intended in such case + with self.assertWarns(Warning): + device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 200, "cpu": "1GB"}) + assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"} + + # Now we have two devices, neither can hold all the buffers, but we are using the offload_buffers=True + # Should NOT print a warning in such case + with warnings.catch_warnings(): + warnings.simplefilter("error") + device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 200, "cpu": "1GB"}, offload_buffers=True) + assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"} + @require_cuda def test_get_balanced_memory(self): model = ModelForTest() From 91ed9dea71ab4f7e726d365fd2b7ea5498f98844 Mon Sep 17 00:00:00 2001 From: notsyncing Date: Wed, 6 Mar 2024 15:41:54 +0800 Subject: [PATCH 2/3] Minor refactors. --- src/accelerate/utils/modeling.py | 17 ++++++++--------- tests/test_modeling_utils.py | 6 +++--- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 166f140393f..ebedfd3ebbc 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1125,7 +1125,6 @@ def infer_auto_device_map( current_device = 0 current_memory_used = 0 device_memory_used = {} - device_memory_reserved = {} device_buffer_sizes = {} # Direct submodules and parameters @@ -1176,10 +1175,11 @@ def infer_auto_device_map( device = devices[current_device] current_max_size = max_memory[device] if device != "disk" else None + current_memory_reserved = 0 # Reduce max size available by the largest layer. if devices[current_device] in main_devices: current_max_size = current_max_size - max_layer_size - device_memory_reserved[current_device] = max_layer_size + current_memory_reserved = max_layer_size # Case 1 -> We're too big! if current_max_size is not None and current_memory_used + module_size > current_max_size: # Split or not split? 
@@ -1198,7 +1198,7 @@ def infer_auto_device_map( if verbose: print("This module cannot be split, going to the next device.") - device_memory_used[device] = current_memory_used + device_memory_used[device] = current_memory_used + current_memory_reserved current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat current_memory_used = 0 @@ -1297,7 +1297,7 @@ def infer_auto_device_map( if verbose: print("None of the tied module can be split, going to the next device.") - device_memory_used[device] = current_memory_used + device_memory_used[device] = current_memory_used + current_memory_reserved current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat current_memory_used = 0 @@ -1312,7 +1312,7 @@ def infer_auto_device_map( f"(available={current_max_size - current_memory_used})." ) current_memory_used += module_size - device_memory_used[device] = current_memory_used + device_memory_used[device] = current_memory_used + current_memory_reserved device_map[name] = devices[current_device] if not offload_buffers and isinstance(module, nn.Module): @@ -1324,16 +1324,15 @@ def infer_auto_device_map( if clean_result: device_map = clean_device_map(device_map) - if not offload_buffers: - non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0) - + non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0) + if non_gpu_buffer_size > 0 and not offload_buffers: is_buffer_fit_any_gpu = False for gpu_device, gpu_max_memory in max_memory.items(): if gpu_device == "cpu" or gpu_device == "disk": continue if not is_buffer_fit_any_gpu: - gpu_memory_used = device_memory_used.get(gpu_device, 0) + device_memory_reserved.get(gpu_device, 0) + gpu_memory_used = device_memory_used.get(gpu_device, 0) if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used: is_buffer_fit_any_gpu = True diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index a023a441844..7ada3725648 100644 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -636,7 +636,7 @@ def test_infer_auto_device_map_with_buffer_check(self): # device 0, but with offload_buffers they won't be loaded to device 0 all at once, so it's ok now # Should NOT print a warning in such case with warnings.catch_warnings(): - warnings.simplefilter("error") + warnings.simplefilter("always") device_map = infer_auto_device_map(model, max_memory={0: 400, "cpu": "1GB"}, offload_buffers=True) assert device_map == {"linear1": 0, "batchnorm": "cpu", "linear2": "cpu"} @@ -654,7 +654,7 @@ def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self): # can hold all remaining buffers # Should NOT print a warning in such case with warnings.catch_warnings(): - warnings.simplefilter("error") + warnings.simplefilter("always") device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 400, "cpu": "1GB"}) assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"} @@ -667,7 +667,7 @@ def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self): # Now we have two devices, neither can hold all the buffers, but we are using the offload_buffers=True # Should NOT print a warning in such case with warnings.catch_warnings(): - warnings.simplefilter("error") + warnings.simplefilter("always") device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 200, "cpu": "1GB"}, offload_buffers=True) assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"} From 
f67ac009563d7c4618d81840f68efe3e854bedf1 Mon Sep 17 00:00:00 2001 From: notsyncing Date: Fri, 8 Mar 2024 12:22:30 +0800 Subject: [PATCH 3/3] Add missing assertions --- tests/test_modeling_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 7ada3725648..e006dc3ffd8 100644 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -635,9 +635,10 @@ def test_infer_auto_device_map_with_buffer_check(self): # Only linear1 (144) fits on device 0, and remaining buffers (batchnorm's 160 + linear2's 120 = 280) won't fit # device 0, but with offload_buffers they won't be loaded to device 0 all at once, so it's ok now # Should NOT print a warning in such case - with warnings.catch_warnings(): + with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") device_map = infer_auto_device_map(model, max_memory={0: 400, "cpu": "1GB"}, offload_buffers=True) + assert len(w) == 0 assert device_map == {"linear1": 0, "batchnorm": "cpu", "linear2": "cpu"} def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self): @@ -653,9 +654,10 @@ def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self): # Now we have two devices, linear1 will fit on device 0, batchnorm will fit on device 1, and the second device # can hold all remaining buffers # Should NOT print a warning in such case - with warnings.catch_warnings(): + with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 400, "cpu": "1GB"}) + assert len(w) == 0 assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"} # Now we have two devices, but neither the first nor the second device can hold all remaining buffers @@ -666,9 +668,10 @@ def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self): # Now we have two devices, neither can hold all the buffers, but we are using the offload_buffers=True # Should NOT print a warning in such case - with warnings.catch_warnings(): + with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 200, "cpu": "1GB"}, offload_buffers=True) + assert len(w) == 0 assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"} @require_cuda
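
Below is a minimal usage sketch of the new offload_buffers flag introduced by this patch series; it is not part of the patches themselves, and the toy model and memory limits are hypothetical, chosen only to illustrate the call signature.

from torch import nn

from accelerate.utils.modeling import infer_auto_device_map

# Hypothetical toy model; the real motivation is large checkpoints (e.g. GPTQ models)
# whose non-parameter buffers are big enough to matter.
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Linear(8, 4))

# Default behaviour (offload_buffers=False): buffers of CPU/disk-offloaded modules are
# still moved to the execution GPU at inference time, so infer_auto_device_map now warns
# when their total size cannot fit alongside any single GPU's already-assigned memory.
device_map = infer_auto_device_map(model, max_memory={0: "200MB", "cpu": "1GB"})

# With offload_buffers=True the buffers stay with their offloaded modules, so the new
# check is skipped and no warning is emitted.
device_map = infer_auto_device_map(
    model, max_memory={0: "200MB", "cpu": "1GB"}, offload_buffers=True
)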