
[Core] better support offloading when side loading is enabled. #4855

Merged · 13 commits · Sep 5, 2023
39 changes: 39 additions & 0 deletions src/diffusers/loaders.py
@@ -45,6 +45,7 @@

if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
from accelerate.utils import set_module_tensor_to_device

logger = logging.get_logger(__name__)
@@ -763,6 +764,19 @@ def load_textual_inversion(
f" `{self.load_textual_inversion.__name__}`"
)

# Remove any existing hooks.
is_model_cpu_offload = False
is_sequential_cpu_offload = False
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
Contributor:

nice!

logger.info(
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component)

Contributor:

Doesn't one of these two hook styles hook into every submodule as well, so shouldn't one of the checks be recursive?

Member Author:

@muellerzr to help here a bit.


Contributor:

You should be able to do remove_hook_from_module(component, recursive=True)

CC @SunMarc too for a second glance :)

Member Author:

But is it required here? Sorry for not making my comment clear.

Contributor:

I'm just trying to understand what we're aiming to achieve (a solid guess based on context, let me know if I'm accurate):

  • Given a model that may be loaded in via device_map="auto" or some form of device_map
  • We wish to remove the hooks when wanting to load the full model into memory/remove it from BMI (big model inference)
  • With the potential of placing it back later

Is this accurate? Otherwise I may need a bit more info/context that I'm missing.

Member Author:

#3922 (comment)

We want to be able to detect if a torch.nn.Module has hooks and we want to remove them. That is the bit relevant to accelerate. Then after loading some auxiliary weights, we want to load the appropriate hooks back in.

Let me know if that helps?
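
For reference, a minimal standalone sketch of that detect → remove → reload → re-apply flow; the `reload_with_offload_preserved` helper and the `load_fn` callback are placeholder names for illustration, not part of this PR:

```python
import torch
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module


def reload_with_offload_preserved(pipe, load_fn):
    """Detect accelerate hooks, drop them, run `load_fn`, then restore offloading."""
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False

    # Detect which offloading style (if any) placed hooks on the pipeline components.
    for component in pipe.components.values():
        if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
            is_model_cpu_offload |= isinstance(component._hf_hook, CpuOffload)
            is_sequential_cpu_offload |= isinstance(component._hf_hook, AlignDevicesHook)
            remove_hook_from_module(component)

    # Load the auxiliary weights (textual inversion, LoRA, ...) with no hooks attached.
    load_fn(pipe)

    # Re-apply the same offloading strategy that was active before.
    if is_model_cpu_offload:
        pipe.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()
```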

Member:

From what I understood from the codebase, if we have is_sequential_cpu_offload, it means that the components were offloaded using cpu_offload, which recursively places hooks on each submodule. In the case of is_model_cpu_offload, we use cpu_offload_with_hook, which places only one hook on the module, so that the entire model is offloaded when another hook is triggered.
I would then suggest using remove_hook_from_module(component, recursive=True) for the first case and remove_hook_from_module(component, recursive=False) for the second case, if you don't want to just recursively remove all the hooks in both cases!
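
To make the distinction concrete, a small sketch under the assumptions above (a CUDA device is available; note that the removal helper's keyword may be spelled `recurse` rather than `recursive` depending on the accelerate version):

```python
import torch
from accelerate import cpu_offload, cpu_offload_with_hook
from accelerate.hooks import remove_hook_from_module

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
device = torch.device("cuda")  # assumes a CUDA device is available

# Sequential offload: `cpu_offload` attaches AlignDevicesHook instances down the
# module tree, so the hooks have to be removed recursively.
cpu_offload(model, execution_device=device)
remove_hook_from_module(model, recurse=True)

# Model offload: `cpu_offload_with_hook` attaches a single CpuOffload hook to the
# top-level module, so a non-recursive removal is enough.
model, hook = cpu_offload_with_hook(model, execution_device=device)
remove_hook_from_module(model)
```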

cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
@@ -916,6 +930,12 @@ def load_textual_inversion(
for token_id, embedding in token_ids_and_embeddings:
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

# offload back
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
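
For context, a hedged usage sketch of the behaviour this enables; the checkpoint and embedding IDs are illustrative only:

```python
import torch
from diffusers import StableDiffusionPipeline

# Assumes a CUDA device is available.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# With this change, the existing CpuOffload hooks are removed, the embedding is
# loaded, and model CPU offload is re-enabled automatically.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
image = pipe("a <cat-toy> on a beach", num_inference_steps=25).images[0]
```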


class LoraLoaderMixin:
r"""
@@ -946,6 +966,19 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
# Remove any existing hooks.
is_model_cpu_offload = False
is_sequential_cpu_offload = False
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component)

state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
self.load_lora_into_text_encoder(
@@ -955,6 +988,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
lora_scale=self.lora_scale,
)

# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
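
And the corresponding LoRA path, here with sequential CPU offload so the `AlignDevicesHook` branch is exercised; the LoRA repo ID below is a placeholder:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()

# The AlignDevicesHook instances are detected and removed, the LoRA parameters are
# loaded into the UNet and text encoder, and sequential offload is re-enabled.
pipe.load_lora_weights("some-user/some-lora-weights")  # placeholder repo ID
```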

@classmethod
def lora_state_dict(
cls,
23 changes: 23 additions & 0 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1221,6 +1221,23 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.

# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
is_model_cpu_offload = False
is_sequential_cpu_offload = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
Comment on lines +1230 to +1232
Contributor:

I think the log statement might be a bit noisy. It would make sense if we expected the user to do additional things with the placed accelerate hooks, or needed them to be aware that some state they expected to be maintained might not be, but we definitely don't want the user to touch the hooks.

Member Author:

I think it's relatively simple given the context the message is being raised from. If you have a better suggestion, let me know.

Contributor:

Sorry, I think my main point is that the log is a bit noisy given that it leaks what is supposed to be an internal implementation detail; it's not really something that should be exposed to an end user.
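
One possible way to address that (an assumption, not something this PR settles on) would be to emit the message at debug level instead:

```python
# Hypothetical alternative: keep the bookkeeping quiet unless verbose diffusers
# logging is enabled.
from diffusers.utils import logging

logger = logging.get_logger(__name__)
logger.debug(
    "Accelerate hooks detected; removing them before loading LoRA weights and re-applying them afterwards."
)
```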

remove_hook_from_module(component)
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
@@ -1248,6 +1265,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
lora_scale=self.lora_scale,
)

# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
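
A usage sketch for this pipeline variant, with placeholder checkpoint and LoRA IDs:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

# Illustrative IDs only; any SDXL ControlNet checkpoint and SDXL LoRA would do.
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# The hook bookkeeping added above removes the CpuOffload hooks, loads the LoRA
# weights, and re-enables model CPU offload afterwards.
pipe.load_lora_weights("some-user/some-sdxl-lora")
```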

@classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def save_lora_weights(
@@ -907,6 +907,23 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.

# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
is_model_cpu_offload = False
is_sequential_cpu_offload = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component)
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
Expand Down Expand Up @@ -934,6 +951,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
lora_scale=self.lora_scale,
)

# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()

@classmethod
def save_lora_weights(
self,
@@ -1065,6 +1065,23 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.

# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
is_model_cpu_offload = False
is_sequential_cpu_offload = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component)
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
Expand Down Expand Up @@ -1092,6 +1109,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
lora_scale=self.lora_scale,
)

# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()

@classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def save_lora_weights(
@@ -1370,6 +1370,23 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.

# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
is_model_cpu_offload = False
is_sequential_cpu_offload = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component)
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
Expand Down Expand Up @@ -1397,6 +1414,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
lora_scale=self.lora_scale,
)

# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()

@classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def save_lora_weights(