[WIP][Training] Flux Control LoRA training script #10130

Merged (102 commits) on Dec 12, 2024

Conversation

sayakpaul (Member) commented Dec 5, 2024

Will be updated later.

This PR has both LoRA and a non-LoRA version for training Control Flux.

The number of changed files is larger than expected because the flux-control-lora branch was merged into this PR branch.

yiyixuxu requested a review from a-r-r-o-w on December 10, 2024 18:33
a-r-r-o-w (Member):

@sayakpaul I think some things like the model card are yet to be updated, and there are probably more things on your mind. Once you mark this ready for review, I can take a deeper look - everything looks great already, though! Thanks for involving me in the discussions; looking forward to starting my own training run soon!

sayakpaul (Member, Author):

> I think some things like the model card are yet to be updated

I think I have already updated the model cards, but let me know if you think anything else should be added or updated.

Marking this PR as ready for review.

sayakpaul marked this pull request as ready for review on December 11, 2024 07:40
sayakpaul (Member, Author):

@apolinario @linoytsaban do you think it could make sense to show users how to derive a LoRA from the full Control fine-tune, like the BFL folks did?
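
For reference, a minimal sketch of what such a derivation could look like: per linear layer, take the weight delta between the fine-tuned and base checkpoints and factorize it with a truncated SVD. The helper name, state-dict key naming, and the `rank` value below are illustrative assumptions, not something this PR implements.

```python
import torch

def extract_lora(base_sd: dict, tuned_sd: dict, rank: int = 64) -> dict:
    """Hypothetical helper: derive low-rank LoRA factors from a full fine-tune."""
    lora_sd = {}
    for name, base_w in base_sd.items():
        tuned_w = tuned_sd.get(name)
        if tuned_w is None or base_w.ndim != 2 or base_w.shape != tuned_w.shape:
            continue  # only factorize 2D (linear) weights present in both checkpoints
        delta = (tuned_w - base_w).float()
        U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
        U, S, Vh = U[:, :rank], S[:rank], Vh[:rank, :]
        prefix = name.removesuffix(".weight")
        # Split the singular values across both factors so that B @ A ~= delta.
        lora_sd[f"{prefix}.lora_B.weight"] = (U * S.sqrt()).to(base_w.dtype)
        lora_sd[f"{prefix}.lora_A.weight"] = (S.sqrt()[:, None] * Vh).to(base_w.dtype)
    return lora_sd
```

Splitting the square root of the singular values across both factors keeps `lora_A` and `lora_B` on a similar scale, which tends to be friendlier numerically.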

a-r-r-o-w (Member) left a comment:

Awesome, LGTM! Just some minor comments

examples/control-lora/README.md (two outdated review threads, resolved)
return pixel_latents.to(weight_dtype)


def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_final_validation=False):
Member:

Very clean!

# controlnet-{repo_id}

These are Control weights trained on {base_model} with new type of conditioning.
{img_str}
a-r-r-o-w (Member):

Oh, this is what I missed earlier when I thought the model card was incomplete. The validation images will be displayed, awesome!

examples/control-lora/train_control_flux.py (outdated review thread, resolved)
sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)
noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise
# Concatenate across channels.
# Question: Should we concatenate before adding noise?
Member:

I think what you have currently is correct: noisy latent + clean control latent.
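
To make the ordering concrete, here is a small self-contained sketch of that recipe with dummy tensors (the shapes, names, and the concatenation axis below are illustrative; the actual script operates on packed Flux latents):

```python
import torch

# Illustrative shapes only: packed latents of shape [batch, sequence, channels].
pixel_latents = torch.randn(2, 1024, 64)    # clean target latents
control_latents = torch.randn(2, 1024, 64)  # clean control latents
sigmas = torch.rand(2, 1, 1)                # per-sample noise levels, broadcastable

noise = torch.randn_like(pixel_latents)
# Noise only the target latents ...
noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise
# ... then concatenate the clean control latents along the channel axis.
packed_input = torch.cat([noisy_model_input, control_latents], dim=-1)
```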

examples/control-lora/train_control_lora_flux.py (two outdated review threads, resolved)
" Please use `huggingface-cli login` to authenticate with the Hub."
)
if args.use_lora_bias and args.gaussian_init_lora:
    raise ValueError("`gaussian` LoRA init scheme isn't supported when `use_lora_bias` is True.")
Member:

Just wondering why this is the case. We could set initialization to gaussian before the lora bias changes IIRC, no?

sayakpaul (Member, Author):

So, if we do:

LoraConfig(lora_bias=True, init_lora_weights="gaussian", ...)

it errors out. Maybe a question for @BenjaminBossan.

Member:

As discussed earlier, it just hasn't been implemented yet. The question is: If a user wants Gaussian init for the LoRA weights, how do we initialize the LoRA bias? If we have an answer to that, we can enable Gaussian init for lora_bias.
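
For context, the unsupported combination can be reproduced outside the training script with plain `peft`. This is a minimal sketch; the toy model, `target_modules`, and the exact point at which the error is raised are assumptions and may vary with the `peft` version.

```python
import torch
from peft import LoraConfig, get_peft_model

model = torch.nn.Sequential(torch.nn.Linear(8, 8))  # toy stand-in for the transformer
config = LoraConfig(
    r=4,
    init_lora_weights="gaussian",  # Gaussian init for lora_A ...
    lora_bias=True,                # ... plus a trainable LoRA bias: not supported together yet
    target_modules=["0"],          # the single Linear layer above
)
model = get_peft_model(model, config)  # expected to error out, per the discussion above
```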

sayakpaul (Member, Author):

@a-r-r-o-w I have addressed your comments, thank you!

If you could review the changes and the README mods (made in 02016ca), that would be helpful!

a-r-r-o-w (Member) left a comment:

Let's get the GPUs burning! LoRA testing scripts can be added in a follow-up PR if required, but they're not very important.

sayakpaul (Member, Author):

Thanks, I think it's also okay not to add tests just yet. Based on the usage, we can always revisit.

sayakpaul merged commit 8170dc3 into main on Dec 12, 2024
12 checks passed
sayakpaul deleted the flux-control-lora-training-script branch on December 12, 2024 10:05
Adenialzz (Contributor) commented Dec 15, 2024

Hi, thanks for your implementation.

I'm curious why model offloading is enabled by default with no option to turn it off. In my testing, it can significantly reduce training speed (~5x slower in my case). I recommend making model offloading optional via args.
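
A rough sketch of what such an option could look like; the `--offload` flag name and the components being moved are illustrative assumptions, not the script's actual interface:

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical opt-in flag: keep everything on the accelerator by default,
# and only move the heavy frozen components to CPU between uses when requested.
parser.add_argument(
    "--offload",
    action="store_true",
    help="Offload the VAE and text encoders to CPU when idle to save VRAM, at the cost of speed.",
)
args = parser.parse_args()

# Later, around the points where the frozen models are needed (illustrative):
# if args.offload:
#     vae.to("cpu")
#     text_encoder.to("cpu")
```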

sayakpaul (Member, Author):

@Adenialzz, thanks for your comment. Feel free to open a PR :)

Adenialzz (Contributor):

Sure, here: #10225.

sayakpaul added a commit that referenced this pull request on Dec 23, 2024:
* update

* add

* update

* add control-lora conversion script; make flux loader handle norms; fix rank calculation assumption

* control lora updates

* remove copied-from

* create separate pipelines for flux control

* make fix-copies

* update docs

* add tests

* fix

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* remove control lora changes

* apply suggestions from review

* Revert "remove control lora changes"

This reverts commit 73cfc51.

* update

* update

* improve log messages

* updates.

* updates

* support register_config.

* fix

* fix

* fix

* updates

* updates

* updates

* fix-copies

* fix

* apply suggestions from review

* add tests

* remove conversion script; enable on-the-fly conversion

* bias -> lora_bias.

* fix-copies

* peft.py

* fix lora conversion

* changes

Co-authored-by: a-r-r-o-w <[email protected]>

* fix-copies

* updates for tests

* fix

* alpha_pattern.

* add a test for varied lora ranks and alphas.

* revert changes in num_channels_latents = self.transformer.config.in_channels // 8

* revert moe

* add a sanity check on unexpected keys when loading norm layers.

* contro lora.

* fixes

* fixes

* fixes

* tests

* reviewer feedback

* fix

* proper peft version for lora_bias

* fix-copies

* updates

* updates

* updates

* remove debug code

* update docs

* integration tests

* nis

* fuse and unload.

* fix

* add slices.

* more updates.

* button up readme

* train()

* add full fine-tuning version.

* fixes

* Apply suggestions from code review

Co-authored-by: Aryan <[email protected]>

* set_grads_to_none remove.

* readme

---------

Co-authored-by: Aryan <[email protected]>
Co-authored-by: yiyixuxu <[email protected]>
Co-authored-by: a-r-r-o-w <[email protected]>