[Core] better support offloading when side loading is enabled. #4855
Conversation
The documentation is not available anymore as the PR was closed or merged.
I think for tests, we will need to add them as SLOW tests requiring GPUs. Given the importance of these use cases, I won't mind adding them. Any objections to adding these new SLOW tests? @patrickvonplaten @williamberman
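For reference, a sketch of what such a SLOW test could look like (a non-authoritative sketch: the `slow` and `require_torch_gpu` decorators come from diffusers' testing utilities, and the checkpoint and textual-inversion repo are illustrative choices, not what the PR ultimately added):

import unittest

import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils.testing_utils import require_torch_gpu, slow


@slow
@require_torch_gpu
class TextualInversionOffloadTests(unittest.TestCase):
    def test_load_textual_inversion_with_model_cpu_offload(self):
        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload()
        # Side loading should remove and then re-apply the offload hooks.
        pipe.load_textual_inversion("sd-concepts-library/cat-toy")
        image = pipe("A <cat-toy> on a table", num_inference_steps=2).images[0]
        self.assertIsNotNone(image)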
src/diffusers/loaders.py
Outdated
# Remove any existing hooks.
is_model_cpu_offload = False
is_sequential_cpu_offload = False
for _, component in self.components.items():
    if isinstance(component, nn.Module):
        if hasattr(component, "_hf_hook"):
            is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
            is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
            logger.info(
                "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
            )
            remove_hook_from_module(component)
Doesn't one of these two hook styles hook into every submodule as well, so shouldn't one of the checks be recursive?
@muellerzr to help here a bit.
You should be able to do `remove_hook_from_module(component, recurse=True)`.
CC @SunMarc too for a second glance :)
But is it required here? Sorry for not making my comment clear.
I guess trying to understand just what we're aiming to achieve (solid guess based on context, let me know if I'm accurate):
- Given a model that may be loaded in via `device_map="auto"` or some form of `device_map`
- We wish to remove the hooks when wanting to load the full model into memory / remove it from BMI (big model inference)
- With the potential of placing it back later

Is this accurate? Otherwise I may need a bit more info/context I'm missing somehow.
We want to be able to detect if a `torch.nn.Module` has hooks, and we want to remove them. That is the bit relevant to `accelerate`. Then, after loading some auxiliary weights, we want to load the appropriate hooks back in.
Let me know if that helps?
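As a rough sketch of that cycle (assuming a diffusers `DiffusionPipeline` instance and a callable that performs the actual side loading; the function name and `load_weights` parameter here are hypothetical):

import torch
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module


def side_load_with_hooks_restored(pipe, load_weights):
    # 1. Detect which offloading style is active and remove the hooks.
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for _, component in pipe.components.items():
        if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
            is_model_cpu_offload |= isinstance(component._hf_hook, CpuOffload)
            is_sequential_cpu_offload |= isinstance(component._hf_hook, AlignDevicesHook)
            remove_hook_from_module(component)

    # 2. Load the auxiliary weights (textual inversion / LoRA) while no hooks are attached.
    load_weights()

    # 3. Re-apply whichever offloading was previously enabled.
    if is_model_cpu_offload:
        pipe.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()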
From what I understood from the codebase, if we have `is_sequential_cpu_offload`, it means that the components were offloaded using `cpu_offload`, which recursively places the hooks on each submodule. In the case of `is_model_cpu_offload`, we use `cpu_offload_with_hook`, which places only one hook on the module, so that the entire model will be offloaded when another hook is triggered.
I would then suggest using `remove_hook_from_module(component, recurse=True)` for the first case and `remove_hook_from_module(component, recurse=False)` for the second case, if you don't want to just recursively remove all the hooks in both cases!
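A minimal sketch of that suggestion (the hook classes and `remove_hook_from_module` are accelerate's; the wrapper function name is hypothetical):

import torch
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module


def remove_offload_hooks(component: torch.nn.Module) -> None:
    hook = getattr(component, "_hf_hook", None)
    if isinstance(hook, AlignDevicesHook):
        # cpu_offload() attaches AlignDevicesHook to every submodule, so recurse.
        remove_hook_from_module(component, recurse=True)
    elif isinstance(hook, CpuOffload):
        # cpu_offload_with_hook() attaches a single top-level hook.
        remove_hook_from_module(component, recurse=False)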
logger.info(
    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
I think the log statement might be a bit noisy. It'd be nice if we expected the user to do additional things with the placed accelerate hooks, or needed them to be aware that some state they expected to be maintained might not be, but we definitely don't want the user to touch the hooks.
I think it's relatively simple given the context the message is being raised from. If you have a better suggestion, let me know.
Sorry, I think my main point is that the log is a bit noisy given that it leaks what is supposed to be an internal implementation detail. It's not really something that should be exposed to an end user.
Design makes sense! A few quick questions.
if isinstance(component, nn.Module):
    if hasattr(component, "_hf_hook"):
        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
nice!
Looks good to me!
@patrickvonplaten ready for another round of review.
Good to go for me!
[Core] better support offloading when side loading is enabled. (huggingface#4855)
* better support offloading when side loading is enabled.
* load_textual_inversion
* better messaging for textual inversion.
* fixes
* address PR feedback.
* sdxl support.
* improve messaging
* recursive removal when cpu sequential offloading is enabled.
* add: lora tests
* recruse.
* add: offload tests for textual inversion.
Revert "[Core] better support offloading when side loading is enabled. (huggingface#4855)" (huggingface#4927). This reverts commit e4b8e79.