[Training] Add datasets version of LCM LoRA SDXL (huggingface#5778)
* add: script to train lcm lora for sdxl with 🤗 datasets

* suit up the args.

* remove comments.

* fix num_update_steps

* fix batch unmarshalling

* fix num_update_steps_per_epoch

* fix: dataloading.

* fix microconditions.

* unconditional predictions debug

* fix batch size.

* no need to use use_auth_token

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* make vae encoding batch size an arg

* final serialization in kohya

* style

* state dict rejigging

* feat: no separate teacher unet.

* debug

* fix state dict serialization

* debug

* debug

* debug

* remove prints.

* remove kohya utility and make style

* fix serialization

* fix

* add test

* add peft dependency.

* add: peft

* remove peft

* autocast device determination from accelerator

* autocast

* reduce lora rank.

* remove unneeded space

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* style

* remove prompt dropout.

* also save in native diffusers ckpt format.

* debug

* debug

* debug

* better formation of the null embeddings.

* remove space.

* autocast fixes.

* autocast fix.

* hacky

* remove lora_sayak

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* style

* make log validation leaner.

* move back enabled in.

* fix: log_validation call.

* add: checkpointing tests

* taking my chances to see if disabling autocasting has any effect?

* start debugging

* name

* name

* name

* more debug

* more debug

* index

* remove index.

* print length

* print length

* print length

* move unet.train() after add_adapter()

* disable some prints.

* enable_adapters() manually.

* remove prints.

* some changes.

* fix params_to_optimize

* more fixes

* debug

* debug

* remove print

* disable grad for certain contexts.

* Add support for IPAdapterFull (huggingface#5911)

* Add support for IPAdapterFull


Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* Fix a bug in `add_noise` function  (huggingface#6085)

* fix

* copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>

* [Advanced Diffusion Script] Add Widget default text (huggingface#6100)

add widget

* [Advanced Training Script] Fix pipe example (huggingface#6106)

* IP-Adapter for StableDiffusionControlNetImg2ImgPipeline (huggingface#5901)

* adapter for StableDiffusionControlNetImg2ImgPipeline

* fix-copies

* fix-copies

---------

Co-authored-by: Sayak Paul <[email protected]>

* IP adapter support for most pipelines (huggingface#5900)

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

* fix: lora_alpha

* make vae casting conditional.

* param upcasting

* propagate comments from huggingface#6145

Co-authored-by: dg845 <[email protected]>

* [Peft] fix saving / loading when unet is not "unet" (huggingface#6046)

* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul <[email protected]>

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------

Co-authored-by: Sayak Paul <[email protected]>

* [Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)

fix fp16 training

Co-authored-by: Sayak Paul <[email protected]>

* [docs] fix: animatediff docs (huggingface#6339)

fix: animatediff docs

* add: note about the new script in readme_sdxl.

* Revert "[Peft] fix saving / loading when unet is not "unet" (huggingface#6046)"

This reverts commit 4c7e983.

* Revert "[Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)"

This reverts commit 0bb9cf0.

* Revert "[docs] fix: animatediff docs (huggingface#6339)"

This reverts commit 11659a6.

* remove tokenize_prompt().

* assistive comments around enable_adapters() and disable_adapters().

---------

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Fabio Rigano <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: apolinário <[email protected]>
Co-authored-by: Charchit Sharma <[email protected]>
Co-authored-by: Aryan V S <[email protected]>
Co-authored-by: dg845 <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
12 people authored and adrvs committed Jan 2, 2024
1 parent 6e5cb95 commit 9594df5
Showing 4 changed files with 1,507 additions and 1 deletion.
@@ -161,6 +161,8 @@ def save_model_card(
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
widget:
- text: '{validation_prompt if validation_prompt else instance_prompt}'
---
"""

36 changes: 35 additions & 1 deletion examples/consistency_distillation/README_sdxl.md
@@ -111,4 +111,38 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \
--report_to=wandb \
--seed=453645634 \
--push_to_hub \
```

We provide another version of the LCM LoRA SDXL training script that follows `peft` best practices and leverages the 🤗 `datasets` library for quick experimentation. Unlike `train_lcm_distill_lora_sdxl_wds.py`, this script doesn't load two UNets, which reduces the memory requirements quite a bit.
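
The memory saving comes from reusing a single UNet for both roles: the LoRA adapter is attached to the frozen base UNet and simply toggled off whenever a teacher prediction is needed. Below is a rough, illustrative sketch of that idea (the helper function and variable names are made up for illustration and do not mirror the script line for line):

```python
import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Frozen base UNet; only the LoRA weights added below are trained.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.requires_grad_(False)

lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # illustrative subset
)
unet.add_adapter(lora_config)


def student_and_teacher_preds(noisy_latents, timesteps, prompt_embeds, added_cond_kwargs):
    # Student pass: LoRA enabled, gradients flow through the adapter weights.
    student = unet(
        noisy_latents, timesteps, prompt_embeds, added_cond_kwargs=added_cond_kwargs
    ).sample

    # Teacher pass: temporarily disable the adapter so the frozen base UNet
    # produces the target, then re-enable it for the next student update.
    with torch.no_grad():
        unet.disable_adapters()
        teacher = unet(
            noisy_latents, timesteps, prompt_embeds, added_cond_kwargs=added_cond_kwargs
        ).sample
        unet.enable_adapters()
    return student, teacher
```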

Below is an example training command that trains an LCM LoRA on the [Pokémon BLIP captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions):

```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"

accelerate launch train_lcm_distill_lora_sdxl.py \
--pretrained_teacher_model=${MODEL_NAME} \
--pretrained_vae_model_name_or_path=${VAE_PATH} \
--output_dir="pokemons-lora-lcm-sdxl" \
--mixed_precision="fp16" \
--dataset_name=$DATASET_NAME \
--resolution=1024 \
--train_batch_size=24 \
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--use_8bit_adam \
--lora_rank=64 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=3000 \
--checkpointing_steps=500 \
--validation_steps=50 \
--seed="0" \
--push_to_hub
```
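
After training (or once `--push_to_hub` has uploaded the adapter), the resulting `pytorch_lora_weights.safetensors` can be loaded into the base SDXL pipeline together with `LCMScheduler` for few-step inference. A minimal sketch, assuming the adapter lives in a hypothetical Hub repository `your-username/pokemons-lora-lcm-sdxl` (substitute your own id or a local path):

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# Swap in the LCM scheduler and load the distilled LoRA weights.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("your-username/pokemons-lora-lcm-sdxl")  # hypothetical repo id

# LCM LoRAs are meant for very few steps and little to no guidance.
image = pipe(
    "a green pokemon with menacing eyes", num_inference_steps=4, guidance_scale=1.0
).images[0]
image.save("pokemon_lcm_lora.png")
```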

112 changes: 112 additions & 0 deletions examples/consistency_distillation/test_lcm_lora.py
@@ -0,0 +1,112 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class TextToImageLCM(ExamplesTestsAccelerate):
    def test_text_to_image_lcm_lora_sdxl(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --lora_rank 4
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

    def test_text_to_image_lcm_lora_sdxl_checkpointing(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --lora_rank 4
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 7
                --checkpointing_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-4", "checkpoint-6"},
            )

            test_args = f"""
                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --lora_rank 4
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 9
                --checkpointing_steps 2
                --resume_from_checkpoint latest
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
            )