Avoid duplicating memory for tied weights in dispatch_model, and in forward with offloading (#2330)

* wip

* fix

* add test

* cleanup

* style

* style & tests pass

* fix offload, submodules

* cleanup

* Update tests/test_big_modeling.py

Co-authored-by: Marc Sun <[email protected]>

* Update tests/test_big_modeling.py

Co-authored-by: Marc Sun <[email protected]>

* disk offloading: do not reload tied parameters in memory

* remove outdated comment

---------

Co-authored-by: Your Name <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
3 people authored Jan 17, 2024
1 parent 31fd2b1 commit 6719cb6
Showing 6 changed files with 402 additions and 17 deletions.
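For context on what the commit optimizes: weight tying makes several parameter names reference a single tensor, and dispatching each name separately used to allocate that tensor's memory once per name. A minimal, hypothetical illustration of such tying (not taken from this PR's tests):

```python
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=100, hidden=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        # Tie the output projection to the input embedding: both names now
        # reference the very same tensor.
        self.lm_head.weight = self.embed.weight

model = TinyLM()
# One storage backs both parameter names; dispatching each name separately
# would materialize that storage twice on the execution device.
assert model.lm_head.weight.data_ptr() == model.embed.weight.data_ptr()
```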
17 changes: 17 additions & 0 deletions src/accelerate/big_modeling.py
@@ -43,6 +43,7 @@
parse_flag_from_env,
retie_parameters,
)
from .utils.other import recursive_getattr


logger = logging.getLogger(__name__)
@@ -395,7 +396,22 @@ def dispatch_model(
else:
weights_map = None

# When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for
# the tied parameters. The dictionary tied_params_map keeps track of the data already allocated for a given tied parameter (represented
# by its original pointer) on each device.
tied_params = find_tied_parameters(model)

tied_params_map = {}
for group in tied_params:
for param_name in group:
# data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need
# to care about views of tensors through storage_offset.
data_ptr = recursive_getattr(model, param_name).data_ptr()
tied_params_map[data_ptr] = {}

# Note: To handle the disk offloading case, we cannot simply use weights_map[param_name].data_ptr() as the reference pointer,
# as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.

attach_align_device_hook_on_blocks(
model,
execution_device=execution_device,
@@ -404,6 +420,7 @@
weights_map=weights_map,
skip_keys=skip_keys,
preload_module_classes=preload_module_classes,
tied_params_map=tied_params_map,
)

# warn if there is any params on the meta device
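In short, dispatch_model now pre-seeds a map keyed by each tied parameter's storage pointer before attaching hooks. A self-contained sketch of what this produces, using the real accelerate helpers on a hypothetical tied model:

```python
import torch.nn as nn
from accelerate.utils import find_tied_parameters
from accelerate.utils.other import recursive_getattr

class TiedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)
        self.lm_head = nn.Linear(16, 100, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the two names

model = TiedNet()
tied_params = find_tied_parameters(model)  # e.g. [['embed.weight', 'lm_head.weight']]

# One entry per shared storage pointer; each value will later map an
# execution device to the tensor already dispatched there.
tied_params_map = {}
for group in tied_params:
    for param_name in group:
        data_ptr = recursive_getattr(model, param_name).data_ptr()
        tied_params_map[data_ptr] = {}

assert len(tied_params_map) == 1  # both names share one pointer
```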
110 changes: 94 additions & 16 deletions src/accelerate/hooks.py
@@ -27,6 +27,7 @@
set_module_tensor_to_device,
)
from .utils.modeling import get_non_persistent_buffers
from .utils.other import recursive_getattr


class ModelHook:
@@ -116,7 +117,9 @@ def detach_hook(self, module):
return module


def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
def add_hook_to_module(
module: nn.Module, hook: ModelHook, append: bool = False, init_hook_kwargs: Optional[Dict] = None
):
"""
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
this behavior and restore the original `forward` method, use `remove_hook_from_module`.
@@ -135,6 +138,8 @@ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False)
The hook to attach.
append (`bool`, *optional*, defaults to `False`):
Whether the hook should be chained with an existing one (if module already contains a hook) or not.
init_hook_kwargs (Optional[Dict], *optional*, defaults to `None`):
Optional arguments to pass to the hook initialization.
Returns:
`torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
@@ -153,7 +158,10 @@ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False)
old_forward = module.forward
module._old_forward = old_forward

module = hook.init_hook(module)
if init_hook_kwargs is None:
init_hook_kwargs = {}

module = hook.init_hook(module, **init_hook_kwargs)
module._hf_hook = hook

def new_forward(module, *args, **kwargs):
@@ -240,6 +248,7 @@ def __init__(
self.input_device = None
self.param_original_devices = {}
self.buffer_original_devices = {}
self.tied_params_names = set()

def __repr__(self):
return (
@@ -248,10 +257,14 @@ def __repr__(self):
f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
)

def init_hook(self, module):
def init_hook(self, module, tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None):
# In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
tied_params_map = None

if not self.offload and self.execution_device is not None:
for name, _ in named_module_tensors(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map)
elif self.offload:
self.original_devices = {
name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
@@ -266,32 +279,67 @@ def init_hook(self, module):
for name, _ in named_module_tensors(
module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
):
# When using disk offloading, we cannot rely on `weights_map[name].data_ptr()` as the reference pointer,
# as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
# Since there is no reliable way to track the shared data pointer of tied weights in this case, we use the set
# `self.tied_params_names` to add pointers to `tied_params_map` on the fly in the pre_forward call.
if tied_params_map is not None and recursive_getattr(module, name).data_ptr() in tied_params_map:
self.tied_params_names.add(name)

set_module_tensor_to_device(module, name, "meta")

if not self.offload_buffers and self.execution_device is not None:
for name, _ in module.named_buffers(recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map)
elif self.offload_buffers and self.execution_device is not None:
for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map)

# The hook's pre_forward/post_forward methods need access to this dictionary, as with offloading we want to avoid
# duplicating memory for tied weights already loaded on the target execution device.
self.tied_params_map = tied_params_map

return module

def pre_forward(self, module, *args, **kwargs):
if self.io_same_device:
self.input_device = find_device([args, kwargs])
if self.offload:
self.tied_pointers_to_remove = set()

for name, _ in named_module_tensors(
module,
include_buffers=self.offload_buffers,
recurse=self.place_submodules,
remove_non_persistent=True,
):
fp16_statistics = None
value = self.weights_map[name]
if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
if self.weights_map[name].dtype == torch.int8:
if value.dtype == torch.int8:
fp16_statistics = self.weights_map[name.replace("weight", "SCB")]

# When offloading with tied weights, we need to keep track of the offloaded weights that are loaded on device
# at this point, as we will need to remove them from self.tied_params_map afterwards so that their memory can be freed.
if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:
self.tied_params_map[value.data_ptr()] = {}

if (
value is not None
and self.tied_params_map is not None
and value.data_ptr() in self.tied_params_map
and self.execution_device not in self.tied_params_map[value.data_ptr()]
):
self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))

set_module_tensor_to_device(
module, name, self.execution_device, value=self.weights_map[name], fp16_statistics=fp16_statistics
module,
name,
self.execution_device,
value=value,
fp16_statistics=fp16_statistics,
tied_params_map=self.tied_params_map,
)

return send_to_device(args, self.execution_device), send_to_device(
@@ -311,6 +359,12 @@ def post_forward(self, module, output):
module.state.SCB = None
module.state.CxB = None

# We may have loaded tied weights into self.tied_params_map (to avoid loading them several times, e.g. in submodules):
# remove them from this dictionary so the garbage collector can free that memory.
for value_pointer, device in self.tied_pointers_to_remove:
del self.tied_params_map[value_pointer][device]
self.tied_pointers_to_remove = None

if self.io_same_device and self.input_device is not None:
output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)

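To make the bookkeeping above concrete, here is a schematic of what pre_forward and post_forward do for a disk-offloaded tied weight, with plain dicts standing in for the hook's attributes (illustrative only, not the PR's code):

```python
import torch

execution_device = torch.device("cpu")
tied_params_map = {}                    # shared map: ptr -> {device: tensor}
tied_params_names = {"lm_head.weight"}  # filled by init_hook when disk-offloaded

# pre_forward: the tensor returned by the weights_map has a fresh pointer, so
# it is registered on the fly, dispatched once, and cached for other modules.
value = torch.randn(100, 16)            # stands in for weights_map["lm_head.weight"]
ptr = value.data_ptr()
if "lm_head.weight" in tied_params_names and ptr not in tied_params_map:
    tied_params_map[ptr] = {}
tied_pointers_to_remove = {(ptr, execution_device)}
tied_params_map[ptr][execution_device] = value.to(execution_device)

# post_forward: drop the cache entry so the memory can actually be freed.
for p, device in tied_pointers_to_remove:
    del tied_params_map[p][device]
assert tied_params_map == {ptr: {}}
```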
@@ -329,6 +383,7 @@ def attach_execution_device_hook(
execution_device: Union[int, str, torch.device],
skip_keys: Optional[Union[str, List[str]]] = None,
preload_module_classes: Optional[List[str]] = None,
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
):
"""
Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right
@@ -346,16 +401,24 @@
of the forward. This should only be used for classes that have submodules which are registered but not
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
A map of data pointers to dictionaries mapping devices to already dispatched tied weights. On a given
execution device, this lets every parameter tied to an already dispatched weight reuse that first copy
instead of duplicating memory.
"""
if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
add_hook_to_module(module, AlignDevicesHook(execution_device, skip_keys=skip_keys))
add_hook_to_module(
module,
AlignDevicesHook(execution_device, skip_keys=skip_keys),
init_hook_kwargs={"tied_params_map": tied_params_map},
)

# Break the recursion if we get to a preload module.
if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes:
return

for child in module.children():
attach_execution_device_hook(child, execution_device)
attach_execution_device_hook(child, execution_device, tied_params_map=tied_params_map)


def attach_align_device_hook(
@@ -367,6 +430,7 @@ def attach_align_device_hook(
module_name: str = "",
skip_keys: Optional[Union[str, List[str]]] = None,
preload_module_classes: Optional[List[str]] = None,
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
):
"""
Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or
@@ -392,6 +456,10 @@
of the forward. This should only be used for classes that have submodules which are registered but not
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
A map of data pointers to dictionaries mapping devices to already dispatched tied weights. On a given
execution device, this lets every parameter tied to an already dispatched weight reuse that first copy
instead of duplicating memory.
"""
# Attach the hook on this module if it has any direct tensor.
directs = named_module_tensors(module)
@@ -413,7 +481,7 @@ def attach_align_device_hook(
place_submodules=full_offload,
skip_keys=skip_keys,
)
add_hook_to_module(module, hook, append=True)
add_hook_to_module(module, hook, append=True, init_hook_kwargs={"tied_params_map": tied_params_map})

# We stop the recursion in case we hit the full offload.
if full_offload:
@@ -431,6 +499,7 @@ def attach_align_device_hook(
module_name=child_name,
preload_module_classes=preload_module_classes,
skip_keys=skip_keys,
tied_params_map=tied_params_map,
)


@@ -455,6 +524,7 @@ def attach_align_device_hook_on_blocks(
module_name: str = "",
skip_keys: Optional[Union[str, List[str]]] = None,
preload_module_classes: Optional[List[str]] = None,
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
):
"""
Attaches `AlignDevicesHook` to all blocks of a given model as needed.
@@ -481,14 +551,18 @@
of the forward. This should only be used for classes that have submodules which are registered but not
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
A map of data pointers to dictionaries mapping devices to already dispatched tied weights. On a given
execution device, this lets every parameter tied to an already dispatched weight reuse that first copy
instead of duplicating memory.
"""
# If one device and one offload, we've got one hook.
if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):
if not offload:
hook = AlignDevicesHook(
execution_device=execution_device, io_same_device=True, skip_keys=skip_keys, place_submodules=True
)
add_hook_to_module(module, hook)
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
else:
attach_align_device_hook(
module,
Expand All @@ -498,6 +572,7 @@ def attach_align_device_hook_on_blocks(
offload_buffers=offload_buffers,
module_name=module_name,
skip_keys=skip_keys,
tied_params_map=tied_params_map,
)
return

Expand All @@ -514,8 +589,8 @@ def attach_align_device_hook_on_blocks(
place_submodules=True,
skip_keys=skip_keys,
)
add_hook_to_module(module, hook)
attach_execution_device_hook(module, execution_device[module_name])
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
attach_execution_device_hook(module, execution_device[module_name], tied_params_map=tied_params_map)
elif module_name in execution_device and module_name in offload:
attach_align_device_hook(
module,
Expand All @@ -526,21 +601,23 @@ def attach_align_device_hook_on_blocks(
module_name=module_name,
skip_keys=skip_keys,
preload_module_classes=preload_module_classes,
tied_params_map=tied_params_map,
)
if not hasattr(module, "_hf_hook"):
hook = AlignDevicesHook(
execution_device=execution_device[module_name], io_same_device=(module_name == ""), skip_keys=skip_keys
)
add_hook_to_module(module, hook)
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
attach_execution_device_hook(
module,
execution_device[module_name],
preload_module_classes=preload_module_classes,
skip_keys=skip_keys,
tied_params_map=tied_params_map,
)
elif module_name == "":
hook = AlignDevicesHook(execution_device=execution_device.get(""), io_same_device=True, skip_keys=skip_keys)
add_hook_to_module(module, hook)
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})

for child_name, child in module.named_children():
child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
Expand All @@ -553,6 +630,7 @@ def attach_align_device_hook_on_blocks(
module_name=child_name,
preload_module_classes=preload_module_classes,
skip_keys=skip_keys,
tied_params_map=tied_params_map,
)


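Taken together, these changes thread one shared tied_params_map through every attach_* entry point via the new init_hook_kwargs argument. A minimal usage sketch under this commit's API, again on a hypothetical tied model:

```python
import torch
import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

class TiedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)
        self.lm_head = nn.Linear(16, 100, bias=False)
        self.lm_head.weight = self.embed.weight  # one tensor, two names

model = TiedNet()
# Shared across hooks: storage pointer -> {device: already dispatched tensor}.
tied_params_map = {model.embed.weight.data_ptr(): {}}

hook = AlignDevicesHook(execution_device=torch.device("cpu"), io_same_device=True, place_submodules=True)
add_hook_to_module(model, hook, init_hook_kwargs={"tied_params_map": tied_params_map})

# The first dispatch of the shared tensor is cached in tied_params_map and the
# second tied name reuses it, so both still alias a single storage.
assert model.lm_head.weight.data_ptr() == model.embed.weight.data_ptr()
```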
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -188,6 +188,7 @@
is_port_in_use,
merge_dicts,
patch_environment,
recursive_getattr,
save,
wait_for_everyone,
write_basic_config,
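The only change here re-exports recursive_getattr, the small helper used above to resolve dotted attribute paths; a quick usage sketch:

```python
import torch.nn as nn
from accelerate.utils import recursive_getattr

model = nn.Sequential(nn.Linear(4, 4))
# Walks the path one attribute at a time, like reduce(getattr, ...).
weight = recursive_getattr(model, "0.weight")
assert weight is model[0].weight
```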