Skip to content

Commit

Permalink
Check if the buffers fit GPU memory after device map auto inferred
Browse files Browse the repository at this point in the history
  * For some models, like TheBloke/WizardCoder-33B-V1.1-GPTQ, contain a
    huge buffer, which may cause OOM on GPU memory if not using
    offload_buffers. This commit adds a check for such case.
  • Loading branch information
notsyncing committed Mar 1, 2024
1 parent ca37b0e commit 0e4c945
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/accelerate/big_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,11 @@ def load_checkpoint_and_dispatch(
low_zero=(device_map == "balanced_low_0"),
)
device_map = infer_auto_device_map(
model, max_memory=max_memory, no_split_module_classes=no_split_module_classes, dtype=dtype
model,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
dtype=dtype,
offload_buffers=offload_buffers,
)
if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
offload_state_dict = True
Expand Down
66 changes: 65 additions & 1 deletion src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,7 @@ def compute_module_sizes(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
buffers_only: bool = False,
):
"""
Compute the size of each submodule of a given model.
Expand All @@ -701,7 +702,16 @@ def compute_module_sizes(
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
module_sizes = defaultdict(int)
for name, tensor in named_module_tensors(model, recurse=True):

module_list = []

if not buffers_only:
module_list = named_module_tensors(model, recurse=True)
else:
if hasattr(model, "named_buffers"):
module_list = model.named_buffers(recurse=True)

for name, tensor in module_list:
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes_size[name]
elif dtype is None:
Expand All @@ -719,6 +729,18 @@ def compute_module_sizes(
return module_sizes


def compute_module_total_buffer_size(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
):
"""
Compute the total size of buffers in each submodule of a given model.
"""
module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes, buffers_only=True)
return module_sizes.get("", 0)


def get_max_layer_size(
modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str]
):
Expand Down Expand Up @@ -1030,6 +1052,7 @@ def infer_auto_device_map(
special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
verbose: bool = False,
clean_result: bool = True,
offload_buffers: bool = False,
):
"""
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
Expand Down Expand Up @@ -1066,6 +1089,9 @@ def infer_auto_device_map(
Whether or not to provide debugging statements as the function builds the device_map.
clean_result (`bool`, *optional*, defaults to `True`):
Clean the resulting device_map by grouping all submodules that go on the same device together.
offload_buffers (`bool`, *optional*, defaults to `False`):
In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
well as the parameters.
"""
# Get default / clean up max_memory
max_memory = get_max_memory(max_memory)
Expand Down Expand Up @@ -1098,6 +1124,8 @@ def infer_auto_device_map(
device_map = OrderedDict()
current_device = 0
current_memory_used = 0
device_memory_used = {}
device_buffer_sizes = {}

# Direct submodules and parameters
modules_to_treat = (
Expand Down Expand Up @@ -1163,6 +1191,8 @@ def infer_auto_device_map(
# -> no split, we go to the next device
if verbose:
print("This module cannot be split, going to the next device.")

device_memory_used[device] = current_memory_used
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
Expand Down Expand Up @@ -1214,6 +1244,12 @@ def infer_auto_device_map(
modules_to_treat.pop(tied_module_index)
device_map[tied_module_name] = devices[current_device]

if not offload_buffers:
current_buffer_size = compute_module_total_buffer_size(
module, dtype=dtype, special_dtypes=special_dtypes
)
device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size

else:
# We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it
# smaller or do we need to go on the next device?
Expand Down Expand Up @@ -1254,6 +1290,8 @@ def infer_auto_device_map(
# If the tied module is not split, we go to the next device
if verbose:
print("None of the tied module can be split, going to the next device.")

device_memory_used[device] = current_memory_used
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
Expand All @@ -1270,8 +1308,34 @@ def infer_auto_device_map(
current_memory_used += module_size
device_map[name] = devices[current_device]

if not offload_buffers:
current_buffer_size = compute_module_total_buffer_size(
module, dtype=dtype, special_dtypes=special_dtypes
)
device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size

if clean_result:
device_map = clean_device_map(device_map)

if not offload_buffers:
non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0)

is_buffer_fit_any_gpu = False
for gpu_device, gpu_max_memory in max_memory.items():
if gpu_device == "cpu" or gpu_device == "disk":
continue

if not is_buffer_fit_any_gpu:
gpu_memory_used = device_memory_used.get(gpu_device, 0)
if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used:
is_buffer_fit_any_gpu = True

if not is_buffer_fit_any_gpu:
logger.warn(
f"Current model requires {non_gpu_buffer_size} bytes of buffer, which seems does not fit any GPU's "
f"remaining memory. Please consider using offload_buffers=True."
)

return device_map


Expand Down

0 comments on commit 0e4c945

Please sign in to comment.