[WIP][Training] Flux Control LoRA training script #10130

Merged: 102 commits, merged on Dec 12, 2024

Changes from 97 commits

Commits
2829679
update
a-r-r-o-w Nov 21, 2024
be67dbd
Merge branch 'main' into flux-new
a-r-r-o-w Nov 21, 2024
f56ffb1
add
yiyixuxu Nov 21, 2024
7e4df06
update
a-r-r-o-w Nov 21, 2024
9ea52da
Merge remote-tracking branch 'origin/flux-fill-yiyi' into flux-new
a-r-r-o-w Nov 21, 2024
217e90c
add control-lora conversion script; make flux loader handle norms; fi…
a-r-r-o-w Nov 21, 2024
b4f1cbf
control lora updates
a-r-r-o-w Nov 22, 2024
414b30b
remove copied-from
a-r-r-o-w Nov 22, 2024
6b02ac2
create separate pipelines for flux control
a-r-r-o-w Nov 22, 2024
3169bf5
make fix-copies
a-r-r-o-w Nov 22, 2024
f7f006d
update docs
a-r-r-o-w Nov 22, 2024
8bb940e
add tests
a-r-r-o-w Nov 22, 2024
9e615fd
fix
a-r-r-o-w Nov 22, 2024
6d168db
Merge branch 'main' into flux-new
a-r-r-o-w Nov 22, 2024
89fd970
Apply suggestions from code review
a-r-r-o-w Nov 22, 2024
73cfc51
remove control lora changes
a-r-r-o-w Nov 22, 2024
c94966f
apply suggestions from review
a-r-r-o-w Nov 22, 2024
cfe13e7
Revert "remove control lora changes"
a-r-r-o-w Nov 22, 2024
0c959a7
update
a-r-r-o-w Nov 23, 2024
6ef2c8b
update
a-r-r-o-w Nov 23, 2024
42970ee
improve log messages
a-r-r-o-w Nov 23, 2024
2ec93ba
Merge branch 'main' into flux-control-lora
a-r-r-o-w Nov 23, 2024
993f3d3
Merge branch 'main' into flux-control-lora
sayakpaul Nov 25, 2024
6523fa6
updates.
sayakpaul Nov 25, 2024
81ab40b
updates
sayakpaul Nov 25, 2024
4432e73
Merge branch 'main' into flux-control-lora
sayakpaul Nov 25, 2024
0f747c0
Merge branch 'main' into flux-control-lora
sayakpaul Nov 26, 2024
6d0c6dc
Merge branch 'flux-control-lora' into sayak-flux-control-lora
sayakpaul Nov 26, 2024
1633619
support register_config.
sayakpaul Nov 26, 2024
b9039b1
fix
sayakpaul Nov 26, 2024
5f94d74
fix
sayakpaul Nov 26, 2024
bd31651
fix
sayakpaul Nov 26, 2024
e18b7ad
Merge branch 'main' into flux-control-lora
a-r-r-o-w Nov 27, 2024
f54ec56
updates
sayakpaul Nov 28, 2024
8032405
updates
sayakpaul Nov 28, 2024
6b70bf7
updates
sayakpaul Nov 28, 2024
3726e2d
fix-copies
sayakpaul Nov 28, 2024
b6ca9d9
Merge branch 'main' into flux-control-lora
sayakpaul Nov 28, 2024
908d151
fix
sayakpaul Nov 29, 2024
6af2097
Merge branch 'main' into flux-control-lora
sayakpaul Nov 29, 2024
07d44e7
apply suggestions from review
a-r-r-o-w Dec 1, 2024
b66e691
add tests
a-r-r-o-w Dec 1, 2024
66d7466
remove conversion script; enable on-the-fly conversion
a-r-r-o-w Dec 2, 2024
d827d1e
Merge branch 'main' into flux-control-lora
a-r-r-o-w Dec 2, 2024
64c821b
bias -> lora_bias.
sayakpaul Dec 2, 2024
30a89a6
fix-copies
sayakpaul Dec 2, 2024
bca1eaa
peft.py
sayakpaul Dec 2, 2024
6ce181b
Merge branch 'main' into flux-control-lora
sayakpaul Dec 2, 2024
e7df197
fix lora conversion
a-r-r-o-w Dec 2, 2024
5fd9fda
changes
sayakpaul Dec 3, 2024
a8c50ba
fix-copies
sayakpaul Dec 3, 2024
b12f797
updates for tests
sayakpaul Dec 3, 2024
f9bd3eb
fix
sayakpaul Dec 3, 2024
6b35c92
Merge branch 'main' into flux-control-lora
sayakpaul Dec 3, 2024
84c168c
alpha_pattern.
sayakpaul Dec 4, 2024
118ed9b
Merge branch 'main' into flux-control-lora
sayakpaul Dec 4, 2024
be1d788
add a test for varied lora ranks and alphas.
sayakpaul Dec 4, 2024
5b1bcd8
revert changes in num_channels_latents = self.transformer.config.in_c…
sayakpaul Dec 4, 2024
cde01e3
revert moe
sayakpaul Dec 4, 2024
79af91d
Merge branch 'main' into flux-control-lora
sayakpaul Dec 5, 2024
4b3efcc
Merge branch 'main' into flux-control-lora
sayakpaul Dec 5, 2024
f688ecf
add a sanity check on unexpected keys when loading norm layers.
sayakpaul Dec 5, 2024
eac6fd1
contro lora.
sayakpaul Dec 5, 2024
a6158d7
fixes
sayakpaul Dec 5, 2024
90708fa
fixes
sayakpaul Dec 5, 2024
8765e1b
Merge branch 'flux-control-lora' into flux-control-lora-training-script
sayakpaul Dec 5, 2024
d6518b7
Merge branch 'main' into flux-control-lora
sayakpaul Dec 6, 2024
9a83eff
Merge branch 'main' into flux-control-lora-training-script
sayakpaul Dec 6, 2024
ecbc4cb
fixes
sayakpaul Dec 6, 2024
55058e2
tests
sayakpaul Dec 6, 2024
a8bd03b
reviewer feedback
sayakpaul Dec 6, 2024
49c0242
fix
sayakpaul Dec 6, 2024
8b050ea
proper peft version for lora_bias
sayakpaul Dec 6, 2024
3204627
fix-copies
sayakpaul Dec 6, 2024
6ce2307
Merge branch 'flux-control-lora' into flux-control-lora-training-script
sayakpaul Dec 6, 2024
1330d17
updates
sayakpaul Dec 6, 2024
9007de0
updates
sayakpaul Dec 6, 2024
67bc7e4
Merge branch 'main' into flux-control-lora-training-script
sayakpaul Dec 6, 2024
7521fec
updates
sayakpaul Dec 6, 2024
2b9bfa3
Merge branch 'main' into flux-control-lora
a-r-r-o-w Dec 6, 2024
130e592
remove debug code
a-r-r-o-w Dec 6, 2024
b20ec7d
update docs
a-r-r-o-w Dec 6, 2024
79d023a
Merge branch 'main' into flux-control-lora
sayakpaul Dec 7, 2024
d1715d3
integration tests
sayakpaul Dec 7, 2024
eb862de
Merge branch 'flux-control-lora' into flux-control-lora-training-script
sayakpaul Dec 7, 2024
cbad4b3
nis
sayakpaul Dec 7, 2024
cd7c155
fuse and unload.
sayakpaul Dec 7, 2024
25616e2
fix
sayakpaul Dec 7, 2024
0b83deb
add slices.
sayakpaul Dec 7, 2024
4d10b33
Merge branch 'flux-control-lora' into flux-control-lora-training-script
sayakpaul Dec 8, 2024
b06376f
more updates.
sayakpaul Dec 9, 2024
7d63a2a
button up readme
sayakpaul Dec 9, 2024
9d90c51
train()
sayakpaul Dec 9, 2024
f46330b
add full fine-tuning version.
sayakpaul Dec 9, 2024
f188e80
fixes
sayakpaul Dec 10, 2024
6c1a60f
Merge branch 'main' into flux-control-lora-training-script
sayakpaul Dec 11, 2024
2610e6a
Merge branch 'main' into flux-control-lora-training-script
sayakpaul Dec 11, 2024
ecc16ed
Merge branch 'main' into flux-control-lora-training-script
sayakpaul Dec 12, 2024
7dfe378
Apply suggestions from code review
sayakpaul Dec 12, 2024
f1d9550
set_grads_to_none remove.
sayakpaul Dec 12, 2024
02016ca
readme
sayakpaul Dec 12, 2024
64814c9
Merge branch 'main' into flux-control-lora-training-script
sayakpaul Dec 12, 2024
117 changes: 117 additions & 0 deletions examples/control-lora/README.md
@@ -0,0 +1,117 @@
# Training Control LoRA with Flux

This (experimental) example shows how to train a Control LoRA with [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev) to condition it with additional structural controls (like depth maps, poses, etc.).

We simply expand the input channels of Flux.1 Dev from 64 to 128 to allow for additional inputs, and then train a regular LoRA on top of it. To account for the newly added input channels, we additionally attach a LoRA to the underlying layer (`x_embedder`). Inference, however, is performed with the `FluxControlPipeline`.
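Conceptually, the channel expansion replaces the `x_embedder` linear layer with one that has twice as many input features, zero-initialising the new columns so the expanded model initially behaves exactly like the original. The sketch below is only an illustration of the idea; the training script performs the actual expansion for you.

```python
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

with torch.no_grad():
    initial_in_channels = transformer.config.in_channels  # 64 for Flux.1 Dev
    old_linear = transformer.x_embedder
    new_linear = torch.nn.Linear(
        old_linear.in_features * 2,
        old_linear.out_features,
        bias=old_linear.bias is not None,
        dtype=transformer.dtype,
    )
    # Zero-init so the extra (control) channels contribute nothing at the start of training.
    new_linear.weight.zero_()
    new_linear.weight[:, : old_linear.in_features].copy_(old_linear.weight)
    if old_linear.bias is not None:
        new_linear.bias.copy_(old_linear.bias)
    transformer.x_embedder = new_linear

# Record the new input channel count so downstream code packs latents accordingly.
transformer.register_to_config(in_channels=initial_in_channels * 2)
```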

> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:

```bash
huggingface-cli login
```
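Alternatively, you can log in from Python with `huggingface_hub` (this is just an alternative to the CLI command above):

```python
from huggingface_hub import login

# Prompts for a User Access Token that has read access to the gated repo.
login()
```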

Example command:

```bash
accelerate launch train_control_lora_flux.py \
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
--dataset_name="raulc0399/open_pose_controlnet" \
--output_dir="pose-control-lora" \
--mixed_precision="bf16" \
--train_batch_size=1 \
--rank=64 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--use_8bit_adam \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=5000 \
--validation_image="openpose.png" \
--validation_prompt="A couple, 4k photo, highly detailed" \
--seed="0" \
--push_to_hub
```

`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).
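Once training finishes, the trained Control LoRA can be loaded into `FluxControlPipeline` for inference. Below is a minimal sketch; `your-username/pose-control-lora` is a placeholder for the Hub repository (or local `output_dir`) the script saved the weights to, and the generation arguments are only suggestions:

```python
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Placeholder: replace with your own LoRA repository or local output directory.
pipe.load_lora_weights("your-username/pose-control-lora")

control_image = load_image("openpose.png")
image = pipe(
    prompt="A couple, 4k photo, highly detailed",
    control_image=control_image,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("output.png")
```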

You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). Once it is merged, you can install `diffusers` from `main`.

The training script exposes additional CLI args that might be useful to experiment with:

* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layers.
* `train_norm_layers`: When set, additionally trains the normalization scales. Saving and loading of these layers is handled automatically.
* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify `"all-linear"`, LoRA will be attached to all linear layers.

## Training with DeepSpeed

It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging its ZeRO Stage 2 optimizations. To use it, save the following config to a YAML file (feel free to modify it as needed):

```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

Then, when launching training, pass the config file:

```bash
accelerate launch --config_file=CONFIG_FILE.yaml ...
```

## Full fine-tuning

We provide a non-LoRA version of the training script `train_control_flux.py`. Here is an example command:

```bash
accelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
--dataset_name="raulc0399/open_pose_controlnet" \
--output_dir="pose-control" \
--mixed_precision="bf16" \
--train_batch_size=2 \
--dataloader_num_workers=4 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--use_8bit_adam \
--proportion_empty_prompts=0.2 \
--learning_rate=5e-5 \
--adam_weight_decay=1e-4 \
--set_grads_to_none \
--report_to="wandb" \
--lr_scheduler="cosine" \
--lr_warmup_steps=1000 \
--checkpointing_steps=1000 \
--max_train_steps=10000 \
--validation_steps=200 \
--validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
--validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
--seed="0" \
--push_to_hub
```

Change the `validation_image` and `validation_prompt` as needed.
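After full fine-tuning, the resulting transformer can be plugged into `FluxControlPipeline` for inference. The sketch below assumes the fine-tuned transformer ends up in a `transformer` subfolder of the output repository; `your-username/pose-control` is a placeholder for wherever the weights were pushed or saved:

```python
import torch
from diffusers import FluxControlPipeline, FluxTransformer2DModel
from diffusers.utils import load_image

# Placeholder repo id; point it at your own output repository or local folder.
transformer = FluxTransformer2DModel.from_pretrained(
    "your-username/pose-control", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

control_image = load_image("2_pose_1024.jpg")
image = pipe(
    prompt="two friends sitting by each other enjoying a day at the park, full hd, cinematic",
    control_image=control_image,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("output.png")
```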
6 changes: 6 additions & 0 deletions examples/control-lora/requirements.txt
@@ -0,0 +1,6 @@
transformers==4.47.0
wandb
torch
torchvision
accelerate==1.2.0
peft>=0.14.0