Commit

Merge pull request #1690 from kohya-ss/multi-gpu-caching
Caching latents and Text Encoder outputs with multiple GPUs
kohya-ss authored Oct 13, 2024
2 parents d02a6ef + 2d5f7fa commit 1275e14
Showing 17 changed files with 347 additions and 259 deletions.
15 changes: 15 additions & 0 deletions README.md
@@ -11,6 +11,21 @@ The command to install PyTorch is as follows:

### Recent Updates

Oct 13, 2024:

- Fixed an issue where loading image sizes during dataset initialization took a long time, especially when the dataset contained a large number of images.

- During multi-GPU training, caching of latents and Text Encoder outputs is now distributed across all GPUs (a rough sketch of the idea follows this list).
  - Please make sure that `--highvram` and `--vae_batch_size` are specified appropriately. If you have enough VRAM, you can increase the batch size to speed up caching.
  - The `--text_encoder_batch_size` option is now available for FLUX.1 LoRA training and fine-tuning. It specifies the batch size for caching Text Encoder outputs (not for training). The default is the same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up caching.
  - Multi-threading is also implemented for latents caching. This may speed up caching by about 5% (depending on the environment).
  - `tools/cache_latents.py` and `tools/cache_text_encoder_outputs.py` have also been updated to support multi-GPU caching.
- The `--skip_cache_check` option has been added to each training script.
  - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped.
  - Specify this option if you have a large number of cache files and the consistency check takes time.
  - Even if this option is specified, the cache will be created if the file does not exist.
  - `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead.
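
The sketch below illustrates the general idea behind multi-GPU caching in this update: each process caches only its own shard of the items, and all processes synchronize before training. It assumes Hugging Face `accelerate`; the names `items` and `cache_one_batch` are illustrative, not the repository's actual API.

```python
# Minimal sketch, assuming Hugging Face accelerate: each rank caches its own
# shard of the dataset, then waits so that training starts only after every
# cache file has been written. `items` and `cache_one_batch` are hypothetical.
from accelerate import Accelerator

def cache_in_parallel(items, cache_one_batch, batch_size=1):
    accelerator = Accelerator()
    # round-robin shard: rank r takes items r, r + N, r + 2N, ... for N processes
    shard = items[accelerator.process_index :: accelerator.num_processes]
    for i in range(0, len(shard), batch_size):
        cache_one_batch(shard[i : i + batch_size])  # e.g. encode and save .npz
    accelerator.wait_for_everyone()
```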

Oct 12, 2024 (update 1):

- [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models.
2 changes: 1 addition & 1 deletion fine_tune.py
@@ -59,7 +59,7 @@ def train(args):
     # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
     if cache_latents:
         latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
-            False, args.cache_latents_to_disk, args.vae_batch_size, False
+            False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
         )
         strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

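For context, the last constructor argument threaded through above is the strategy's `skip_disk_cache_validity_check` flag (see `library/strategy_base.py` further down). Below is a hedged sketch of the kind of check that `--skip_cache_check` bypasses; the `.npz` key names are hypothetical, not the repository's actual cache format.

```python
# Sketch only: what skipping the disk-cache validity check roughly amounts to.
# The .npz keys ("latents", "original_size") are hypothetical placeholders.
import os
import numpy as np

def is_disk_cache_usable(npz_path: str, expected_size, skip_validity_check: bool) -> bool:
    if not os.path.exists(npz_path):
        return False  # a missing cache file is always (re)created
    if skip_validity_check:
        return True  # trust the existing file without opening it (fast for huge datasets)
    try:
        with np.load(npz_path) as npz:
            # e.g. confirm the cache was produced for the same bucketed image size
            return "latents" in npz and tuple(npz["original_size"]) == tuple(expected_size)
    except Exception:
        return False
```
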
12 changes: 8 additions & 4 deletions flux_train.py
@@ -57,6 +57,10 @@ def train(args):
     deepspeed_utils.prepare_deepspeed_args(args)
     setup_logging(args, reset=True)
 
+    # temporary: backward compatibility for deprecated options. remove in the future
+    if not args.skip_cache_check:
+        args.skip_cache_check = args.skip_latents_validity_check
+
     # assert (
     #     not args.weighted_captions
     # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
@@ -81,7 +85,7 @@
     # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
     if args.cache_latents:
         latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(
-            args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
+            args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
         )
         strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
 
@@ -142,7 +146,7 @@ def train(args):
     if args.cache_text_encoder_outputs:
         strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
             strategy_flux.FluxTextEncoderOutputsCachingStrategy(
-                args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
+                args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
             )
         )
     t5xxl_max_token_length = (
@@ -229,7 +233,7 @@ def train(args):
         strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
 
         with accelerator.autocast():
-            train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process)
+            train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator)
 
         # cache sample prompt's embeddings to free text encoder's memory
         if args.sample_prompts is not None:
@@ -952,7 +956,7 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--skip_latents_validity_check",
         action="store_true",
-        help="skip latents validity check / latentsの正当性チェックをスキップする",
+        help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
     )
     parser.add_argument(
         "--blocks_to_swap",
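
The two flux_train.py changes above follow a common pattern for retiring a command-line flag: the deprecated option is still accepted, but its value is folded into the new option right after parsing, so the rest of the script only reads the new flag. A minimal standalone sketch of that pattern:

```python
# Minimal sketch of the deprecated-alias pattern used in flux_train.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--skip_cache_check", action="store_true", help="skip cache consistency check")
parser.add_argument(
    "--skip_latents_validity_check",
    action="store_true",
    help="[Deprecated] use 'skip_cache_check' instead",
)

args = parser.parse_args()
# backward compatibility: the old flag still works but simply maps onto the new one
if not args.skip_cache_check:
    args.skip_cache_check = args.skip_latents_validity_check
```
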
6 changes: 3 additions & 3 deletions flux_train_network.py
@@ -188,8 +188,8 @@ def get_text_encoder_outputs_caching_strategy(self, args):
             # if the text encoders is trained, we need tokenization, so is_partial is True
             return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
                 args.cache_text_encoder_outputs_to_disk,
-                None,
-                False,
+                args.text_encoder_batch_size,
+                args.skip_cache_check,
                 is_partial=self.train_clip_l or self.train_t5xxl,
                 apply_t5_attn_mask=args.apply_t5_attn_mask,
             )
@@ -222,7 +222,7 @@ def cache_text_encoder_outputs_if_needed(
             text_encoders[1].to(weight_dtype)
 
             with accelerator.autocast():
-                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)
+                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
 
             # cache sample prompts
             if args.sample_prompts is not None:
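
Note the signature change in the second hunk: `new_cache_text_encoder_outputs` now receives the whole `Accelerator` object instead of the `is_main_process` boolean. With only the boolean, rank 0 did all of the caching while the other ranks idled; with the accelerator, every rank can take its own share of the work. A hedged, self-contained sketch of the contrast (the helper names are illustrative):

```python
# Contrast sketch (hypothetical helpers, not the repository's API): why the call
# site now passes the Accelerator instead of accelerator.is_main_process.
from accelerate import Accelerator

def cache_outputs_single_rank(items, cache_fn, is_main_process: bool):
    # old behaviour: rank 0 does everything while the other ranks sit idle
    if is_main_process:
        for item in items:
            cache_fn(item)

def cache_outputs_all_ranks(items, cache_fn, accelerator: Accelerator):
    # new behaviour: every rank caches its own shard, then all ranks synchronize
    for item in items[accelerator.process_index :: accelerator.num_processes]:
        cache_fn(item)
    accelerator.wait_for_everyone()
```
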
2 changes: 1 addition & 1 deletion library/strategy_base.py
@@ -325,7 +325,7 @@ class TextEncoderOutputsCachingStrategy:
     def __init__(
         self,
         cache_to_disk: bool,
-        batch_size: int,
+        batch_size: Optional[int],
         skip_disk_cache_validity_check: bool,
         is_partial: bool = False,
         is_weighted: bool = False,
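
The `batch_size` parameter becoming `Optional[int]` matches the README note that `--text_encoder_batch_size` defaults to the dataset batch size: `None` means no explicit value was given. A minimal sketch of that resolution; `dataset_batch_size` is an illustrative name:

```python
# Sketch of resolving an optional caching batch size: None falls back to the
# dataset's own batch size. `dataset_batch_size` is an illustrative name.
from typing import Optional

def resolve_cache_batch_size(batch_size: Optional[int], dataset_batch_size: int) -> int:
    return batch_size if batch_size is not None else dataset_batch_size
```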

1 comment on commit 1275e14

@AbstractEyes
Big patch. Much appreciated.
