diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index f6dea33e8e82..d829cc3a844b 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -516,10 +516,47 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd): f"transformer.single_transformer_blocks.{i}.norm.linear", ) + remaining_keys = list(sds_sd.keys()) + te_state_dict = {} + if remaining_keys: + if not all(k.startswith("lora_te1") for k in remaining_keys): + raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}") + for key in remaining_keys: + if not key.endswith("lora_down.weight"): + continue + + lora_name = key.split(".")[0] + lora_name_up = f"{lora_name}.lora_up.weight" + lora_name_alpha = f"{lora_name}.alpha" + diffusers_name = _convert_text_encoder_lora_key(key, lora_name) + + if lora_name.startswith(("lora_te_", "lora_te1_")): + down_weight = sds_sd.pop(key) + sd_lora_rank = down_weight.shape[0] + te_state_dict[diffusers_name] = down_weight + te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up) + + if lora_name_alpha in sds_sd: + alpha = sds_sd.pop(lora_name_alpha).item() + scale = alpha / sd_lora_rank + + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + te_state_dict[diffusers_name] *= scale_down + te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up + if len(sds_sd) > 0: - logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}") + logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}") + + if te_state_dict: + te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()} - return ait_sd + new_state_dict = {**ait_sd, **te_state_dict} + return new_state_dict return _convert_sd_scripts_to_ai_toolkit(state_dict) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 0c336ebc3cbf..a75f9df91047 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -228,6 +228,26 @@ def test_flux_kohya(self): assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4) + def test_flux_kohya_with_text_encoder(self): + self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors") + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + self.pipeline.enable_model_cpu_offload() + + prompt = "optimus is cleaning the house with broomstick" + out = self.pipeline( + prompt, + num_inference_steps=self.num_inference_steps, + guidance_scale=4.5, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + + out_slice = out[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.4023, 0.4043, 0.4023, 0.3965, 0.3984, 0.3984, 0.3906, 0.3906, 0.4219]) + + assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4) + def test_flux_xlabs(self): self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors") self.pipeline.fuse_lora()