
[Core] better support offloading when side loading is enabled. #4855

Merged
13 commits merged on Sep 5, 2023
43 changes: 43 additions & 0 deletions src/diffusers/loaders.py
@@ -45,6 +45,7 @@

if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
from accelerate.utils import set_module_tensor_to_device

logger = logging.get_logger(__name__)
@@ -768,6 +769,21 @@ def load_textual_inversion(
f" `{self.load_textual_inversion.__name__}`"
)

# Remove any existing hooks.
is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
Contributor: nice!

logger.info(
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
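# Note: removal only needs to recurse for sequential offload, where
# accelerate attaches an AlignDevicesHook to every submodule; model
# offload places a single top-level CpuOffload hook per component.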

cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
@@ -921,6 +937,12 @@ def load_textual_inversion(
for token_id, embedding in token_ids_and_embeddings:
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
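
What this enables, sketched minimally below: with offloading active, the loader now detects the hooks, removes them, loads the new parameters, and re-enables offloading. The model and embedding IDs are illustrative examples from the Hub, not part of this diff.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Offload first; load_textual_inversion() tears the CpuOffload hook down,
# loads the embedding, then calls enable_model_cpu_offload() again.
pipe.enable_model_cpu_offload()
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
image = pipe("a <cat-toy> on a bench", num_inference_steps=25).images[0]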


class LoraLoaderMixin:
r"""
@@ -952,6 +974,21 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
# Remove any existing hooks.
is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)

state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
self.load_lora_into_text_encoder(
@@ -961,6 +998,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
lora_scale=self.lora_scale,
)

# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
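
The same round trip for LoRA with sequential offloading, as a minimal sketch; the repository and weight names are taken from the tests added in this PR.

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "hf-internal-testing/Counterfeit-V2.5", safety_checker=None
)
# Sequential offload hooks live on every submodule (AlignDevicesHook),
# hence the recursive removal inside load_lora_weights().
pipe.enable_sequential_cpu_offload()
pipe.load_lora_weights(
    "hf-internal-testing/civitai-light-shadow-lora",
    weight_name="light_and_shadow.safetensors",
)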

@classmethod
def lora_state_dict(
cls,
@@ -1549,6 +1549,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.

# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")

is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
@@ -1576,6 +1596,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
lora_scale=self.lora_scale,
)

# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()

@classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def save_lora_weights(
26 changes: 26 additions & 0 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1212,6 +1212,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.

# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")

is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
Comment on lines +1230 to +1232

Contributor: I think the log statement might be a bit noisy. It would be justified if we expected the user to do additional things with the placed accelerate hooks, or if they needed to know whether some state is maintained, but we definitely don't want the user to touch the hooks.

Member (Author): I think it's relatively simple given the context the message is raised from. If you have a better suggestion, let me know.

Contributor: Sorry, my main point is that the log is a bit noisy given that it leaks what is supposed to be an internal implementation detail; it's not really something that should be exposed to an end user.
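
A side note on the noise concern, not part of this PR: a user who does not want the INFO message can raise the threshold with the existing diffusers logging utilities.

from diffusers.utils import logging

logging.set_verbosity_warning()  # hides INFO-level messages such as the one above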

recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
@@ -1239,6 +1259,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
lora_scale=self.lora_scale,
)

# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()

@classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def save_lora_weights(
@@ -916,6 +916,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.

# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")

is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
@@ -943,6 +963,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
lora_scale=self.lora_scale,
)

# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()

@classmethod
def save_lora_weights(
self,
@@ -1070,6 +1070,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.

# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")

is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
@@ -1097,6 +1117,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
lora_scale=self.lora_scale,
)

# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()

@classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def save_lora_weights(
@@ -1384,6 +1384,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.

# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")

is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
@@ -1411,6 +1431,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
lora_scale=self.lora_scale,
)

# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()

@classmethod
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
def save_lora_weights(
56 changes: 55 additions & 1 deletion tests/models/test_lora_layers.py
@@ -1081,6 +1081,42 @@ def test_a1111(self):

self.assertTrue(np.allclose(images, expected, atol=1e-3))

def test_a1111_with_model_cpu_offload(self):
generator = torch.Generator().manual_seed(0)

pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
pipe.enable_model_cpu_offload()
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)

images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images

images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])

self.assertTrue(np.allclose(images, expected, atol=1e-3))

def test_a1111_with_sequential_cpu_offload(self):
generator = torch.Generator().manual_seed(0)

pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
pipe.enable_sequential_cpu_offload()
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)

images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images

images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])

self.assertTrue(np.allclose(images, expected, atol=1e-3))
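
# Note: both offload tests above expect the same image slice, since
# enabling CPU offload should not change the generated output.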

def test_kohya_sd_v15_with_higher_dimensions(self):
generator = torch.Generator().manual_seed(0)

@@ -1257,10 +1293,10 @@ def test_sdxl_1_0_lora(self):
generator = torch.Generator().manual_seed(0)

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.enable_model_cpu_offload()
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.enable_model_cpu_offload()
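# (This hunk moves enable_model_cpu_offload() to before load_lora_weights(),
# so the existing test now exercises the hook removal and re-offload path.)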

images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
@@ -1413,3 +1449,21 @@ def test_sdxl_1_0_fuse_unfuse_all(self):
assert state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd)
assert state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd)
assert state_dicts_almost_equal(unet_sd, new_unet_sd)

def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self):
generator = torch.Generator().manual_seed(0)

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.enable_sequential_cpu_offload()
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)

images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images

images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])

self.assertTrue(np.allclose(images, expected, atol=1e-3))