SOFT MIN SNR GAMMA
gesen2egee committed Aug 22, 2024
1 parent de9e83d commit cca6a8c
Showing 2 changed files with 3 additions and 7 deletions.
flux.bat (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 6 flux_train_network.py --pretrained_model_name_or_path "D:/SDXL/webui_forge_cu121_torch231/webui/models/Stable-diffusion/flux1-dev.safetensors" --clip_l "D:/SDXL/webui_forge_cu121_torch231/webui/models/VAE/clip_l.safetensors" --t5xxl "D:/SDXL/webui_forge_cu121_torch231/webui/models/VAE/t5xxl_fp16.safetensors" --ae "D:/SDXL/webui_forge_cu121_torch231/webui/models/VAE/ae.safetensors" --cache_latents --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 6 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 16 --learning_rate 5e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --resolution="512,512" --save_every_n_steps="100" --train_data_dir="F:/train/data/sakuranomiya_maika" --output_dir "D:/SDXL/webui_forge_cu121_torch231/webui/models/Lora/sakuranomiya_maika" --logging_dir "F:/sakuranomiya_maika" --output_name sakuranomiya_maika_7 --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 --optimizer_type came --optimizer_args "betas=0.9,0.999,0.9999" "weight_decay=0.01" --split_mode --network_args "train_blocks=single" --max_train_steps="1500" --enable_bucket --caption_extension=".txt" --train_batch_size=4 --apply_t5_attn_mask --noise_offset 0.1 --min_snr_gamma 5
+accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 6 flux_train_network.py --pretrained_model_name_or_path "D:/SDXL/webui_forge_cu121_torch231/webui/models/Stable-diffusion/flux1-dev.safetensors" --clip_l "D:/SDXL/webui_forge_cu121_torch231/webui/models/VAE/clip_l.safetensors" --t5xxl "D:/SDXL/webui_forge_cu121_torch231/webui/models/VAE/t5xxl_fp16.safetensors" --ae "D:/SDXL/webui_forge_cu121_torch231/webui/models/VAE/ae.safetensors" --cache_latents --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 6 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 8 --learning_rate 5e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --resolution="512,512" --save_every_n_steps="100" --train_data_dir="F:/train/data/sakuranomiya_maika" --output_dir "D:/SDXL/webui_forge_cu121_torch231/webui/models/Lora/sakuranomiya_maika" --logging_dir "F:/sakuranomiya_maika" --output_name sakuranomiya_maika_8 --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 --optimizer_type came --optimizer_args "betas=0.9,0.999,0.9999" "weight_decay=0.01" --split_mode --network_args "train_blocks=single" --max_train_steps="2000" --enable_bucket --caption_extension=".txt" --train_batch_size=4 --apply_t5_attn_mask --noise_offset 0.1 --min_snr_gamma 1 --lr_scheduler "REX" --lr_warmup_steps 200
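
For readability, the flags that actually differ between the two invocations above (this summary is derived from the diff and is not part of the commit; all other flags are unchanged):

    --network_dim 16                     ->  --network_dim 8
    --output_name sakuranomiya_maika_7   ->  --output_name sakuranomiya_maika_8
    --max_train_steps="1500"             ->  --max_train_steps="2000"
    --min_snr_gamma 5                    ->  --min_snr_gamma 1
    (not present)                        ->  --lr_scheduler "REX" --lr_warmup_steps 200
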
library/custom_train_functions.py (8 changes: 2 additions & 6 deletions)
@@ -65,12 +65,8 @@ def enforce_zero_terminal_snr(betas):

 def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
     snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
-    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
-    if v_prediction:
-        snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
-    else:
-        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
-    loss = loss * snr_weight
+    soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma)))
+    loss = loss * soft_min_snr_gamma_weight
     return loss
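
A minimal standalone sketch (not part of the commit, assumes only PyTorch) contrasting the removed hard min-SNR weight, min(SNR, gamma) / SNR, with the soft weight added above, 1 / (SNR^2 + 1/gamma), for the epsilon-prediction case (v_prediction=False). The SNR values and gamma=5 are arbitrary illustration inputs, not taken from the training run in flux.bat:

    import torch

    # Arbitrary example SNR values spanning low to high timestep SNR
    snr = torch.tensor([0.1, 1.0, 5.0, 25.0, 100.0])
    gamma = 5.0

    # Removed weighting: min(SNR, gamma) / SNR
    hard_weight = torch.minimum(snr, torch.full_like(snr, gamma)) / snr

    # Added weighting: 1 / (SNR^2 + 1/gamma)
    soft_weight = 1.0 / (torch.pow(snr, 2) + 1.0 / gamma)

    for s, h, w in zip(snr.tolist(), hard_weight.tolist(), soft_weight.tolist()):
        print(f"SNR={s:6.1f}  hard={h:.4f}  soft={w:.6f}")

Both weights multiply the per-sample loss; the soft variant decays smoothly with SNR instead of clamping at gamma, which is the behavioral change this commit introduces.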


