Skip to content

Commit

Permalink
update the logic of is_sequential_cpu_offload (#7788)
Browse files Browse the repository at this point in the history
* up

* add comment to the tests + fix dit

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
yiyixuxu and sayakpaul authored May 1, 2024
1 parent 8909ab4 commit 21a7ff1
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 24 deletions.
6 changes: 5 additions & 1 deletion examples/community/pipeline_demofusion_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
is_sequential_cpu_offload = (
isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/loaders/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,11 @@ def _optionally_disable_offloading(cls, _pipeline):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)

logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/loaders/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,11 @@ def load_textual_inversion(
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
is_sequential_cpu_offload = (
isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
)
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,11 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
is_sequential_cpu_offload = (
isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)

logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
Expand Down
3 changes: 3 additions & 0 deletions src/diffusers/pipelines/dit/pipeline_dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ def __call__(
if output_type == "pil":
samples = self.numpy_to_pil(samples)

# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
return (samples,)

Expand Down
9 changes: 6 additions & 3 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,11 @@ def module_is_sequentially_offloaded(module):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False

return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
return hasattr(module, "_hf_hook") and (
isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
or hasattr(module._hf_hook, "hooks")
and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
)

def module_is_offloaded(module):
if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
Expand Down Expand Up @@ -1005,8 +1009,7 @@ def remove_all_hooks(self):
"""
for _, model in self.components.items():
if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
is_sequential_cpu_offload = isinstance(getattr(model, "_hf_hook"), accelerate.hooks.AlignDevicesHook)
accelerate.hooks.remove_hook_from_module(model, recurse=is_sequential_cpu_offload)
accelerate.hooks.remove_hook_from_module(model, recurse=True)
self._all_hooks = []

def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
Expand Down
4 changes: 0 additions & 4 deletions tests/pipelines/pixart_alpha/test_pixart.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,6 @@ def test_raises_warning_for_mask_feature(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3)

# PixArt transformer model does not work with sequential offload so skip it for now
def test_sequential_offload_forward_pass_twice(self):
pass


@slow
@require_torch_gpu
Expand Down
4 changes: 0 additions & 4 deletions tests/pipelines/pixart_sigma/test_pixart.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,6 @@ def test_inference_with_multiple_images_per_prompt(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3)

# PixArt transformer model does not work with sequential offload so skip it for now
def test_sequential_offload_forward_pass_twice(self):
pass


@slow
@require_torch_gpu
Expand Down
103 changes: 94 additions & 9 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,8 @@ def _test_attention_slicing_forward_pass(
reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
)
def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
import accelerate

components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
Expand All @@ -1373,18 +1375,56 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
output_without_offload = pipe(**inputs)[0]

pipe.enable_sequential_cpu_offload()
assert pipe._execution_device.type == pipe._offload_device.type

inputs = self.get_dummy_inputs(generator_device)
output_with_offload = pipe(**inputs)[0]

max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")

# make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
offloaded_modules = {
k: v
for k, v in pipe.components.items()
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
}
# 1. all offloaded modules should be saved to cpu and moved to meta device
self.assertTrue(
all(v.device.type == "meta" for v in offloaded_modules.values()),
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
)
# 2. all offloaded modules should have hook installed
self.assertTrue(
all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
)
# 3. all offloaded modules should have correct hooks installed, should be either one of these two
# - `AlignDevicesHook`
# - a SequentialHook` that contains `AlignDevicesHook`
offloaded_modules_with_incorrect_hooks = {}
for k, v in offloaded_modules.items():
if hasattr(v, "_hf_hook"):
if isinstance(v._hf_hook, accelerate.hooks.SequentialHook):
# if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook`
for hook in v._hf_hook.hooks:
if not isinstance(hook, accelerate.hooks.AlignDevicesHook):
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook.hooks[0])
elif not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)

self.assertTrue(
len(offloaded_modules_with_incorrect_hooks) == 0,
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
)

@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
)
def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
import accelerate

generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
Expand All @@ -1400,19 +1440,39 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
output_without_offload = pipe(**inputs)[0]

pipe.enable_model_cpu_offload()
assert pipe._execution_device.type == pipe._offload_device.type

inputs = self.get_dummy_inputs(generator_device)
output_with_offload = pipe(**inputs)[0]

max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
offloaded_modules = [
v

# make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
offloaded_modules = {
k: v
for k, v in pipe.components.items()
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
]
(
self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
}
# 1. check if all offloaded modules are saved to cpu
self.assertTrue(
all(v.device.type == "cpu" for v in offloaded_modules.values()),
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
)
# 2. check if all offloaded modules have hooks installed
self.assertTrue(
all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
)
# 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload`
offloaded_modules_with_incorrect_hooks = {}
for k, v in offloaded_modules.items():
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)

self.assertTrue(
len(offloaded_modules_with_incorrect_hooks) == 0,
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
)

@unittest.skipIf(
Expand Down Expand Up @@ -1444,16 +1504,24 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
self.assertLess(
max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
)

# make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
offloaded_modules = {
k: v
for k, v in pipe.components.items()
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
}
# 1. check if all offloaded modules are saved to cpu
self.assertTrue(
all(v.device.type == "cpu" for v in offloaded_modules.values()),
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
)

# 2. check if all offloaded modules have hooks installed
self.assertTrue(
all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
)
# 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload`
offloaded_modules_with_incorrect_hooks = {}
for k, v in offloaded_modules.items():
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
Expand Down Expand Up @@ -1493,19 +1561,36 @@ def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
self.assertLess(
max_diff, expected_max_diff, "running sequential offloading second time should have the inference results"
)

# make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
offloaded_modules = {
k: v
for k, v in pipe.components.items()
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
}
# 1. check if all offloaded modules are moved to meta device
self.assertTrue(
all(v.device.type == "meta" for v in offloaded_modules.values()),
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
)
# 2. check if all offloaded modules have hook installed
self.assertTrue(
all(hasattr(v, "_hf_hook") for k, v in offloaded_modules.items()),
f"No hook attached: {[k for k, v in offloaded_modules.items() if not hasattr(v, '_hf_hook')]}",
)
# 3. check if all offloaded modules have correct hooks installed, should be either one of these two
# - `AlignDevicesHook`
# - a SequentialHook` that contains `AlignDevicesHook`
offloaded_modules_with_incorrect_hooks = {}
for k, v in offloaded_modules.items():
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
if hasattr(v, "_hf_hook"):
if isinstance(v._hf_hook, accelerate.hooks.SequentialHook):
# if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook`
for hook in v._hf_hook.hooks:
if not isinstance(hook, accelerate.hooks.AlignDevicesHook):
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook.hooks[0])
elif not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)

self.assertTrue(
len(offloaded_modules_with_incorrect_hooks) == 0,
Expand Down

0 comments on commit 21a7ff1

Please sign in to comment.