
Commit

fixes
sayakpaul committed Dec 5, 2024
1 parent eac6fd1 commit a6158d7
Showing 2 changed files with 40 additions and 1 deletion.
35 changes: 35 additions & 0 deletions examples/control-lora/README.md
@@ -0,0 +1,35 @@
# Training Control LoRA with Flux

This example shows how to train a Control LoRA with Flux to condition it on additional structural controls (like depth maps, poses, etc.).

This is still an experimental version and the following differences exist:

* No use of bias on `lora_B`.
* No updates on the norm scales.

We simply expand the input channels of Flux.1 Dev from 64 to 128 to allow for additional inputs, and then train a regular LoRA on top of it. To account for the newly added input channels, we additionally append a LoRA to the underlying layer (`x_embedder`), as sketched below. Inference, however, is performed with the `FluxControlPipeline`.
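For reference, here is a minimal sketch of the channel-expansion step, assuming the transformer is loaded standalone as `flux_transformer`; it mirrors the relevant lines of `train_control_lora_flux.py` but is not a drop-in replacement for the script:

```python
import torch
from diffusers import FluxTransformer2DModel

flux_transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Double the input channels of `x_embedder` so the packed control latents can be
# concatenated to the packed image latents along the channel dimension.
initial_input_channels = flux_transformer.config.in_channels  # 64 for Flux.1 Dev
new_linear = torch.nn.Linear(
    flux_transformer.x_embedder.in_features * 2,
    flux_transformer.x_embedder.out_features,
    bias=flux_transformer.x_embedder.bias is not None,
    dtype=flux_transformer.dtype,
)

with torch.no_grad():
    # Keep the pretrained mapping for the original channels and zero-initialize
    # the weights that see the newly added control channels.
    new_linear.weight.zero_()
    new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight)
    new_linear.bias.copy_(flux_transformer.x_embedder.bias)

flux_transformer.x_embedder = new_linear
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
```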

Example command:

```bash
accelerate launch train_control_lora_flux.py \
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
--dataset_name="raulc0399/open_pose_controlnet" \
--output_dir="pose-control-lora" \
--mixed_precision="bf16" \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--use_8bit_adam \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=5000 \
--validation_image="openpose.png" \
--validation_prompt="A couple, 4k photo, highly detailed" \
--seed="0" \
--push_to_hub
```

You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999).
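After training (and with the PR branch above installed), inference could look roughly like the following sketch; the LoRA id `pose-control-lora`, the output filename, and the sampler settings are assumptions derived from the example command rather than tested values:

```python
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Load the trained Control LoRA (the `--output_dir` / pushed Hub repo from the command above).
pipe.load_lora_weights("pose-control-lora")

control_image = load_image("openpose.png")
image = pipe(
    prompt="A couple, 4k photo, highly detailed",
    control_image=control_image,
    num_inference_steps=50,
    guidance_scale=10.0,
).images[0]
image.save("output.png")
```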
6 changes: 5 additions & 1 deletion examples/control-lora/train_control_lora_flux.py
```diff
@@ -817,12 +817,16 @@ def main(args):
         new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight)
         new_linear.bias.copy_(flux_transformer.x_embedder.bias)
         flux_transformer.x_embedder = new_linear
-    flux_transformer.register_config(in_channels=initial_input_channels * 2)
+    flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
 
     if args.lora_layers is not None:
         target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+        # add the input layer to the mix.
+        if "x_embedder" not in target_modules:
+            target_modules.append("x_embedder")
     else:
         target_modules = [
+            "x_embedder",
             "attn.to_k",
             "attn.to_q",
             "attn.to_v",
```
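For context, a hedged sketch of how `target_modules` typically feeds into the LoRA setup downstream of this diff (the rank and alpha values are assumptions; the script exposes them via CLI flags):

```python
from peft import LoraConfig

# `target_modules` comes from the branch above; `x_embedder` is always included so the
# newly added input channels get trainable LoRA parameters.
transformer_lora_config = LoraConfig(
    r=16,                      # assumed rank
    lora_alpha=16,             # assumed alpha
    init_lora_weights="gaussian",
    target_modules=target_modules,
)
flux_transformer.add_adapter(transformer_lora_config)
```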
