Skip to content

Commit

Permalink
[LoRA] support Kohya Flux LoRAs that have text encoders as well (#9542)
Browse files Browse the repository at this point in the history
* support kohya flux loras that have tes.
  • Loading branch information
sayakpaul authored Sep 30, 2024
1 parent 8e7d6c0 commit f9fd511
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
41 changes: 39 additions & 2 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,10 +516,47 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
f"transformer.single_transformer_blocks.{i}.norm.linear",
)

remaining_keys = list(sds_sd.keys())
te_state_dict = {}
if remaining_keys:
if not all(k.startswith("lora_te1") for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
continue

lora_name = key.split(".")[0]
lora_name_up = f"{lora_name}.lora_up.weight"
lora_name_alpha = f"{lora_name}.alpha"
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)

if lora_name.startswith(("lora_te_", "lora_te1_")):
down_weight = sds_sd.pop(key)
sd_lora_rank = down_weight.shape[0]
te_state_dict[diffusers_name] = down_weight
te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)

if lora_name_alpha in sds_sd:
alpha = sds_sd.pop(lora_name_alpha).item()
scale = alpha / sd_lora_rank

scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2

te_state_dict[diffusers_name] *= scale_down
te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up

if len(sds_sd) > 0:
logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")

if te_state_dict:
te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}

return ait_sd
new_state_dict = {**ait_sd, **te_state_dict}
return new_state_dict

return _convert_sd_scripts_to_ai_toolkit(state_dict)

Expand Down
20 changes: 20 additions & 0 deletions tests/lora/test_lora_layers_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,26 @@ def test_flux_kohya(self):

assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)

def test_flux_kohya_with_text_encoder(self):
self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.enable_model_cpu_offload()

prompt = "optimus is cleaning the house with broomstick"
out = self.pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.5,
output_type="np",
generator=torch.manual_seed(self.seed),
).images

out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.4023, 0.4043, 0.4023, 0.3965, 0.3984, 0.3984, 0.3906, 0.3906, 0.4219])

assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)

def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
Expand Down

0 comments on commit f9fd511

Please sign in to comment.