From cca6a8c408b447438f6b1b0273adde43b831c1e9 Mon Sep 17 00:00:00 2001
From: gesen2egee
Date: Fri, 23 Aug 2024 02:24:02 +0800
Subject: [PATCH] SOFT MIN SNR GAMMA

---
 flux.bat                          | 2 +-
 library/custom_train_functions.py | 8 ++------
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/flux.bat b/flux.bat
index b7572a4a..38606ec9 100644
--- a/flux.bat
+++ b/flux.bat
@@ -1 +1 @@
-accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 6 flux_train_network.py --pretrained_model_name_or_path "D:/SDXL/webui_forge_cu121_torch231/webui/models/Stable-diffusion/flux1-dev.safetensors" --clip_l "D:/SDXL/webui_forge_cu121_torch231/webui/models/VAE/clip_l.safetensors" --t5xxl "D:/SDXL/webui_forge_cu121_torch231/webui/models/VAE/t5xxl_fp16.safetensors" --ae "D:/SDXL/webui_forge_cu121_torch231/webui/models/VAE/ae.safetensors" --cache_latents --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 6 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 16 --learning_rate 5e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --resolution="512,512" --save_every_n_steps="100" --train_data_dir="F:/train/data/sakuranomiya_maika" --output_dir "D:/SDXL/webui_forge_cu121_torch231/webui/models/Lora/sakuranomiya_maika" --logging_dir "F:/sakuranomiya_maika" --output_name sakuranomiya_maika_7 --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 --optimizer_type came --optimizer_args "betas=0.9,0.999,0.9999" "weight_decay=0.01" --split_mode --network_args "train_blocks=single" --max_train_steps="1500" --enable_bucket --caption_extension=".txt" --train_batch_size=4 --apply_t5_attn_mask --noise_offset 0.1 --min_snr_gamma 5
\ No newline at end of file
+accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 6 flux_train_network.py --pretrained_model_name_or_path "D:/SDXL/webui_forge_cu121_torch231/webui/models/Stable-diffusion/flux1-dev.safetensors" --clip_l "D:/SDXL/webui_forge_cu121_torch231/webui/models/VAE/clip_l.safetensors" --t5xxl "D:/SDXL/webui_forge_cu121_torch231/webui/models/VAE/t5xxl_fp16.safetensors" --ae "D:/SDXL/webui_forge_cu121_torch231/webui/models/VAE/ae.safetensors" --cache_latents --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 6 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 8 --learning_rate 5e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --resolution="512,512" --save_every_n_steps="100" --train_data_dir="F:/train/data/sakuranomiya_maika" --output_dir "D:/SDXL/webui_forge_cu121_torch231/webui/models/Lora/sakuranomiya_maika" --logging_dir "F:/sakuranomiya_maika" --output_name sakuranomiya_maika_8 --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 --optimizer_type came --optimizer_args "betas=0.9,0.999,0.9999" "weight_decay=0.01" --split_mode --network_args "train_blocks=single" --max_train_steps="2000" --enable_bucket --caption_extension=".txt" --train_batch_size=4 --apply_t5_attn_mask --noise_offset 0.1 --min_snr_gamma 1 --lr_scheduler "REX" --lr_warmup_steps 200
\ No newline at end of file
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index 2a513dc5..f6720811 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -65,12 +65,8 @@ def enforce_zero_terminal_snr(betas):
 
 
 def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
     snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
-    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
-    if v_prediction:
-        snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
-    else:
-        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
-    loss = loss * snr_weight
+    soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma)))
+    loss = loss * soft_min_snr_gamma_weight
     return loss
 