diff --git a/README.md b/README.md index 81a3199bc..f9c85e3ac 100644 --- a/README.md +++ b/README.md @@ -68,11 +68,11 @@ When training LoRA for Text Encoder (without `--network_train_unet_only`), more __Options for GPUs with less VRAM:__ -By specifying `--block_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. +By specifying `--blocks_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. -Specify a number like `--block_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks. +Specify a number like `--blocks_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks. -`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--block_to_swap`. +`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--blocks_to_swap`. Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settings like below: @@ -82,7 +82,7 @@ Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settin The training can be done with 16GB VRAM GPUs with the batch size of 1. Please change your dataset configuration. -The training can be done with 12GB VRAM GPUs with `--block_to_swap 16` with 8bit AdamW. Please use settings like below: +The training can be done with 12GB VRAM GPUs with `--blocks_to_swap 16` with 8bit AdamW. Please use settings like below: ``` --blocks_to_swap 16 diff --git a/flux_train_network.py b/flux_train_network.py index 6dcfadba2..19b9fda87 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -450,6 +450,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t if len(diff_output_pr_indices) > 0: network.set_multiplier(0.0) + unet.prepare_block_swap_before_forward() with torch.no_grad(): model_pred_prior = call_dit( img=packed_noisy_model_input[diff_output_pr_indices],