Check if the buffers fit GPU memory after the device map is auto-inferred
  * Some models, such as TheBloke/WizardCoder-33B-V1.1-GPTQ, contain a
    huge buffer, which may cause OOM on GPU memory when offload_buffers
    is not used. This commit adds a check for that case.
notsyncing committed Mar 5, 2024
1 parent 65544d8 commit 2a278f4
Showing 3 changed files with 142 additions and 2 deletions.
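For context: buffers created with register_buffer are treated differently from parameters during dispatch. Unless offload_buffers=True, the buffers of layers whose weights are offloaded to CPU or disk are still materialized on the execution device, which is what can exhaust GPU memory. A minimal sketch of the distinction, using a hypothetical toy module (not part of this commit):

    import torch
    from torch import nn

    class WithBigBuffer(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)  # parameters: offloadable layer by layer
            # Without offload_buffers=True, this buffer still lands on the GPU
            # at dispatch time, even if the layer's weights are offloaded.
            self.register_buffer("big_cache", torch.zeros(1024, 1024))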
6 changes: 5 additions & 1 deletion src/accelerate/big_modeling.py
@@ -578,7 +578,11 @@ def load_checkpoint_and_dispatch(
low_zero=(device_map == "balanced_low_0"),
)
device_map = infer_auto_device_map(
model, max_memory=max_memory, no_split_module_classes=no_split_module_classes, dtype=dtype
model,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
dtype=dtype,
offload_buffers=offload_buffers,
)
if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
offload_state_dict = True
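With load_checkpoint_and_dispatch now forwarding the flag, a device_map="auto" load can account for buffers when the map is inferred. A usage sketch, assuming model is an already-instantiated (typically meta-device) module and the checkpoint path is a placeholder:

    from accelerate import load_checkpoint_and_dispatch

    model = load_checkpoint_and_dispatch(
        model,
        "path/to/checkpoint",  # placeholder path
        device_map="auto",     # runs infer_auto_device_map internally
        offload_buffers=True,  # now also passed through to the map inference
    )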
71 changes: 70 additions & 1 deletion src/accelerate/utils/modeling.py
@@ -22,6 +22,7 @@
import re
import shutil
import tempfile
import warnings
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional, Tuple, Union

@@ -690,6 +691,7 @@ def compute_module_sizes(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
buffers_only: bool = False,
):
"""
Compute the size of each submodule of a given model.
@@ -701,7 +703,15 @@
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
module_sizes = defaultdict(int)
for name, tensor in named_module_tensors(model, recurse=True):

module_list = []

if not buffers_only:
module_list = named_module_tensors(model, recurse=True)
else:
module_list = model.named_buffers(recurse=True)

for name, tensor in module_list:
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes_size[name]
elif dtype is None:
@@ -719,6 +729,18 @@
return module_sizes


def compute_module_total_buffer_size(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
):
"""
Compute the total size of all the buffers in a given model.
"""
module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes, buffers_only=True)
return module_sizes.get("", 0)
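A quick sketch of the new helper on a toy module, assuming it is imported from accelerate.utils.modeling:

    import torch
    from torch import nn
    from accelerate.utils.modeling import compute_module_total_buffer_size

    m = nn.Linear(3, 4)                              # a Linear has no buffers
    m.register_buffer("cache", torch.zeros(10, 10))  # 100 float32 elements
    print(compute_module_total_buffer_size(m))       # 400 (bytes); parameters are ignored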


def get_max_layer_size(
modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str]
):
@@ -1030,6 +1052,7 @@ def infer_auto_device_map(
special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
verbose: bool = False,
clean_result: bool = True,
offload_buffers: bool = False,
):
"""
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
@@ -1066,6 +1089,9 @@
Whether or not to provide debugging statements as the function builds the device_map.
clean_result (`bool`, *optional*, defaults to `True`):
Clean the resulting device_map by grouping all submodules that go on the same device together.
offload_buffers (`bool`, *optional*, defaults to `False`):
In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
well as the parameters.
"""
# Get default / clean up max_memory
max_memory = get_max_memory(max_memory)
@@ -1098,6 +1124,9 @@
device_map = OrderedDict()
current_device = 0
current_memory_used = 0
device_memory_used = {}
device_memory_reserved = {}
device_buffer_sizes = {}

# Direct submodules and parameters
modules_to_treat = (
@@ -1150,6 +1179,7 @@
# Reduce max size available by the largest layer.
if devices[current_device] in main_devices:
current_max_size = current_max_size - max_layer_size
device_memory_reserved[current_device] = max_layer_size
# Case 1 -> We're too big!
if current_max_size is not None and current_memory_used + module_size > current_max_size:
# Split or not split?
@@ -1167,6 +1197,8 @@
# -> no split, we go to the next device
if verbose:
print("This module cannot be split, going to the next device.")

device_memory_used[device] = current_memory_used
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
@@ -1218,6 +1250,12 @@
modules_to_treat.pop(tied_module_index)
device_map[tied_module_name] = devices[current_device]

if not offload_buffers and isinstance(module, nn.Module):
current_buffer_size = compute_module_total_buffer_size(
module, dtype=dtype, special_dtypes=special_dtypes
)
device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size

else:
# We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it
# smaller or do we need to go on the next device?
@@ -1258,6 +1296,8 @@
# If the tied module is not split, we go to the next device
if verbose:
print("None of the tied module can be split, going to the next device.")

device_memory_used[device] = current_memory_used
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
@@ -1272,10 +1312,39 @@
f"(available={current_max_size - current_memory_used})."
)
current_memory_used += module_size
device_memory_used[device] = current_memory_used
device_map[name] = devices[current_device]

if not offload_buffers and isinstance(module, nn.Module):
current_buffer_size = compute_module_total_buffer_size(
module, dtype=dtype, special_dtypes=special_dtypes
)
device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size

if clean_result:
device_map = clean_device_map(device_map)

if not offload_buffers:
non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0)

is_buffer_fit_any_gpu = False
for gpu_device, gpu_max_memory in max_memory.items():
if gpu_device == "cpu" or gpu_device == "disk":
continue

if not is_buffer_fit_any_gpu:
gpu_memory_used = device_memory_used.get(gpu_device, 0) + device_memory_reserved.get(gpu_device, 0)

if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used:
is_buffer_fit_any_gpu = True

if len(gpus) > 0 and not is_buffer_fit_any_gpu:
warnings.warn(
f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does "
f"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using "
f"offload_buffers=True."
)

return device_map


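Callers can also react to the new warning instead of merely surfacing it. A sketch, assuming model and max_memory are already defined; matching on the message text is purely illustrative, not a stable API:

    import warnings

    from accelerate import infer_auto_device_map

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        device_map = infer_auto_device_map(model, max_memory=max_memory)

    # Retry with buffers offloaded if the buffer-fit warning fired.
    if any("offload_buffers=True" in str(w.message) for w in caught):
        device_map = infer_auto_device_map(model, max_memory=max_memory, offload_buffers=True)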
67 changes: 67 additions & 0 deletions tests/test_modeling_utils.py
@@ -16,6 +16,7 @@
import os
import tempfile
import unittest
import warnings
from collections import OrderedDict

import torch
@@ -28,6 +29,7 @@
check_device_map,
clean_device_map,
compute_module_sizes,
compute_module_total_buffer_size,
convert_file_size_to_int,
find_tied_parameters,
get_balanced_memory,
@@ -297,6 +299,18 @@ def test_compute_module_sizes(self):
module_sizes = compute_module_sizes(model)
assert module_sizes == expected_sizes

def test_compute_module_total_buffer_size(self):
model = ModelForTest()
model.linear1.register_buffer("test_buffer", torch.zeros(10, 10))
model.register_buffer("test_buffer2", torch.zeros(20, 10))

buffer_size = compute_module_total_buffer_size(model)
assert buffer_size == 1240

model.half()
buffer_size = compute_module_total_buffer_size(model)
assert buffer_size == 624

def test_check_device_map(self):
model = ModelForTest()
check_device_map(model, {"": 0})
@@ -604,6 +618,59 @@ def test_infer_auto_device_map_on_t0pp(self):
assert device_map["encoder.embed_tokens"] == 0
assert device_map["decoder.embed_tokens"] == 0

def test_infer_auto_device_map_with_buffer_check(self):
model = ModelForTest()
model.linear1.register_buffer("test_buffer1", torch.zeros(10, 2))
model.batchnorm.register_buffer("test_buffer2", torch.zeros(10, 3))
model.linear2.register_buffer("test_buffer3", torch.zeros(10, 3))
# model has size 236 (parameters) + 360 (buffers): linear1 64 + 80, batchnorm 72 + 160, linear2 100 + 120

# Only linear1 (144) fits on device 0, and the remaining buffers (batchnorm's 160 + linear2's 120 = 280) won't
# fit on device 0 either, because without offload_buffers they are all loaded onto device 0 at inference time
# Should print a warning in this case
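# Concretely: device 0 reserves max_layer_size (linear2, 100 + 120 = 220), so of its
# 400 bytes only linear1 (64 + 80 = 144) is placed there; the fit check then sees
# 280 (offloaded buffers) + 144 (used) + 220 (reserved) = 644 > 400 and warns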
with self.assertWarns(Warning):
device_map = infer_auto_device_map(model, max_memory={0: 400, "cpu": "1GB"})
assert device_map == {"linear1": 0, "batchnorm": "cpu", "linear2": "cpu"}

# Only linear1 (144) fits on device 0, and the remaining buffers (batchnorm's 160 + linear2's 120 = 280) won't
# fit on device 0, but with offload_buffers=True they are never loaded onto device 0 all at once, so this is fine
# Should NOT print a warning in this case
with warnings.catch_warnings():
warnings.simplefilter("error")
device_map = infer_auto_device_map(model, max_memory={0: 400, "cpu": "1GB"}, offload_buffers=True)
assert device_map == {"linear1": 0, "batchnorm": "cpu", "linear2": "cpu"}

def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self):
model = ModelForTest()
model.linear1.register_buffer("test_buffer1", torch.zeros(10, 2))
model.batchnorm.register_buffer("test_buffer2", torch.zeros(10, 3))
model.linear2.register_buffer("test_buffer3", torch.zeros(10, 3))
model.linear3 = nn.Linear(4, 5)
model.linear3.register_buffer("test_buffer4", torch.zeros(10, 2))
# model has size 336 (parameters) + 440 (buffers): linear1 64 + 80, batchnorm 72 + 160, linear2 100 + 120,
# linear3 100 + 80

# Now we have two devices: linear1 fits on device 0, batchnorm fits on device 1, and device 1 can also
# hold all the remaining buffers
# Should NOT print a warning in this case
with warnings.catch_warnings():
warnings.simplefilter("error")
device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 400, "cpu": "1GB"})
assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"}

# Now we have two devices, but neither the first nor the second can hold all the remaining buffers
# Should print a warning in this case
with self.assertWarns(Warning):
device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 200, "cpu": "1GB"})
assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"}

# Now we have two devices, neither of which can hold all the buffers, but offload_buffers=True is used
# Should NOT print a warning in this case
with warnings.catch_warnings():
warnings.simplefilter("error")
device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 200, "cpu": "1GB"}, offload_buffers=True)
assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"}

@require_cuda
def test_get_balanced_memory(self):
model = ModelForTest()
