From 21a7ff12a75ecf43a85898838d1990cda853ffaf Mon Sep 17 00:00:00 2001
From: YiYi Xu
Date: Wed, 1 May 2024 06:25:57 -1000
Subject: [PATCH] update the logic of `is_sequential_cpu_offload` (#7788)

* up

* add comment to the tests + fix dit

---------

Co-authored-by: Sayak Paul
---
 .../community/pipeline_demofusion_sdxl.py   |   6 +-
 src/diffusers/loaders/lora.py               |   6 +-
 src/diffusers/loaders/textual_inversion.py  |   6 +-
 src/diffusers/loaders/unet.py               |   6 +-
 src/diffusers/pipelines/dit/pipeline_dit.py |   3 +
 src/diffusers/pipelines/pipeline_utils.py   |   9 +-
 tests/pipelines/pixart_alpha/test_pixart.py |   4 -
 tests/pipelines/pixart_sigma/test_pixart.py |   4 -
 tests/pipelines/test_pipelines_common.py    | 103 ++++++++++++++++--
 9 files changed, 123 insertions(+), 24 deletions(-)

diff --git a/examples/community/pipeline_demofusion_sdxl.py b/examples/community/pipeline_demofusion_sdxl.py
index 93e1463638f0..f46d635dae2b 100644
--- a/examples/community/pipeline_demofusion_sdxl.py
+++ b/examples/community/pipeline_demofusion_sdxl.py
@@ -1304,7 +1304,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             if isinstance(component, torch.nn.Module):
                 if hasattr(component, "_hf_hook"):
                     is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    is_sequential_cpu_offload = (
+                        isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        or hasattr(component._hf_hook, "hooks")
+                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                    )
                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                     )
diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 8703cdee4011..d69db5a83af1 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -369,7 +369,11 @@ def _optionally_disable_offloading(cls, _pipeline):
                     if not is_model_cpu_offload:
                         is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                     if not is_sequential_cpu_offload:
-                        is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
+                        is_sequential_cpu_offload = (
+                            isinstance(component._hf_hook, AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )

                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py
index c1c224975cb8..05ed64f5dcad 100644
--- a/src/diffusers/loaders/textual_inversion.py
+++ b/src/diffusers/loaders/textual_inversion.py
@@ -423,7 +423,11 @@ def load_textual_inversion(
             if isinstance(component, nn.Module):
                 if hasattr(component, "_hf_hook"):
                     is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    is_sequential_cpu_offload = (
+                        isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        or hasattr(component._hf_hook, "hooks")
+                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                    )
                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
                     )
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 294db44ee61d..3e74411865a3 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -359,7 +359,11 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             for _, component in _pipeline.components.items():
                 if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                     is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    is_sequential_cpu_offload = (
+                        isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        or hasattr(component._hf_hook, "hooks")
+                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                    )

                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py
index 289ea496028d..a3ea90874a12 100644
--- a/src/diffusers/pipelines/dit/pipeline_dit.py
+++ b/src/diffusers/pipelines/dit/pipeline_dit.py
@@ -227,6 +227,9 @@ def __call__(
         if output_type == "pil":
             samples = self.numpy_to_pil(samples)

+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return (samples,)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 68433332546b..59e38c910d4a 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -376,7 +376,11 @@ def module_is_sequentially_offloaded(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                 return False

-            return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
+            return hasattr(module, "_hf_hook") and (
+                isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
+                or hasattr(module._hf_hook, "hooks")
+                and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
+            )

         def module_is_offloaded(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
@@ -1005,8 +1009,7 @@ def remove_all_hooks(self):
         """
         for _, model in self.components.items():
             if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
-                is_sequential_cpu_offload = isinstance(getattr(model, "_hf_hook"), accelerate.hooks.AlignDevicesHook)
-                accelerate.hooks.remove_hook_from_module(model, recurse=is_sequential_cpu_offload)
+                accelerate.hooks.remove_hook_from_module(model, recurse=True)
         self._all_hooks = []

     def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py
index d981b55260c7..dd358af08395 100644
--- a/tests/pipelines/pixart_alpha/test_pixart.py
+++ b/tests/pipelines/pixart_alpha/test_pixart.py
@@ -324,10 +324,6 @@ def test_raises_warning_for_mask_feature(self):
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=1e-3)

-    # PixArt transformer model does not work with sequential offload so skip it for now
-    def test_sequential_offload_forward_pass_twice(self):
-        pass
-

 @slow
 @require_torch_gpu
diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py
index 7b1d5e389f32..c0df15ae661d 100644
--- a/tests/pipelines/pixart_sigma/test_pixart.py
+++ b/tests/pipelines/pixart_sigma/test_pixart.py
@@ -308,10 +308,6 @@ def test_inference_with_multiple_images_per_prompt(self):
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=1e-3)

-    # PixArt transformer model does not work with sequential offload so skip it for now
-    def test_sequential_offload_forward_pass_twice(self):
-        pass
-

 @slow
 @require_torch_gpu
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 0c0a765f662d..032fbb81ea31 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1360,6 +1360,8 @@ def _test_attention_slicing_forward_pass(
         reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
     )
     def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
+        import accelerate
+
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         for component in pipe.components.values():
@@ -1373,6 +1375,7 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
         output_without_offload = pipe(**inputs)[0]

         pipe.enable_sequential_cpu_offload()
+        assert pipe._execution_device.type == pipe._offload_device.type

         inputs = self.get_dummy_inputs(generator_device)
         output_with_offload = pipe(**inputs)[0]
@@ -1380,11 +1383,48 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
         max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
         self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")

+        # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
+        offloaded_modules = {
+            k: v
+            for k, v in pipe.components.items()
+            if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
+        }
+        # 1. all offloaded modules should be saved to cpu and moved to meta device
+        self.assertTrue(
+            all(v.device.type == "meta" for v in offloaded_modules.values()),
+            f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
+        )
+        # 2. all offloaded modules should have hook installed
+        self.assertTrue(
+            all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
+            f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
+        )
+        # 3. all offloaded modules should have correct hooks installed, should be either one of these two
+        # - `AlignDevicesHook`
+        # - a `SequentialHook` that contains `AlignDevicesHook`
+        offloaded_modules_with_incorrect_hooks = {}
+        for k, v in offloaded_modules.items():
+            if hasattr(v, "_hf_hook"):
+                if isinstance(v._hf_hook, accelerate.hooks.SequentialHook):
+                    # if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook`
+                    for hook in v._hf_hook.hooks:
+                        if not isinstance(hook, accelerate.hooks.AlignDevicesHook):
+                            offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook.hooks[0])
+                elif not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
+                    offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
+
+        self.assertTrue(
+            len(offloaded_modules_with_incorrect_hooks) == 0,
+            f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
+        )
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
         reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
     )
     def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
+        import accelerate
+
         generator_device = "cpu"
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
@@ -1400,19 +1440,39 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
         output_without_offload = pipe(**inputs)[0]

         pipe.enable_model_cpu_offload()
+        assert pipe._execution_device.type == pipe._offload_device.type
+
         inputs = self.get_dummy_inputs(generator_device)
         output_with_offload = pipe(**inputs)[0]

         max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
         self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
-        offloaded_modules = [
-            v
+
+        # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
+        offloaded_modules = {
+            k: v
             for k, v in pipe.components.items()
             if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
-        ]
-        (
-            self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
-            f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
+        }
+        # 1. check if all offloaded modules are saved to cpu
+        self.assertTrue(
+            all(v.device.type == "cpu" for v in offloaded_modules.values()),
+            f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
+        )
+        # 2. check if all offloaded modules have hooks installed
+        self.assertTrue(
+            all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
+            f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
+        )
+        # 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload`
+        offloaded_modules_with_incorrect_hooks = {}
+        for k, v in offloaded_modules.items():
+            if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
+                offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
+
+        self.assertTrue(
+            len(offloaded_modules_with_incorrect_hooks) == 0,
+            f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
         )

     @unittest.skipIf(
@@ -1444,16 +1504,24 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
         self.assertLess(
             max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
         )
+
+        # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
         offloaded_modules = {
             k: v
             for k, v in pipe.components.items()
             if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
         }
+        # 1. check if all offloaded modules are saved to cpu
         self.assertTrue(
             all(v.device.type == "cpu" for v in offloaded_modules.values()),
             f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
         )
-
+        # 2. check if all offloaded modules have hooks installed
+        self.assertTrue(
+            all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
+            f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
+        )
+        # 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload`
         offloaded_modules_with_incorrect_hooks = {}
         for k, v in offloaded_modules.items():
             if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
@@ -1493,19 +1561,36 @@ def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
         self.assertLess(
             max_diff, expected_max_diff, "running sequential offloading second time should have the inference results"
         )
+
+        # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
         offloaded_modules = {
             k: v
             for k, v in pipe.components.items()
             if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
         }
+        # 1. check if all offloaded modules are moved to meta device
         self.assertTrue(
             all(v.device.type == "meta" for v in offloaded_modules.values()),
             f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
         )
+        # 2. check if all offloaded modules have hook installed
+        self.assertTrue(
+            all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
+            f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
+        )
+        # 3. check if all offloaded modules have correct hooks installed, should be either one of these two
+        # - `AlignDevicesHook`
+        # - a `SequentialHook` that contains `AlignDevicesHook`
         offloaded_modules_with_incorrect_hooks = {}
         for k, v in offloaded_modules.items():
-            if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
-                offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
+            if hasattr(v, "_hf_hook"):
+                if isinstance(v._hf_hook, accelerate.hooks.SequentialHook):
+                    # if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook`
+                    for hook in v._hf_hook.hooks:
+                        if not isinstance(hook, accelerate.hooks.AlignDevicesHook):
+                            offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook.hooks[0])
+                elif not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
+                    offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)

         self.assertTrue(
             len(offloaded_modules_with_incorrect_hooks) == 0,
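
All of the call sites touched by this patch converge on the same predicate: a module counts as sequentially offloaded if its `_hf_hook` is an `AlignDevicesHook` directly, or if it is a composite hook (such as accelerate's `SequentialHook`) whose first sub-hook is an `AlignDevicesHook`. The standalone sketch below is illustrative only and not part of the patch; it assumes `accelerate` is installed, and `is_sequentially_offloaded` is a hypothetical helper name rather than an existing diffusers function:

    from accelerate.hooks import AlignDevicesHook

    def is_sequentially_offloaded(module) -> bool:
        # Hypothetical helper mirroring the check introduced in this patch.
        # No accelerate hook attached at all -> not offloaded.
        hook = getattr(module, "_hf_hook", None)
        if hook is None:
            return False
        # Sequential CPU offload may attach an AlignDevicesHook directly ...
        if isinstance(hook, AlignDevicesHook):
            return True
        # ... or wrap one inside a composite hook (e.g. SequentialHook), which
        # exposes its sub-hooks via a `hooks` attribute; the patch inspects the
        # first sub-hook, just like `A or B and C` == `A or (B and C)` above.
        return hasattr(hook, "hooks") and isinstance(hook.hooks[0], AlignDevicesHook)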