Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support SD3 #1374

Draft
wants to merge 313 commits into
base: dev
Choose a base branch
from
Draft
Changes from 1 commit
Commits
Show all changes
313 commits
Select commit Hold shift + click to select a range
6ab48b0
feat: Support multi-resolution training with caching latents to disk
kohya-ss Aug 20, 2024
7e459c0
Update T5 attention mask handling in FLUX
kohya-ss Aug 20, 2024
e17c42c
Add BFL/Diffusers LoRA converter #1467 #1458 #1483
kohya-ss Aug 21, 2024
2b07a92
Fix error in applying mask in Attention and add LoRA converter script
kohya-ss Aug 21, 2024
e1cd19c
add stochastic rounding, fix single block
kohya-ss Aug 21, 2024
98c91a7
Fix bug in FLUX multi GPU training
kohya-ss Aug 22, 2024
a4d27a2
Fix --debug_dataset to work.
kohya-ss Aug 22, 2024
2d8fa33
Fix to remove zero pad for t5xxl output
kohya-ss Aug 22, 2024
b0a9808
added a script to extract LoRA
kohya-ss Aug 22, 2024
bf9f798
chore: fix typos, remove debug print
kohya-ss Aug 22, 2024
99744af
Merge branch 'dev' into sd3
kohya-ss Aug 22, 2024
81411a3
speed up getting image sizes
kohya-ss Aug 22, 2024
2e89cd2
Fix issue with attention mask not being applied in single blocks
kohya-ss Aug 24, 2024
cf689e7
feat: Add option to split projection layers and apply LoRA
kohya-ss Aug 24, 2024
5639c2a
fix typo
kohya-ss Aug 24, 2024
ea92426
Merge branch 'dev' into sd3
kohya-ss Aug 24, 2024
72287d3
feat: Add `shift` option to `--timestep_sampling` in FLUX.1 fine-tuni…
kohya-ss Aug 25, 2024
0087a46
FLUX.1 LoRA supports CLIP-L
kohya-ss Aug 27, 2024
3be712e
feat: Update direct loading fp8 ckpt for LoRA training
kohya-ss Aug 27, 2024
a61cf73
update readme
kohya-ss Aug 27, 2024
6c0e8a5
make guidance_scale keep float in args
Akegarasu Aug 29, 2024
a0cfb08
Cleaned up README
kohya-ss Aug 29, 2024
daa6ad5
Update README.md
kohya-ss Aug 29, 2024
930d709
Merge pull request #1525 from Akegarasu/sd3
kohya-ss Aug 29, 2024
8ecf0fc
Refactor code to ensure args.guidance_scale is always a float #1525
kohya-ss Aug 29, 2024
8fdfd8c
Update safetensors to version 0.4.4 in requirements.txt #1524
kohya-ss Aug 29, 2024
34f2315
fix: text_encoder_conds referenced before assignment
Akegarasu Aug 29, 2024
35882f8
fix
Akegarasu Aug 29, 2024
25c9040
Update flux_train_utils.py
sdbds Aug 30, 2024
928e0fc
Merge pull request #1529 from Akegarasu/sd3
kohya-ss Sep 1, 2024
ef510b3
Sd3 freeze x_block (#1417)
sdbds Sep 1, 2024
92e7600
Move freeze_blocks to sd3_train because it's only for sd3
kohya-ss Sep 1, 2024
1e30aa8
Merge pull request #1541 from sdbds/flux_shift
kohya-ss Sep 1, 2024
4f6d915
update help and README
kohya-ss Sep 1, 2024
6abacf0
update README
kohya-ss Sep 2, 2024
b65ae9b
T5XXL LoRA training, fp8 T5XXL support
kohya-ss Sep 4, 2024
b7cff0a
update README
kohya-ss Sep 4, 2024
56cb2fc
support T5XXL LoRA, reduce peak memory usage #1560
kohya-ss Sep 4, 2024
90ed2df
feat: Add support for merging CLIP-L and T5XXL LoRA models
kohya-ss Sep 4, 2024
d912952
set dtype before calling ae closes #1562
kohya-ss Sep 5, 2024
2889108
feat: Add --cpu_offload_checkpointing option to LoRA training
kohya-ss Sep 5, 2024
ce14447
Merge branch 'dev' into sd3
kohya-ss Sep 7, 2024
d29af14
add negative prompt for flux inference script
kohya-ss Sep 9, 2024
d10ff62
support individual LR for CLIP-L/T5XXL
kohya-ss Sep 10, 2024
65b8a06
update README
kohya-ss Sep 10, 2024
eaafa5c
Merge branch 'dev' into sd3
kohya-ss Sep 11, 2024
8311e88
typo fix
cocktailpeanut Sep 11, 2024
d83f2e9
Merge pull request #1592 from cocktailpeanut/sd3
kohya-ss Sep 11, 2024
a823fd9
Improve wandb logging (#1576)
p1atdev Sep 11, 2024
237317f
update README
kohya-ss Sep 11, 2024
cefe526
fix to work old notation for TE LR in .toml
kohya-ss Sep 12, 2024
f3ce80e
Merge branch 'dev' into sd3
kohya-ss Sep 13, 2024
c15a3a1
Merge branch 'dev' into sd3
kohya-ss Sep 13, 2024
0485f23
Merge branch 'dev' into sd3
kohya-ss Sep 13, 2024
2d8ee3c
OFT for FLUX.1
kohya-ss Sep 14, 2024
c9ff4de
Add support for specifying rank for each layer in FLUX.1
kohya-ss Sep 14, 2024
6445bb2
update README
kohya-ss Sep 14, 2024
9f44ef1
add diffusers to FLUX.1 conversion script
kohya-ss Sep 15, 2024
be078bd
fix typo
kohya-ss Sep 15, 2024
96c677b
fix to work lienar/cosine lr scheduler closes #1602 ref #1393
kohya-ss Sep 16, 2024
d8d15f1
add support for specifying blocks in FLUX.1 LoRA training
kohya-ss Sep 16, 2024
0cbe95b
fix text_encoder_lr to work with int closes #1608
kohya-ss Sep 17, 2024
a2ad7e5
blocks_to_swap=0 means no swap
kohya-ss Sep 17, 2024
bbd160b
sd3 schedule free opt (#1605)
kohya-ss Sep 17, 2024
e745021
update README
kohya-ss Sep 17, 2024
1286e00
fix to call train/eval in schedulefree #1605
kohya-ss Sep 18, 2024
706a48d
Merge branch 'dev' into sd3
kohya-ss Sep 19, 2024
b844c70
Merge branch 'dev' into sd3
kohya-ss Sep 19, 2024
3957372
Retain alpha in `pil_resize`
emcmanus Sep 19, 2024
de4bb65
Update utils.py
emcmanus Sep 19, 2024
0535cd2
fix: backward compatibility for text_encoder_lr
Akegarasu Sep 20, 2024
24f8975
Merge pull request #1620 from Akegarasu/sd3
kohya-ss Sep 20, 2024
583d4a4
add compatibility for int LR (D-Adaptation etc.) #1620
kohya-ss Sep 20, 2024
95ff9db
Merge pull request #1619 from emcmanus/patch-1
kohya-ss Sep 20, 2024
fba7692
Merge branch 'dev' into sd3
kohya-ss Sep 23, 2024
65fb69f
Merge branch 'dev' into sd3
kohya-ss Sep 25, 2024
56a7bc1
new block swap for FLUX.1 fine tuning
kohya-ss Sep 25, 2024
da94fd9
fix typos
kohya-ss Sep 25, 2024
2cd6aa2
Merge branch 'dev' into sd3
kohya-ss Sep 26, 2024
392e8de
fix flip_aug, alpha_mask, random_crop issue in caching in caching str…
kohya-ss Sep 26, 2024
3ebb65f
Merge branch 'dev' into sd3
kohya-ss Sep 26, 2024
9249d00
experimental support for multi-gpus latents caching
kohya-ss Sep 26, 2024
24b1fdb
remove debug print
kohya-ss Sep 26, 2024
a9aa526
fix sample generation is not working in FLUX1 fine tuning #1647
kohya-ss Sep 28, 2024
822fe57
add workaround for 'Some tensors share memory' error #1614
kohya-ss Sep 28, 2024
1a0f5b0
re-fix sample generation is not working in FLUX1 split mode #1647
kohya-ss Sep 28, 2024
d050638
Merge branch 'dev' into sd3
kohya-ss Sep 29, 2024
e0c3630
Support Sdxl Controlnet (#1648)
sdbds Sep 29, 2024
56a63f0
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Sep 29, 2024
8919b31
use original ControlNet instead of Diffusers
kohya-ss Sep 29, 2024
0243c65
fix typo
kohya-ss Sep 29, 2024
8bea039
Merge branch 'dev' into sd3
kohya-ss Sep 29, 2024
d78f6a7
Merge branch 'sd3' into sdxl-ctrl-net
kohya-ss Sep 29, 2024
793999d
sample generation in SDXL ControlNet training
kohya-ss Sep 30, 2024
33e942e
Merge branch 'sd3' into fast_image_sizes
kohya-ss Sep 30, 2024
c2440f9
fix cond image normlization, add independent LR for control
kohya-ss Oct 3, 2024
ba08a89
call optimizer eval/train for sample_at_first, also set train after r…
kohya-ss Oct 4, 2024
83e3048
load Diffusers format, check schnell/dev
kohya-ss Oct 6, 2024
126159f
Merge branch 'sd3' into sdxl-ctrl-net
kohya-ss Oct 7, 2024
886f753
support weighted captions for sdxl LoRA and fine tuning
kohya-ss Oct 9, 2024
3de42b6
fix: distributed training in windows
Akegarasu Oct 10, 2024
9f4dac5
torch 2.4
Akegarasu Oct 10, 2024
f2bc820
support weighted captions for SD/SDXL
kohya-ss Oct 10, 2024
035c4a8
update docs and help text
kohya-ss Oct 11, 2024
43bfeea
Merge pull request #1655 from kohya-ss/sdxl-ctrl-net
kohya-ss Oct 11, 2024
d005652
Merge pull request #1686 from Akegarasu/sd3
kohya-ss Oct 12, 2024
0d3058b
update README
kohya-ss Oct 12, 2024
ff4083b
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Oct 12, 2024
c80c304
Refactor caching in train scripts
kohya-ss Oct 12, 2024
ecaea90
update README
kohya-ss Oct 12, 2024
e277b57
Update FLUX.1 support for compact models
kohya-ss Oct 12, 2024
5bb9f7f
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Oct 13, 2024
74228c9
update cache_latents/text_encoder_outputs
kohya-ss Oct 13, 2024
c65cf38
Merge branch 'sd3' into fast_image_sizes
kohya-ss Oct 13, 2024
2244cf5
load images in parallel when caching latents
kohya-ss Oct 13, 2024
bfc3a65
fix to work cache latents/text encoder outputs
kohya-ss Oct 13, 2024
d02a6ef
Merge pull request #1660 from kohya-ss/fast_image_sizes
kohya-ss Oct 13, 2024
886ffb4
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Oct 13, 2024
2d5f7fa
update README
kohya-ss Oct 13, 2024
1275e14
Merge pull request #1690 from kohya-ss/multi-gpu-caching
kohya-ss Oct 13, 2024
2500f5a
fix latents caching not working closes #1696
kohya-ss Oct 14, 2024
3cc5b8d
Diff Output Preserv loss for SDXL
kohya-ss Oct 18, 2024
d8d7142
fix to work caching latents #1696
kohya-ss Oct 18, 2024
ef70aa7
add FLUX.1 support
kohya-ss Oct 18, 2024
2c45d97
update README, remove unnecessary autocast
kohya-ss Oct 19, 2024
09b4d1e
Merge branch 'sd3' into diff_output_prsv
kohya-ss Oct 19, 2024
aa93242
Merge pull request #1710 from kohya-ss/diff_output_prsv
kohya-ss Oct 19, 2024
7fe8e16
fix to work ControlNetSubset with custom_attributes
kohya-ss Oct 19, 2024
138dac4
update README
kohya-ss Oct 20, 2024
623017f
refactor SD3 CLIP to transformers etc.
kohya-ss Oct 24, 2024
e3c43bd
reduce memory usage in sample image generation
kohya-ss Oct 24, 2024
0286114
support SD3.5L, fix final saving
kohya-ss Oct 24, 2024
f8c5146
support block swap with fused_optimizer_pass
kohya-ss Oct 24, 2024
5fba6f5
Merge branch 'dev' into sd3
kohya-ss Oct 25, 2024
f52fb66
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 25, 2024
d2c549d
support SD3 LoRA
kohya-ss Oct 25, 2024
0031d91
add latent scaling/shifting
kohya-ss Oct 25, 2024
56bf761
fix errors in SD3 LoRA training with Text Encoders close #1724
kohya-ss Oct 26, 2024
014064f
fix sample image generation without seed failed close #1726
kohya-ss Oct 26, 2024
8549669
Merge branch 'dev' into sd3
kohya-ss Oct 26, 2024
150579d
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 26, 2024
731664b
Merge branch 'dev' into sd3
kohya-ss Oct 27, 2024
b649bbf
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 27, 2024
db2b4d4
Add dropout rate arguments for CLIP-L, CLIP-G, and T5, fix Text Encod…
kohya-ss Oct 27, 2024
a1255d6
Fix SD3 LoRA training to work (WIP)
kohya-ss Oct 27, 2024
d4f7849
prevent unintended cast for disk cached TE outputs
kohya-ss Oct 27, 2024
1065dd1
Fix to work dropout_rate for TEs
kohya-ss Oct 27, 2024
af8e216
Fix sample image gen to work with block swap
kohya-ss Oct 28, 2024
7555486
Fix error on saving T5XXL
kohya-ss Oct 28, 2024
0af4edd
Fix split_qkv
kohya-ss Oct 29, 2024
d4e19fb
Support Lora
kohya-ss Oct 29, 2024
80bb3f4
Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-script…
kohya-ss Oct 29, 2024
1e2f7b0
Support for checkpoint files with a mysterious prefix "model.diffusio…
kohya-ss Oct 29, 2024
ce5b532
Fix additional LoRA to work
kohya-ss Oct 29, 2024
c9a1417
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 29, 2024
b502f58
Fix emb_dim to work.
kohya-ss Oct 29, 2024
bdddc20
support SD3.5M
kohya-ss Oct 30, 2024
8c3c825
Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-script…
kohya-ss Oct 30, 2024
70a179e
Fix to use SDPA instead of xformers
kohya-ss Oct 30, 2024
1434d85
Support SD3.5M multi resolutional training
kohya-ss Oct 31, 2024
9e23368
Update SD3 training
kohya-ss Oct 31, 2024
830df4a
Fix crashing if image is too tall or wide.
kohya-ss Oct 31, 2024
9aa6f52
Fix memory leak in latent caching. bmp failed to cache
kohya-ss Nov 1, 2024
82daa98
remove duplicate resolution for scaled pos embed
kohya-ss Nov 1, 2024
264328d
Merge pull request #1719 from kohya-ss/sd3_5_support
kohya-ss Nov 1, 2024
e0db596
update multi-res training in SD3.5M
kohya-ss Nov 2, 2024
5e32ee2
fix crashing in DDP training closes #1751
kohya-ss Nov 2, 2024
81c0c96
faster block swap
kohya-ss Nov 5, 2024
aab943c
remove unused weight swapping functions from utils.py
kohya-ss Nov 5, 2024
4384903
Fix to work without latent cache #1758
kohya-ss Nov 6, 2024
40ed54b
Simplify Timestep weighting
Dango233 Nov 7, 2024
e54462a
Fix SD3 trained lora loading and merging
Dango233 Nov 7, 2024
bafd10d
Fix typo
Dango233 Nov 7, 2024
588ea9e
Merge pull request #1768 from Dango233/dango/timesteps_fix
kohya-ss Nov 7, 2024
5e86323
Update README and clean-up the code for SD3 timesteps
kohya-ss Nov 7, 2024
f264f40
Update README.md
kohya-ss Nov 7, 2024
5eb6d20
Update README.md
Dango233 Nov 7, 2024
387b40e
Merge pull request #1769 from Dango233/patch-1
kohya-ss Nov 7, 2024
e877b30
Merge branch 'dev' into sd3
kohya-ss Nov 7, 2024
123474d
Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into sd3
kohya-ss Nov 7, 2024
b8d3fec
Merge branch 'sd3' into faster-block-swap
kohya-ss Nov 7, 2024
186aa5b
fix illeagal block is swapped #1764
kohya-ss Nov 7, 2024
b3248a8
fix: sort order when getting image size from cache file
feffy380 Nov 7, 2024
2a2042a
Merge pull request #1770 from feffy380/fix-size-from-cache
kohya-ss Nov 9, 2024
8fac3c3
update README
kohya-ss Nov 9, 2024
26bd454
init
sdbds Nov 11, 2024
02bd76e
Refactor block swapping to utilize custom offloading utilities
kohya-ss Nov 11, 2024
7feaae5
Merge branch 'sd3' into faster-block-swap
kohya-ss Nov 11, 2024
92482c7
Merge pull request #1774 from sdbds/avif_get_imagesize
kohya-ss Nov 11, 2024
3fe94b0
update comment
kohya-ss Nov 11, 2024
cde90b8
feat: implement block swapping for FLUX.1 LoRA (WIP)
kohya-ss Nov 11, 2024
17cf249
Merge branch 'sd3' into faster-block-swap
kohya-ss Nov 11, 2024
2cb7a6d
feat: add block swap for FLUX.1/SD3 LoRA training
kohya-ss Nov 12, 2024
2bb0f54
update grad hook creation to fix TE lr in sd3 fine tuning
kohya-ss Nov 14, 2024
5c5b544
refactor: remove unused prepare_split_model method from FluxNetworkTr…
kohya-ss Nov 14, 2024
fd2d879
docs: update README
kohya-ss Nov 14, 2024
0047bb1
Merge pull request #1779 from kohya-ss/faster-block-swap
kohya-ss Nov 14, 2024
ccfaa00
add flux controlnet base module
minux302 Nov 15, 2024
42f6edf
fix for adding controlnet
minux302 Nov 15, 2024
e358b11
fix dataloader
minux302 Nov 16, 2024
2a188f0
Fix to work DOP with bock swap
kohya-ss Nov 17, 2024
b2660bb
train run
minux302 Nov 17, 2024
35778f0
fix sample_images type
minux302 Nov 17, 2024
4dd4cd6
work cn load and validation
minux302 Nov 18, 2024
31ca899
fix depth value
minux302 Nov 18, 2024
2a61fc0
docs: fix typo from block_to_swap to blocks_to_swap in README
kohya-ss Nov 20, 2024
0b5229a
save cn
minux302 Nov 21, 2024
420a180
Implement pseudo Huber loss for Flux and SD3
recris Nov 27, 2024
740ec1d
Fix issues found in review
recris Nov 28, 2024
9dff44d
fix device
minux302 Nov 29, 2024
575f583
add README
minux302 Nov 29, 2024
be5860f
add schnell option to load_cn
minux302 Nov 29, 2024
f40632b
rm abundant arg
minux302 Nov 29, 2024
928b939
Allow unknown schedule-free optimizers to continue to module loader
rockerBOO Nov 20, 2024
87f5224
Support d*lr for ProdigyPlus optimizer
rockerBOO Nov 20, 2024
6593cfb
Fix d * lr step log
rockerBOO Nov 21, 2024
c7cadbc
Add pytest testing
rockerBOO Nov 29, 2024
2dd063a
add torch torchvision accelerate versions
rockerBOO Nov 29, 2024
e59e276
Add dadaptation
rockerBOO Nov 29, 2024
dd3b846
Install pytorch first to pin version
rockerBOO Nov 29, 2024
89825d6
Run typos workflows once where appropriate
rockerBOO Nov 29, 2024
4f7f248
Bump typos action
rockerBOO Nov 29, 2024
9c885e5
fix: improve pos_embed handling for oversized images and update resol…
kohya-ss Nov 30, 2024
7b61e9e
Fix issues found in review (pt 2)
recris Nov 30, 2024
a5a27fe
Merge pull request #1808 from recris/huber-loss-flux
kohya-ss Dec 1, 2024
14f642f
fix: huber_schedule exponential not working on sd3_train.py
kohya-ss Dec 1, 2024
0fe6320
fix flux_train.py is not working
kohya-ss Dec 1, 2024
cc11989
fix: refactor huber-loss calculation in multiple training scripts
kohya-ss Dec 1, 2024
1476040
fix: update help text for huber loss parameters in train_util.py
kohya-ss Dec 1, 2024
bdf9a8c
Merge pull request #1815 from kohya-ss/flux-huber-loss
kohya-ss Dec 1, 2024
34e7f50
docs: update README for huber loss
kohya-ss Dec 1, 2024
14c9ba9
Merge pull request #1811 from rockerBOO/schedule-free-prodigy
kohya-ss Dec 1, 2024
1dc873d
update README and clean up code for schedulefree optimizer
kohya-ss Dec 1, 2024
e3fd6c5
Merge pull request #1812 from rockerBOO/tests
kohya-ss Dec 2, 2024
09a3740
Merge pull request #1813 from minux302/flux-controlnet
kohya-ss Dec 2, 2024
e369b9a
docs: update README with FLUX.1 ControlNet training details and impro…
kohya-ss Dec 2, 2024
5ab00f9
Update workflow tests with cleanup and documentation
rockerBOO Dec 2, 2024
63738ec
Add tests documentation
rockerBOO Dec 2, 2024
2610e96
Pytest
rockerBOO Dec 2, 2024
3e5d89c
Add more resources
rockerBOO Dec 2, 2024
8b36d90
feat: support block_to_swap for FLUX.1 ControlNet training
kohya-ss Dec 2, 2024
6bee18d
fix: resolve model corruption issue with pos_embed when using --enabl…
kohya-ss Dec 7, 2024
2be3366
Merge pull request #1817 from rockerBOO/workflow-tests-fixes
kohya-ss Dec 7, 2024
abff4b0
Unify controlnet parameters name and change scripts name. (#1821)
sdbds Dec 7, 2024
e425996
feat: unify ControlNet model name option and deprecate old training s…
kohya-ss Dec 7, 2024
3cb8cb2
Prevent git credentials from leaking into other actions
rockerBOO Dec 9, 2024
8e378cf
add RAdamScheduleFree support
nhamanasu Dec 11, 2024
d3305f9
Merge pull request #1828 from rockerBOO/workflow-security-audit
kohya-ss Dec 15, 2024
f2d38e6
Merge pull request #1830 from nhamanasu/sd3
kohya-ss Dec 15, 2024
e896539
update requirements.txt and README to include RAdamScheduleFree optim…
kohya-ss Dec 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
T5XXL LoRA training, fp8 T5XXL support
kohya-ss committed Sep 4, 2024
commit b65ae9b439e4324359014d6d720aa01def3a19fc
45 changes: 34 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -11,6 +11,11 @@ The command to install PyTorch is as follows:

### Recent Updates

Sep 4, 2024:
- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI.
- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching.
- Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training.

Sep 1, 2024:
- `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds!
- This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`.
@@ -41,8 +46,8 @@ Sample command is below. It will work with 24GB VRAM GPUs.

```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py
--pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4
--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base
@@ -72,6 +77,11 @@ The trained LoRA model can be used with ComfyUI.

There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome.

- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format.
- `--clip_l` is the path to the CLIP-L model.
- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching.
- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`).

- `--timestep_sampling` is the method to sample timesteps (0-1):
- `sigma`: sigma-based, same as SD3
- `uniform`: uniform random
@@ -114,16 +124,29 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times

#### Key Features for FLUX.1 LoRA training

1. CLIP-L LoRA Support:
- FLUX.1 LoRA training now supports CLIP-L LoRA.
1. CLIP-L and T5XXL LoRA Support:
- FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training.
- Remove `--network_train_unet_only` from your command.
- T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required.
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time.
- T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- The trained LoRA can be used with ComfyUI.
- Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA.
- Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet.

| trained LoRA|option|network_args|cache_text_encoder_outputs (*1)|
|---|---|---|---|
|FLUX.1|`--network_train_unet_only`|-|o|
|FLUX.1 + CLIP-L|-|-|o (*2)|
|FLUX.1 + CLIP-L + T5XXL|-|`train_t5xxl=True`|-|
|CLIP-L (*3)|`--network_train_text_encoder_only`|-|o (*2)|
|CLIP-L + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-|

- *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- *2: T5XXL output can be cached for CLIP-L LoRA training.
- *3: Not tested yet.

2. Experimental FP8/FP16 mixed training:
- `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L.
- FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16.
- `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL.
- FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16.
- When specifying this option, the `--fp8_base` option is automatically enabled.

3. Split Q/K/V Projection Layers (Experimental):
@@ -153,7 +176,7 @@ The compatibility of the saved model (state dict) is ensured by concatenating th
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.

```
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
```

### FLUX.1 fine-tuning
@@ -164,7 +187,7 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP

```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py
--pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors
--save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2
--seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name
@@ -256,7 +279,7 @@ CLIP-L LoRA is not supported.
`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__

```
python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
```

You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`.
112 changes: 90 additions & 22 deletions flux_train_network.py
Original file line number Diff line number Diff line change
@@ -43,13 +43,9 @@ def assert_extra_args(self, args, train_dataset_group):
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

# assert (
# args.network_train_unet_only or not args.cache_text_encoder_outputs
# ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
if not args.network_train_unet_only:
logger.info(
"network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません"
)
# prepare CLIP-L/T5XXL training flags
self.train_clip_l = not args.network_train_unet_only
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False

if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
@@ -63,12 +59,10 @@ def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
name = self.get_flux_model_name(args)

# if we load to cpu, flux.to(fp8) takes a long time
if args.fp8_base:
loading_dtype = None # as is
else:
loading_dtype = weight_dtype
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype

# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
model = flux_utils.load_flow_model(
name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
)
@@ -85,9 +79,21 @@ def load_target_model(self, args, weight_dtype, accelerator):
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
clip_l.eval()

# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
if args.fp8_base and not args.fp8_base_unet:
loading_dtype = None # as is
else:
loading_dtype = weight_dtype

# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
t5xxl.eval()
if args.fp8_base and not args.fp8_base_unet:
# check dtype of model
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")

ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)

@@ -154,25 +160,35 @@ def get_latents_caching_strategy(self, args):
def get_text_encoding_strategy(self, args):
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)

def post_process_network(self, args, accelerator, network, text_encoders, unet):
# check t5xxl is trained or not
self.train_t5xxl = network.train_t5xxl

if self.train_t5xxl and args.cache_text_encoder_outputs:
raise ValueError(
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
)

def get_models_for_text_encoding(self, args, accelerator, text_encoders):
if args.cache_text_encoder_outputs:
if self.is_train_text_encoder(args):
if self.train_clip_l and not self.train_t5xxl:
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
else:
return text_encoders # ignored
return None # no text encoders are needed for encoding because both are cached
else:
return text_encoders # both CLIP-L and T5XXL are needed for encoding

def get_text_encoders_train_flags(self, args, text_encoders):
return [True, False] if self.is_train_text_encoder(args) else [False, False]
return [self.train_clip_l, self.train_t5xxl]

def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
False,
is_partial=self.is_train_text_encoder(args),
is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
@@ -193,8 +209,16 @@ def cache_text_encoder_outputs_if_needed(

# When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[1].to(accelerator.device)

if text_encoders[1].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[1].to(weight_dtype)

with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)

@@ -235,7 +259,7 @@ def cache_text_encoder_outputs_if_needed(
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device)

# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -255,9 +279,12 @@ def cache_text_encoder_outputs_if_needed(
# return noise_pred

def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)

if not args.split_mode:
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
return

@@ -281,7 +308,7 @@ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_a
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
clean_memory_on_device(accelerator.device)
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
)
clean_memory_on_device(accelerator.device)

@@ -421,6 +448,47 @@ def update_metadata(self, metadata, args):
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
if index == 0: # CLIP-L
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
else: # T5XXL
text_encoder.encoder.embed_tokens.requires_grad_(True)

def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
if index == 0: # CLIP-L
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype) # fp8
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
else: # T5XXL

def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)

hidden_states = module.wo(hidden_states)
return hidden_states

return forward

for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)

if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
logger.info(f"T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)


def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
Loading