Skip to content

Commit

Permalink
remove init_hook_kwargs (#2365)
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty authored Jan 22, 2024
1 parent 53845d2 commit 0d6a5fa
Showing 1 changed file with 43 additions and 29 deletions.
72 changes: 43 additions & 29 deletions src/accelerate/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,7 @@ def detach_hook(self, module):
return module


def add_hook_to_module(
module: nn.Module, hook: ModelHook, append: bool = False, init_hook_kwargs: Optional[Dict] = None
):
def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
"""
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook. To remove
this behavior and restore the original `forward` method, use `remove_hook_from_module`.
Expand All @@ -138,8 +136,6 @@ def add_hook_to_module(
The hook to attach.
append (`bool`, *optional*, defaults to `False`):
Whether the hook should be chained with an existing one (if module already contains a hook) or not.
init_hook_kwargs (Optional[Dict], *optional*, defaults to `None`):
Optional arguments to pass to the hook initialization.
Returns:
`torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
Expand All @@ -158,10 +154,7 @@ def add_hook_to_module(
old_forward = module.forward
module._old_forward = old_forward

if init_hook_kwargs is None:
init_hook_kwargs = {}

module = hook.init_hook(module, **init_hook_kwargs)
module = hook.init_hook(module)
module._hf_hook = hook

def new_forward(module, *args, **kwargs):
Expand Down Expand Up @@ -235,6 +228,7 @@ def __init__(
offload_buffers: bool = False,
place_submodules: bool = False,
skip_keys: Optional[Union[str, List[str]]] = None,
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
):
self.execution_device = execution_device
self.offload = offload
Expand All @@ -250,21 +244,25 @@ def __init__(
self.buffer_original_devices = {}
self.tied_params_names = set()

# The hook's pre_forward/post_forward methods need access to this dictionary: with offloading, we want to avoid
# duplicating memory for tied weights that are already loaded on the target execution device.
self.tied_params_map = tied_params_map

def __repr__(self):
return (
f"AlignDevicesHook(execution_device={self.execution_device}, offload={self.offload}, "
f"io_same_device={self.io_same_device}, offload_buffers={self.offload_buffers}, "
f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
)

def init_hook(self, module, tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None):
def init_hook(self, module):
# In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
tied_params_map = None
self.tied_params_map = None

if not self.offload and self.execution_device is not None:
for name, _ in named_module_tensors(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map)
set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
elif self.offload:
self.original_devices = {
name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
Expand All @@ -283,21 +281,24 @@ def init_hook(self, module, tied_params_map: Optional[Dict[int, Dict[torch.devic
# as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
# As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: Set[str]
# to add pointers to `tied_params_map` on the fly in the pre_forward call.
if tied_params_map is not None and recursive_getattr(module, name).data_ptr() in tied_params_map:
if (
self.tied_params_map is not None
and recursive_getattr(module, name).data_ptr() in self.tied_params_map
):
self.tied_params_names.add(name)

set_module_tensor_to_device(module, name, "meta")

if not self.offload_buffers and self.execution_device is not None:
for name, _ in module.named_buffers(recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map)
set_module_tensor_to_device(
module, name, self.execution_device, tied_params_map=self.tied_params_map
)
elif self.offload_buffers and self.execution_device is not None:
for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map)

# The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory
# for tied weights already loaded on the target execution device.
self.tied_params_map = tied_params_map
set_module_tensor_to_device(
module, name, self.execution_device, tied_params_map=self.tied_params_map
)

return module

Expand Down Expand Up @@ -409,8 +410,7 @@ def attach_execution_device_hook(
if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
add_hook_to_module(
module,
AlignDevicesHook(execution_device, skip_keys=skip_keys),
init_hook_kwargs={"tied_params_map": tied_params_map},
AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),
)

# Break the recursion if we get to a preload module.
Expand Down Expand Up @@ -480,8 +480,9 @@ def attach_align_device_hook(
offload_buffers=offload_buffers,
place_submodules=full_offload,
skip_keys=skip_keys,
tied_params_map=tied_params_map,
)
add_hook_to_module(module, hook, append=True, init_hook_kwargs={"tied_params_map": tied_params_map})
add_hook_to_module(module, hook, append=True)

# We stop the recursion in case we hit the full offload.
if full_offload:
Expand Down Expand Up @@ -560,9 +561,13 @@ def attach_align_device_hook_on_blocks(
if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):
if not offload:
hook = AlignDevicesHook(
execution_device=execution_device, io_same_device=True, skip_keys=skip_keys, place_submodules=True
execution_device=execution_device,
io_same_device=True,
skip_keys=skip_keys,
place_submodules=True,
tied_params_map=tied_params_map,
)
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
add_hook_to_module(module, hook)
else:
attach_align_device_hook(
module,
Expand All @@ -588,8 +593,9 @@ def attach_align_device_hook_on_blocks(
io_same_device=(module_name == ""),
place_submodules=True,
skip_keys=skip_keys,
tied_params_map=tied_params_map,
)
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
add_hook_to_module(module, hook)
attach_execution_device_hook(module, execution_device[module_name], tied_params_map=tied_params_map)
elif module_name in execution_device and module_name in offload:
attach_align_device_hook(
Expand All @@ -605,9 +611,12 @@ def attach_align_device_hook_on_blocks(
)
if not hasattr(module, "_hf_hook"):
hook = AlignDevicesHook(
execution_device=execution_device[module_name], io_same_device=(module_name == ""), skip_keys=skip_keys
execution_device=execution_device[module_name],
io_same_device=(module_name == ""),
skip_keys=skip_keys,
tied_params_map=tied_params_map,
)
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
add_hook_to_module(module, hook)
attach_execution_device_hook(
module,
execution_device[module_name],
Expand All @@ -616,8 +625,13 @@ def attach_align_device_hook_on_blocks(
tied_params_map=tied_params_map,
)
elif module_name == "":
hook = AlignDevicesHook(execution_device=execution_device.get(""), io_same_device=True, skip_keys=skip_keys)
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
hook = AlignDevicesHook(
execution_device=execution_device.get(""),
io_same_device=True,
skip_keys=skip_keys,
tied_params_map=tied_params_map,
)
add_hook_to_module(module, hook)

for child_name, child in module.named_children():
child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
Expand Down

0 comments on commit 0d6a5fa

Please sign in to comment.