Enable cpu offload with weights inside the module #2214
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Thanks for tackling this annoying issue. Overall, I understand too little about the mechanism used by accelerate to control this, so I'd leave that part of the review to others.
This PR adds the possibility to perform cpu offload with the weights stored inside the module.
Sorry for being dense, but where exactly is that happening?
not sure about the naming of the arg as it can be confusing

Yes, I'd definitely rename the argument, especially since we already have a cpu_offload function in the same file.
@@ -468,6 +478,8 @@ def attach_align_device_hook_on_blocks(
offload (`bool`, *optional*, defaults to `False`):
    Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
    module, or a dictionary mapping module name to boolean.
cpu_offload (`Union[bool, Dict[str, bool]]`, *optional*, defaults to `False`):
    Whether the weights offloaded on the cpu should be kept in the module or not.
The docstring misses explaining what passing a dict here does for this option.
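Presumably the dict form follows the same module-name-to-boolean convention as `offload` above; this is a guess from the `Union[bool, Dict[str, bool]]` type hint, not something the docstring states:

```python
# Hypothetical dict form, guessed from the type hint: keep linear1's offloaded
# weights inside the module (on cpu), but not linear2's.
cpu_offload = {"linear1": True, "linear2": False}
```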
dispatch_model(model, device_map, offload_dir=tmp_dir, cpu_offload=True)

self.assertEqual(model.linear1.weight.device, torch.device("meta"))
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
The new behavior of getting "cpu" here instead of "meta" looks more intuitive to me.
Yes, this is what we are aiming for in this PR! We want to leave the module on the `cpu` device and not on the `meta` device.
@@ -302,6 +302,7 @@ def dispatch_model(
    offload_dir: Optional[Union[str, os.PathLike]] = None,
    offload_index: Optional[Dict[str, str]] = None,
    offload_buffers: bool = False,
    cpu_offload: bool = False,
In general, should newly added parameters be placed last in case someone calls this function with purely positional arguments?
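A toy illustration (nothing to do with accelerate's actual signatures or call sites) of the failure mode that placing new parameters last avoids:

```python
# Inserting a new keyword argument in the middle of a signature silently shifts
# purely positional callers; no error is raised, the values just land elsewhere.
def dispatch_old(model, device_map, offload_dir=None, offload_buffers=False):
    return {"offload_buffers": offload_buffers}


def dispatch_new(model, device_map, offload_dir=None, cpu_offload=False, offload_buffers=False):
    return {"cpu_offload": cpu_offload, "offload_buffers": offload_buffers}


# A caller that used to set offload_buffers positionally...
print(dispatch_old("model", {}, "offload", True))  # {'offload_buffers': True}
# ...now toggles cpu_offload instead.
print(dispatch_new("model", {}, "offload", True))  # {'cpu_offload': True, 'offload_buffers': False}
```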
@@ -230,12 +233,12 @@ def __init__(
):
    self.execution_device = execution_device
    self.offload = offload
    self.cpu_offload = cpu_offload
Would it make sense to pass both `offload` and `cpu_offload`? It seems the former would take precedence over the latter. Maybe this could be checked or documented?
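If it helps, here is a sketch of one way to make that precedence explicit; this is not the code from the PR, and the real hook takes more arguments than shown:

```python
import warnings


# Sketch only: guard/document the interaction between `offload` and `cpu_offload`,
# assuming (per the comment above) that `offload` wins when both are set.
class AlignDevicesHookSketch:
    def __init__(self, execution_device=None, offload=False, cpu_offload=False):
        if offload and cpu_offload:
            warnings.warn("Both `offload` and `cpu_offload` were passed; `offload` takes precedence.")
        self.execution_device = execution_device
        self.offload = offload
        self.cpu_offload = cpu_offload
```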
elif self.cpu_offload:
    for name, _ in named_module_tensors(module, recurse=self.place_submodules):
        set_module_tensor_to_device(module, name, "cpu")
Is special handling for Linear8bitLt required, similar to above?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@@ -302,6 +302,7 @@ def dispatch_model(
    offload_dir: Optional[Union[str, os.PathLike]] = None,
    offload_index: Optional[Dict[str, str]] = None,
    offload_buffers: bool = False,
    cpu_offload: bool = False,
I think it can be confusing, as CPU offloading is already indicated in the `device_map`. IMO, ideally there should not be any argument added, and by default the weights of modules offloaded to RAM should be on the `cpu` device, not `meta`. But this is kind of a breaking change in case anybody is assuming that, by default, the attached weights are on `meta` and `weights_map` holds the true weights.
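For context, a mocked illustration (not real accelerate code) of the kind of downstream assumption that would break: code written against the current default expects the attached weight to be a `meta` placeholder and reads the true values from the hook's `weights_map`. The `_hf_hook` attribute name matches what accelerate attaches today; everything else here is a stand-in.

```python
import torch
from torch import nn


# Stand-in for an accelerate hook: only the weights_map lookup matters here.
class FakeHook:
    def __init__(self, real_weight):
        self.weights_map = {"weight": real_weight}


linear = nn.Linear(3, 4)
real_weight = linear.weight.data.clone()
# Mimic the current offload behavior: the attached weight becomes a meta placeholder.
linear.weight = nn.Parameter(torch.empty_like(linear.weight, device="meta"))
linear._hf_hook = FakeHook(real_weight)

# Downstream code relying on the current default:
assert linear.weight.device.type == "meta"
w = linear._hf_hook.weights_map["weight"]  # true values live outside the module

# If offloaded weights stayed on the cpu device by default, this lookup pattern
# would have to read the parameter itself instead.
print(w.shape)
```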
What does this PR do?

This PR adds the possibility to perform cpu offload with the weights stored inside the module. You just need to pass `cpu_offload=True` to `dispatch_model` (see the sketch below). (Not sure about the naming of the arg as it can be confusing.)

Before this PR, all offloaded modules were placed on the `meta` device, and the weights were either stored in a dict (cpu offload) or a mmap (disk offload). We would then move the modules to the execution device, with their respective values taken from the dict/mmap, during the `forward`.

For the user, it seems a little counter-intuitive to put weights in a dict in the cpu offload case. Moreover, leaving these weights on the modules should not degrade performance during inference at all. Offloading has created a number of issues about parameters being on the `meta` device. While this PR does not completely solve issues related to the `meta` device, it should cover most cases, since users don't use disk offload that much.

For now, the default value is `False`, but I would like to make it the default behavior and extend it to Transformers. LMK if this makes sense.

cc @mfuntowicz since you had an issue with offloaded model + quantization
cc @LysandreJik for visibility
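A minimal sketch of the proposed usage, mirroring the module names in the test excerpt above (exact layer sizes are made up for the sketch); it assumes a machine with at least one GPU, and the `cpu_offload` argument name is the one proposed in this PR and may still change:

```python
import torch
from torch import nn
from accelerate import dispatch_model


class ModelForTest(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3, 4)
        self.batchnorm = nn.BatchNorm1d(4)
        self.linear2 = nn.Linear(4, 5)

    def forward(self, x):
        return self.linear2(self.batchnorm(self.linear1(x)))


model = ModelForTest()
# batchnorm is offloaded to RAM; with cpu_offload=True its weights stay on the
# cpu device instead of being swapped for meta tensors backed by a weights map.
device_map = {"linear1": 0, "batchnorm": "cpu", "linear2": 0}
dispatch_model(model, device_map, cpu_offload=True)

print(model.batchnorm.weight.device)  # expected: cpu
```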
Partially solves #1190
TODO: