From 9249d00311002c84b189c2f6792cbe7aa344a1d5 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Thu, 26 Sep 2024 22:19:56 +0900 Subject: [PATCH 1/7] experimental support for multi-gpus latents caching --- library/train_util.py | 27 ++++++++++++++++----------- train_network.py | 2 +- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 3768b6051..2ca662dcb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -981,7 +981,7 @@ def is_text_encoder_output_cacheable(self): ] ) - def new_cache_latents(self, model: Any, is_main_process: bool): + def new_cache_latents(self, model: Any, accelerator: Accelerator): r""" a brand new method to cache latents. This method caches latents with caching strategy. normal cache_latents method is used by default, but this method is used when caching strategy is specified. @@ -1013,8 +1013,12 @@ def __eq__(self, other): batch: List[ImageInfo] = [] current_condition = None + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + logger.info("checking cache validity...") - for info in tqdm(image_infos): + for i, info in enumerate(tqdm(image_infos)): subset = self.image_to_subset[info.image_key] if info.latents_npz is not None: # fine tuning dataset @@ -1024,9 +1028,14 @@ def __eq__(self, other): if caching_strategy.cache_to_disk: # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) - if not is_main_process: # prepare for multi-gpu, only store to info + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: continue + print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + cache_available = caching_strategy.is_disk_cached_latents_expected( info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask ) @@ -1051,10 +1060,6 @@ def __eq__(self, other): if len(batch) > 0: batches.append((current_condition, batch)) - # if cache to disk, don't cache latents in non-main process, set to info only - if caching_strategy.cache_to_disk and not is_main_process: - return - if len(batches) == 0: logger.info("no latents to cache") return @@ -2258,8 +2263,8 @@ def make_buckets(self): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) - def new_cache_latents(self, model: Any, is_main_process: bool): - return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process) + def new_cache_latents(self, model: Any, accelerator: Accelerator): + return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator) def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) @@ -2363,10 +2368,10 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) - def new_cache_latents(self, model: Any, is_main_process: bool): + def new_cache_latents(self, model: Any, accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_latents(model, is_main_process) + dataset.new_cache_latents(model, accelerator) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True diff --git a/train_network.py b/train_network.py index b24f89b1e..7eb7aa49c 100644 --- a/train_network.py +++ b/train_network.py @@ -384,7 +384,7 @@ def train(self, args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) From 24b1fdb66485af70b3c79feaf8ff1a348b66668e Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Thu, 26 Sep 2024 22:22:06 +0900 Subject: [PATCH 2/7] remove debug print --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 2ca662dcb..8d6164b1b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1031,10 +1031,10 @@ def __eq__(self, other): # if the modulo of num_processes is not equal to process_index, skip caching # this makes each process cache different latents - if i % num_processes != process_index: + if i % num_processes != process_index: continue - print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") cache_available = caching_strategy.is_disk_cached_latents_expected( info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask From c80c304779775f4d00fd8f4856bfc8e6599e2de0 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 12 Oct 2024 20:18:41 +0900 Subject: [PATCH 3/7] Refactor caching in train scripts --- README.md | 10 +++++ fine_tune.py | 2 +- flux_train.py | 14 ++++--- flux_train_network.py | 6 +-- library/train_util.py | 64 +++++++++++++++++++++++--------- sd3_train.py | 17 +++++++-- sdxl_train.py | 4 +- sdxl_train_control_net.py | 4 +- sdxl_train_control_net_lllite.py | 5 +-- sdxl_train_network.py | 8 ++-- sdxl_train_textual_inversion.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- 14 files changed, 95 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index 37fc911f6..2b2562831 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,16 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 12, 2024 (update 1): + +- During multi-GPU training, caching of latents and Text Encoder outputs is now done in multi-GPU. +- `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. +- `--skip_cache_check` option is added to each training script. + - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped. + - Specify this option if you have a large number of cache files and the consistency check takes time. + - Even if this option is specified, the cache will be created if the file does not exist. + - `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead. + Oct 12, 2024: - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! diff --git a/fine_tune.py b/fine_tune.py index fd63385b3..cdc005d9a 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -59,7 +59,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if cache_latents: latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) diff --git a/flux_train.py b/flux_train.py index ecc87c0a8..e18a92443 100644 --- a/flux_train.py +++ b/flux_train.py @@ -57,6 +57,10 @@ def train(args): deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + # assert ( # not args.weighted_captions # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" @@ -81,7 +85,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if args.cache_latents: latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -142,7 +146,7 @@ def train(args): if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False ) ) t5xxl_max_token_length = ( @@ -181,7 +185,7 @@ def train(args): # load VAE for caching latents ae = None if cache_latents: - ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae.to(accelerator.device, dtype=weight_dtype) ae.requires_grad_(False) ae.eval() @@ -229,7 +233,7 @@ def train(args): strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) # cache sample prompt's embeddings to free text encoder's memory if args.sample_prompts is not None: @@ -952,7 +956,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_latents_validity_check", action="store_true", - help="skip latents validity check / latentsの正当性チェックをスキップする", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) parser.add_argument( "--blocks_to_swap", diff --git a/flux_train_network.py b/flux_train_network.py index 5d14bd28e..3bd8316d4 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -188,8 +188,8 @@ def get_text_encoder_outputs_caching_strategy(self, args): # if the text encoders is trained, we need tokenization, so is_partial is True return strategy_flux.FluxTextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, - None, - False, + args.text_encoder_batch_size, + args.skip_cache_check, is_partial=self.train_clip_l or self.train_t5xxl, apply_t5_attn_mask=args.apply_t5_attn_mask, ) @@ -222,7 +222,7 @@ def cache_text_encoder_outputs_if_needed( text_encoders[1].to(weight_dtype) with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) # cache sample prompts if args.sample_prompts is not None: diff --git a/library/train_util.py b/library/train_util.py index 67eaae41b..4e6b3408d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -31,6 +31,7 @@ import subprocess from io import BytesIO import toml + # from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm @@ -1192,7 +1193,7 @@ def __eq__(self, other): for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) - def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): r""" a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. """ @@ -1207,15 +1208,25 @@ def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: boo # split by resolution batches = [] batch = [] - logger.info("checking cache validity...") - for info in tqdm(image_infos): - te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) - # check disk cache exists and size of latents + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + + logger.info("checking cache validity...") + for i, info in enumerate(tqdm(image_infos)): + # check disk cache exists and size of text encoder outputs if caching_strategy.cache_to_disk: - info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different text encoder outputs + if i % num_processes != process_index: + continue + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) - if cache_available or not is_main_process: # do not add to batch + if cache_available: # do not add to batch continue batch.append(info) @@ -2420,6 +2431,7 @@ def new_cache_latents(self, model: Any, accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") dataset.new_cache_latents(model, accelerator) + accelerator.wait_for_everyone() def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2437,10 +2449,11 @@ def cache_text_encoder_outputs_sd3( tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) - def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_text_encoder_outputs(models, is_main_process) + dataset.new_cache_text_encoder_outputs(models, accelerator) + accelerator.wait_for_everyone() def set_caching_mode(self, caching_mode): for dataset in self.datasets: @@ -4210,6 +4223,12 @@ def add_dataset_arguments( action="store_true", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", ) + parser.add_argument( + "--skip_cache_check", + action="store_true", + help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist" + " / cacheの内容の検証をスキップする(latentとテキストエンコーダの出力)。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる", + ) parser.add_argument( "--enable_bucket", action="store_true", @@ -5084,15 +5103,24 @@ def prepare_accelerator(args: argparse.Namespace): dynamo_backend = args.dynamo_backend kwargs_handlers = [ - InitProcessGroupKwargs( - backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", - init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None, - timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None - ) if torch.cuda.device_count() > 1 else None, - DistributedDataParallelKwargs( - gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, - static_graph=args.ddp_static_graph - ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None + ( + InitProcessGroupKwargs( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method=( + "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None + ), + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None, + ) + if torch.cuda.device_count() > 1 + else None + ), + ( + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph + ) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None + ), ] kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) diff --git a/sd3_train.py b/sd3_train.py index 5120105f2..7290956ad 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -57,6 +57,10 @@ def train(args): deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" @@ -103,7 +107,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if args.cache_latents: latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -312,7 +316,7 @@ def train(args): text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, - False, + args.skip_cache_check, train_clip_g or train_clip_l or args.use_t5xxl_cache_only, args.apply_lg_attn_mask, args.apply_t5_attn_mask, @@ -325,7 +329,7 @@ def train(args): t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator) # cache sample prompt's embeddings to free text encoder's memory if args.sample_prompts is not None: @@ -1052,7 +1056,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_latents_validity_check", action="store_true", - help="skip latents validity check / latentsの正当性チェックをスキップする", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--skip_cache_check", + action="store_true", + help="skip cache (latents and text encoder outputs) check / キャッシュ(latentsとtext encoder outputs)のチェックをスキップする", ) parser.add_argument( "--num_last_block_to_freeze", diff --git a/sdxl_train.py b/sdxl_train.py index aeff9c469..9b2d19165 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -131,7 +131,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if args.cache_latents: latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -328,7 +328,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 67c8d52c8..74b3a64a4 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -84,7 +84,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -230,7 +230,7 @@ def unwrap_model(model): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 9d1cfc63e..14ff7c240 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -93,7 +93,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -202,7 +202,7 @@ def train(args): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() @@ -431,7 +431,6 @@ def remove_model(old_ckpt_name): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: # Text Encoder outputs are cached diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 20e32155c..4a16a4891 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -67,7 +67,7 @@ def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy @@ -80,7 +80,7 @@ def get_models_for_text_encoding(self, args, accelerator, text_encoders): def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions ) else: return None @@ -102,9 +102,7 @@ def cache_text_encoder_outputs_if_needed( text_encoders[0].to(accelerator.device, dtype=weight_dtype) text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs( - text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process - ) + dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator) accelerator.wait_for_everyone() text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index cbfcef554..821a69558 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -49,7 +49,7 @@ def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy diff --git a/train_db.py b/train_db.py index e49a7e70f..683b42332 100644 --- a/train_db.py +++ b/train_db.py @@ -64,7 +64,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) diff --git a/train_network.py b/train_network.py index 7437157b9..d5330aef4 100644 --- a/train_network.py +++ b/train_network.py @@ -116,7 +116,7 @@ def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> L def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - True, args.cache_latents_to_disk, args.vae_batch_size, False + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 3b3d3393f..4d8a3abbf 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -114,7 +114,7 @@ def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> L def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - True, args.cache_latents_to_disk, args.vae_batch_size, False + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy From 74228c9953b4ba0f8b0d68e8f6c8a8a6a469c2f5 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 16:27:22 +0900 Subject: [PATCH 4/7] update cache_latents/text_encoder_outputs --- library/strategy_base.py | 2 +- tools/cache_latents.py | 147 +++++++++++------------ tools/cache_text_encoder_outputs.py | 178 ++++++++++++++++------------ 3 files changed, 166 insertions(+), 161 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 2bff4178a..363996cec 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -325,7 +325,7 @@ class TextEncoderOutputsCachingStrategy: def __init__( self, cache_to_disk: bool, - batch_size: int, + batch_size: Optional[int], skip_disk_cache_validity_check: bool, is_partial: bool = False, is_weighted: bool = False, diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 2f0098b42..d8154ec31 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -9,7 +9,7 @@ import torch from tqdm import tqdm -from library import config_util +from library import config_util, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl from library import train_util from library import sdxl_train_util from library.config_util import ( @@ -17,42 +17,73 @@ BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging logger = logging.getLogger(__name__) +def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argparse.Namespace) -> None: + if is_flux: + _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) + else: + is_schnell = False + + if is_sd or is_sdxl: + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + elif is_sdxl: + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + else: + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) - # check cache latents arg - assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + # assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + args.cache_latents = True + args.cache_latents_to_disk = True use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] + is_sd = not args.sdxl and not args.flux + is_sdxl = args.sdxl + is_flux = args.flux + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) + + if is_sd or is_sdxl: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check) else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(True, args.vae_batch_size, args.skip_cache_check) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する + use_user_config = args.dataset_config is not None if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) @@ -83,17 +114,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - # datasetのcache_latentsを呼ばなければ、生の画像が返る - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) # acceleratorを準備する logger.info("prepare accelerator") @@ -106,72 +131,27 @@ def cache_to_disk(args: argparse.Namespace) -> None: # モデルを読み込む logger.info("load model") - if args.sdxl: + if is_sd: + _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + elif is_sdxl: (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) else: - _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + vae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + if is_sd or is_sdxl: + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - # dataloaderを準備する - train_dataset_group.set_caching_mode("latents") - - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - b_size = len(batch["images"]) - vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size - flip_aug = batch["flip_aug"] - alpha_mask = batch["alpha_mask"] - random_crop = batch["random_crop"] - bucket_reso = batch["bucket_reso"] - - # バッチを分割して処理する - for i in range(0, b_size, vae_batch_size): - images = batch["images"][i : i + vae_batch_size] - absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] - resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] - - image_infos = [] - for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.image = image - image_info.bucket_reso = bucket_reso - image_info.resized_size = resized_size - image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" - - if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected( - image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask - ): - logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") - continue - - image_infos.append(image_info) - - if len(image_infos) > 0: - train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop) + # cache latents with dataset + # TODO use DataLoader to speed up + train_dataset_group.new_cache_latents(vae, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching latents to disk.") def setup_parser() -> argparse.ArgumentParser: @@ -182,7 +162,11 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) config_util.add_config_arguments(parser) + parser.add_argument( + "--ae", type=str, default=None, help="Autoencoder model of FLUX to use / 使用するFLUXのオートエンコーダモデル" + ) parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( "--no_half_vae", action="store_true", @@ -191,7 +175,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." + " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) return parser diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index a75d9da74..d294d46c4 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -9,55 +9,68 @@ import torch from tqdm import tqdm -from library import config_util +from library import ( + config_util, + flux_train_utils, + flux_utils, + sdxl_model_util, + strategy_base, + strategy_flux, + strategy_sd, + strategy_sdxl, +) from library import train_util from library import sdxl_train_util +from library import utils from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments +from tools import cache_latents + setup_logging() import logging + logger = logging.getLogger(__name__) + def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) - # check cache arg - assert ( - args.cache_text_encoder_outputs_to_disk - ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" - - # できるだけ準備はしておくが今のところSDXLのみしか動かない - assert ( - args.sdxl - ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です" + args.cache_text_encoder_outputs = True + args.cache_text_encoder_outputs_to_disk = True use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] - else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] + is_sd = not args.sdxl and not args.flux + is_sdxl = args.sdxl + is_flux = args.flux + + assert ( + is_sdxl or is_flux + ), "Cache text encoder outputs to disk is only supported for SDXL and FLUX models / テキストエンコーダ出力のディスクキャッシュはSDXLまたはFLUXでのみ有効です" + assert ( + is_sdxl or args.weighted_captions is None + ), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です" + + cache_latents.set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) # データセットを準備する + use_user_config = args.dataset_config is not None if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) @@ -88,15 +101,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) # acceleratorを準備する logger.info("prepare accelerator") @@ -105,66 +114,68 @@ def cache_to_disk(args: argparse.Namespace) -> None: # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, _ = train_util.prepare_dtype(args) + t5xxl_dtype = utils.str_to_dtype(args.t5xxl_dtype, weight_dtype) # モデルを読み込む logger.info("load model") - if args.sdxl: - (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) + if is_sdxl: + _, text_encoder1, text_encoder2, _, _, _, _ = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype + ) + text_encoder1.to(accelerator.device, weight_dtype) + text_encoder2.to(accelerator.device, weight_dtype) text_encoders = [text_encoder1, text_encoder2] else: - text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder1] + clip_l = flux_utils.load_clip_l( + args.clip_l, weight_dtype, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors + ) + + t5xxl = flux_utils.load_t5xxl(args.t5xxl, None, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors) + + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + if t5xxl_dtype != t5xxl_dtype: + if t5xxl.dtype == torch.float8_e4m3fn and t5xxl_dtype.itemsize() >= 2: + logger.warning( + "The loaded model is fp8, but the specified T5XXL dtype is larger than fp8. This may cause a performance drop." + " / ロードされたモデルはfp8ですが、指定されたT5XXLのdtypeがfp8より高精度です。精度低下が発生する可能性があります。" + ) + logger.info(f"Casting T5XXL model to {t5xxl_dtype}") + t5xxl.to(t5xxl_dtype) + + text_encoders = [clip_l, t5xxl] for text_encoder in text_encoders: - text_encoder.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) text_encoder.eval() - # dataloaderを準備する - train_dataset_group.set_caching_mode("text") - - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) + # build text encoder outputs caching strategy + if is_sdxl: + text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions + ) + else: + text_encoder_outputs_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=False, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) + + # build text encoding strategy + if is_sdxl: + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + else: + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - absolute_paths = batch["absolute_paths"] - input_ids1_list = batch["input_ids1_list"] - input_ids2_list = batch["input_ids2_list"] - - image_infos = [] - for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX - image_info - - if args.skip_existing: - if os.path.exists(image_info.text_encoder_outputs_npz): - logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") - continue - - image_info.input_ids1 = input_ids1 - image_info.input_ids2 = input_ids2 - image_infos.append(image_info) - - if len(image_infos) > 0: - b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) - b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos]) - train_util.cache_batch_text_encoder_outputs( - image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype - ) + # cache text encoder outputs + train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator) accelerator.wait_for_everyone() accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") @@ -179,11 +190,20 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_dataset_arguments(parser, True, True, True) config_util.add_config_arguments(parser) sdxl_train_util.add_sdxl_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="T5XXL model dtype, default: None (use mixed precision dtype) / T5XXLモデルのdtype, デフォルト: None (mixed precisionのdtypeを使用)", + ) parser.add_argument( "--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." + " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) return parser From 2244cf5b835cc35179f29b1babb4a2d19f54bfae Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 18:22:19 +0900 Subject: [PATCH 5/7] load images in parallel when caching latents --- library/train_util.py | 93 ++++++++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 40 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 4e6b3408d..1db470d63 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3,6 +3,7 @@ import argparse import ast import asyncio +from concurrent.futures import Future, ThreadPoolExecutor import datetime import importlib import json @@ -1058,7 +1059,6 @@ def __eq__(self, other): and self.random_crop == other.random_crop ) - batches: List[Tuple[Condition, List[ImageInfo]]] = [] batch: List[ImageInfo] = [] current_condition = None @@ -1066,57 +1066,70 @@ def __eq__(self, other): num_processes = accelerator.num_processes process_index = accelerator.process_index - logger.info("checking cache validity...") - for i, info in enumerate(tqdm(image_infos)): - subset = self.image_to_subset[info.image_key] + # define a function to submit a batch to cache + def submit_batch(batch, cond): + for info in batch: + if info.image is not None and isinstance(info.image, Future): + info.image = info.image.result() # future to image + caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop) - if info.latents_npz is not None: # fine tuning dataset - continue + # define ThreadPoolExecutor to load images in parallel + max_workers = min(os.cpu_count(), len(image_infos)) + max_workers = max(1, max_workers // num_processes) # consider multi-gpu + max_workers = min(max_workers, caching_strategy.batch_size) # max_workers should be less than batch_size + executor = ThreadPoolExecutor(max_workers) - # check disk cache exists and size of latents - if caching_strategy.cache_to_disk: - # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) + try: + # iterate images + logger.info("caching latents...") + for i, info in enumerate(tqdm(image_infos)): + subset = self.image_to_subset[info.image_key] - # if the modulo of num_processes is not equal to process_index, skip caching - # this makes each process cache different latents - if i % num_processes != process_index: + if info.latents_npz is not None: # fine tuning dataset continue - # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) - cache_available = caching_strategy.is_disk_cached_latents_expected( - info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask - ) - if cache_available: # do not add to batch - continue + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: + continue - # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty - condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) - if len(batch) > 0 and current_condition != condition: - batches.append((current_condition, batch)) - batch = [] + # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") - batch.append(info) - current_condition = condition + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue - # if number of data in batch is enough, flush the batch - if len(batch) >= caching_strategy.batch_size: - batches.append((current_condition, batch)) - batch = [] - current_condition = None + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + submit_batch(batch, current_condition) + batch = [] - if len(batch) > 0: - batches.append((current_condition, batch)) + if info.image is None: + # load image in parallel + info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask) - if len(batches) == 0: - logger.info("no latents to cache") - return + batch.append(info) + current_condition = condition - # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded - logger.info("caching latents...") - for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): - caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + submit_batch(batch, current_condition) + batch = [] + current_condition = None + + if len(batch) > 0: + submit_batch(batch, current_condition) + + finally: + executor.shutdown() def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと From bfc3a65acda7f90abef9c16db279d2952f73fb77 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 19:08:16 +0900 Subject: [PATCH 6/7] fix to work cache latents/text encoder outputs --- library/train_util.py | 11 +++++++---- tools/cache_latents.py | 11 ++++++----- tools/cache_text_encoder_outputs.py | 18 +++++++++++++----- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1db470d63..926609267 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4064,15 +4064,18 @@ def verify_command_line_training_args(args: argparse.Namespace): ) +def enable_high_vram(args: argparse.Namespace): + if args.highvram: + logger.info("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する """ - if args.highvram: - print("highvram is enabled / highvramが有効です") - global HIGH_VRAM - HIGH_VRAM = True + enable_high_vram(args) if args.v_parameterization and not args.v2: logger.warning( diff --git a/tools/cache_latents.py b/tools/cache_latents.py index d8154ec31..e2faa58a7 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -9,7 +9,7 @@ import torch from tqdm import tqdm -from library import config_util, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl +from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl from library import train_util from library import sdxl_train_util from library.config_util import ( @@ -30,7 +30,7 @@ def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argpa else: is_schnell = False - if is_sd or is_sdxl: + if is_sd: tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) elif is_sdxl: tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) @@ -51,6 +51,7 @@ def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argpa def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) # assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" args.cache_latents = True @@ -161,10 +162,10 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) - parser.add_argument( - "--ae", type=str, default=None, help="Autoencoder model of FLUX to use / 使用するFLUXのオートエンコーダモデル" - ) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index d294d46c4..7be9ad781 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -27,7 +27,7 @@ BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments -from tools import cache_latents +from cache_latents import set_tokenize_strategy setup_logging() import logging @@ -38,6 +38,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) args.cache_text_encoder_outputs = True args.cache_text_encoder_outputs_to_disk = True @@ -57,8 +58,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: assert ( is_sdxl or args.weighted_captions is None ), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です" - - cache_latents.set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) # データセットを準備する use_user_config = args.dataset_config is not None @@ -178,7 +179,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching text encoder outputs to disk.") def setup_parser() -> argparse.ArgumentParser: @@ -188,9 +189,10 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) - sdxl_train_util.add_sdxl_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( @@ -205,6 +207,12 @@ def setup_parser() -> argparse.ArgumentParser: help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", + ) return parser From 2d5f7fa709c31d07a1bb44b5be391c29b77d3cfc Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 19:23:21 +0900 Subject: [PATCH 7/7] update README --- README.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 544c665de..7fae50d1a 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,15 @@ The command to install PyTorch is as follows: ### Recent Updates -Oct 12, 2024 (update 1): +Oct 13, 2024: + +- Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. - During multi-GPU training, caching of latents and Text Encoder outputs is now done in multi-GPU. -- `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. + - Please make sure that `--highvram` and `--vae_batch_size` are specified correctly. If you have enough VRAM, you can increase the batch size to speed up the caching. + - `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. + - Multi-threading is also implemented for caching of latents. This may speed up the caching process about 5% (depends on the environment). + - `tools/cache_latents.py` and `tools/cache_text_encoder_outputs.py` also have been updated to support multi-GPU caching. - `--skip_cache_check` option is added to each training script. - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped. - Specify this option if you have a large number of cache files and the consistency check takes time.