[LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill #10259

Open
wants to merge 15 commits into main

Conversation

@sayakpaul (Member) commented Dec 17, 2024

What does this PR do?

Fixes #10180, #10227, #10184

In short, this PR enables few-step inference for Flux Control, Fill, Redux, etc.

Fill + Turbo LoRA
from diffusers import FluxFillPipeline
from diffusers.utils import load_image
import torch

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

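# A regular FLUX.1-dev LoRA (the 8-step Turbo LoRA); this PR makes it loadable into the Fill pipeline.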
adapter_id = "alimama-creative/FLUX.1-Turbo-Alpha"
pipe.load_lora_weights(adapter_id)

image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png")

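# Only 8 inference steps are needed thanks to the Turbo LoRA.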
image = pipe(
    prompt="a white paper cup",
    image=image,
    mask_image=mask,
    height=1632,
    width=1232,
    guidance_scale=30,
    num_inference_steps=8,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux-fill-dev.png")
Flux Control LoRA + Turbo LoRA (a different Turbo LoRA from the previous example)
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
import torch

control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
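# Stack the Control (depth) LoRA with a regular few-step LoRA; loading both together is what this PR enables.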
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
control_pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
control_pipe.enable_model_cpu_offload()

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

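# Compute the depth map used as the control image.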
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")

image = control_pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=8,
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("output.png")

TODOs

  • Integration tests
  • Docs

@sayakpaul sayakpaul added the lora label Dec 17, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member Author)

Requesting a review from @BenjaminBossan for the initial stuff. I will then request reviews from others.

@BenjaminBossan (Member) left a comment

Thanks for extending the functionality of loading LoRA adapters when shapes need to be expanded. The PR LGTM; I only have one nit.

One question that came up (maybe it was already discussed and I just missed it or forgot): Right now, this type of expansion is permanent, right? I.e. even after unloading the LoRA that made the expansion necessary in the first place, the expansion is not undone. Probably that would be quite hard to add and not worth the effort, I'm just curious.
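For context, here is a minimal sketch of what such an input-feature expansion looks like on a plain torch.nn.Linear. This is an illustration only, not the PR's actual implementation:

import torch
import torch.nn as nn

# Sketch: widen a Linear layer's input features by zero-padding its weight,
# analogous to how Flux Control widens x_embedder. Illustrative only.
def expand_linear_in_features(linear: nn.Linear, new_in_features: int) -> nn.Linear:
    assert new_in_features >= linear.in_features
    expanded = nn.Linear(
        new_in_features,
        linear.out_features,
        bias=linear.bias is not None,
        dtype=linear.weight.dtype,
    )
    with torch.no_grad():
        expanded.weight.zero_()
        # The extra input columns stay zero, so the layer behaves identically
        # on the original input slice.
        expanded.weight[:, : linear.in_features].copy_(linear.weight)
        if linear.bias is not None:
            expanded.bias.copy_(linear.bias)
    return expanded

layer = nn.Linear(64, 128)
wider = expand_linear_in_features(layer, 128)  # e.g. widen from 64 to 128 input features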

@sayakpaul sayakpaul requested a review from a-r-r-o-w December 17, 2024 11:57
@sayakpaul sayakpaul marked this pull request as ready for review December 17, 2024 11:57
@sayakpaul changed the title from “[WIP][LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill” to “[LoRA] feat: support loading regular Flux LoRAs into Flux Control, and Fill” Dec 17, 2024
@sayakpaul (Member Author)

@BenjaminBossan

Right now, this type of expansion is permanent, right? I.e. even after unloading the LoRA that made the expansion necessary in the first place, the expansion is not undone. Probably that would be quite hard to add and not worth the effort, I'm just curious.

The LoRA state dict expansion is permanent. But the ability to undo model-level state dict expansion is being added in #10206.

Comment on lines -343 to -356
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
    "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
    "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
# We should error out because the LoRA input features are fewer than the original's.
# We only support expanding the module, not shrinking it.
with self.assertRaises(NotImplementedError):
    pipe.load_lora_weights(lora_state_dict, "adapter-1")
@sayakpaul (Member Author) commented Dec 17, 2024

Removing this part of the test because, when the LoRA input feature dimensions are smaller than the original, we now expand them.

This is tested below with test_lora_expanding_shape_with_normal_lora() and test_load_regular_lora().
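As a rough sketch (illustrative, not the library's implementation), expanding a narrower LoRA down-projection to match a wider base layer amounts to zero-padding its input dimension:

import torch

# Sketch: pad a LoRA A weight of shape (rank, in_features) with zero columns
# so it matches a base layer whose in_features was expanded. Illustrative only.
def expand_lora_a(lora_a_weight: torch.Tensor, target_in_features: int) -> torch.Tensor:
    rank, in_features = lora_a_weight.shape
    if in_features > target_in_features:
        raise NotImplementedError("Only expansion is supported, not shrinking.")
    if in_features == target_in_features:
        return lora_a_weight
    expanded = torch.zeros(
        rank, target_in_features, dtype=lora_a_weight.dtype, device=lora_a_weight.device
    )
    # Zero columns mean the adapter simply ignores the extra inputs.
    expanded[:, :in_features] = lora_a_weight
    return expanded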

@BenjaminBossan (Member)

The LoRA state dict expansion is permanent. But the ability to undo model-level state dict expansion is being added in #10206.

Yes, something similar for LoRA would be nice, but it's not as important, as the overhead for LoRA should be relatively small.

Successfully merging this pull request may close these issues.

Can't load multiple loras when using Flux Control LoRA