Skip to content

Commit

Permalink
Test error raised when loading normal and expanding loras together in…
Browse files Browse the repository at this point in the history
… Flux (#10188)

* add test for expanding lora and normal lora error

* Update tests/lora/test_lora_layers_flux.py

* fix things.

* Update src/diffusers/loaders/peft.py

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
a-r-r-o-w and sayakpaul authored Dec 15, 2024
1 parent 96a9097 commit 22c4f07
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 5 deletions.
13 changes: 10 additions & 3 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2337,12 +2337,19 @@ def _maybe_expand_transformer_param_shape_or_error_(
f"this please open an issue at https://github.com/huggingface/diffusers/issues."
)

logger.debug(
debug_message = (
f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
f"checkpoint contains higher number of features than expected. The number of input_features will be "
f"expanded from {module_in_features} to {in_features}, and the number of output features will be "
f"expanded from {module_out_features} to {out_features}."
f"expanded from {module_in_features} to {in_features}"
)
if module_out_features != out_features:
debug_message += (
", and the number of output features will be "
f"expanded from {module_out_features} to {out_features}."
)
else:
debug_message += "."
logger.debug(debug_message)

has_param_with_shape_update = True
parent_module_name, _, current_module_name = name.rpartition(".")
Expand Down
19 changes: 17 additions & 2 deletions src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
weights.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer

cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
Expand Down Expand Up @@ -316,8 +317,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
# To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
# we should also delete the `peft_config` associated to the `adapter_name`.
try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
except RuntimeError as e:
for module in self.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
for active_adapter in active_adapters:
if adapter_name in active_adapter:
module.delete_adapter(adapter_name)

self.peft_config.pop(adapter_name)
logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
raise

warn_msg = ""
if incompatible_keys is not None:
Expand Down
116 changes: 116 additions & 0 deletions tests/lora/test_lora_layers_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,122 @@ def test_correct_lora_configs_with_different_ranks(self):
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

def test_lora_expanding_shape_with_normal_lora_raises_error(self):
# TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
# another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
# When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
components["transformer"] = transformer

pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)

out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4

shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")

self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

_, _, inputs = self.get_dummy_inputs(with_generator=False)
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}

# The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
# input features before expansion. This should raise an error about the weight shapes being incompatible.
self.assertRaisesRegex(
RuntimeError,
"size mismatch for x_embedder.lora_A.adapter-2.weight",
pipe.load_lora_weights,
lora_state_dict,
"adapter-2",
)
# We should have `adapter-1` as the only adapter.
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])

# Check if the output is the same after lora loading error
lora_output_after_error = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(lora_output, lora_output_after_error, atol=1e-3, rtol=1e-3))

# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
# This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
# original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora
# weight is compatible with the current model inadequate. This should be addressed when attempting support for
# https://github.com/huggingface/diffusers/issues/10180 (TODO)
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
components["transformer"] = transformer

pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)

out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4

lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}

with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)
self.assertFalse(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

lora_state_dict = {
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
}

# We should check for input shapes being incompatible here. But because above mentioned issue is
# not a supported use case, and because of the PEFT renaming, we will currently have a shape
# mismatch error.
self.assertRaisesRegex(
RuntimeError,
"size mismatch for x_embedder.lora_A.adapter-2.weight",
pipe.load_lora_weights,
lora_state_dict,
"adapter-2",
)

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
Expand Down

0 comments on commit 22c4f07

Please sign in to comment.