
Add EMA #893

Closed
wants to merge 16 commits into from

Conversation

vvern999
Contributor

@vvern999 vvern999 commented Oct 22, 2023

Add EMA support. Usually it is used like this:

  • At the start of training, the model parameters are copied to the EMA parameters.
  • At each gradient update step, the EMA parameters are updated using the following formula: ema_param = ema_param * decay + param * (1 - decay)
  • At the end of training, the model parameters are replaced by the EMA parameters.

In this implementation, both the EMA and non-EMA weights are saved by default (the model is saved as 2 files). Usually you only need the EMA weights if you are using EMA.
EMA is usually used when training large models, but I tried it on a short finetune and on LoRA training, and it seems to work, though the difference between the EMA and normal weights is small.
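
A minimal sketch of that update loop (the class below is illustrative; the constructor arguments and integration details in the actual PR differ):

```python
import torch

class EMAModel:
    """Keeps an exponential moving average (shadow copy) of model parameters."""

    def __init__(self, parameters, decay=0.9999):
        self.decay = decay
        # Start of training: copy the model parameters into the EMA parameters.
        self.shadow_params = [p.clone().detach() for p in parameters]

    @torch.no_grad()
    def step(self, parameters):
        # Each update step: ema_param = ema_param * decay + param * (1 - decay)
        for shadow, param in zip(self.shadow_params, parameters):
            shadow.mul_(self.decay).add_(param, alpha=1.0 - self.decay)

    @torch.no_grad()
    def copy_to(self, parameters):
        # End of training: replace the model parameters with the EMA parameters.
        for shadow, param in zip(self.shadow_params, parameters):
            param.copy_(shadow)
```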

@FurkanGozukara

Amazing, thank you so much!

@kohya-ss
Owner

Thank you for this! I will review and merge it soon.

@FurkanGozukara

I hope this can be tested and merged.

This will be only for SD 1.5, right?

@vvern999
Contributor Author

vvern999 commented Nov 1, 2023

Tests on SD1:
1000-step LoRA - https://files.catbox.moe/wf03b9.png - very small difference, mostly in the backgrounds
1000-step finetune - https://files.catbox.moe/m381z4.jpg - noticeable difference
Not sure if `with torch.no_grad(), accelerator.autocast()` is necessary, but it doesn't seem to break anything.

> This will be only for SD 1.5, right?

It works with SDXL LoRA training. SDXL finetuning with EMA would need a lot of memory; it probably won't fit in 24 GB, so it's not very useful there.

@FurkanGozukara

> Tests on SD1: 1000-step LoRA - https://files.catbox.moe/wf03b9.png - very small difference, mostly in the backgrounds. 1000-step finetune - https://files.catbox.moe/m381z4.jpg - noticeable difference. Not sure if `with torch.no_grad(), accelerator.autocast()` is necessary, but it doesn't seem to break anything.
>
> > This will be only for SD 1.5, right?
>
> It works with SDXL LoRA training. SDXL finetuning with EMA would need a lot of memory; it probably won't fit in 24 GB, so it's not very useful there.

But if it works, great. For best quality we can run it on a higher-VRAM machine.

@vvern999
Contributor Author

vvern999 commented Nov 3, 2023

Trained a LyCORIS on multiple characters with --enable_ema --ema_decay=0.9995 for about 10,000 steps.
It looks like it works properly now; EMA stabilizes training a little.
The prompt was "character riding a horse", and the character on the left is supposed to have 2 feathers in her hair. In the EMA version this looks correct.

(comparison images attached: Untitled1, Untitled2)

@FurkanGozukara

@vvern999 EMA was a huge improvement when I was using the Automatic1111 DreamBooth extension.

Can you verify SDXL too? Also, how did you come up with the value --ema_decay=0.9995?

@IdiotSandwichTheThird

IdiotSandwichTheThird commented Nov 4, 2023

I know it is difficult, but please, for the love of god, STOP adding features to only some of the trainers. It causes so much unnecessary confusion, especially when it is silently not doing anything, like in this case when using --enable_ema with sdxl_train.py with this PR.

If you don't want to add the functionality, inform the user that it's not doing anything.

@FurkanGozukara

> I know it is difficult, but please, for the love of god, STOP adding features to only some of the trainers. It causes so much unnecessary confusion, especially when it is silently not doing anything, like in this case with the SDXL trainer in this PR.

It adds a huge quality improvement to SD 1.5-based training. Why complain?

@vvern999
Contributor Author

vvern999 commented Nov 5, 2023

@IdiotSandwichTheThird
Added the messages.
I was only interested in adding this for LoRA/LyCORIS training; I should have made that clear, sorry. I assumed people use this repo mostly for LoRA training.

Plenty of other training scripts like hcp-diffusion, everdream2, onetrainer, etc. support EMA for SDXL training, so you can use them for that. Some of these scripts also support additional optimizations like DeepSpeed/Colossal-AI and interesting features like custom loss functions, or training embeddings and the model at the same time. Those features are missing here.

@FurkanGozukara
ema_decay=0.9995 is in the middle of the recommended range. I remember reading somewhere that it should be 0.999 for small datasets and 0.9999 for large datasets.
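
For intuition (a back-of-the-envelope note, not from the PR): an EMA with decay d averages over roughly the last 1/(1 - d) steps, so those recommendations correspond to averaging windows of about 1,000, 2,000, and 10,000 steps respectively.

```python
# Rough effective averaging window of an EMA: ~1 / (1 - decay) steps.
for decay in (0.999, 0.9995, 0.9999):
    print(f"decay={decay}: ~{1 / (1 - decay):.0f} steps")
# decay=0.999: ~1000 steps
# decay=0.9995: ~2000 steps
# decay=0.9999: ~10000 steps
```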

@FurkanGozukara

> @IdiotSandwichTheThird I added the messages. As for SDXL: plenty of other training scripts like hcp-diffusion, everdream2, onetrainer, etc. support EMA for SDXL finetuning. You can use them.
>
> I was only interested in adding this for LoRA/LyCORIS training; I should have made that clear, sorry.
>
> @FurkanGozukara ema_decay=0.9995 is in the middle of the recommended range. I remember reading somewhere that it should be 0.999 for small datasets and 0.9999 for large datasets.

I would really prefer Kohya to support it.

Can you add EMA to Kohya's SDXL DreamBooth?

@FurkanGozukara

Can we expect EMA for SDXL?

@IdiotSandwichTheThird

@vvern999
It appears there is a weird bug where training crashes when the command-line argument --learning_rate_te is used together with --enable_ema.

Traceback (most recent call last):
  File "F:\Kohya2\sd-scripts-new\.fluffyv\train_db.py", line 537, in <module>
    train(args)
  File "F:\Kohya2\sd-scripts-new\.fluffyv\train_db.py", line 222, in train
    ema = EMAModel(trainable_params, decay=args.ema_decay, beta=args.ema_exp_beta, max_train_steps=args.max_train_steps)
  File "F:\Kohya2\sd-scripts-new\.fluffyv\library\train_util.py", line 2305, in __init__
    self.shadow_params = [p.clone().detach() for p in parameters]
  File "F:\Kohya2\sd-scripts-new\.fluffyv\library\train_util.py", line 2305, in <listcomp>
    self.shadow_params = [p.clone().detach() for p in parameters]
AttributeError: 'dict' object has no attribute 'clone'
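
The traceback suggests (my reading, not confirmed anywhere in the PR) that --learning_rate_te turns trainable_params into a list of optimizer param-group dicts rather than a flat list of tensors, so p.clone() ends up being called on a dict. A possible fix would be to flatten the groups before building the shadow copies; the helper below is hypothetical:

```python
# Hypothetical helper: accept either bare tensors or optimizer param-group
# dicts ({"params": [...], "lr": ...}), which appear when a separate text
# encoder learning rate creates multiple parameter groups.
def flatten_parameters(parameters):
    flat = []
    for p in parameters:
        if isinstance(p, dict):
            flat.extend(p["params"])  # unwrap the param group
        else:
            flat.append(p)            # already a tensor / nn.Parameter
    return flat

# In EMAModel.__init__ (the line referenced by the traceback), one could then use:
# self.shadow_params = [p.clone().detach() for p in flatten_parameters(parameters)]
```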

@IdiotSandwichTheThird

IdiotSandwichTheThird commented Nov 20, 2023

Additionally, I could not get EMA to actually train with the train_db script; the resulting EMA checkpoint ended up being the same as the base model, and only the non-EMA one was trained.
No errors related to this are printed during training, so I'm not sure how to troubleshoot it.

@vvern999
Contributor Author

vvern999 commented Nov 20, 2023

@IdiotSandwichTheThird

> ema checkpoint ended up being the same as the base model

Either something is wrong with copying/updating the weights, or the default decay value is too high.
With values like --ema_decay=0.5 the difference is very noticeable.
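
One way to see why a very high decay can leave the EMA checkpoint looking identical to the base model (an illustration, assuming the standard update above): after N steps the initial weights still carry a factor of decay^N in the EMA.

```python
# Weight that the initial (pre-training) parameters still carry in the EMA after N steps.
n = 1000
for decay in (0.9999, 0.9995, 0.5):
    print(f"decay={decay}, steps={n}: initial weights retain {decay ** n:.3f}")
# decay=0.9999 -> 0.905 (the EMA is still ~90% the starting weights)
# decay=0.9995 -> 0.606
# decay=0.5    -> 0.000 (the EMA is essentially the current weights)
```

So over only a few thousand low-LR steps, decay values near 1 keep the EMA very close to the starting checkpoint.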

@IdiotSandwichTheThird

> @IdiotSandwichTheThird
>
> > ema checkpoint ended up being the same as the base model
>
> Either something is wrong with copying/updating the weights, or the default decay value is too high. With values like --ema_decay=0.5 the difference is very noticeable.

Okay, with ema_decay 0.5 there is indeed a difference. I guess I set it way too high, possibly also because of the extremely low LR I chose.

@ThereforeGames

> ema_decay=0.9995 is in the middle of the recommended range. I remember reading somewhere that it should be 0.999 for small datasets and 0.9999 for large datasets.

Just curious - what is "large" in this context? Dataset size varies greatly between LoRA training and finetuning, for example. Are we talking about 100 images or 10k?

@vvern999
Contributor Author

vvern999 commented Dec 7, 2023

There's a new method of doing EMA: https://arxiv.org/abs/2312.02696
(screenshot from the paper attached)

Closing this PR. It would be better to implement the method from that paper instead.

@vvern999 vvern999 closed this Dec 7, 2023
@FurkanGozukara

> There's a new method of doing EMA: https://arxiv.org/abs/2312.02696
>
> Closing this PR. It would be better to implement the method from that paper instead.

I am looking forward to an EMA implementation for SDXL.

When can we expect it?

@FurkanGozukara

EMA is not added to the code yet, right?

@mykeehu

mykeehu commented Jan 30, 2024

I also regret that the EMA feature was not implemented; if it could have achieved better results, it would have been good to be able to use it.

@FurkanGozukara

I tested EMA on OneTrainer and it definitely improves quality. EMA can also be made to run on the CPU there, so it costs no extra VRAM.

However, EMA on OneTrainer didn't work on SDXL, though it worked on SD 1.5.
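
A minimal sketch of the CPU-offload idea mentioned above (my own illustration, not code from OneTrainer or this PR): keep the shadow parameters on the CPU and copy each model parameter over only for the update, trading a device-to-host copy per step for zero extra VRAM.

```python
import torch

@torch.no_grad()
def ema_step_on_cpu(shadow_params, model_params, decay=0.9999):
    # Shadow (EMA) tensors live on the CPU, so they use no VRAM; each GPU
    # parameter is copied to the CPU only for the duration of its update.
    for shadow, param in zip(shadow_params, model_params):
        cpu_param = param.detach().to("cpu", dtype=shadow.dtype)
        shadow.mul_(decay).add_(cpu_param, alpha=1.0 - decay)
```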

@vvern999
Contributor Author

vvern999 commented Feb 1, 2024

There's a nice explanation of how it should work here: https://github.com/cloneofsimo/karras-power-ema-tutorial
I'll try to implement it in a few days.
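
For reference, here is a rough sketch of the power-function EMA update as I understand it from the paper and the tutorial above (the class name and structure are my own, not code from this PR): instead of a fixed decay, the per-step decay is beta_t = (1 - 1/t)^(gamma + 1), so the averaging profile follows a power function of the step count.

```python
import torch

class PowerEMA:
    """Sketch of a power-function EMA (Karras et al., arXiv:2312.02696).

    Assumes the update theta_ema = beta_t * theta_ema + (1 - beta_t) * theta
    with beta_t = (1 - 1/t) ** (gamma + 1); the paper derives gamma from a
    target relative EMA width (sigma_rel) rather than setting it directly.
    """

    def __init__(self, parameters, gamma):
        self.gamma = gamma
        self.t = 0
        self.shadow_params = [p.clone().detach() for p in parameters]

    @torch.no_grad()
    def step(self, parameters):
        self.t += 1
        beta = (1.0 - 1.0 / self.t) ** (self.gamma + 1.0)  # 0 at t=1, grows toward 1
        for shadow, param in zip(self.shadow_params, parameters):
            shadow.mul_(beta).add_(param, alpha=1.0 - beta)
```

The paper also describes reconstructing other EMA profiles post hoc from a few stored snapshots, which is the part the tutorial above focuses on.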

@FurkanGozukara

> There's a nice explanation of how it should work here: https://github.com/cloneofsimo/karras-power-ema-tutorial
> I'll try to implement it in a few days.

Awesome.

@parth1313

parth1313 commented Oct 10, 2024

@vvern999 @FurkanGozukara I am still interested in knowing whether EMA for full SDXL finetuning has been implemented or not.

@FurkanGozukara

> @vvern999 @FurkanGozukara I am still interested in knowing whether EMA for full SDXL finetuning has been implemented or not.

I think it has not been implemented.

@parth1313

parth1313 commented Oct 10, 2024

@FurkanGozukara @vvern999
I have a few questions:
Did you test this script with --enable_ema?
Will it only do LoRA training?
Has the EMA method from that paper been implemented, or the traditional method?

Can you explain in detail, as I am new to this.

@FurkanGozukara

> @FurkanGozukara @vvern999 I have a few questions: Did you test this script with --enable_ema? Will it only do LoRA training? Has the EMA method from that paper been implemented, or the traditional method?
>
> Can you explain in detail, as I am new to this.

I haven't tested EMA on Kohya at all yet.
