Enable cpu offload with weights inside the module #2214

Draft · wants to merge 3 commits into base: main
18 changes: 15 additions & 3 deletions src/accelerate/big_modeling.py
@@ -302,6 +302,7 @@ def dispatch_model(
offload_dir: Optional[Union[str, os.PathLike]] = None,
offload_index: Optional[Dict[str, str]] = None,
offload_buffers: bool = False,
cpu_offload: bool = False,
Member:

In general, should newly added parameters be placed last in case someone calls this function with purely positional arguments?
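
A minimal sketch of the risk for purely positional callers (function names hypothetical):

```python
# Before: a caller relying on positional argument order.
def dispatch_v1(model, device_map, offload_buffers=False):
    return offload_buffers

# After inserting `cpu_offload` before `offload_buffers`:
def dispatch_v2(model, device_map, cpu_offload=False, offload_buffers=False):
    return offload_buffers

# The same positional call that meant offload_buffers=True...
assert dispatch_v1(None, {}, True) is True
# ...now silently sets cpu_offload instead.
assert dispatch_v2(None, {}, True) is False
```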

Contributor:

I think it can be confusing, as CPU offloading is already indicated in the device_map.

IMO, ideally no argument should be added: by default, the weights of modules offloaded to RAM should be on the cpu device, not meta. But this is kind of a breaking change in case anybody is assuming that attached weights are on meta by default and that weights_map holds the true weights.
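
For context, the device_map already encodes per-module CPU placement — a minimal sketch (module names illustrative):

```python
device_map = {
    "linear1": 0,        # kept on GPU 0
    "batchnorm": "cpu",  # offloaded to RAM
    "linear2": "disk",   # offloaded to disk
}

# The modules mapped to "cpu" can be derived directly, without an extra flag:
cpu_modules = [name for name, device in device_map.items() if device == "cpu"]
assert cpu_modules == ["batchnorm"]
```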

skip_keys: Optional[Union[str, List[str]]] = None,
preload_module_classes: Optional[List[str]] = None,
force_hooks: bool = False,
@@ -321,6 +322,8 @@ def dispatch_model(
`"disk"`.
state_dict (`Dict[str, torch.Tensor]`, *optional*):
The state dict of the part of the model that will be kept on CPU.
cpu_offload (`bool`, *optional*, defaults to `False`):
Whether the weights offloaded to the CPU should be kept in the module or not.
offload_dir (`str` or `os.PathLike`):
The folder in which to offload the model weights (or where the model weights are already offloaded).
offload_index (`Dict`, *optional*):
@@ -358,7 +361,7 @@ def dispatch_model(
else:
main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]

if main_device != "cpu":
if main_device != "cpu" and not cpu_offload:
cpu_modules = [name for name, device in device_map.items() if device == "cpu"]
if state_dict is None and len(cpu_modules) > 0:
state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)
@@ -381,8 +384,12 @@ def dispatch_model(
name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items()
}
execution_device[""] = main_device
offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
offloaded_devices = (
["disk"] if cpu_offload or main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
)
offload = {name: device in offloaded_devices for name, device in device_map.items()}
if cpu_offload:
cpu_offload = {name: device == "cpu" for name, device in device_map.items()}
save_folder = offload_dir if len(disk_modules) > 0 else None
if state_dict is not None or save_folder is not None or offload_index is not None:
device = main_device if offload_index is not None else None
@@ -397,6 +404,7 @@ def dispatch_model(
model,
execution_device=execution_device,
offload=offload,
cpu_offload=cpu_offload,
offload_buffers=offload_buffers,
weights_map=weights_map,
skip_keys=skip_keys,
@@ -405,7 +413,7 @@ def dispatch_model(

# warn if there is any params on the meta device
offloaded_devices_str = " and ".join(
[device for device in set(device_map.values()) if device in ("cpu", "disk")]
[device for device in set(device_map.values()) if device in offloaded_devices]
)
if len(offloaded_devices_str) > 0:
logging.warning(
@@ -450,6 +458,7 @@ def load_checkpoint_and_dispatch(
no_split_module_classes: Optional[List[str]] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
offload_buffers: bool = False,
cpu_offload: bool = False,
dtype: Optional[Union[str, torch.dtype]] = None,
offload_state_dict: Optional[bool] = None,
skip_keys: Optional[Union[str, List[str]]] = None,
@@ -484,6 +493,8 @@ def load_checkpoint_and_dispatch(
offload_buffers (`bool`, *optional*, defaults to `False`):
In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
well as the parameters.
cpu_offload (`bool`, *optional*, defaults to `False`):
Whether the weights offloaded to the CPU should be kept in the module or not.
dtype (`str` or `torch.dtype`, *optional*):
If provided, the weights will be converted to that type when loaded.
offload_state_dict (`bool`, *optional*):
@@ -558,6 +569,7 @@ def load_checkpoint_and_dispatch(
device_map=device_map,
offload_dir=offload_folder,
offload_buffers=offload_buffers,
cpu_offload=cpu_offload,
skip_keys=skip_keys,
preload_module_classes=preload_module_classes,
force_hooks=force_hooks,
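A hedged usage sketch of the new flag as documented above (MyModel and the offload directory are placeholders; this mirrors the tests added below):

```python
import torch
from accelerate import dispatch_model

model = MyModel()  # hypothetical nn.Module with linear1 / batchnorm / linear2
device_map = {"linear1": "disk", "batchnorm": "cpu", "linear2": 0}

# With cpu_offload=True, the "cpu" module keeps its weights attached on the
# cpu device instead of being swapped to meta with weights_map holding them.
model = dispatch_model(model, device_map, offload_dir="offload", cpu_offload=True)

assert model.batchnorm.weight.device == torch.device("cpu")
assert model.linear1.weight.device == torch.device("meta")  # disk-offloaded
```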
33 changes: 26 additions & 7 deletions src/accelerate/hooks.py
@@ -208,6 +208,8 @@ class AlignDevicesHook(ModelHook):
The device on which inputs and model weights should be placed before the forward pass.
offload (`bool`, *optional*, defaults to `False`):
Whether or not the weights should be offloaded after the forward pass.
cpu_offload (`bool`, *optional*, defaults to `False`):
Whether the weights offloaded to the CPU should be kept in the module or not.
io_same_device (`bool`, *optional*, defaults to `False`):
Whether or not the output should be placed on the same device as the input was.
weights_map (`Mapping[str, torch.Tensor]`, *optional*):
@@ -222,6 +224,7 @@ def __init__(
self,
execution_device: Optional[Union[int, str, torch.device]] = None,
offload: bool = False,
cpu_offload: bool = False,
io_same_device: bool = False,
weights_map: Optional[Mapping] = None,
offload_buffers: bool = False,
@@ -230,12 +233,12 @@ def __init__(
):
self.execution_device = execution_device
self.offload = offload
self.cpu_offload = cpu_offload
Member:

Would it make sense to pass both offload and cpu_offload? It seems the former would take precedence over the latter. Maybe this could be checked or documented?
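
For what it's worth, the branch order in this diff's init_hook does give offload precedence — a condensed sketch of the control flow:

```python
def init_device_sketch(offload: bool, cpu_offload: bool, has_execution_device: bool) -> str:
    # Mirrors the branch order of AlignDevicesHook.init_hook in this diff.
    if not offload and has_execution_device and not cpu_offload:
        return "execution_device"
    elif offload:
        return "meta + weights_map"  # wins even when cpu_offload=True
    elif cpu_offload:
        return "cpu"
    return "unchanged"

assert init_device_sketch(offload=True, cpu_offload=True, has_execution_device=True) == "meta + weights_map"
```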

self.io_same_device = io_same_device
self.weights_map = weights_map
self.offload_buffers = offload_buffers
self.place_submodules = place_submodules
self.skip_keys = skip_keys

# Will contain the input device when `io_same_device=True`.
self.input_device = None
self.param_original_devices = {}
@@ -249,7 +252,7 @@ def __repr__(self):
)

def init_hook(self, module):
if not self.offload and self.execution_device is not None:
if not self.offload and self.execution_device is not None and not self.cpu_offload:
for name, _ in named_module_tensors(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
elif self.offload:
@@ -273,7 +276,9 @@ def init_hook(self, module):
elif self.offload_buffers and self.execution_device is not None:
for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)

elif self.cpu_offload:
for name, _ in named_module_tensors(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, "cpu")
return module

def pre_forward(self, module, *args, **kwargs):
@@ -293,7 +298,9 @@ def pre_forward(self, module, *args, **kwargs):
set_module_tensor_to_device(
module, name, self.execution_device, value=self.weights_map[name], fp16_statistics=fp16_statistics
)

elif self.cpu_offload:
for name, _ in named_module_tensors(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
return send_to_device(args, self.execution_device), send_to_device(
kwargs, self.execution_device, skip_keys=self.skip_keys
)
@@ -310,7 +317,9 @@ def post_forward(self, module, output):
if type(module).__name__ == "Linear8bitLt":
module.state.SCB = None
module.state.CxB = None

elif self.cpu_offload:
for name, _ in named_module_tensors(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, "cpu")
Member:

Is special handling for Linear8bitLt required, similar to above?
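
A hedged sketch of what mirroring that handling could look like in the new branch — an assumption only, since whether it is actually needed is exactly the open question here:

```python
from accelerate.utils import named_module_tensors, set_module_tensor_to_device

def cpu_offload_post_forward(module, place_submodules=False):
    # Hypothetical variant of the new cpu_offload branch in post_forward.
    for name, _ in named_module_tensors(module, recurse=place_submodules):
        set_module_tensor_to_device(module, name, "cpu")
    # Assumption: mirror the Linear8bitLt state reset from the offload branch.
    if type(module).__name__ == "Linear8bitLt":
        module.state.SCB = None
        module.state.CxB = None
```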

if self.io_same_device and self.input_device is not None:
output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)

@@ -450,6 +459,7 @@ def attach_align_device_hook_on_blocks(
module: nn.Module,
execution_device: Optional[Union[torch.device, Dict[str, torch.device]]] = None,
offload: Union[bool, Dict[str, bool]] = False,
cpu_offload: Union[bool, Dict[str, bool]] = False,
weights_map: Mapping = None,
offload_buffers: bool = False,
module_name: str = "",
@@ -468,6 +478,8 @@ def attach_align_device_hook_on_blocks(
offload (`bool`, *optional*, defaults to `False`):
Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
module, or a dictionary mapping module name to boolean.
cpu_offload (`Union[bool, Dict[str, bool]]`, *optional*, defaults to `False`):
Whether the weights offloaded to the CPU should be kept in the module or not.
Member:

The docstring doesn't explain what the option of passing a dict here is for.
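
For illustration, the dict form could be documented like offload's — per-module booleans (module names hypothetical):

```python
# `cpu_offload` accepts one boolean for the whole model, or a per-module map:
cpu_offload = {
    "linear1": False,
    "batchnorm": True,   # keep this module's offloaded weights on cpu
    "linear2": False,
}
```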

weights_map (`Mapping[str, torch.Tensor]`, *optional*):
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
offload_buffers (`bool`, *optional*, defaults to `False`):
@@ -505,10 +517,16 @@ def attach_align_device_hook_on_blocks(
execution_device = {key: execution_device for key in offload.keys()}
if not isinstance(offload, Mapping):
offload = {key: offload for key in execution_device.keys()}

if module_name in execution_device and module_name in offload and not offload[module_name]:
if not isinstance(cpu_offload, Mapping):
cpu_offload = {key: cpu_offload for key in execution_device.keys()}
if (
module_name in execution_device
and module_name in offload
and (not offload[module_name] or cpu_offload[module_name])
):
hook = AlignDevicesHook(
execution_device=execution_device[module_name],
cpu_offload=cpu_offload[module_name],
offload_buffers=offload_buffers,
io_same_device=(module_name == ""),
place_submodules=True,
@@ -548,6 +566,7 @@ def attach_align_device_hook_on_blocks(
child,
execution_device=execution_device,
offload=offload,
cpu_offload=cpu_offload,
weights_map=weights_map,
offload_buffers=offload_buffers,
module_name=child_name,
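A hedged sketch of calling the hook attachment directly with the new argument, with the dicts shaped the way dispatch_model computes them in this diff (model, names, and weights_map are illustrative):

```python
from accelerate.hooks import attach_align_device_hook_on_blocks

# For device_map = {"linear1": "disk", "batchnorm": "cpu", "linear2": 0}:
execution_device = {"": 0, "linear1": 0, "batchnorm": 0, "linear2": 0}
offload = {"linear1": True, "batchnorm": False, "linear2": False}      # disk only
cpu_offload = {"linear1": False, "batchnorm": True, "linear2": False}  # RAM module

attach_align_device_hook_on_blocks(
    model,                    # hypothetical nn.Module
    execution_device=execution_device,
    offload=offload,
    cpu_offload=cpu_offload,
    weights_map=weights_map,  # hypothetical map holding disk-offloaded tensors
    offload_buffers=False,
)
```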
43 changes: 43 additions & 0 deletions tests/test_big_modeling.py
@@ -338,6 +338,24 @@ def test_dispatch_model_with_non_persistent_buffers(self):

with TemporaryDirectory() as tmp_dir:
dispatch_model(model, device_map, offload_dir=tmp_dir, offload_buffers=True)

output = model(x)
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

def test_dispatch_model_with_cpu_offload(self):
model = ModelForTest()
device_map = {"linear1": "disk", "batchnorm": "cpu", "linear2": 0}

x = torch.randn(2, 3)
expected = model(x)

with TemporaryDirectory() as tmp_dir:
dispatch_model(model, device_map, offload_dir=tmp_dir, cpu_offload=True)

self.assertEqual(model.linear1.weight.device, torch.device("meta"))
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
Member:

The new behavior of getting "cpu" here instead of "meta" looks more intuitive to me.

Member Author:

Yes, this is what we are aiming for in this PR! We want to leave the module's weights on the cpu device rather than on the meta device.

self.assertEqual(model.linear2.weight.device, torch.device(0))

output = model(x)
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

@@ -548,6 +566,31 @@ def test_load_checkpoint_and_dispatch(self):
output = new_model(x)
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

@require_cuda
def test_load_checkpoint_and_dispatch_with_cpu_offload(self):
model = ModelForTest()
device_map = {"linear1": "cpu", "batchnorm": "disk", "linear2": 0}

x = torch.randn(2, 3)
expected = model(x)

with TemporaryDirectory() as tmp_dir:
checkpoint = os.path.join(tmp_dir, "pt_model.bin")
torch.save(model.state_dict(), checkpoint)

new_model = ModelForTest()
new_model = load_checkpoint_and_dispatch(
new_model, checkpoint, device_map=device_map, cpu_offload=True, offload_folder=tmp_dir
)

# With cpu_offload=True, CPU-offloaded weights stay attached to the module;
# disk-offloaded weights sit on the meta device until the forward pass.
self.assertEqual(new_model.linear1.weight.device, torch.device("cpu"))
self.assertEqual(new_model.batchnorm.weight.device, torch.device("meta"))
self.assertEqual(new_model.linear2.weight.device, torch.device(0))

output = new_model(x)
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

@require_mps
def test_load_checkpoint_and_dispatch_mps(self):
model = ModelForTest()