Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support SD3 #1374

Draft
wants to merge 313 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
313 commits
Select commit Hold shift + click to select a range
6ab48b0
feat: Support multi-resolution training with caching latents to disk
kohya-ss Aug 20, 2024
7e459c0
Update T5 attention mask handling in FLUX
kohya-ss Aug 20, 2024
e17c42c
Add BFL/Diffusers LoRA converter #1467 #1458 #1483
kohya-ss Aug 21, 2024
2b07a92
Fix error in applying mask in Attention and add LoRA converter script
kohya-ss Aug 21, 2024
e1cd19c
add stochastic rounding, fix single block
kohya-ss Aug 21, 2024
98c91a7
Fix bug in FLUX multi GPU training
kohya-ss Aug 22, 2024
a4d27a2
Fix --debug_dataset to work.
kohya-ss Aug 22, 2024
2d8fa33
Fix to remove zero pad for t5xxl output
kohya-ss Aug 22, 2024
b0a9808
added a script to extract LoRA
kohya-ss Aug 22, 2024
bf9f798
chore: fix typos, remove debug print
kohya-ss Aug 22, 2024
99744af
Merge branch 'dev' into sd3
kohya-ss Aug 22, 2024
81411a3
speed up getting image sizes
kohya-ss Aug 22, 2024
2e89cd2
Fix issue with attention mask not being applied in single blocks
kohya-ss Aug 24, 2024
cf689e7
feat: Add option to split projection layers and apply LoRA
kohya-ss Aug 24, 2024
5639c2a
fix typo
kohya-ss Aug 24, 2024
ea92426
Merge branch 'dev' into sd3
kohya-ss Aug 24, 2024
72287d3
feat: Add `shift` option to `--timestep_sampling` in FLUX.1 fine-tuni…
kohya-ss Aug 25, 2024
0087a46
FLUX.1 LoRA supports CLIP-L
kohya-ss Aug 27, 2024
3be712e
feat: Update direct loading fp8 ckpt for LoRA training
kohya-ss Aug 27, 2024
a61cf73
update readme
kohya-ss Aug 27, 2024
6c0e8a5
make guidance_scale keep float in args
Akegarasu Aug 29, 2024
a0cfb08
Cleaned up README
kohya-ss Aug 29, 2024
daa6ad5
Update README.md
kohya-ss Aug 29, 2024
930d709
Merge pull request #1525 from Akegarasu/sd3
kohya-ss Aug 29, 2024
8ecf0fc
Refactor code to ensure args.guidance_scale is always a float #1525
kohya-ss Aug 29, 2024
8fdfd8c
Update safetensors to version 0.4.4 in requirements.txt #1524
kohya-ss Aug 29, 2024
34f2315
fix: text_encoder_conds referenced before assignment
Akegarasu Aug 29, 2024
35882f8
fix
Akegarasu Aug 29, 2024
25c9040
Update flux_train_utils.py
sdbds Aug 30, 2024
928e0fc
Merge pull request #1529 from Akegarasu/sd3
kohya-ss Sep 1, 2024
ef510b3
Sd3 freeze x_block (#1417)
sdbds Sep 1, 2024
92e7600
Move freeze_blocks to sd3_train because it's only for sd3
kohya-ss Sep 1, 2024
1e30aa8
Merge pull request #1541 from sdbds/flux_shift
kohya-ss Sep 1, 2024
4f6d915
update help and README
kohya-ss Sep 1, 2024
6abacf0
update README
kohya-ss Sep 2, 2024
b65ae9b
T5XXL LoRA training, fp8 T5XXL support
kohya-ss Sep 4, 2024
b7cff0a
update README
kohya-ss Sep 4, 2024
56cb2fc
support T5XXL LoRA, reduce peak memory usage #1560
kohya-ss Sep 4, 2024
90ed2df
feat: Add support for merging CLIP-L and T5XXL LoRA models
kohya-ss Sep 4, 2024
d912952
set dtype before calling ae closes #1562
kohya-ss Sep 5, 2024
2889108
feat: Add --cpu_offload_checkpointing option to LoRA training
kohya-ss Sep 5, 2024
ce14447
Merge branch 'dev' into sd3
kohya-ss Sep 7, 2024
d29af14
add negative prompt for flux inference script
kohya-ss Sep 9, 2024
d10ff62
support individual LR for CLIP-L/T5XXL
kohya-ss Sep 10, 2024
65b8a06
update README
kohya-ss Sep 10, 2024
eaafa5c
Merge branch 'dev' into sd3
kohya-ss Sep 11, 2024
8311e88
typo fix
cocktailpeanut Sep 11, 2024
d83f2e9
Merge pull request #1592 from cocktailpeanut/sd3
kohya-ss Sep 11, 2024
a823fd9
Improve wandb logging (#1576)
p1atdev Sep 11, 2024
237317f
update README
kohya-ss Sep 11, 2024
cefe526
fix to work old notation for TE LR in .toml
kohya-ss Sep 12, 2024
f3ce80e
Merge branch 'dev' into sd3
kohya-ss Sep 13, 2024
c15a3a1
Merge branch 'dev' into sd3
kohya-ss Sep 13, 2024
0485f23
Merge branch 'dev' into sd3
kohya-ss Sep 13, 2024
2d8ee3c
OFT for FLUX.1
kohya-ss Sep 14, 2024
c9ff4de
Add support for specifying rank for each layer in FLUX.1
kohya-ss Sep 14, 2024
6445bb2
update README
kohya-ss Sep 14, 2024
9f44ef1
add diffusers to FLUX.1 conversion script
kohya-ss Sep 15, 2024
be078bd
fix typo
kohya-ss Sep 15, 2024
96c677b
fix to work lienar/cosine lr scheduler closes #1602 ref #1393
kohya-ss Sep 16, 2024
d8d15f1
add support for specifying blocks in FLUX.1 LoRA training
kohya-ss Sep 16, 2024
0cbe95b
fix text_encoder_lr to work with int closes #1608
kohya-ss Sep 17, 2024
a2ad7e5
blocks_to_swap=0 means no swap
kohya-ss Sep 17, 2024
bbd160b
sd3 schedule free opt (#1605)
kohya-ss Sep 17, 2024
e745021
update README
kohya-ss Sep 17, 2024
1286e00
fix to call train/eval in schedulefree #1605
kohya-ss Sep 18, 2024
706a48d
Merge branch 'dev' into sd3
kohya-ss Sep 19, 2024
b844c70
Merge branch 'dev' into sd3
kohya-ss Sep 19, 2024
3957372
Retain alpha in `pil_resize`
emcmanus Sep 19, 2024
de4bb65
Update utils.py
emcmanus Sep 19, 2024
0535cd2
fix: backward compatibility for text_encoder_lr
Akegarasu Sep 20, 2024
24f8975
Merge pull request #1620 from Akegarasu/sd3
kohya-ss Sep 20, 2024
583d4a4
add compatibility for int LR (D-Adaptation etc.) #1620
kohya-ss Sep 20, 2024
95ff9db
Merge pull request #1619 from emcmanus/patch-1
kohya-ss Sep 20, 2024
fba7692
Merge branch 'dev' into sd3
kohya-ss Sep 23, 2024
65fb69f
Merge branch 'dev' into sd3
kohya-ss Sep 25, 2024
56a7bc1
new block swap for FLUX.1 fine tuning
kohya-ss Sep 25, 2024
da94fd9
fix typos
kohya-ss Sep 25, 2024
2cd6aa2
Merge branch 'dev' into sd3
kohya-ss Sep 26, 2024
392e8de
fix flip_aug, alpha_mask, random_crop issue in caching in caching str…
kohya-ss Sep 26, 2024
3ebb65f
Merge branch 'dev' into sd3
kohya-ss Sep 26, 2024
9249d00
experimental support for multi-gpus latents caching
kohya-ss Sep 26, 2024
24b1fdb
remove debug print
kohya-ss Sep 26, 2024
a9aa526
fix sample generation is not working in FLUX1 fine tuning #1647
kohya-ss Sep 28, 2024
822fe57
add workaround for 'Some tensors share memory' error #1614
kohya-ss Sep 28, 2024
1a0f5b0
re-fix sample generation is not working in FLUX1 split mode #1647
kohya-ss Sep 28, 2024
d050638
Merge branch 'dev' into sd3
kohya-ss Sep 29, 2024
e0c3630
Support Sdxl Controlnet (#1648)
sdbds Sep 29, 2024
56a63f0
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Sep 29, 2024
8919b31
use original ControlNet instead of Diffusers
kohya-ss Sep 29, 2024
0243c65
fix typo
kohya-ss Sep 29, 2024
8bea039
Merge branch 'dev' into sd3
kohya-ss Sep 29, 2024
d78f6a7
Merge branch 'sd3' into sdxl-ctrl-net
kohya-ss Sep 29, 2024
793999d
sample generation in SDXL ControlNet training
kohya-ss Sep 30, 2024
33e942e
Merge branch 'sd3' into fast_image_sizes
kohya-ss Sep 30, 2024
c2440f9
fix cond image normlization, add independent LR for control
kohya-ss Oct 3, 2024
ba08a89
call optimizer eval/train for sample_at_first, also set train after r…
kohya-ss Oct 4, 2024
83e3048
load Diffusers format, check schnell/dev
kohya-ss Oct 6, 2024
126159f
Merge branch 'sd3' into sdxl-ctrl-net
kohya-ss Oct 7, 2024
886f753
support weighted captions for sdxl LoRA and fine tuning
kohya-ss Oct 9, 2024
3de42b6
fix: distributed training in windows
Akegarasu Oct 10, 2024
9f4dac5
torch 2.4
Akegarasu Oct 10, 2024
f2bc820
support weighted captions for SD/SDXL
kohya-ss Oct 10, 2024
035c4a8
update docs and help text
kohya-ss Oct 11, 2024
43bfeea
Merge pull request #1655 from kohya-ss/sdxl-ctrl-net
kohya-ss Oct 11, 2024
d005652
Merge pull request #1686 from Akegarasu/sd3
kohya-ss Oct 12, 2024
0d3058b
update README
kohya-ss Oct 12, 2024
ff4083b
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Oct 12, 2024
c80c304
Refactor caching in train scripts
kohya-ss Oct 12, 2024
ecaea90
update README
kohya-ss Oct 12, 2024
e277b57
Update FLUX.1 support for compact models
kohya-ss Oct 12, 2024
5bb9f7f
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Oct 13, 2024
74228c9
update cache_latents/text_encoder_outputs
kohya-ss Oct 13, 2024
c65cf38
Merge branch 'sd3' into fast_image_sizes
kohya-ss Oct 13, 2024
2244cf5
load images in parallel when caching latents
kohya-ss Oct 13, 2024
bfc3a65
fix to work cache latents/text encoder outputs
kohya-ss Oct 13, 2024
d02a6ef
Merge pull request #1660 from kohya-ss/fast_image_sizes
kohya-ss Oct 13, 2024
886ffb4
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Oct 13, 2024
2d5f7fa
update README
kohya-ss Oct 13, 2024
1275e14
Merge pull request #1690 from kohya-ss/multi-gpu-caching
kohya-ss Oct 13, 2024
2500f5a
fix latents caching not working closes #1696
kohya-ss Oct 14, 2024
3cc5b8d
Diff Output Preserv loss for SDXL
kohya-ss Oct 18, 2024
d8d7142
fix to work caching latents #1696
kohya-ss Oct 18, 2024
ef70aa7
add FLUX.1 support
kohya-ss Oct 18, 2024
2c45d97
update README, remove unnecessary autocast
kohya-ss Oct 19, 2024
09b4d1e
Merge branch 'sd3' into diff_output_prsv
kohya-ss Oct 19, 2024
aa93242
Merge pull request #1710 from kohya-ss/diff_output_prsv
kohya-ss Oct 19, 2024
7fe8e16
fix to work ControlNetSubset with custom_attributes
kohya-ss Oct 19, 2024
138dac4
update README
kohya-ss Oct 20, 2024
623017f
refactor SD3 CLIP to transformers etc.
kohya-ss Oct 24, 2024
e3c43bd
reduce memory usage in sample image generation
kohya-ss Oct 24, 2024
0286114
support SD3.5L, fix final saving
kohya-ss Oct 24, 2024
f8c5146
support block swap with fused_optimizer_pass
kohya-ss Oct 24, 2024
5fba6f5
Merge branch 'dev' into sd3
kohya-ss Oct 25, 2024
f52fb66
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 25, 2024
d2c549d
support SD3 LoRA
kohya-ss Oct 25, 2024
0031d91
add latent scaling/shifting
kohya-ss Oct 25, 2024
56bf761
fix errors in SD3 LoRA training with Text Encoders close #1724
kohya-ss Oct 26, 2024
014064f
fix sample image generation without seed failed close #1726
kohya-ss Oct 26, 2024
8549669
Merge branch 'dev' into sd3
kohya-ss Oct 26, 2024
150579d
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 26, 2024
731664b
Merge branch 'dev' into sd3
kohya-ss Oct 27, 2024
b649bbf
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 27, 2024
db2b4d4
Add dropout rate arguments for CLIP-L, CLIP-G, and T5, fix Text Encod…
kohya-ss Oct 27, 2024
a1255d6
Fix SD3 LoRA training to work (WIP)
kohya-ss Oct 27, 2024
d4f7849
prevent unintended cast for disk cached TE outputs
kohya-ss Oct 27, 2024
1065dd1
Fix to work dropout_rate for TEs
kohya-ss Oct 27, 2024
af8e216
Fix sample image gen to work with block swap
kohya-ss Oct 28, 2024
7555486
Fix error on saving T5XXL
kohya-ss Oct 28, 2024
0af4edd
Fix split_qkv
kohya-ss Oct 29, 2024
d4e19fb
Support Lora
kohya-ss Oct 29, 2024
80bb3f4
Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-script…
kohya-ss Oct 29, 2024
1e2f7b0
Support for checkpoint files with a mysterious prefix "model.diffusio…
kohya-ss Oct 29, 2024
ce5b532
Fix additional LoRA to work
kohya-ss Oct 29, 2024
c9a1417
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 29, 2024
b502f58
Fix emb_dim to work.
kohya-ss Oct 29, 2024
bdddc20
support SD3.5M
kohya-ss Oct 30, 2024
8c3c825
Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-script…
kohya-ss Oct 30, 2024
70a179e
Fix to use SDPA instead of xformers
kohya-ss Oct 30, 2024
1434d85
Support SD3.5M multi resolutional training
kohya-ss Oct 31, 2024
9e23368
Update SD3 training
kohya-ss Oct 31, 2024
830df4a
Fix crashing if image is too tall or wide.
kohya-ss Oct 31, 2024
9aa6f52
Fix memory leak in latent caching. bmp failed to cache
kohya-ss Nov 1, 2024
82daa98
remove duplicate resolution for scaled pos embed
kohya-ss Nov 1, 2024
264328d
Merge pull request #1719 from kohya-ss/sd3_5_support
kohya-ss Nov 1, 2024
e0db596
update multi-res training in SD3.5M
kohya-ss Nov 2, 2024
5e32ee2
fix crashing in DDP training closes #1751
kohya-ss Nov 2, 2024
81c0c96
faster block swap
kohya-ss Nov 5, 2024
aab943c
remove unused weight swapping functions from utils.py
kohya-ss Nov 5, 2024
4384903
Fix to work without latent cache #1758
kohya-ss Nov 6, 2024
40ed54b
Simplify Timestep weighting
Dango233 Nov 7, 2024
e54462a
Fix SD3 trained lora loading and merging
Dango233 Nov 7, 2024
bafd10d
Fix typo
Dango233 Nov 7, 2024
588ea9e
Merge pull request #1768 from Dango233/dango/timesteps_fix
kohya-ss Nov 7, 2024
5e86323
Update README and clean-up the code for SD3 timesteps
kohya-ss Nov 7, 2024
f264f40
Update README.md
kohya-ss Nov 7, 2024
5eb6d20
Update README.md
Dango233 Nov 7, 2024
387b40e
Merge pull request #1769 from Dango233/patch-1
kohya-ss Nov 7, 2024
e877b30
Merge branch 'dev' into sd3
kohya-ss Nov 7, 2024
123474d
Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into sd3
kohya-ss Nov 7, 2024
b8d3fec
Merge branch 'sd3' into faster-block-swap
kohya-ss Nov 7, 2024
186aa5b
fix illeagal block is swapped #1764
kohya-ss Nov 7, 2024
b3248a8
fix: sort order when getting image size from cache file
feffy380 Nov 7, 2024
2a2042a
Merge pull request #1770 from feffy380/fix-size-from-cache
kohya-ss Nov 9, 2024
8fac3c3
update README
kohya-ss Nov 9, 2024
26bd454
init
sdbds Nov 11, 2024
02bd76e
Refactor block swapping to utilize custom offloading utilities
kohya-ss Nov 11, 2024
7feaae5
Merge branch 'sd3' into faster-block-swap
kohya-ss Nov 11, 2024
92482c7
Merge pull request #1774 from sdbds/avif_get_imagesize
kohya-ss Nov 11, 2024
3fe94b0
update comment
kohya-ss Nov 11, 2024
cde90b8
feat: implement block swapping for FLUX.1 LoRA (WIP)
kohya-ss Nov 11, 2024
17cf249
Merge branch 'sd3' into faster-block-swap
kohya-ss Nov 11, 2024
2cb7a6d
feat: add block swap for FLUX.1/SD3 LoRA training
kohya-ss Nov 12, 2024
2bb0f54
update grad hook creation to fix TE lr in sd3 fine tuning
kohya-ss Nov 14, 2024
5c5b544
refactor: remove unused prepare_split_model method from FluxNetworkTr…
kohya-ss Nov 14, 2024
fd2d879
docs: update README
kohya-ss Nov 14, 2024
0047bb1
Merge pull request #1779 from kohya-ss/faster-block-swap
kohya-ss Nov 14, 2024
ccfaa00
add flux controlnet base module
minux302 Nov 15, 2024
42f6edf
fix for adding controlnet
minux302 Nov 15, 2024
e358b11
fix dataloader
minux302 Nov 16, 2024
2a188f0
Fix to work DOP with bock swap
kohya-ss Nov 17, 2024
b2660bb
train run
minux302 Nov 17, 2024
35778f0
fix sample_images type
minux302 Nov 17, 2024
4dd4cd6
work cn load and validation
minux302 Nov 18, 2024
31ca899
fix depth value
minux302 Nov 18, 2024
2a61fc0
docs: fix typo from block_to_swap to blocks_to_swap in README
kohya-ss Nov 20, 2024
0b5229a
save cn
minux302 Nov 21, 2024
420a180
Implement pseudo Huber loss for Flux and SD3
recris Nov 27, 2024
740ec1d
Fix issues found in review
recris Nov 28, 2024
9dff44d
fix device
minux302 Nov 29, 2024
575f583
add README
minux302 Nov 29, 2024
be5860f
add schnell option to load_cn
minux302 Nov 29, 2024
f40632b
rm abundant arg
minux302 Nov 29, 2024
928b939
Allow unknown schedule-free optimizers to continue to module loader
rockerBOO Nov 20, 2024
87f5224
Support d*lr for ProdigyPlus optimizer
rockerBOO Nov 20, 2024
6593cfb
Fix d * lr step log
rockerBOO Nov 21, 2024
c7cadbc
Add pytest testing
rockerBOO Nov 29, 2024
2dd063a
add torch torchvision accelerate versions
rockerBOO Nov 29, 2024
e59e276
Add dadaptation
rockerBOO Nov 29, 2024
dd3b846
Install pytorch first to pin version
rockerBOO Nov 29, 2024
89825d6
Run typos workflows once where appropriate
rockerBOO Nov 29, 2024
4f7f248
Bump typos action
rockerBOO Nov 29, 2024
9c885e5
fix: improve pos_embed handling for oversized images and update resol…
kohya-ss Nov 30, 2024
7b61e9e
Fix issues found in review (pt 2)
recris Nov 30, 2024
a5a27fe
Merge pull request #1808 from recris/huber-loss-flux
kohya-ss Dec 1, 2024
14f642f
fix: huber_schedule exponential not working on sd3_train.py
kohya-ss Dec 1, 2024
0fe6320
fix flux_train.py is not working
kohya-ss Dec 1, 2024
cc11989
fix: refactor huber-loss calculation in multiple training scripts
kohya-ss Dec 1, 2024
1476040
fix: update help text for huber loss parameters in train_util.py
kohya-ss Dec 1, 2024
bdf9a8c
Merge pull request #1815 from kohya-ss/flux-huber-loss
kohya-ss Dec 1, 2024
34e7f50
docs: update README for huber loss
kohya-ss Dec 1, 2024
14c9ba9
Merge pull request #1811 from rockerBOO/schedule-free-prodigy
kohya-ss Dec 1, 2024
1dc873d
update README and clean up code for schedulefree optimizer
kohya-ss Dec 1, 2024
e3fd6c5
Merge pull request #1812 from rockerBOO/tests
kohya-ss Dec 2, 2024
09a3740
Merge pull request #1813 from minux302/flux-controlnet
kohya-ss Dec 2, 2024
e369b9a
docs: update README with FLUX.1 ControlNet training details and impro…
kohya-ss Dec 2, 2024
5ab00f9
Update workflow tests with cleanup and documentation
rockerBOO Dec 2, 2024
63738ec
Add tests documentation
rockerBOO Dec 2, 2024
2610e96
Pytest
rockerBOO Dec 2, 2024
3e5d89c
Add more resources
rockerBOO Dec 2, 2024
8b36d90
feat: support block_to_swap for FLUX.1 ControlNet training
kohya-ss Dec 2, 2024
6bee18d
fix: resolve model corruption issue with pos_embed when using --enabl…
kohya-ss Dec 7, 2024
2be3366
Merge pull request #1817 from rockerBOO/workflow-tests-fixes
kohya-ss Dec 7, 2024
abff4b0
Unify controlnet parameters name and change scripts name. (#1821)
sdbds Dec 7, 2024
e425996
feat: unify ControlNet model name option and deprecate old training s…
kohya-ss Dec 7, 2024
3cb8cb2
Prevent git credentials from leaking into other actions
rockerBOO Dec 9, 2024
8e378cf
add RAdamScheduleFree support
nhamanasu Dec 11, 2024
d3305f9
Merge pull request #1828 from rockerBOO/workflow-security-audit
kohya-ss Dec 15, 2024
f2d38e6
Merge pull request #1830 from nhamanasu/sd3
kohya-ss Dec 15, 2024
e896539
update requirements.txt and README to include RAdamScheduleFree optim…
kohya-ss Dec 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,134 @@
This repository contains training, generation and utility scripts for Stable Diffusion.

## FLUX.1 LoRA training (WIP)

This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.

__Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchvision==0.19.0` with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__

The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`

Aug 16, 2024:

FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model.

Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell.

Previously, when `--max_token_length` was specified, that value was used, and 512 was used when omitted (default). Therefore, there is no impact if `--max_token_length` was not specified. If `--max_token_length` was specified, please specify `--t5xxl_max_token_length` instead. `--max_token_length` is ignored during FLUX.1 training.

Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified.

Aug 13, 2024:

__Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (omission) is `all`.

This argument is available even if `--split_mode` is not specified.

__Experimental__ `--split_mode` option is added to `flux_train_network.py`. This splits FLUX into double blocks and single blocks for training. By enabling gradients only for the single blocks part, memory usage is reduced. When this option is specified, you need to specify `"train_blocks=single"` in the network arguments.

This option enables training with 12GB VRAM GPUs, but the training speed is 2-3 times slower than the default.

Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before.

Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI.

We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs.

```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2
```

The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below:

```
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"
```

The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below:

```
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single"
```

LoRAs for Text Encoders are not tested yet.

We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows:

- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux).
- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform.
- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3).
- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3).

`--loss_type` may be useful for FLUX.1 training. The default is `l2`.

In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted.

additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). This seems to be a good starting point. Thanks to Ostris for the great work!

Other settings may work better, so please try different settings.

We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted.

The trained LoRA model can be used with ComfyUI.

The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.

Aug 12: `--interactive` option is now working.

```
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
```

## SD3 training

SD3 training is done with `sd3_train.py`.

__Jul 27, 2024__:
- Latents and text encoder outputs caching mechanism is refactored significantly.
- Existing cache files for SD3 need to be recreated. Please delete the previous cache files.
- With this change, dataset initialization is significantly faster, especially for large datasets.

- Architecture-dependent parts are extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures.

- Architecture-dependent parts including the cache mechanism for SD1/2/SDXL are also extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training.

---

`fp16` and `bf16` are available for mixed precision training. We are not sure which is better.

`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently.

`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them.

t5xxl works with `fp16` now.

There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype.

`text_encoder_batch_size` is added experimentally for caching faster.

```toml
learning_rate = 1e-6 # seems to depend on the batch size
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
cache_text_encoder_outputs = true
cache_text_encoder_outputs_to_disk = true
vae_batch_size = 1
text_encoder_batch_size = 4
cache_latents = true
cache_latents_to_disk = true
```

__2024/7/27:__

Latents およびテキストエンコーダ出力のキャッシュの仕組みを大きくリファクタリングしました。SD3 用の既存のキャッシュファイルの再作成が必要になりますが、ご了承ください(以前のキャッシュファイルは削除してください)。これにより、特にデータセットの規模が大きい場合のデータセット初期化が大幅に高速化されます。

データセット (`train_util.py`) からアーキテクチャ依存の部分を切り出しました。これにより将来的なアーキテクチャ追加が容易になると期待しています。

SD1/2/SDXL のキャッシュ機構を含むアーキテクチャ依存の部分も切り出しました。sd3 ブランチの SD1/2/SDXL 学習について、基本的な動作は確認していますが、不具合があるかもしれません。SD1/2/SDXL の学習には main または dev ブランチをお使いください。

---

[__Change History__](#change-history) is moved to the bottom of the page.
更新履歴は[ページ末尾](#change-history)に移しました。

Expand Down
54 changes: 39 additions & 15 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tqdm import tqdm

import torch
from library import deepspeed_utils
from library import deepspeed_utils, strategy_base
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()
Expand Down Expand Up @@ -39,6 +39,7 @@
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import library.strategy_sd as strategy_sd


def train(args):
Expand All @@ -52,7 +53,15 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する

tokenizer = train_util.load_tokenizer(args)
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

# データセットを準備する
if args.dataset_class is None:
Expand Down Expand Up @@ -81,10 +90,10 @@ def train(args):
]
}

blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
train_dataset_group = train_util.load_arbitrary_dataset(args)

current_epoch = Value("i", 0)
current_step = Value("i", 0)
Expand Down Expand Up @@ -165,8 +174,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)

train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)

vae.to("cpu")
clean_memory_on_device(accelerator.device)

Expand All @@ -192,6 +202,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
text_encoder.eval()

text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

if not cache_latents:
vae.requires_grad_(False)
vae.eval()
Expand All @@ -214,7 +227,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)

# dataloaderを準備する
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()

# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
Expand Down Expand Up @@ -317,7 +334,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)

# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
Expand All @@ -342,19 +361,22 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
# TODO move to strategy_sd.py
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
tokenize_strategy.tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder], [input_ids]
)[0]
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
Expand Down Expand Up @@ -409,7 +431,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
global_step += 1

train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)

# 指定ステップごとにモデルを保存
Expand Down Expand Up @@ -472,7 +494,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae,
)

train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)

is_main_process = accelerator.is_main_process
if is_main_process:
Expand Down
Loading