OOM on train on 48GB GPU with batch 1 proc 1 #49
Seems my training stays under 48GB:

```bash
bash train_scripts/train.sh configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
  --data.data_dir="[asset/example_data]" \
  --data.type=SanaImgDataset \
  --model.multi_scale=false \
  --data.load_vae_feat=false \
  --train.train_batch_size=1
```

Actually, you can also switch the optimizer type in the config file to:

```yaml
train:
  optimizer:
    lr: 1.0e-4
    type: AdamW
    weight_decay: 0.01
    eps: 1.0e-8
    betas: [0.9, 0.999]
```
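For context, a config block like this is typically mapped onto a torch optimizer by the trainer. Here is a minimal sketch of such a builder, assuming plain PyTorch; Sana's actual builder may differ:

```python
# Hypothetical builder: turns the YAML optimizer block above into a torch optimizer.
import torch

def build_optimizer(model, cfg):
    if cfg["type"] == "AdamW":
        return torch.optim.AdamW(
            model.parameters(),
            lr=cfg["lr"],
            weight_decay=cfg["weight_decay"],
            eps=cfg["eps"],
            betas=tuple(cfg["betas"]),
        )
    raise ValueError(f"unsupported optimizer type: {cfg['type']}")
```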
But how do you do it? You have a 32x VAE vs. 8x in SDXL and a smaller model, yet training at 1024 needs more than 2.5x the memory of SDXL? |
The model you are using is 1.6B, and the VAE and text encoder are both extracting features online. |
It is all about the training scripts. Currently, using Kohya, we are able to fully fine-tune the 12-billion-parameter FLUX dev in 16-bit precision even on 6 GB GPUs via block swapping :) A rough sketch of the idea is below. |
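For readers unfamiliar with block swapping, here is a deliberately naive sketch of the idea, assuming a list of transformer blocks; real implementations such as Kohya's overlap CPU/GPU transfers with compute and handle gradients and optimizer state far more carefully:

```python
# Naive block swapping: keep blocks on CPU, move each onto the GPU only while it runs.
import torch

def forward_with_block_swap(blocks, x, device="cuda"):
    for block in blocks:
        block.to(device)   # bring this block into VRAM
        x = block(x)
        block.to("cpu")    # evict it so the next block fits
    return x
```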
Cool. Hh, I can't imagine the speed. lol : ) |
I very much hope you will add some minimal optimizations in the future. |
What do you mean it's dead? |
With the latest improvements, speeds are really decent: an RTX 3090 takes 7 seconds per sample image at batch size 1. |
It's dead for full fine-tuning by GPU-poor guys. |
Could you please try this: |
@Muinez , nice man. I'll try this out. |
@Muinez thank you! Bucketing works nicely; the code is clean and very simple! I added checkpoints dir creation around line 500 (it fails on the first run):
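The snippet itself was not preserved in the thread; a guarded directory creation along these lines is presumably what was meant (`checkpoint_dir` is a hypothetical name for whatever path the script saves to):

```python
# Hypothetical fix: make sure the checkpoint directory exists before the first save.
import os

os.makedirs(checkpoint_dir, exist_ok=True)  # no-op if the directory already exists
```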
It does not work for me on a 48GB A40 with batch 1 and the CAME optimizer (default settings except batch/num proc): OOM after generating validation images with wandb (probably after decoding; you can see memory spikes at validation). But while it looks like it does something strange, it does work with Adam: https://wandb.ai/recoilme/sana/runs/1?nw=nwuserrecoilme What do you think about adding Adafactor? Is it possible? I get consistently better results with that optimizer on SDXL. Adam is like a tank: good only for the pretrain case. |
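If Sana's optimizer builder were extended, one option is the Adafactor implementation from HuggingFace transformers; a minimal sketch, assuming a fixed learning rate (whether Sana's config system exposes this is an open question):

```python
# Sketch: Adafactor with an explicit LR (relative-step mode disabled).
from transformers.optimization import Adafactor

optimizer = Adafactor(
    model.parameters(),
    lr=1e-4,
    scale_parameter=False,  # required when passing an explicit lr
    relative_step=False,
    warmup_init=False,
    weight_decay=0.01,
)
```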
That loss means you are training from scratch (no ckpts loaded). |
Looking forward to your news. If you meet the NaN problem, you could try changing your config file with:

```yaml
model:
  mixed_precision: bf16
```

which is mentioned in another issue. |
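As a rough illustration of what a bf16 mixed-precision flag usually amounts to in a PyTorch trainer (Sana may wire it through its own trainer or accelerate, so treat this as an assumption, not the project's actual code):

```python
# Forward pass and loss computed under bf16 autocast; master weights stay fp32.
import torch

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(latents, timesteps, text_embeddings)  # hypothetical call signature
loss.backward()
```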
I'll look into it. |
First of all, I'm GPU-poor and train on an old A40, slowly, sorry for that. I don't train in one big run, but step by step. So instead of one big run with a high LR on billions of images for 100 epochs, I train 4-8 epochs on a small dataset, to which I add new data on each iteration. So far, 2 steps are done: step 1 trained 4 epochs on a 4k dataset (just a mix of very aesthetic images) with 1e-4; step 2 added 2k images to the dataset (a new concept, illustrations) and trained 4 epochs (it's small; 8 epochs would work better here). For now I'm thinking of adding 4k anime images and maybe switching to bf16? (In SDXL, bf16 works much better than fp16; fp16 training just doesn't work in SDXL.) What do you think in general? Is this a good result for 1 day? Any suggestions? (I haven't trained from zero before.) |
Looks like just switching precision doesn't work out of the box. Looking at this code now, but not sure it's good: https://github.com/bghira/SimpleTuner/blob/main/helpers/training/optimizers/adamw_bfloat16/__init__.py |
No need to change the optimizer to bf16. |
Are you training from scratch, or did you load a pretrained checkpoint? |
I train from scratch (it started by my mistake, lol, I forgot load_from). OK, switched to bf16; let's see how it goes. |
It was a mistake in the config: lr: 9e-5 (not parsed as a float) instead of 9.0e-5. |
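This is a classic YAML 1.1 gotcha: without a decimal point, 9e-5 does not match the float pattern used by common parsers such as PyYAML. A quick way to check (assuming PyYAML):

```python
# YAML 1.1 floats need a dot: "9e-5" falls through to a plain string.
import yaml

print(type(yaml.safe_load("lr: 9e-5")["lr"]))    # <class 'str'>
print(type(yaml.safe_load("lr: 9.0e-5")["lr"]))  # <class 'float'>
```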
If I make the buckets in fp16, make the model bf16, and start training: does that mean we can train in bf16, convert to fp16 for inference, and it will work? What do you think? |
You don't need to redo the VAE latents for bf16, and I think it's pointless to train from scratch because it will take too long to achieve reasonable results |
I'm also currently trying to train a 512 model with a batch size of 96, and I'm noticing that the model requires a high learning rate (5e-5) and is learning extremely slowly. After over 10,000 steps, it still hasn't fully grasped the concept, and the samples don't look great. It seems the model is indeed struggling to work with this VAE. I'm also training on low-quality images for the negative prompt, and while it helps a bit, the improvements are minimal. |
@Muinez pls take a look at #50 (comment); it looks critical if I'm right. As for training from scratch, it's research: how to train faster in bf16, and how fast training goes in general. |
We have a totally different diffusion training schedule from SDXL: flow matching vs. DDPM. The loss level is caused by this and has nothing to do with the VAE. If you try to train SD3, I think the loss should be larger than 0.6. Sana's loss will converge at ~0.7 at 1024px and ~0.8 at 512px. |
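For anyone comparing loss scales, here is a minimal sketch of a flow-matching training step, assuming the common linear-interpolation path with a velocity target (the exact parameterization in Sana's code may differ):

```python
# Flow matching: regress the velocity (noise - data) along a straight path.
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x0, cond):
    noise = torch.randn_like(x0)
    t = torch.rand(x0.shape[0], device=x0.device)  # t ~ U[0, 1]
    t_ = t.view(-1, 1, 1, 1)
    xt = (1 - t_) * x0 + t_ * noise                # point on the straight path
    target = noise - x0                            # constant velocity target
    return F.mse_loss(model(xt, t, cond), target)
```

Under this parameterization the irreducible MSE stays well above zero, which is consistent with the ~0.7 / ~0.8 convergence values quoted above.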
Actually, we train with 1e-4 if the new dataset has many new concepts. |
This might be a silly question since I don't have much experience with this, but if we add a sampling step to the training pipeline, could it help the model mitigate its own errors? I mean, instead of providing the model with an almost perfect noise-to-signal ratio, it would also learn to correct itself along the way |
Could you pls explain more about it? |
Make the model predict from a noised sample, treat it as a sampling step, add (or subtract, can’t remember) a random percentage of the prediction back to the noised sample, then make another prediction based on this modified version. After that, calculate the MSE loss between this final prediction and the actual target |
What's the insight? How is that different from computing the MSE loss directly between the actual target and the model's prediction? |
I’m not entirely sure, but the idea is that by letting the model make a prediction, we’re assuming it will introduce some errors, and the loss would then train the model to predict the correct image from that "incorrect" noise. Whereas if we’re just training from a perfect sample where we added the noise ourselves, it doesn’t contain those kinds of errors. Or am I wrong? |
```python
# Pseudocode: two-step self-correction loss.
# add_noise / predict are placeholders for the noising step and the model call.
import torch.nn.functional as F

noisy_data = add_noise(real_data, time_steps)                 # standard noised input
predicted_sample = predict(noisy_data, time_steps)            # first prediction
# Mix the model's own (imperfect) prediction back into the noised sample
weighted_mean_sample = (noisy_data + time_steps * predicted_sample) / 2
final_prediction = predict(weighted_mean_sample, time_steps)  # second prediction
# Supervise both predictions against the clean target
loss = (F.mse_loss(final_prediction, real_data) + F.mse_loss(predicted_sample, real_data)) / 2
```

Or maybe something like this could work, though I'm not sure if I've got it right. |
Hh, this is more like something done by a CM (consistency model). |
I didn't know; I was just going from the idea that the model is trained to predict data without noise. That's why I decided to take the average: mix the noisy data with the sample predicted by the model, which still has errors, and then predict the target data based on that. Anyway, I hope you get the idea. I'm not sure if you'll develop this further or do anything with it. Looking at the code now, I probably didn't do it quite right; maybe we should have applied add_noise to time_steps * predicted_sample, but I'm not sure. |
Yah, I totally understand your statement. Cool idea. I have tried a similar idea before; refer to PixArt-delta, which combines PixArt and LCM together. LCM or CM is something like what you said. Refer to: https://arxiv.org/abs/2401.05252 |
https://huggingface.co/Muinez/sana-512-anime This is what I ended up with. Negative prompt: bad artwork |
why download_full_repo=True 😭 |
Oh, looks pretty cool on my side. 👍 |
from_pretrained uses find_model, which uses hf_download_or_fpath. I asked about this because when I run the app with my model, it downloads a 19GB file that isn't needed, since I specifically moved the state_dict into a separate file. |
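A possible workaround on the caller's side is to fetch just the needed file instead of the whole repo, e.g. with huggingface_hub (the filename here is hypothetical; use whatever the separated state_dict is actually called):

```python
# Download a single file from the repo rather than a full snapshot.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="Muinez/sana-512-anime",
    filename="checkpoint.pth",  # hypothetical name of the separated state_dict
)
```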
Yes, the model has generally learned the quality of the dataset (style, etc.). For example, it generates scenery beautifully, but the details are quite poor, and it's not very coherent |
lol, got it. We assume there is only the related ckpt in the HF hub when using from_pretrained. Not a big deal, anyway; it could be refined. If you want, you can draft a PR too. 👍 |
Makes sense; we will try to address the detail issue in our next version, this year or at the beginning of next year. |
Besides, the 1024 model will have far fewer of the detail issues you mentioned, since for a 512px image our VAE compresses it into only 16x16 tokens. If you want, the 1024px models may give you much better results. |
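The arithmetic behind that token count, given the 32x-downsampling VAE discussed earlier in the thread:

```python
# Latent token counts for a VAE with 32x spatial downsampling.
for px in (512, 1024):
    side = px // 32
    print(f"{px}px -> {side}x{side} = {side * side} tokens")
# 512px -> 16x16 = 256 tokens
# 1024px -> 32x32 = 1024 tokens
```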
Well, I'm already out of money to rent a GPU 😅. Besides, I've already tried, and the results weren't great: either I ended up with NaN loss or the model just trained much slower. Maybe I should have increased the learning rate back then, but I'm not sure how stable it would have been with a batch size of 30. |
OK, I see. I tried bf16 fine-tuning: no NaN, and it converges fast. Maybe we will release this more stable version later in the future. |
By the way, about Sana's low variability: maybe it's because fewer tokens with higher dimensions make the model focus on precision over diversity. More tokens might allow for greater variation, since each one wouldn't need to carry as much global information. Or maybe it just feels like the model generates overly similar images, or it could even be a dataset issue, idk. What do you think? |
Actually, I don't think the variability is related to the smaller token count. First, the CFG and PAG guidance scales will affect the variability. Then, model size is another core element affecting generalization. The data may also have an effect. |
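For reference, a minimal sketch of how the CFG guidance scale enters sampling (PAG follows a similar pattern with a perturbed-attention branch); higher scales push samples toward the prompt and tend to reduce diversity:

```python
# Classifier-free guidance: blend conditional and unconditional predictions.
def cfg_prediction(model, x, t, cond, uncond, guidance_scale=4.5):
    pred_cond = model(x, t, cond)
    pred_uncond = model(x, t, uncond)
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
```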
@lawrence-cj It seems like the 600M model has a lower loss than the 1600M model. The loss is around 0.6 on the 600M, while on the 1600M, it's around 0.7. It also feels like the model is learning new concepts much faster on its own. Shouldn't it be the other way around, or is it just me? |
Why does it understand a complex prompt (not from the dataset), but not "1girl"? |
Seems like the BF16 training corrupts the model at the very beginning. We will release a BF16 version for fine-tuning later. |
OK, but it's not retraining the fp16 model in bf16; it's training a new model from scratch, from zero. |
I've been waiting so long for Sana to train on a potato, but it's not working on an A40 with 48GB :(