From 36770942562a31cd7d28ac06d4e1bf4332a206e9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 27 Nov 2024 12:57:04 +0900 Subject: [PATCH] Text Encoder cache (WIP) --- flux_train.py | 20 +- flux_train_network.py | 11 +- library/strategy_base.py | 304 ++++++++++++++++++++++++------- library/strategy_flux.py | 126 ++++++------- library/strategy_sd.py | 7 +- library/strategy_sd3.py | 166 ++++++++--------- library/strategy_sdxl.py | 103 ++++++----- library/train_util.py | 303 +++++++++++++----------------- library/utils.py | 35 ++++ sd3_train.py | 12 +- sd3_train_network.py | 6 +- sdxl_train.py | 6 +- sdxl_train_control_net.py | 6 +- sdxl_train_control_net_lllite.py | 6 +- sdxl_train_network.py | 6 +- 15 files changed, 637 insertions(+), 480 deletions(-) diff --git a/flux_train.py b/flux_train.py index a7d38c58d..e5e4d17c1 100644 --- a/flux_train.py +++ b/flux_train.py @@ -151,15 +151,20 @@ def train(args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) if args.debug_dataset: + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) + ) if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + t5xxl_max_token_length, + args.apply_t5_attn_mask, + False, ) ) - t5xxl_max_token_length = ( - args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) - ) strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) train_dataset_group.set_current_strategies() @@ -236,7 +241,12 @@ def train(args): t5xxl.to(accelerator.device) text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + t5xxl_max_token_length, + args.apply_t5_attn_mask, + False, ) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) diff --git a/flux_train_network.py b/flux_train_network.py index 704c4d32e..6dcfadba2 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -10,8 +10,6 @@ init_ipex() -from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util -import train_network from library.utils import setup_logging setup_logging() @@ -19,6 +17,9 @@ logger = logging.getLogger(__name__) +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util +import train_network + class FluxNetworkTrainer(train_network.NetworkTrainer): def __init__(self): @@ -174,13 +175,17 @@ def get_text_encoders_train_flags(self, args, text_encoders): def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: + fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length + # if the text encoders is trained, we need tokenization, so is_partial is True return strategy_flux.FluxTextEncoderOutputsCachingStrategy( 
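# A quick reference for the reordered constructor used in this hunk, kept as comments so
# the surrounding call still reads cleanly. The parameter names come from the reworked
# TextEncoderOutputsCachingStrategy base class added later in this patch:
#
#     FluxTextEncoderOutputsCachingStrategy(
#         cache_to_disk,                    # args.cache_text_encoder_outputs_to_disk
#         batch_size,                       # args.text_encoder_batch_size
#         skip_disk_cache_validity_check,   # args.skip_cache_check
#         max_token_length,                 # FluxTokenizeStrategy.t5xxl_max_length, so the
#                                           # tokenize strategy must be set before this one
#         masked,                           # replaces the old apply_t5_attn_mask= keyword
#         is_partial,                       # True while CLIP-L or T5-XXL is being trained
#     )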
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, + t5xxl_max_token_length, + args.apply_t5_attn_mask, is_partial=self.train_clip_l or self.train_t5xxl, - apply_t5_attn_mask=args.apply_t5_attn_mask, ) else: return None diff --git a/library/strategy_base.py b/library/strategy_base.py index 6bcf3dc6a..1e02aae53 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -10,10 +10,6 @@ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection -# TODO remove circular import by moving ImageInfo to a separate file -# from library.train_util import ImageInfo - -from library import utils from library.utils import setup_logging setup_logging() @@ -21,6 +17,8 @@ logger = logging.getLogger(__name__) +from library import utils + def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]: if dtype is None: @@ -43,6 +41,58 @@ def get_available_dtypes() -> List[torch.dtype]: return [torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] +def remove_lower_precision_values(tensor_dict: Dict[str, torch.Tensor], keys_without_dtype: list[str]) -> None: + """ + Removes lower precision values from tensor_dict. + """ + available_dtypes = get_available_dtypes() + available_dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dtype)}" for dtype in available_dtypes] + + for key_without_dtype in keys_without_dtype: + available_itemsize = None + for dtype, dtype_suffix in zip(available_dtypes, available_dtype_suffixes): + key = key_without_dtype + dtype_suffix + + if key in tensor_dict: + if available_itemsize is None: + available_itemsize = dtype.itemsize + elif available_itemsize > dtype.itemsize: + # if higher precision latents are already cached, remove lower precision latents + del tensor_dict[key] + + +def get_compatible_dtype_keys( + dict_keys: set[str], keys_without_dtype: list[str], dtype: Optional[Union[str, torch.dtype]] +) -> list[Optional[str]]: + """ + Returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable. + If the key is not found, it returns None. + If the key in dict_keys doesn't have dtype suffix, it is acceptable, because it it long tensor. + + :param dict_keys: set of keys in the dictionary + :param keys_without_dtype: list of keys without dtype suffix to check + :param dtype: dtype to check, or None for any dtype + :return: list of keys with the specified dtype or higher precision dtype. If the key is not found, it returns None for that key. 
+ """ + compatible_dtypes = get_compatible_dtypes(dtype) + dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dt)}" for dt in compatible_dtypes] + + available_keys = [] + for key_without_dtype in keys_without_dtype: + available_key = None + if key_without_dtype in dict_keys: + available_key = key_without_dtype + else: + for dtype_suffix in dtype_suffixes: + key = key_without_dtype + dtype_suffix + if key in dict_keys: + available_key = key + break + available_keys.append(available_key) + + return available_keys + + class TokenizeStrategy: _strategy = None # strategy instance: actual strategy class @@ -347,17 +397,26 @@ class TextEncoderOutputsCachingStrategy: def __init__( self, + architecture: str, cache_to_disk: bool, batch_size: Optional[int], skip_disk_cache_validity_check: bool, + max_token_length: int, + masked: bool = False, is_partial: bool = False, is_weighted: bool = False, ) -> None: + """ + max_token_length: maximum token length for the model. Including/excluding starting and ending tokens depends on the model. + """ + self._architecture = architecture self._cache_to_disk = cache_to_disk self._batch_size = batch_size self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + self._max_token_length = max_token_length + self._masked = masked self._is_partial = is_partial - self._is_weighted = is_weighted + self._is_weighted = is_weighted # enable weighting by `()` or `[]` in the prompt @classmethod def set_strategy(cls, strategy): @@ -369,6 +428,18 @@ def set_strategy(cls, strategy): def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]: return cls._strategy + @property + def architecture(self): + return self._architecture + + @property + def max_token_length(self): + return self._max_token_length + + @property + def masked(self): + return self._masked + @property def cache_to_disk(self): return self._cache_to_disk @@ -377,6 +448,11 @@ def cache_to_disk(self): def batch_size(self): return self._batch_size + @property + def cache_suffix(self): + suffix_masked = "_m" if self.masked else "" + return f"_{self.architecture.lower()}_{self.max_token_length}{suffix_masked}_te.safetensors" + @property def is_partial(self): return self._is_partial @@ -385,24 +461,145 @@ def is_partial(self): def is_weighted(self): return self._is_weighted - def get_outputs_npz_path(self, image_abs_path: str) -> str: - raise NotImplementedError + def get_cache_path(self, absolute_path: str) -> str: + return os.path.splitext(absolute_path)[0] + self.cache_suffix - def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]: raise NotImplementedError - def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: + def load_from_disk_for_keys(self, cache_path: str, caption_index: int, base_keys: list[str]) -> list[Optional[torch.Tensor]]: + """ + get tensors for keys_without_dtype, without dtype suffix. if the key is not found, it returns None. + all dtype tensors are returned, because cache validation is done in advance. + """ + with safe_open(cache_path, framework="pt") as f: + metadata = f.metadata() + version = metadata.get("format_version", "0.0.0") + major, minor, patch = map(int, version.split(".")) + if major > 1: # or (major == 1 and minor > 0): + if not self.load_version_warning_printed: + self.load_version_warning_printed = True + logger.warning( + f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues." 
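# A minimal sketch of how the dtype-suffix helpers above compose. The literal "bf16"
# suffix is only a stand-in: the real suffix string comes from utils.dtype_to_normalized_str(),
# which lives in library/utils.py and is not shown in this patch.
#
#     keys_on_disk = {"t5_out_0_bf16", "txt_ids_0"}   # one cached caption (index 0)
#
#     # Requesting bf16 (or None, i.e. "any dtype") succeeds: the bf16 tensor qualifies and
#     # the suffix-less integer tensor ("txt_ids_0") is always accepted as-is.
#     get_compatible_dtype_keys(keys_on_disk, ["t5_out_0", "txt_ids_0"], torch.bfloat16)
#     #   -> ["t5_out_0_bf16", "txt_ids_0"]
#
#     # Requesting fp32 fails for t5_out, because only equal-or-higher precision counts.
#     get_compatible_dtype_keys(keys_on_disk, ["t5_out_0", "txt_ids_0"], torch.float32)
#     #   -> [None, "txt_ids_0"]
#
#     # remove_lower_precision_values() is the write-side counterpart: when the same key is
#     # cached in several precisions, it drops the copies whose dtype has a smaller item size
#     # than the best one already stored.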
+ ) + + dict_keys = f.keys() + results = [] + compatible_keys = self.get_compatible_output_keys(dict_keys, caption_index, base_keys, None) + for key in compatible_keys: + results.append(f.get_tensor(key) if key is not None else None) + + return results + + def is_disk_cached_outputs_expected( + self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]] + ) -> bool: raise NotImplementedError + def get_key_suffix(self, prompt_id: int, dtype: Optional[Union[str, torch.dtype]] = None) -> str: + """ + masked: may be False even if self.masked is True. It is False for some outputs. + """ + key_suffix = f"_{prompt_id}" + if dtype is not None and dtype.is_floating_point: # float tensor only + key_suffix += "_" + utils.dtype_to_normalized_str(dtype) + return key_suffix + + def get_compatible_output_keys( + self, dict_keys: set[str], caption_index: int, base_keys: list[str], dtype: Optional[Union[str, torch.dtype]] + ) -> list[Optional[str], Optional[str]]: + """ + returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable. + """ + key_suffix = self.get_key_suffix(caption_index, None) + keys_without_dtype = [k + key_suffix for k in base_keys] + return get_compatible_dtype_keys(dict_keys, keys_without_dtype, dtype) + + def _default_is_disk_cached_outputs_expected( + self, + cache_path: str, + captions: list[str], + base_keys: list[tuple[str, bool]], + preferred_dtype: Optional[Union[str, torch.dtype]], + ): + if not self.cache_to_disk: + return False + if not os.path.exists(cache_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + with utils.MemoryEfficientSafeOpen(cache_path) as f: + keys = f.keys() + metadata = f.metadata() + + # check captions in metadata + for i, caption in enumerate(captions): + if metadata.get(f"caption{i+1}") != caption: + return False + + compatible_keys = self.get_compatible_output_keys(keys, i, base_keys, preferred_dtype) + if any(key is None for key in compatible_keys): + return False + except Exception as e: + logger.error(f"Error loading file: {cache_path}") + raise e + + return True + def cache_batch_outputs( - self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List + self, + tokenize_strategy: TokenizeStrategy, + models: list[Any], + text_encoding_strategy: TextEncodingStrategy, + batch: list[tuple[utils.ImageInfo, int, str]], ): raise NotImplementedError + def save_outputs_to_disk(self, cache_path: str, caption_index: int, caption: str, keys: list[str], outputs: list[torch.Tensor]): + tensor_dict = {} + + overwrite = False + if os.path.exists(cache_path): + # load existing safetensors and update it + overwrite = True + + with utils.MemoryEfficientSafeOpen(cache_path) as f: + metadata = f.metadata() + keys = f.keys() + for key in keys: + tensor_dict[key] = f.get_tensor(key) + assert metadata["architecture"] == self.architecture + + file_version = metadata.get("format_version", "0.0.0") + major, minor, patch = map(int, file_version.split(".")) + if major > 1 or (major == 1 and minor > 0): + self.save_version_warning_printed = True + logger.warning( + f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues." 
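# For reference, a sketch of the safetensors layout that save_outputs_to_disk() produces.
# The concrete architecture string and dtype suffixes depend on the strategy subclass and
# on utils.dtype_to_normalized_str(); the names below are illustrative only.
#
#     <image stem>_flux_512_te.safetensors      # get_cache_path() = image stem + cache_suffix
#                                               # ("_m" is inserted before "_te" when masked)
#       metadata:
#         architecture   = "flux"
#         format_version = "1.0.0"
#         caption1       = "<first caption>"    # caption{N} maps to caption_index N-1
#         caption2       = "<second caption>"
#       tensors, one set per caption index i (dtype suffix only for floating-point values):
#         l_pooled_<i>_<dtype>, t5_out_<i>_<dtype>, txt_ids_<i>[_<dtype>], ...
#
# Because the dtype is part of the key, several precisions of the same output can coexist;
# the overwrite path above then prunes lower-precision duplicates via remove_lower_precision_values().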
+ ) + else: + metadata = {} + metadata["architecture"] = self.architecture + metadata["format_version"] = "1.0.0" + + metadata[f"caption{caption_index+1}"] = caption + + for key, output in zip(keys, outputs): + dtype = output.dtype # long or one of float + key_suffix = self.get_key_suffix(caption_index, dtype) + tensor_dict[key + key_suffix] = output + + # remove lower precision latents if higher precision latents are already cached + if overwrite: + suffix_without_dtype = self.get_key_suffix(caption_index, None) + remove_lower_precision_values(tensor_dict, [key + suffix_without_dtype]) + + save_file(tensor_dict, cache_path, metadata=metadata) -class LatentsCachingStrategy: - # TODO commonize utillity functions to this class, such as npz handling etc. +class LatentsCachingStrategy: _strategy = None # strategy instance: actual strategy class def __init__( @@ -495,36 +692,22 @@ def get_key_suffix( def get_compatible_latents_keys( self, keys: set[str], - dtype: Union[str, torch.dtype], + dtype: Optional[Union[str, torch.dtype]], flip_aug: bool, bucket_reso: Optional[Tuple[int, int]] = None, latents_size: Optional[Tuple[int, int]] = None, - ) -> Tuple[Optional[str], Optional[str]]: + ) -> list[Optional[str], Optional[str]]: """ bucket_reso is (W, H), latents_size is (H, W) """ - latents_key = None - flipped_latents_key = None - - compatible_dtypes = get_compatible_dtypes(dtype) - - for compat_dtype in compatible_dtypes: - key_suffix = self.get_key_suffix(bucket_reso, latents_size, compat_dtype) - - if latents_key is None: - latents_key = "latents" + key_suffix - if latents_key not in keys: - latents_key = None - if flip_aug and flipped_latents_key is None: - flipped_latents_key = "latents_flipped" + key_suffix - if flipped_latents_key not in keys: - flipped_latents_key = None - - if latents_key is not None and (flipped_latents_key is not None or not flip_aug): - break + key_suffix = self.get_key_suffix(bucket_reso, latents_size, None) + keys_without_dtype = ["latents" + key_suffix] + if flip_aug: + keys_without_dtype.append("latents_flipped" + key_suffix) - return latents_key, flipped_latents_key + compatible_keys = get_compatible_dtype_keys(keys, keys_without_dtype, dtype) + return compatible_keys if flip_aug else compatible_keys[0] + [None] def _default_is_disk_cached_latents_expected( self, @@ -555,24 +738,13 @@ def _default_is_disk_cached_latents_expected( # print(f"alpha_mask not found: {latents_cache_path}") return False - if preferred_dtype is None: - # remove dtype suffix from keys, because any dtype is acceptable - keys = [key.rsplit("_", 1)[0] for key in keys if not key.endswith(key_suffix_without_dtype)] - keys = set(keys) - if "latents" + key_suffix_without_dtype not in keys: - # print(f"No preferred: latents {key_suffix_without_dtype} not found: {latents_cache_path}") - return False - if flip_aug and "latents_flipped" + key_suffix_without_dtype not in keys: - # print(f"No preferred: latents_flipped {key_suffix_without_dtype} not found: {latents_cache_path}") - return False - else: - # specific dtype or compatible dtype is required - latents_key, flipped_latents_key = self.get_compatible_latents_keys( - keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso - ) - if latents_key is None or (flip_aug and flipped_latents_key is None): - # print(f"Precise dtype not found: {latents_cache_path}") - return False + # preferred_dtype is None if any dtype is acceptable + latents_key, flipped_latents_key = self.get_compatible_latents_keys( + keys, preferred_dtype, flip_aug, 
bucket_reso=bucket_reso + ) + if latents_key is None or (flip_aug and flipped_latents_key is None): + # print(f"Precise dtype not found: {latents_cache_path}") + return False except Exception as e: logger.error(f"Error loading file: {latents_cache_path}") raise e @@ -581,7 +753,14 @@ def _default_is_disk_cached_latents_expected( # TODO remove circular dependency for ImageInfo def _default_cache_batch_latents( - self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool + self, + encode_by_vae, + vae_device, + vae_dtype, + image_infos: List[utils.ImageInfo], + flip_aug: bool, + alpha_mask: bool, + random_crop: bool, ): """ Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. @@ -711,22 +890,7 @@ def save_latents_to_disk( # remove lower precision latents if higher precision latents are already cached if overwrite: - available_dtypes = get_available_dtypes() - available_itemsize = None - available_itemsize_flipped = None - for dtype in available_dtypes: - key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype) - if "latents" + key_suffix in tensor_dict: - if available_itemsize is None: - available_itemsize = dtype.itemsize - elif available_itemsize > dtype.itemsize: - # if higher precision latents are already cached, remove lower precision latents - del tensor_dict["latents" + key_suffix] - - if "latents_flipped" + key_suffix in tensor_dict: - if available_itemsize_flipped is None: - available_itemsize_flipped = dtype.itemsize - elif available_itemsize_flipped > dtype.itemsize: - del tensor_dict["latents_flipped" + key_suffix] + suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None) + remove_lower_precision_values(tensor_dict, ["latents" + suffix_without_dtype, "latents_flipped" + suffix_without_dtype]) save_file(tensor_dict, cache_path, metadata=metadata) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 2439acc04..135ba76c4 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -5,9 +5,6 @@ import numpy as np from transformers import CLIPTokenizer, T5TokenizerFast -from library import flux_utils, train_util -from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy - from library.utils import setup_logging setup_logging() @@ -15,6 +12,8 @@ logger = logging.getLogger(__name__) +from library import flux_utils, train_util, utils +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" @@ -86,64 +85,56 @@ def encode_tokens( class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): - FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz" + KEYS = ["l_pooled", "t5_out", "txt_ids"] + KEYS_MASKED = ["t5_attn_mask", "apply_t5_attn_mask"] def __init__( self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, + max_token_length: int, + masked: bool, is_partial: bool = False, - apply_t5_attn_mask: bool = False, ) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) - self.apply_t5_attn_mask = apply_t5_attn_mask + super().__init__( + FluxLatentsCachingStrategy.ARCHITECTURE, + cache_to_disk, + batch_size, + skip_disk_cache_validity_check, + max_token_length, + masked, + 
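# For reference, the key lists this strategy caches (comments only, so the call above
# still parses). KEYS is always written; KEYS_MASKED is added when masked=True:
#
#     KEYS        = ["l_pooled", "t5_out", "txt_ids"]
#     KEYS_MASKED = ["t5_attn_mask", "apply_t5_attn_mask"]
#
#     base_keys = KEYS + KEYS_MASKED if masked else list(KEYS)
#
# Building the combined list with "+" yields a new list and leaves the shared class-level
# KEYS attribute untouched between calls, which matters because these lists are class
# constants reused by every instance.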
is_partial, + ) self.warn_fp8_weights = False - def get_outputs_npz_path(self, image_abs_path: str) -> str: - return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX - - def is_disk_cached_outputs_expected(self, npz_path: str): - if not self.cache_to_disk: - return False - if not os.path.exists(npz_path): - return False - if self.skip_disk_cache_validity_check: - return True - - try: - npz = np.load(npz_path) - if "l_pooled" not in npz: - return False - if "t5_out" not in npz: - return False - if "txt_ids" not in npz: - return False - if "t5_attn_mask" not in npz: - return False - if "apply_t5_attn_mask" not in npz: - return False - npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] - if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: - return False - except Exception as e: - logger.error(f"Error loading file: {npz_path}") - raise e - - return True - - def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: - data = np.load(npz_path) - l_pooled = data["l_pooled"] - t5_out = data["t5_out"] - txt_ids = data["txt_ids"] - t5_attn_mask = data["t5_attn_mask"] - # apply_t5_attn_mask should be same as self.apply_t5_attn_mask + def is_disk_cached_outputs_expected( + self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]] + ): + keys = FluxTextEncoderOutputsCachingStrategy.KEYS + if self.masked: + keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED + return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype) + + def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]: + l_pooled, t5_out, txt_ids = self.load_from_disk_for_keys( + cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS + ) + if self.masked: + t5_attn_mask = self.load_from_disk_for_keys( + cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED + )[0] + else: + t5_attn_mask = None return [l_pooled, t5_out, txt_ids, t5_attn_mask] def cache_batch_outputs( - self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + text_encoding_strategy: TextEncodingStrategy, + batch: list[tuple[utils.ImageInfo, int, str]], ): if not self.warn_fp8_weights: if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn: @@ -154,44 +145,38 @@ def cache_batch_outputs( self.warn_fp8_weights = True flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy - captions = [info.caption for info in infos] + captions = [caption for _, _, caption in batch] tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks) - if l_pooled.dtype == torch.bfloat16: - l_pooled = l_pooled.float() - if t5_out.dtype == torch.bfloat16: - t5_out = t5_out.float() - if txt_ids.dtype == torch.bfloat16: - txt_ids = txt_ids.float() + l_pooled = l_pooled.cpu() + t5_out = t5_out.cpu() + txt_ids = txt_ids.cpu() + t5_attn_mask = tokens_and_masks[2].cpu() - l_pooled = l_pooled.cpu().numpy() - t5_out = t5_out.cpu().numpy() - txt_ids = txt_ids.cpu().numpy() - t5_attn_mask = tokens_and_masks[2].cpu().numpy() + keys = FluxTextEncoderOutputsCachingStrategy.KEYS + if self.masked: + keys += 
FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED - for i, info in enumerate(infos): + for i, (info, caption_index, caption) in enumerate(batch): l_pooled_i = l_pooled[i] t5_out_i = t5_out[i] txt_ids_i = txt_ids[i] t5_attn_mask_i = t5_attn_mask[i] - apply_t5_attn_mask_i = self.apply_t5_attn_mask if self.cache_to_disk: - np.savez( - info.text_encoder_outputs_npz, - l_pooled=l_pooled_i, - t5_out=t5_out_i, - txt_ids=txt_ids_i, - t5_attn_mask=t5_attn_mask_i, - apply_t5_attn_mask=apply_t5_attn_mask_i, - ) + outputs = [l_pooled_i, t5_out_i, txt_ids_i] + if self.masked: + outputs += [t5_attn_mask_i] + self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs) else: # it's fine that attn mask is not None. it's overwritten before calling the model if necessary - info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) + while len(info.text_encoder_outputs) <= caption_index: + info.text_encoder_outputs.append(None) + info.text_encoder_outputs[caption_index] = [l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i] class FluxLatentsCachingStrategy(LatentsCachingStrategy): @@ -215,8 +200,7 @@ def load_latents_from_disk( ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]: return self._default_load_latents_from_disk(cache_path, bucket_reso) - # TODO remove circular dependency for ImageInfo - def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") vae_device = vae.device vae_dtype = vae.dtype diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 65ea294b0..d8a860f42 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -4,8 +4,6 @@ import torch from transformers import CLIPTokenizer -from library import train_util -from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy from library.utils import setup_logging setup_logging() @@ -13,6 +11,8 @@ logger = logging.getLogger(__name__) +from library import train_util, utils +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy TOKENIZER_ID = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ @@ -157,8 +157,7 @@ def load_latents_from_disk( ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]: return self._default_load_latents_from_disk(cache_path, bucket_reso) - # TODO remove circular dependency for ImageInfo - def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample() vae_device = vae.device vae_dtype = vae.dtype diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 2ad5288b3..384103c61 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -6,10 +6,6 @@ import numpy as np from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel -from library import sd3_utils, train_util -from library import sd3_models -from library.strategy_base import 
LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy - from library.utils import setup_logging setup_logging() @@ -17,6 +13,9 @@ logger = logging.getLogger(__name__) +from library import train_util, utils +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" @@ -254,7 +253,8 @@ def concat_encodings( class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): - SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" + KEYS = ["lg_out", "t5_out", "lg_pooled"] + KEYS_MASKED = ["clip_l_attn_mask", "clip_g_attn_mask", "t5_attn_mask"] def __init__( self, @@ -262,70 +262,51 @@ def __init__( batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False, - apply_lg_attn_mask: bool = False, - apply_t5_attn_mask: bool = False, + max_token_length: int = 256, + masked: bool = False, ) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) - self.apply_lg_attn_mask = apply_lg_attn_mask - self.apply_t5_attn_mask = apply_t5_attn_mask - - def get_outputs_npz_path(self, image_abs_path: str) -> str: - return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX - - def is_disk_cached_outputs_expected(self, npz_path: str): - if not self.cache_to_disk: - return False - if not os.path.exists(npz_path): - return False - if self.skip_disk_cache_validity_check: - return True - - try: - npz = np.load(npz_path) - if "lg_out" not in npz: - return False - if "lg_pooled" not in npz: - return False - if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used - return False - if "apply_lg_attn_mask" not in npz: - return False - if "t5_out" not in npz: - return False - if "t5_attn_mask" not in npz: - return False - npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"] - if npz_apply_lg_attn_mask != self.apply_lg_attn_mask: - return False - if "apply_t5_attn_mask" not in npz: - return False - npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] - if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: - return False - except Exception as e: - logger.error(f"Error loading file: {npz_path}") - raise e - - return True - - def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: - data = np.load(npz_path) - lg_out = data["lg_out"] - lg_pooled = data["lg_pooled"] - t5_out = data["t5_out"] - - l_attn_mask = data["clip_l_attn_mask"] - g_attn_mask = data["clip_g_attn_mask"] - t5_attn_mask = data["t5_attn_mask"] - - # apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask + """ + apply_lg_attn_mask and apply_t5_attn_mask must be same + """ + super().__init__( + Sd3LatentsCachingStrategy.ARCHITECTURE_SD3, + cache_to_disk, + batch_size, + skip_disk_cache_validity_check, + max_token_length, + masked=masked, + is_partial=is_partial, + ) + + def is_disk_cached_outputs_expected( + self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]] + ) -> bool: + keys = Sd3TextEncoderOutputsCachingStrategy.KEYS + if self.masked: + keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED + return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype) + + def load_from_disk(self, cache_path: str, caption_index: int) -> 
list[Optional[torch.Tensor]]: + lg_out, lg_pooled, t5_out = self.load_from_disk_for_keys( + cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS + ) + if self.masked: + l_attn_mask, g_attn_mask, t5_attn_mask = self.load_from_disk_for_keys( + cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED + ) + else: + l_attn_mask = g_attn_mask = t5_attn_mask = None return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] def cache_batch_outputs( - self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + text_encoding_strategy: TextEncodingStrategy, + batch: list[tuple[utils.ImageInfo, int, str]], ): sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy - captions = [info.caption for info in infos] + captions = [caption for _, _, caption in batch] tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): @@ -334,51 +315,47 @@ def cache_batch_outputs( tokenize_strategy, models, tokens_and_masks, - apply_lg_attn_mask=self.apply_lg_attn_mask, - apply_t5_attn_mask=self.apply_t5_attn_mask, + apply_lg_attn_mask=self.masked, + apply_t5_attn_mask=self.masked, enable_dropout=False, ) - if lg_out.dtype == torch.bfloat16: - lg_out = lg_out.float() - if lg_pooled.dtype == torch.bfloat16: - lg_pooled = lg_pooled.float() - if t5_out.dtype == torch.bfloat16: - t5_out = t5_out.float() + lg_out = lg_out.cpu() + lg_pooled = lg_pooled.cpu() + t5_out = t5_out.cpu() - lg_out = lg_out.cpu().numpy() - lg_pooled = lg_pooled.cpu().numpy() - t5_out = t5_out.cpu().numpy() + l_attn_mask = tokens_and_masks[3].cpu() + g_attn_mask = tokens_and_masks[4].cpu() + t5_attn_mask = tokens_and_masks[5].cpu() - l_attn_mask = tokens_and_masks[3].cpu().numpy() - g_attn_mask = tokens_and_masks[4].cpu().numpy() - t5_attn_mask = tokens_and_masks[5].cpu().numpy() - - for i, info in enumerate(infos): + keys = Sd3TextEncoderOutputsCachingStrategy.KEYS + if self.masked: + keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED + for i, (info, caption_index, caption) in enumerate(batch): lg_out_i = lg_out[i] t5_out_i = t5_out[i] lg_pooled_i = lg_pooled[i] l_attn_mask_i = l_attn_mask[i] g_attn_mask_i = g_attn_mask[i] t5_attn_mask_i = t5_attn_mask[i] - apply_lg_attn_mask = self.apply_lg_attn_mask - apply_t5_attn_mask = self.apply_t5_attn_mask if self.cache_to_disk: - np.savez( - info.text_encoder_outputs_npz, - lg_out=lg_out_i, - lg_pooled=lg_pooled_i, - t5_out=t5_out_i, - clip_l_attn_mask=l_attn_mask_i, - clip_g_attn_mask=g_attn_mask_i, - t5_attn_mask=t5_attn_mask_i, - apply_lg_attn_mask=apply_lg_attn_mask, - apply_t5_attn_mask=apply_t5_attn_mask, - ) + outputs = [lg_out_i, t5_out_i, lg_pooled_i] + if self.masked: + outputs += [l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i] + self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs) else: # it's fine that attn mask is not None. 
it's overwritten before calling the model if necessary - info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i) + while len(info.text_encoder_outputs) <= caption_index: + info.text_encoder_outputs.append(None) + info.text_encoder_outputs[caption_index] = [ + lg_out_i, + t5_out_i, + lg_pooled_i, + l_attn_mask_i, + g_attn_mask_i, + t5_attn_mask_i, + ] class Sd3LatentsCachingStrategy(LatentsCachingStrategy): @@ -402,8 +379,7 @@ def load_latents_from_disk( ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]: return self._default_load_latents_from_disk(cache_path, bucket_reso) - # TODO remove circular dependency for ImageInfo - def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") vae_device = vae.device vae_dtype = vae.dtype diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index 6b3e2afa6..f1b5abd05 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -4,8 +4,6 @@ import numpy as np import torch from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection -from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy - from library.utils import setup_logging @@ -14,6 +12,8 @@ logger = logging.getLogger(__name__) +from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy +from library import utils TOKENIZER1_PATH = "openai/clip-vit-large-patch14" TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" @@ -21,6 +21,9 @@ class SdxlTokenizeStrategy(TokenizeStrategy): def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + """ + max_length: maximum length of the input text, **excluding** the special tokens. 
None or 150 or 225 + """ self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir) self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir) self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2 @@ -220,51 +223,51 @@ def encode_tokens_with_weights( class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): - SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" + ARCHITECTURE_SDXL = "sdxl" + KEYS = ["hidden_state1", "hidden_state2", "pool2"] def __init__( self, cache_to_disk: bool, - batch_size: int, + batch_size: Optional[int], skip_disk_cache_validity_check: bool, + max_token_length: Optional[int] = None, is_partial: bool = False, is_weighted: bool = False, ) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted) - - def get_outputs_npz_path(self, image_abs_path: str) -> str: - return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX - - def is_disk_cached_outputs_expected(self, npz_path: str): - if not self.cache_to_disk: - return False - if not os.path.exists(npz_path): - return False - if self.skip_disk_cache_validity_check: - return True - - try: - npz = np.load(npz_path) - if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: - return False - except Exception as e: - logger.error(f"Error loading file: {npz_path}") - raise e - - return True - - def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: - data = np.load(npz_path) - hidden_state1 = data["hidden_state1"] - hidden_state2 = data["hidden_state2"] - pool2 = data["pool2"] - return [hidden_state1, hidden_state2, pool2] + """ + max_token_length: maximum length of the input text, **excluding** the special tokens. 
None or 150 or 225 + """ + max_token_length = max_token_length or 75 + super().__init__( + SdxlTextEncoderOutputsCachingStrategy.ARCHITECTURE_SDXL, + cache_to_disk, + batch_size, + skip_disk_cache_validity_check, + is_partial, + is_weighted, + max_token_length=max_token_length, + ) + + def is_disk_cached_outputs_expected( + self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]] + ) -> bool: + # SDXL does not support attn mask + base_keys = SdxlTextEncoderOutputsCachingStrategy.KEYS + return self._default_is_disk_cached_outputs_expected(cache_path, prompts, base_keys, preferred_dtype) + + def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]: + return self.load_from_disk_for_keys(cache_path, caption_index, SdxlTextEncoderOutputsCachingStrategy.KEYS) def cache_batch_outputs( - self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + text_encoding_strategy: TextEncodingStrategy, + batch: list[tuple[utils.ImageInfo, int, str]], ): sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy - captions = [info.caption for info in infos] + captions = [caption for _, _, caption in batch] if self.is_weighted: tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions) @@ -279,28 +282,24 @@ def cache_batch_outputs( tokenize_strategy, models, [tokens1, tokens2] ) - if hidden_state1.dtype == torch.bfloat16: - hidden_state1 = hidden_state1.float() - if hidden_state2.dtype == torch.bfloat16: - hidden_state2 = hidden_state2.float() - if pool2.dtype == torch.bfloat16: - pool2 = pool2.float() - - hidden_state1 = hidden_state1.cpu().numpy() - hidden_state2 = hidden_state2.cpu().numpy() - pool2 = pool2.cpu().numpy() + hidden_state1 = hidden_state1.cpu() + hidden_state2 = hidden_state2.cpu() + pool2 = pool2.cpu() - for i, info in enumerate(infos): + for i, (info, caption_index, caption) in enumerate(batch): hidden_state1_i = hidden_state1[i] hidden_state2_i = hidden_state2[i] pool2_i = pool2[i] if self.cache_to_disk: - np.savez( - info.text_encoder_outputs_npz, - hidden_state1=hidden_state1_i, - hidden_state2=hidden_state2_i, - pool2=pool2_i, + self.save_outputs_to_disk( + info.text_encoder_outputs_cache_path, + caption_index, + caption, + SdxlTextEncoderOutputsCachingStrategy.KEYS, + [hidden_state1_i, hidden_state2_i, pool2_i], ) else: - info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i] + while len(info.text_encoder_outputs) <= caption_index: + info.text_encoder_outputs.append(None) + info.text_encoder_outputs[caption_index] = [hidden_state1_i, hidden_state2_i, pool2_i] diff --git a/library/train_util.py b/library/train_util.py index 02770ca3b..b9c43ebdd 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -83,7 +83,7 @@ import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, pil_resize, ImageInfo setup_logging() import logging @@ -146,36 +146,6 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" -class ImageInfo: - def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: - self.image_key: str = image_key - self.num_repeats: int = num_repeats - self.caption: str = caption - 
self.is_reg: bool = is_reg - self.absolute_path: str = absolute_path - self.image_size: Tuple[int, int] = None - self.resized_size: Tuple[int, int] = None - self.bucket_reso: Tuple[int, int] = None - self.latents: Optional[torch.Tensor] = None - self.latents_flipped: Optional[torch.Tensor] = None - self.latents_cache_path: Optional[str] = None # set in cache_latents - self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size - # crop left top right bottom in original pixel size, not latents size - self.latents_crop_ltrb: Optional[Tuple[int, int]] = None - self.cond_img_path: Optional[str] = None - self.image: Optional[Image.Image] = None # optional, original PIL Image - self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs - - # new - self.text_encoder_outputs: Optional[List[torch.Tensor]] = None - # old - self.text_encoder_outputs1: Optional[torch.Tensor] = None - self.text_encoder_outputs2: Optional[torch.Tensor] = None - self.text_encoder_pool2: Optional[torch.Tensor] = None - - self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime - - class BucketManager: def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: if max_size is not None: @@ -751,116 +721,111 @@ def enable_XTI(self, layers=None, token_strings=None): def add_replacement(self, str_from, str_to): self.replacements[str_from] = str_to - def process_caption(self, subset: BaseSubset, caption): - # caption に prefix/suffix を付ける - if subset.caption_prefix: - caption = subset.caption_prefix + " " + caption - if subset.caption_suffix: - caption = caption + " " + subset.caption_suffix - - # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い + def process_caption(self, subset: BaseSubset, caption: str, tags: Optional[str]) -> str: + # drop out caption is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate is_drop_out = ( is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0 ) - if is_drop_out: - caption = "" - else: - # process wildcards - if subset.enable_wildcard: - # if caption is multiline, random choice one line - if "\n" in caption: - caption = random.choice(caption.split("\n")) - - # wildcard is like '{aaa|bbb|ccc...}' - # escape the curly braces like {{ or }} - replacer1 = "⦅" - replacer2 = "⦆" - while replacer1 in caption or replacer2 in caption: - replacer1 += "⦅" - replacer2 += "⦆" - - caption = caption.replace("{{", replacer1).replace("}}", replacer2) - - # replace the wildcard - def replace_wildcard(match): - return random.choice(match.group(1).split("|")) - - caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption) - - # unescape the curly braces - caption = caption.replace(replacer1, "{").replace(replacer2, "}") + return "" + + # add prefix and suffix for caption + # DreamBooth: treated as tags, FineTuning: treated as caption, tags are processed separately + if subset.caption_prefix: + caption = subset.caption_prefix + " " + caption + if subset.caption_suffix: + caption = caption + " " + subset.caption_suffix + + # shuffle tags + if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: + if tags is None and caption is not None: # DreamBooth method + tags = caption + caption = "" + + fixed_tokens = [] + flex_tokens = [] + fixed_suffix_tokens = [] + if hasattr(subset, "keep_tokens_separator") and subset.keep_tokens_separator and 
subset.keep_tokens_separator in tags: + fixed_part, flex_part = tags.split(subset.keep_tokens_separator, 1) + if subset.keep_tokens_separator in flex_part: + flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1) + fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()] + + fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()] + flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()] else: - # if caption is multiline, use the first line - caption = caption.split("\n")[0] - - if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: - fixed_tokens = [] - flex_tokens = [] - fixed_suffix_tokens = [] - if ( - hasattr(subset, "keep_tokens_separator") - and subset.keep_tokens_separator - and subset.keep_tokens_separator in caption - ): - fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1) - if subset.keep_tokens_separator in flex_part: - flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1) - fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()] - - fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()] - flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()] - else: - tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] - flex_tokens = tokens[:] - if subset.keep_tokens > 0: - fixed_tokens = flex_tokens[: subset.keep_tokens] - flex_tokens = tokens[subset.keep_tokens :] - - if subset.token_warmup_step < 1: # 初回に上書きする - subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) - if subset.token_warmup_step and self.current_step < subset.token_warmup_step: - tokens_len = ( - math.floor( - (self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)) - ) - + subset.token_warmup_min - ) - flex_tokens = flex_tokens[:tokens_len] - - def dropout_tags(tokens): - if subset.caption_tag_dropout_rate <= 0: - return tokens - l = [] - for token in tokens: - if random.random() >= subset.caption_tag_dropout_rate: - l.append(token) - return l - - if subset.shuffle_caption: - random.shuffle(flex_tokens) - - flex_tokens = dropout_tags(flex_tokens) - - caption = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens) - - # process secondary separator - if subset.secondary_separator: - caption = caption.replace(subset.secondary_separator, subset.caption_separator) - - # textual inversion対応 - for str_from, str_to in self.replacements.items(): - if str_from == "": - # replace all - if type(str_to) == list: - caption = random.choice(str_to) - else: - caption = str_to + tokens = [t.strip() for t in tags.strip().split(subset.caption_separator)] + flex_tokens = tokens[:] + if subset.keep_tokens > 0: + fixed_tokens = flex_tokens[: subset.keep_tokens] + flex_tokens = tokens[subset.keep_tokens :] + + if subset.token_warmup_step < 1: # 初回に上書きする + subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) + if subset.token_warmup_step and self.current_step < subset.token_warmup_step: + tokens_len = ( + math.floor((self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + + subset.token_warmup_min + ) + flex_tokens = flex_tokens[:tokens_len] + + def dropout_tags(tokens): + if subset.caption_tag_dropout_rate <= 0: + 
return tokens + l = [] + for token in tokens: + if random.random() >= subset.caption_tag_dropout_rate: + l.append(token) + return l + + if subset.shuffle_caption: + random.shuffle(flex_tokens) + + flex_tokens = dropout_tags(flex_tokens) + + tags = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens) + + if tags is not None: + caption = caption + " " + tags + + # process wildcards + if subset.enable_wildcard: + # wildcard is like '{aaa|bbb|ccc...}' + # escape the curly braces like {{ or }} + replacer1 = "⦅" + replacer2 = "⦆" + while replacer1 in caption or replacer2 in caption: + replacer1 += "⦅" + replacer2 += "⦆" + + caption = caption.replace("{{", replacer1).replace("}}", replacer2) + + # replace the wildcard + def replace_wildcard(match): + return random.choice(match.group(1).split("|")) + + caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption) + + # unescape the curly braces + caption = caption.replace(replacer1, "{").replace(replacer2, "}") + + # process secondary separator + if subset.secondary_separator: + caption = caption.replace(subset.secondary_separator, subset.caption_separator) + + # textual inversion対応 + for str_from, str_to in self.replacements.items(): + if str_from == "": + # replace all + if type(str_to) == list: + caption = random.choice(str_to) else: - caption = caption.replace(str_from, str_to) + caption = str_to + else: + caption = caption.replace(str_from, str_to) return caption @@ -1171,24 +1136,28 @@ def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Acceler for i, info in enumerate(tqdm(image_infos)): # check disk cache exists and size of text encoder outputs if caching_strategy.cache_to_disk: - te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) - info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability + cache_path = caching_strategy.get_cache_path(info.absolute_path) + info.text_encoder_outputs_cache_path = cache_path # set npz filename regardless of cache availability # if the modulo of num_processes is not equal to process_index, skip caching # this makes each process cache different text encoder outputs if i % num_processes != process_index: continue - cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) + cache_available = caching_strategy.is_disk_cached_outputs_expected(cache_path) if cache_available: # do not add to batch continue - batch.append(info) + for j, caption in enumerate(info.captions): + # do not recommend to use tags when caching text encoder outputs + if info.list_of_tags is not None and len(info.list_of_tags) > 0: + caption = caption + " " + info.list_of_tags[j % len(info.list_of_tags)] + batch.append((info, j, caption)) # if number of data in batch is enough, flush the batch - if len(batch) >= batch_size: - batches.append(batch) - batch = [] + while len(batch) >= batch_size: + batches.append(batch[:batch_size]) + batch = batch[batch_size:] if len(batch) > 0: batches.append(batch) @@ -1413,54 +1382,43 @@ def __getitem__(self, index): flippeds.append(flipped) # captionとtext encoder outputを処理する - caption = image_info.caption # default - tokenization_required = ( self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial ) text_encoder_outputs = None input_ids = None + caption = "" if image_info.text_encoder_outputs is not None: - # cached + # cached on memory text_encoder_outputs = image_info.text_encoder_outputs - elif image_info.text_encoder_outputs_npz is not None: + if 
len(text_encoder_outputs) == 1: + text_encoder_outputs = text_encoder_outputs[0] + else: + text_encoder_outputs = random.choices(text_encoder_outputs, weights=image_info.caption_weights)[0] + elif image_info.text_encoder_outputs_cache_path is not None: # on disk - text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( - image_info.text_encoder_outputs_npz + index = 0 + if len(image_info.captions) > 1: + index = random.choices(range(len(image_info.captions), weights=image_info.caption_weights))[0] + text_encoder_outputs = self.text_encoder_output_caching_strategy.load_from_disk( + image_info.text_encoder_outputs_cache_path, index ) else: tokenization_required = True text_encoder_outputs_list.append(text_encoder_outputs) if tokenization_required: - caption = self.process_caption(subset, image_info.caption) + caption = "" + tags = None # None if no tags in dataset metadata or Dreambooth method is used + if image_info.captions is not None and len(image_info.captions) > 0: + # captions_weights may be None + caption = random.choices(image_info.captions, weights=image_info.caption_weights)[0] + if image_info.list_of_tags is not None and len(image_info.list_of_tags) > 0: + tags = random.choices(image_info.list_of_tags, weights=image_info.tags_weights)[0] + + caption = self.process_caption(subset, caption, tags) input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension - # if self.XTI_layers: - # caption_layer = [] - # for layer in self.XTI_layers: - # token_strings_from = " ".join(self.token_strings) - # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - # caption_ = caption.replace(token_strings_from, token_strings_to) - # caption_layer.append(caption_) - # captions.append(caption_layer) - # else: - # captions.append(caption) - - # if not self.token_padding_disabled: # this option might be omitted in future - # # TODO get_input_ids must support SD3 - # if self.XTI_layers: - # token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) - # else: - # token_caption = self.get_input_ids(caption, self.tokenizers[0]) - # input_ids_list.append(token_caption) - - # if len(self.tokenizers) > 1: - # if self.XTI_layers: - # token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) - # else: - # token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) - # input_ids2_list.append(token_caption2) input_ids_list.append(input_ids) captions.append(caption) @@ -1798,7 +1756,8 @@ def load_dreambooth_dir(subset: DreamBoothSubset): num_train_images += subset.num_repeats * len(img_paths) for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + captions = caption.split("\n") # empty line is allowed + info = ImageInfo(img_path, subset.num_repeats, captions, subset.is_reg, img_path) if size is not None: info.image_size = size if subset.is_reg: diff --git a/library/utils.py b/library/utils.py index 0a0333eb4..03394c361 100644 --- a/library/utils.py +++ b/library/utils.py @@ -21,6 +21,41 @@ def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() +class ImageInfo: + def __init__( + self, image_key: str, num_repeats: int, captions: Optional[Union[str, list[str]]], is_reg: bool, absolute_path: str + ) -> None: + self.image_key: str = image_key + self.num_repeats: int = num_repeats + self.captions: Optional[list[str]] = None if captions is None else ([captions] if 
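# How the new multi-caption fields are meant to be consumed at training time, kept as
# comments so the expression above still parses. "info" stands for an ImageInfo and
# "caching_strategy" for a TextEncoderOutputsCachingStrategy instance; weights of None
# mean uniform sampling.
#
#     import random
#
#     n = len(info.captions)
#     idx = random.choices(range(n), weights=info.caption_weights)[0] if n > 1 else 0
#     caption = info.captions[idx]
#
#     # disk cache: the same index selects the matching "<key>_<idx>" entries
#     outputs = caching_strategy.load_from_disk(info.text_encoder_outputs_cache_path, idx)
#
#     # in-memory cache: info.text_encoder_outputs[idx] holds the corresponding tensor list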
isinstance(captions, str) else captions) + self.caption_weights: Optional[list[float]] = None # weights for each caption in sampling + self.list_of_tags: Optional[list[str]] = None + self.tags_weights: Optional[list[float]] = None + self.is_reg: bool = is_reg + self.absolute_path: str = absolute_path + self.image_size: Tuple[int, int] = None + self.resized_size: Tuple[int, int] = None + self.bucket_reso: Tuple[int, int] = None + self.latents: Optional[torch.Tensor] = None + self.latents_flipped: Optional[torch.Tensor] = None + self.latents_cache_path: Optional[str] = None # set in cache_latents + self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size + # crop left top right bottom in original pixel size, not latents size + self.latents_crop_ltrb: Optional[Tuple[int, int]] = None + self.cond_img_path: Optional[str] = None + self.image: Optional[Image.Image] = None # optional, original PIL Image. None if not the latents is cached + self.text_encoder_outputs_cache_path: Optional[str] = None # set in cache_text_encoder_outputs + + # new + self.text_encoder_outputs: Optional[list[list[torch.Tensor]]] = None + # old + self.text_encoder_outputs1: Optional[torch.Tensor] = None + self.text_encoder_outputs2: Optional[torch.Tensor] = None + self.text_encoder_pool2: Optional[torch.Tensor] = None + + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + + # region Logging diff --git a/sd3_train.py b/sd3_train.py index c9b9783a9..fbbdf8bdb 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -75,6 +75,12 @@ def train(args): ) args.cache_text_encoder_outputs = True + if args.cache_text_encoder_outputs: + assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, ( + "apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs" + " / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります" + ) + assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" + " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)" @@ -168,8 +174,8 @@ def train(args): args.text_encoder_batch_size, False, False, - False, - False, + args.t5xxl_max_token_length, + args.apply_lg_attn_mask, ) ) train_dataset_group.set_current_strategies() @@ -278,8 +284,8 @@ def train(args): args.text_encoder_batch_size, args.skip_cache_check, train_clip or args.use_t5xxl_cache_only, # if clip is trained or t5xxl is cached, caching is partial + args.t5xxl_max_token_length, args.apply_lg_attn_mask, - args.apply_t5_attn_mask, ) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) diff --git a/sd3_train_network.py b/sd3_train_network.py index 1726e325f..164f834a2 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -43,6 +43,10 @@ def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): assert ( train_dataset_group.is_text_encoder_output_cacheable() ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, ( + "apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs" + " / text 
encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
+        )
 
         # prepare CLIP-L/CLIP-G/T5XXL training flags
         self.train_clip = not args.network_train_unet_only
@@ -183,8 +187,8 @@ def get_text_encoder_outputs_caching_strategy(self, args):
                 args.text_encoder_batch_size,
                 args.skip_cache_check,
                 is_partial=self.train_clip or self.train_t5xxl,
+                max_token_length=args.t5xxl_max_token_length,
                 apply_lg_attn_mask=args.apply_lg_attn_mask,
-                apply_t5_attn_mask=args.apply_t5_attn_mask,
             )
         else:
             return None
diff --git a/sdxl_train.py b/sdxl_train.py
index aa21d669c..b62d10b5b 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -321,7 +321,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     if args.cache_text_encoder_outputs:
         # Text Encodes are eval and no grad
         text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
-            args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
+            args.cache_text_encoder_outputs_to_disk,
+            None,
+            args.skip_cache_check,
+            args.max_token_length,
+            is_weighted=args.weighted_captions,
         )
         strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py
index 315f20b56..50f2cf20d 100644
--- a/sdxl_train_control_net.py
+++ b/sdxl_train_control_net.py
@@ -223,7 +223,11 @@ def unwrap_model(model):
     if args.cache_text_encoder_outputs:
         # Text Encodes are eval and no grad
         text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
-            args.cache_text_encoder_outputs_to_disk, None, False
+            args.cache_text_encoder_outputs_to_disk,
+            None,
+            args.skip_cache_check,
+            args.max_token_length,
+            is_weighted=args.weighted_captions,
         )
         strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index a80c64372..aa3fd6623 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -195,7 +195,11 @@ def train(args):
     if args.cache_text_encoder_outputs:
         # Text Encodes are eval and no grad
         text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
-            args.cache_text_encoder_outputs_to_disk, None, False
+            args.cache_text_encoder_outputs_to_disk,
+            None,
+            args.skip_cache_check,
+            args.max_token_length,
+            is_weighted=args.weighted_captions,
         )
         strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
diff --git a/sdxl_train_network.py b/sdxl_train_network.py
index d45df6e05..3730f1216 100644
--- a/sdxl_train_network.py
+++ b/sdxl_train_network.py
@@ -81,7 +81,11 @@ def get_models_for_text_encoding(self, args, accelerator, text_encoders):
     def get_text_encoder_outputs_caching_strategy(self, args):
         if args.cache_text_encoder_outputs:
             return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
-                args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
+                args.cache_text_encoder_outputs_to_disk,
+                None,
+                args.skip_cache_check,
+                args.max_token_length,
+                is_weighted=args.weighted_captions,
             )
         else:
             return None
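To make the call-site changes above easier to follow, here is a minimal sketch of the disk round trip the new caching API is designed for. It uses only names introduced in this patch (FluxTextEncoderOutputsCachingStrategy, get_cache_path, save_outputs_to_disk, load_from_disk); the image path, tensor shapes, and caption are illustrative placeholders, and the exact cache-file suffix depends on the architecture string and dtype-suffix helpers in library/utils.py.

    import os
    import tempfile

    import torch

    from library import strategy_flux

    # Build the Flux strategy with the new argument order; masked=False means only
    # KEYS (l_pooled, t5_out, txt_ids) are cached, without the T5 attention mask.
    strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
        cache_to_disk=True,
        batch_size=16,
        skip_disk_cache_validity_check=False,
        max_token_length=512,
        masked=False,
        is_partial=False,
    )

    # The cache file sits next to the image: "<stem>" + cache_suffix, where cache_suffix
    # encodes the architecture, the max token length and the masked flag,
    # e.g. something like "_flux_512_te.safetensors".
    image_path = os.path.join(tempfile.mkdtemp(), "image_0001.png")
    cache_path = strategy.get_cache_path(image_path)

    # One entry per caption index: each tensor is stored as "<key>_<caption_index>[_<dtype>]"
    # and the caption text goes into the safetensors metadata as "caption1", "caption2", ...
    dummy_outputs = [
        torch.zeros(768, dtype=torch.bfloat16),        # l_pooled (CLIP-L pooled output)
        torch.zeros(512, 4096, dtype=torch.bfloat16),  # t5_out   (T5-XXL hidden states)
        torch.zeros(512, 3),                           # txt_ids
    ]
    strategy.save_outputs_to_disk(
        cache_path,
        caption_index=0,
        caption="a photo of a cat",
        keys=strategy_flux.FluxTextEncoderOutputsCachingStrategy.KEYS,
        outputs=dummy_outputs,
    )

    # Reading back: the unmasked variant returns [l_pooled, t5_out, txt_ids, None].
    # load_from_disk accepts whichever precision is on disk; the preferred-dtype check
    # happens earlier, in is_disk_cached_outputs_expected().
    l_pooled, t5_out, txt_ids, t5_attn_mask = strategy.load_from_disk(cache_path, caption_index=0)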