diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index e1ac6a498a12..40d8804893e2 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -45,6 +45,7 @@
 if is_accelerate_available():
     from accelerate import init_empty_weights
+    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
     from accelerate.utils import set_module_tensor_to_device
 
 logger = logging.get_logger(__name__)
@@ -768,6 +769,21 @@ def load_textual_inversion(
                 f" `{self.load_textual_inversion.__name__}`"
             )
 
+        # Remove any existing hooks.
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        recursive = False
+        for _, component in self.components.items():
+            if isinstance(component, nn.Module):
+                if hasattr(component, "_hf_hook"):
+                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
+                    )
+                    recursive = is_sequential_cpu_offload
+                remove_hook_from_module(component, recurse=recursive)
+
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
@@ -921,6 +937,12 @@ def load_textual_inversion(
         for token_id, embedding in token_ids_and_embeddings:
             self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
 
+        # offload back
+        if is_model_cpu_offload:
+            self.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            self.enable_sequential_cpu_offload()
+
 
 class LoraLoaderMixin:
     r"""
@@ -952,6 +974,21 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
+        # Remove any existing hooks.
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        recursive = False
+        for _, component in self.components.items():
+            if isinstance(component, nn.Module):
+                if hasattr(component, "_hf_hook"):
+                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    recursive = is_sequential_cpu_offload
+                remove_hook_from_module(component, recurse=recursive)
+
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
         self.load_lora_into_text_encoder(
@@ -961,6 +998,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             state_dict,
             network_alphas=network_alphas,
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
         )
 
+        # Offload back.
+        if is_model_cpu_offload:
+            self.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            self.enable_sequential_cpu_offload()
+
     @classmethod
     def lora_state_dict(
         cls,
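Reviewer note: the detection above keys off the `_hf_hook` attribute that accelerate attaches to offloaded modules. `enable_model_cpu_offload()` installs a `CpuOffload` hook on each top-level component, while `enable_sequential_cpu_offload()` attaches `AlignDevicesHook`s down to the submodule level, which is why `recursive` is only set in the sequential case. A minimal standalone sketch of the same pattern (illustrative only, not the shipped mixin code; `strip_offload_hooks` is a hypothetical helper, and the `|=` accumulation is a small liberty over the plain assignment in the diff):

import torch
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module


def strip_offload_hooks(pipe):
    # Return which offload scheme was active so the caller can restore it
    # afterwards via enable_model_cpu_offload() / enable_sequential_cpu_offload().
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for component in pipe.components.values():
        if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
            is_model_cpu_offload |= isinstance(component._hf_hook, CpuOffload)
            is_sequential_cpu_offload |= isinstance(component._hf_hook, AlignDevicesHook)
            # Sequential offload hooks also live on submodules, hence the recursion.
            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
    return is_model_cpu_offload, is_sequential_cpu_offload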
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index b20d1f0c636e..c64204501b97 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -1549,6 +1549,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         # We could have accessed the unet config from `lora_state_dict()` too. We pass
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.
+
+        # Remove any existing hooks.
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+        else:
+            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
+
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        recursive = False
+        for _, component in self.components.items():
+            if isinstance(component, torch.nn.Module):
+                if hasattr(component, "_hf_hook"):
+                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    recursive = is_sequential_cpu_offload
+                remove_hook_from_module(component, recurse=recursive)
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
@@ -1576,6 +1596,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             lora_scale=self.lora_scale,
         )
 
+        # Offload back.
+        if is_model_cpu_offload:
+            self.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            self.enable_sequential_cpu_offload()
+
     @classmethod
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
     def save_lora_weights(
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index f54adc1a1aba..391d58134627 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1212,6 +1212,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         # We could have accessed the unet config from `lora_state_dict()` too. We pass
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.
+
+        # Remove any existing hooks.
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+        else:
+            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
+
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        recursive = False
+        for _, component in self.components.items():
+            if isinstance(component, torch.nn.Module):
+                if hasattr(component, "_hf_hook"):
+                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    recursive = is_sequential_cpu_offload
+                remove_hook_from_module(component, recurse=recursive)
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
@@ -1239,6 +1259,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             lora_scale=self.lora_scale,
         )
 
+        # Offload back.
+        if is_model_cpu_offload:
+            self.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            self.enable_sequential_cpu_offload()
+
     @classmethod
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
     def save_lora_weights(
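The two ControlNet hunks are verbatim copies of the same guard-and-strip block, inlined because these pipelines override `load_lora_weights`. The user-facing effect is that offloading can be enabled before LoRA weights are loaded. A hedged usage sketch; the ControlNet checkpoint ID is a stand-in, while the LoRA repo and weight name are the ones exercised by the tests later in this diff:

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # hooks installed here...
# ...are stripped and re-applied transparently by the new code path.
pipe.load_lora_weights("hf-internal-testing/sdxl-1.0-lora", weight_name="sd_xl_offset_example-lora_1.0.safetensors")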
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 2ab8003b95f0..1cb1b05a3d27 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -916,6 +916,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         # We could have accessed the unet config from `lora_state_dict()` too. We pass
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.
+
+        # Remove any existing hooks.
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+        else:
+            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
+
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        recursive = False
+        for _, component in self.components.items():
+            if isinstance(component, torch.nn.Module):
+                if hasattr(component, "_hf_hook"):
+                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    recursive = is_sequential_cpu_offload
+                remove_hook_from_module(component, recurse=recursive)
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
@@ -943,6 +963,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             lora_scale=self.lora_scale,
         )
 
+        # Offload back.
+        if is_model_cpu_offload:
+            self.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            self.enable_sequential_cpu_offload()
+
     @classmethod
     def save_lora_weights(
         self,
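This hunk defines the canonical version of the block that the ControlNet and img2img/inpaint pipelines copy. The contract it establishes, mirrored by the reordered `test_sdxl_1_0_lora` below, is that `enable_model_cpu_offload()` may now precede `load_lora_weights()`. A minimal sketch using the same checkpoints as the tests (fp16 is added here purely to keep memory down):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # offloading first...
pipe.load_lora_weights("hf-internal-testing/sdxl-1.0-lora", weight_name="sd_xl_offset_example-lora_1.0.safetensors")
image = pipe("masterpiece, best quality, mountain", num_inference_steps=20).images[0]  # ...hooks restored before inference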
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index bfb3861c9965..4c7aa1ff4668 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -1070,6 +1070,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         # We could have accessed the unet config from `lora_state_dict()` too. We pass
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.
+
+        # Remove any existing hooks.
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+        else:
+            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
+
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        recursive = False
+        for _, component in self.components.items():
+            if isinstance(component, torch.nn.Module):
+                if hasattr(component, "_hf_hook"):
+                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    recursive = is_sequential_cpu_offload
+                remove_hook_from_module(component, recurse=recursive)
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
@@ -1097,6 +1117,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             lora_scale=self.lora_scale,
         )
 
+        # Offload back.
+        if is_model_cpu_offload:
+            self.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            self.enable_sequential_cpu_offload()
+
     @classmethod
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
     def save_lora_weights(
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index ab8722b0bfa0..23105b1413e3 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -1384,6 +1384,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         # We could have accessed the unet config from `lora_state_dict()` too. We pass
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.
+
+        # Remove any existing hooks.
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+        else:
+            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
+
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        recursive = False
+        for _, component in self.components.items():
+            if isinstance(component, torch.nn.Module):
+                if hasattr(component, "_hf_hook"):
+                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    recursive = is_sequential_cpu_offload
+                remove_hook_from_module(component, recurse=recursive)
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
@@ -1411,6 +1431,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             lora_scale=self.lora_scale,
         )
 
+        # Offload back.
+        if is_model_cpu_offload:
+            self.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            self.enable_sequential_cpu_offload()
+
     @classmethod
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
     def save_lora_weights(
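Why `recurse` is only set for sequential offload: sequential offloading hooks every parameterized submodule, so removing only the top-level hook would leave stale `AlignDevicesHook`s behind, whereas model offload places a single `CpuOffload` hook per component. A toy sketch of that asymmetry, using a plain `Sequential` as a stand-in for a pipeline component (assumes accelerate >= 0.17; not part of this diff):

import torch
from accelerate import cpu_offload
from accelerate.hooks import AlignDevicesHook, remove_hook_from_module

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
cpu_offload(net, execution_device=torch.device("cpu"))  # hooks each parameterized submodule
assert isinstance(net[0]._hf_hook, AlignDevicesHook)

remove_hook_from_module(net, recurse=True)  # recurse=False would leave net[0]'s hook in place
assert not hasattr(net[0], "_hf_hook")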
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index 253427a87b1f..f5aa95b6571b 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -1081,6 +1081,42 @@ def test_a1111(self):
 
         self.assertTrue(np.allclose(images, expected, atol=1e-3))
 
+    def test_a1111_with_model_cpu_offload(self):
+        generator = torch.Generator().manual_seed(0)
+
+        pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
+        pipe.enable_model_cpu_offload()
+        lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
+        lora_filename = "light_and_shadow.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+
+        images = pipe(
+            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+        ).images
+
+        images = images[0, -3:, -3:, -1].flatten()
+        expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
+
+        self.assertTrue(np.allclose(images, expected, atol=1e-3))
+
+    def test_a1111_with_sequential_cpu_offload(self):
+        generator = torch.Generator().manual_seed(0)
+
+        pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
+        pipe.enable_sequential_cpu_offload()
+        lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
+        lora_filename = "light_and_shadow.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+
+        images = pipe(
+            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+        ).images
+
+        images = images[0, -3:, -3:, -1].flatten()
+        expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
+
+        self.assertTrue(np.allclose(images, expected, atol=1e-3))
+
     def test_kohya_sd_v15_with_higher_dimensions(self):
         generator = torch.Generator().manual_seed(0)
 
@@ -1257,10 +1293,10 @@ def test_sdxl_1_0_lora(self):
         generator = torch.Generator().manual_seed(0)
 
         pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        pipe.enable_model_cpu_offload()
         lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
         lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
         pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
-        pipe.enable_model_cpu_offload()
 
         images = pipe(
             "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
@@ -1413,3 +1449,21 @@ def test_sdxl_1_0_fuse_unfuse_all(self):
         assert state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd)
         assert state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd)
         assert state_dicts_almost_equal(unet_sd, new_unet_sd)
+
+    def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self):
+        generator = torch.Generator().manual_seed(0)
+
+        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        pipe.enable_sequential_cpu_offload()
+        lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
+        lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+
+        images = pipe(
+            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+        ).images
+
+        images = images[0, -3:, -3:, -1].flatten()
+        expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
+
+        self.assertTrue(np.allclose(images, expected, atol=1e-3))
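For readers skimming the new tests: the assertions compare a small deterministic patch of the output against recorded reference values rather than the whole image. The pattern in isolation (the zero arrays are placeholders for a real pipeline output and its reference slice):

import numpy as np

images = np.zeros((1, 64, 64, 3), dtype=np.float32)  # stand-in for pipe(...).images
image_slice = images[0, -3:, -3:, -1].flatten()  # bottom-right 3x3 patch of the last channel
expected = np.zeros(9, dtype=np.float32)  # values recorded from a known-good run
assert np.allclose(image_slice, expected, atol=1e-3)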
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 7935a63eceaa..31de557a0ac3 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -1019,6 +1019,56 @@ def test_stable_diffusion_textual_inversion(self):
 
         max_diff = np.abs(expected_image - image).max()
         assert max_diff < 8e-1
 
+    def test_stable_diffusion_textual_inversion_with_model_cpu_offload(self):
+        pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+        pipe.enable_model_cpu_offload()
+        pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")
+
+        a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt")
+        a111_file_neg = hf_hub_download(
+            "hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt"
+        )
+        pipe.load_textual_inversion(a111_file)
+        pipe.load_textual_inversion(a111_file_neg)
+
+        generator = torch.Generator(device="cpu").manual_seed(1)
+
+        prompt = "An logo of a turtle in strong Style-Winter with <low-poly-hd-logos-icons>"
+        neg_prompt = "Style-Winter-neg"
+
+        image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0]
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy"
+        )
+
+        max_diff = np.abs(expected_image - image).max()
+        assert max_diff < 8e-1
+
+    def test_stable_diffusion_textual_inversion_with_sequential_cpu_offload(self):
+        pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+        pipe.enable_sequential_cpu_offload()
+        pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")
+
+        a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt")
+        a111_file_neg = hf_hub_download(
+            "hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt"
+        )
+        pipe.load_textual_inversion(a111_file)
+        pipe.load_textual_inversion(a111_file_neg)
+
+        generator = torch.Generator(device="cpu").manual_seed(1)
+
+        prompt = "An logo of a turtle in strong Style-Winter with <low-poly-hd-logos-icons>"
+        neg_prompt = "Style-Winter-neg"
+
+        image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0]
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy"
+        )
+
+        max_diff = np.abs(expected_image - image).max()
+        assert max_diff < 8e-1
+
     @require_torch_2
     def test_stable_diffusion_compile(self):
         seed = 0
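Beyond the pixel comparisons above, it can be worth checking directly that the hooks are restored after loading. A hedged sketch for a CUDA machine (not part of this diff; it relies on the same `_hf_hook` attribute the loader inspects):

import torch
from accelerate.hooks import CpuOffload
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.enable_model_cpu_offload()
pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")

# load_textual_inversion() should have stripped and then re-applied the hooks.
for name, component in pipe.components.items():
    if isinstance(component, torch.nn.Module):
        assert isinstance(component._hf_hook, CpuOffload), f"{name} lost its offload hook"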