diff --git a/README.md b/README.md index b64515a19..7fae50d1a 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,21 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 13, 2024: + +- Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. + +- During multi-GPU training, caching of latents and Text Encoder outputs is now done in multi-GPU. + - Please make sure that `--highvram` and `--vae_batch_size` are specified correctly. If you have enough VRAM, you can increase the batch size to speed up the caching. + - `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. + - Multi-threading is also implemented for caching of latents. This may speed up the caching process about 5% (depends on the environment). + - `tools/cache_latents.py` and `tools/cache_text_encoder_outputs.py` also have been updated to support multi-GPU caching. +- `--skip_cache_check` option is added to each training script. + - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped. + - Specify this option if you have a large number of cache files and the consistency check takes time. + - Even if this option is specified, the cache will be created if the file does not exist. + - `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead. + Oct 12, 2024 (update 1): - [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models. diff --git a/fine_tune.py b/fine_tune.py index fd63385b3..cdc005d9a 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -59,7 +59,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if cache_latents: latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) diff --git a/flux_train.py b/flux_train.py index 2fc13068e..46a8babdb 100644 --- a/flux_train.py +++ b/flux_train.py @@ -57,6 +57,10 @@ def train(args): deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + # assert ( # not args.weighted_captions # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" @@ -81,7 +85,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if args.cache_latents: latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -142,7 +146,7 @@ def train(args): if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False ) ) t5xxl_max_token_length = ( @@ -229,7 +233,7 @@ def train(args): strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) # cache sample prompt's embeddings to free text encoder's memory if args.sample_prompts is not None: @@ -952,7 +956,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_latents_validity_check", action="store_true", - help="skip latents validity check / latentsの正当性チェックをスキップする", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) parser.add_argument( "--blocks_to_swap", diff --git a/flux_train_network.py b/flux_train_network.py index a24c1905b..aa92fe3ae 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -188,8 +188,8 @@ def get_text_encoder_outputs_caching_strategy(self, args): # if the text encoders is trained, we need tokenization, so is_partial is True return strategy_flux.FluxTextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, - None, - False, + args.text_encoder_batch_size, + args.skip_cache_check, is_partial=self.train_clip_l or self.train_t5xxl, apply_t5_attn_mask=args.apply_t5_attn_mask, ) @@ -222,7 +222,7 @@ def cache_text_encoder_outputs_if_needed( text_encoders[1].to(weight_dtype) with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) # cache sample prompts if args.sample_prompts is not None: diff --git a/library/strategy_base.py b/library/strategy_base.py index c6cf825ce..e390c5f35 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -325,7 +325,7 @@ class TextEncoderOutputsCachingStrategy: def __init__( self, cache_to_disk: bool, - batch_size: int, + batch_size: Optional[int], skip_disk_cache_validity_check: bool, is_partial: bool = False, is_weighted: bool = False, diff --git a/library/train_util.py b/library/train_util.py index 1701bb992..4a446e81c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3,6 +3,7 @@ import argparse import ast import asyncio +from concurrent.futures import Future, ThreadPoolExecutor import datetime import importlib import json @@ -31,6 +32,7 @@ import subprocess from io import BytesIO import toml + # from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm @@ -1029,7 +1031,7 @@ def is_text_encoder_output_cacheable(self): ] ) - def new_cache_latents(self, model: Any, is_main_process: bool): + def new_cache_latents(self, model: Any, accelerator: Accelerator): r""" a brand new method to cache latents. This method caches latents with caching strategy. normal cache_latents method is used by default, but this method is used when caching strategy is specified. @@ -1057,60 +1059,77 @@ def __eq__(self, other): and self.random_crop == other.random_crop ) - batches: List[Tuple[Condition, List[ImageInfo]]] = [] batch: List[ImageInfo] = [] current_condition = None - logger.info("checking cache validity...") - for info in tqdm(image_infos): - subset = self.image_to_subset[info.image_key] + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index - if info.latents_npz is not None: # fine tuning dataset - continue + # define a function to submit a batch to cache + def submit_batch(batch, cond): + for info in batch: + if info.image is not None and isinstance(info.image, Future): + info.image = info.image.result() # future to image + caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop) - # check disk cache exists and size of latents - if caching_strategy.cache_to_disk: - # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) - if not is_main_process: # prepare for multi-gpu, only store to info - continue + # define ThreadPoolExecutor to load images in parallel + max_workers = min(os.cpu_count(), len(image_infos)) + max_workers = max(1, max_workers // num_processes) # consider multi-gpu + max_workers = min(max_workers, caching_strategy.batch_size) # max_workers should be less than batch_size + executor = ThreadPoolExecutor(max_workers) - cache_available = caching_strategy.is_disk_cached_latents_expected( - info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask - ) - if cache_available: # do not add to batch + try: + # iterate images + logger.info("caching latents...") + for i, info in enumerate(tqdm(image_infos)): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: # fine tuning dataset continue - # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty - condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) - if len(batch) > 0 and current_condition != condition: - batches.append((current_condition, batch)) - batch = [] + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) - batch.append(info) - current_condition = condition + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: + continue - # if number of data in batch is enough, flush the batch - if len(batch) >= caching_strategy.batch_size: - batches.append((current_condition, batch)) - batch = [] - current_condition = None + # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") - if len(batch) > 0: - batches.append((current_condition, batch)) + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue - # if cache to disk, don't cache latents in non-main process, set to info only - if caching_strategy.cache_to_disk and not is_main_process: - return + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + submit_batch(batch, current_condition) + batch = [] - if len(batches) == 0: - logger.info("no latents to cache") - return + if info.image is None: + # load image in parallel + info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask) - # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded - logger.info("caching latents...") - for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): - caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) + batch.append(info) + current_condition = condition + + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + submit_batch(batch, current_condition) + batch = [] + current_condition = None + + if len(batch) > 0: + submit_batch(batch, current_condition) + + finally: + executor.shutdown() def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと @@ -1187,7 +1206,7 @@ def __eq__(self, other): for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) - def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): r""" a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. """ @@ -1202,15 +1221,25 @@ def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: boo # split by resolution batches = [] batch = [] - logger.info("checking cache validity...") - for info in tqdm(image_infos): - te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) - # check disk cache exists and size of latents + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + + logger.info("checking cache validity...") + for i, info in enumerate(tqdm(image_infos)): + # check disk cache exists and size of text encoder outputs if caching_strategy.cache_to_disk: - info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different text encoder outputs + if i % num_processes != process_index: + continue + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) - if cache_available or not is_main_process: # do not add to batch + if cache_available: # do not add to batch continue batch.append(info) @@ -2327,8 +2356,8 @@ def make_buckets(self): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) - def new_cache_latents(self, model: Any, is_main_process: bool): - return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process) + def new_cache_latents(self, model: Any, accelerator: Accelerator): + return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator) def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) @@ -2432,10 +2461,11 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) - def new_cache_latents(self, model: Any, is_main_process: bool): + def new_cache_latents(self, model: Any, accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_latents(model, is_main_process) + dataset.new_cache_latents(model, accelerator) + accelerator.wait_for_everyone() def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2453,10 +2483,11 @@ def cache_text_encoder_outputs_sd3( tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) - def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_text_encoder_outputs(models, is_main_process) + dataset.new_cache_text_encoder_outputs(models, accelerator) + accelerator.wait_for_everyone() def set_caching_mode(self, caching_mode): for dataset in self.datasets: @@ -4054,15 +4085,18 @@ def verify_command_line_training_args(args: argparse.Namespace): ) +def enable_high_vram(args: argparse.Namespace): + if args.highvram: + logger.info("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する """ - if args.highvram: - print("highvram is enabled / highvramが有効です") - global HIGH_VRAM - HIGH_VRAM = True + enable_high_vram(args) if args.v_parameterization and not args.v2: logger.warning( @@ -4226,6 +4260,12 @@ def add_dataset_arguments( action="store_true", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", ) + parser.add_argument( + "--skip_cache_check", + action="store_true", + help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist" + " / cacheの内容の検証をスキップする(latentとテキストエンコーダの出力)。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる", + ) parser.add_argument( "--enable_bucket", action="store_true", @@ -5100,15 +5140,24 @@ def prepare_accelerator(args: argparse.Namespace): dynamo_backend = args.dynamo_backend kwargs_handlers = [ - InitProcessGroupKwargs( - backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", - init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None, - timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None - ) if torch.cuda.device_count() > 1 else None, - DistributedDataParallelKwargs( - gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, - static_graph=args.ddp_static_graph - ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None + ( + InitProcessGroupKwargs( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method=( + "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None + ), + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None, + ) + if torch.cuda.device_count() > 1 + else None + ), + ( + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph + ) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None + ), ] kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) diff --git a/sd3_train.py b/sd3_train.py index 5120105f2..7290956ad 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -57,6 +57,10 @@ def train(args): deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" @@ -103,7 +107,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if args.cache_latents: latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -312,7 +316,7 @@ def train(args): text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, - False, + args.skip_cache_check, train_clip_g or train_clip_l or args.use_t5xxl_cache_only, args.apply_lg_attn_mask, args.apply_t5_attn_mask, @@ -325,7 +329,7 @@ def train(args): t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator) # cache sample prompt's embeddings to free text encoder's memory if args.sample_prompts is not None: @@ -1052,7 +1056,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_latents_validity_check", action="store_true", - help="skip latents validity check / latentsの正当性チェックをスキップする", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--skip_cache_check", + action="store_true", + help="skip cache (latents and text encoder outputs) check / キャッシュ(latentsとtext encoder outputs)のチェックをスキップする", ) parser.add_argument( "--num_last_block_to_freeze", diff --git a/sdxl_train.py b/sdxl_train.py index aeff9c469..9b2d19165 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -131,7 +131,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if args.cache_latents: latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -328,7 +328,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 67c8d52c8..74b3a64a4 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -84,7 +84,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -230,7 +230,7 @@ def unwrap_model(model): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 9d1cfc63e..14ff7c240 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -93,7 +93,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -202,7 +202,7 @@ def train(args): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() @@ -431,7 +431,6 @@ def remove_model(old_ckpt_name): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: # Text Encoder outputs are cached diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 20e32155c..4a16a4891 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -67,7 +67,7 @@ def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy @@ -80,7 +80,7 @@ def get_models_for_text_encoding(self, args, accelerator, text_encoders): def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions ) else: return None @@ -102,9 +102,7 @@ def cache_text_encoder_outputs_if_needed( text_encoders[0].to(accelerator.device, dtype=weight_dtype) text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs( - text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process - ) + dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator) accelerator.wait_for_everyone() text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index cbfcef554..821a69558 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -49,7 +49,7 @@ def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 2f0098b42..e2faa58a7 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -9,7 +9,7 @@ import torch from tqdm import tqdm -from library import config_util +from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl from library import train_util from library import sdxl_train_util from library.config_util import ( @@ -17,42 +17,74 @@ BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging logger = logging.getLogger(__name__) +def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argparse.Namespace) -> None: + if is_flux: + _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) + else: + is_schnell = False + + if is_sd: + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + elif is_sdxl: + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + else: + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) - # check cache latents arg - assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + # assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + args.cache_latents = True + args.cache_latents_to_disk = True use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] + is_sd = not args.sdxl and not args.flux + is_sdxl = args.sdxl + is_flux = args.flux + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) + + if is_sd or is_sdxl: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check) else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(True, args.vae_batch_size, args.skip_cache_check) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する + use_user_config = args.dataset_config is not None if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) @@ -83,17 +115,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - # datasetのcache_latentsを呼ばなければ、生の画像が返る - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) # acceleratorを準備する logger.info("prepare accelerator") @@ -106,72 +132,27 @@ def cache_to_disk(args: argparse.Namespace) -> None: # モデルを読み込む logger.info("load model") - if args.sdxl: + if is_sd: + _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + elif is_sdxl: (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) else: - _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + vae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + if is_sd or is_sdxl: + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - # dataloaderを準備する - train_dataset_group.set_caching_mode("latents") - - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - b_size = len(batch["images"]) - vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size - flip_aug = batch["flip_aug"] - alpha_mask = batch["alpha_mask"] - random_crop = batch["random_crop"] - bucket_reso = batch["bucket_reso"] - - # バッチを分割して処理する - for i in range(0, b_size, vae_batch_size): - images = batch["images"][i : i + vae_batch_size] - absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] - resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] - - image_infos = [] - for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.image = image - image_info.bucket_reso = bucket_reso - image_info.resized_size = resized_size - image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" - - if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected( - image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask - ): - logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") - continue - - image_infos.append(image_info) - - if len(image_infos) > 0: - train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop) + # cache latents with dataset + # TODO use DataLoader to speed up + train_dataset_group.new_cache_latents(vae, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching latents to disk.") def setup_parser() -> argparse.ArgumentParser: @@ -181,8 +162,12 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( "--no_half_vae", action="store_true", @@ -191,7 +176,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." + " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) return parser diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index a75d9da74..7be9ad781 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -9,55 +9,69 @@ import torch from tqdm import tqdm -from library import config_util +from library import ( + config_util, + flux_train_utils, + flux_utils, + sdxl_model_util, + strategy_base, + strategy_flux, + strategy_sd, + strategy_sdxl, +) from library import train_util from library import sdxl_train_util +from library import utils from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments +from cache_latents import set_tokenize_strategy + setup_logging() import logging + logger = logging.getLogger(__name__) + def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) - # check cache arg - assert ( - args.cache_text_encoder_outputs_to_disk - ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" - - # できるだけ準備はしておくが今のところSDXLのみしか動かない - assert ( - args.sdxl - ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です" + args.cache_text_encoder_outputs = True + args.cache_text_encoder_outputs_to_disk = True use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] - else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] + is_sd = not args.sdxl and not args.flux + is_sdxl = args.sdxl + is_flux = args.flux + + assert ( + is_sdxl or is_flux + ), "Cache text encoder outputs to disk is only supported for SDXL and FLUX models / テキストエンコーダ出力のディスクキャッシュはSDXLまたはFLUXでのみ有効です" + assert ( + is_sdxl or args.weighted_captions is None + ), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です" + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) # データセットを準備する + use_user_config = args.dataset_config is not None if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) @@ -88,15 +102,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) # acceleratorを準備する logger.info("prepare accelerator") @@ -105,69 +115,71 @@ def cache_to_disk(args: argparse.Namespace) -> None: # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, _ = train_util.prepare_dtype(args) + t5xxl_dtype = utils.str_to_dtype(args.t5xxl_dtype, weight_dtype) # モデルを読み込む logger.info("load model") - if args.sdxl: - (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) + if is_sdxl: + _, text_encoder1, text_encoder2, _, _, _, _ = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype + ) + text_encoder1.to(accelerator.device, weight_dtype) + text_encoder2.to(accelerator.device, weight_dtype) text_encoders = [text_encoder1, text_encoder2] else: - text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder1] + clip_l = flux_utils.load_clip_l( + args.clip_l, weight_dtype, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors + ) + + t5xxl = flux_utils.load_t5xxl(args.t5xxl, None, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors) + + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + if t5xxl_dtype != t5xxl_dtype: + if t5xxl.dtype == torch.float8_e4m3fn and t5xxl_dtype.itemsize() >= 2: + logger.warning( + "The loaded model is fp8, but the specified T5XXL dtype is larger than fp8. This may cause a performance drop." + " / ロードされたモデルはfp8ですが、指定されたT5XXLのdtypeがfp8より高精度です。精度低下が発生する可能性があります。" + ) + logger.info(f"Casting T5XXL model to {t5xxl_dtype}") + t5xxl.to(t5xxl_dtype) + + text_encoders = [clip_l, t5xxl] for text_encoder in text_encoders: - text_encoder.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) text_encoder.eval() - # dataloaderを準備する - train_dataset_group.set_caching_mode("text") - - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) + # build text encoder outputs caching strategy + if is_sdxl: + text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions + ) + else: + text_encoder_outputs_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=False, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) + + # build text encoding strategy + if is_sdxl: + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + else: + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - absolute_paths = batch["absolute_paths"] - input_ids1_list = batch["input_ids1_list"] - input_ids2_list = batch["input_ids2_list"] - - image_infos = [] - for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX - image_info - - if args.skip_existing: - if os.path.exists(image_info.text_encoder_outputs_npz): - logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") - continue - - image_info.input_ids1 = input_ids1 - image_info.input_ids2 = input_ids2 - image_infos.append(image_info) - - if len(image_infos) > 0: - b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) - b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos]) - train_util.cache_batch_text_encoder_outputs( - image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype - ) + # cache text encoder outputs + train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching text encoder outputs to disk.") def setup_parser() -> argparse.ArgumentParser: @@ -177,13 +189,29 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) - sdxl_train_util.add_sdxl_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="T5XXL model dtype, default: None (use mixed precision dtype) / T5XXLモデルのdtype, デフォルト: None (mixed precisionのdtypeを使用)", + ) parser.add_argument( "--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." + " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", + ) + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", ) return parser diff --git a/train_db.py b/train_db.py index e49a7e70f..683b42332 100644 --- a/train_db.py +++ b/train_db.py @@ -64,7 +64,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) diff --git a/train_network.py b/train_network.py index e48e6a070..d5330aef4 100644 --- a/train_network.py +++ b/train_network.py @@ -116,7 +116,7 @@ def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> L def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - True, args.cache_latents_to_disk, args.vae_batch_size, False + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy @@ -384,7 +384,7 @@ def train(self, args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 3b3d3393f..4d8a3abbf 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -114,7 +114,7 @@ def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> L def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - True, args.cache_latents_to_disk, args.vae_batch_size, False + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy