diff --git a/fine_tune.py b/fine_tune.py index 0090bd190..b1668d657 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -177,7 +177,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator) + train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 019c737a6..8e2d0b052 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -180,7 +180,7 @@ def process_batch(is_last): # バッチへ追加 image_info = train_util.ImageInfo(image_key, 1, "", False, image_path) - image_info.latents_npz = npz_file_name + image_info.latents_cache_path = npz_file_name image_info.bucket_reso = reso image_info.resized_size = resized_size image_info.image = image diff --git a/flux_train.py b/flux_train.py index a89e2f139..a7d38c58d 100644 --- a/flux_train.py +++ b/flux_train.py @@ -198,7 +198,7 @@ def train(args): ae.requires_grad_(False) ae.eval() - train_dataset_group.new_cache_latents(ae, accelerator) + train_dataset_group.new_cache_latents(ae, accelerator, args.force_cache_precision) ae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) diff --git a/library/strategy_base.py b/library/strategy_base.py index 358e42f1d..6bcf3dc6a 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -2,9 +2,10 @@ import os import re -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np +from safetensors.torch import safe_open, save_file import torch from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection @@ -12,6 +13,7 @@ # TODO remove circular import by moving ImageInfo to a separate file # from library.train_util import ImageInfo +from library import utils from library.utils import setup_logging setup_logging() @@ -20,6 +22,27 @@ logger = logging.getLogger(__name__) +def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]: + if dtype is None: + # all dtypes are acceptable + return get_available_dtypes() + + dtype = utils.str_to_dtype(dtype) if isinstance(dtype, str) else dtype + compatible_dtypes = [torch.float32] + if dtype.itemsize == 1: # fp8 + compatible_dtypes.append(torch.bfloat16) + compatible_dtypes.append(torch.float16) + compatible_dtypes.append(dtype) # add the specified: bf16, fp16, one of fp8 + return compatible_dtypes + + +def get_available_dtypes() -> List[torch.dtype]: + """ + Returns the list of available dtypes for latents caching. Higher precision is preferred. 
+ """ + return [torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] + + class TokenizeStrategy: _strategy = None # strategy instance: actual strategy class @@ -382,11 +405,18 @@ class LatentsCachingStrategy: _strategy = None # strategy instance: actual strategy class - def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + def __init__( + self, architecture: str, latents_stride: int, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool + ) -> None: + self._architecture = architecture + self._latents_stride = latents_stride self._cache_to_disk = cache_to_disk self._batch_size = batch_size self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + self.load_version_warning_printed = False + self.save_version_warning_printed = False + @classmethod def set_strategy(cls, strategy): if cls._strategy is not None: @@ -397,6 +427,14 @@ def set_strategy(cls, strategy): def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: return cls._strategy + @property + def architecture(self): + return self._architecture + + @property + def latents_stride(self): + return self._latents_stride + @property def cache_to_disk(self): return self._cache_to_disk @@ -407,69 +445,143 @@ def batch_size(self): @property def cache_suffix(self): - raise NotImplementedError + return f"_{self.architecture.lower()}.safetensors" - def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]: - w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x") + def get_image_size_from_disk_cache_path(self, absolute_path: str, cache_path: str) -> Tuple[Optional[int], Optional[int]]: + w, h = os.path.splitext(cache_path)[0].rsplit("_", 2)[-2].split("x") return int(w), int(h) - def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: - raise NotImplementedError + def get_latents_cache_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix def is_disk_cached_latents_expected( - self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + self, + bucket_reso: Tuple[int, int], + cache_path: str, + flip_aug: bool, + alpha_mask: bool, + preferred_dtype: Optional[Union[str, torch.dtype]], ) -> bool: raise NotImplementedError def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): raise NotImplementedError + def get_key_suffix( + self, + bucket_reso: Optional[Tuple[int, int]] = None, + latents_size: Optional[Tuple[int, int]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + ) -> str: + """ + if dtype is None, it returns "_32x64" for example. + """ + if latents_size is not None: + expected_latents_size = latents_size # H, W + else: + # bucket_reso is (W, H) + expected_latents_size = (bucket_reso[1] // self.latents_stride, bucket_reso[0] // self.latents_stride) # H, W + + if dtype is None: + dtype_suffix = "" + else: + dtype_suffix = "_" + utils.dtype_to_normalized_str(dtype) + + # e.g. 
"_32x64_float16", HxW, dtype + key_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}{dtype_suffix}" + + return key_suffix + + def get_compatible_latents_keys( + self, + keys: set[str], + dtype: Union[str, torch.dtype], + flip_aug: bool, + bucket_reso: Optional[Tuple[int, int]] = None, + latents_size: Optional[Tuple[int, int]] = None, + ) -> Tuple[Optional[str], Optional[str]]: + """ + bucket_reso is (W, H), latents_size is (H, W) + """ + + latents_key = None + flipped_latents_key = None + + compatible_dtypes = get_compatible_dtypes(dtype) + + for compat_dtype in compatible_dtypes: + key_suffix = self.get_key_suffix(bucket_reso, latents_size, compat_dtype) + + if latents_key is None: + latents_key = "latents" + key_suffix + if latents_key not in keys: + latents_key = None + if flip_aug and flipped_latents_key is None: + flipped_latents_key = "latents_flipped" + key_suffix + if flipped_latents_key not in keys: + flipped_latents_key = None + + if latents_key is not None and (flipped_latents_key is not None or not flip_aug): + break + + return latents_key, flipped_latents_key + def _default_is_disk_cached_latents_expected( self, - latents_stride: int, bucket_reso: Tuple[int, int], - npz_path: str, + latents_cache_path: str, flip_aug: bool, alpha_mask: bool, - multi_resolution: bool = False, + preferred_dtype: Optional[Union[str, torch.dtype]], ): + # multi_resolution is always enabled for any strategy if not self.cache_to_disk: return False - if not os.path.exists(npz_path): + if not os.path.exists(latents_cache_path): return False if self.skip_disk_cache_validity_check: return True - expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) - - # e.g. "_32x64", HxW - key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "" + key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None) try: - npz = np.load(npz_path) - if "latents" + key_reso_suffix not in npz: - return False - if flip_aug and "latents_flipped" + key_reso_suffix not in npz: - return False - if alpha_mask and "alpha_mask" + key_reso_suffix not in npz: + # safe_open locks the file, so we cannot use it for checking keys + # with safe_open(latents_cache_path, framework="pt") as f: + # keys = f.keys() + with utils.MemoryEfficientSafeOpen(latents_cache_path) as f: + keys = f.keys() + + if alpha_mask and "alpha_mask" + key_suffix_without_dtype not in keys: + # print(f"alpha_mask not found: {latents_cache_path}") return False + + if preferred_dtype is None: + # remove dtype suffix from keys, because any dtype is acceptable + keys = [key.rsplit("_", 1)[0] for key in keys if not key.endswith(key_suffix_without_dtype)] + keys = set(keys) + if "latents" + key_suffix_without_dtype not in keys: + # print(f"No preferred: latents {key_suffix_without_dtype} not found: {latents_cache_path}") + return False + if flip_aug and "latents_flipped" + key_suffix_without_dtype not in keys: + # print(f"No preferred: latents_flipped {key_suffix_without_dtype} not found: {latents_cache_path}") + return False + else: + # specific dtype or compatible dtype is required + latents_key, flipped_latents_key = self.get_compatible_latents_keys( + keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso + ) + if latents_key is None or (flip_aug and flipped_latents_key is None): + # print(f"Precise dtype not found: {latents_cache_path}") + return False except Exception as e: - logger.error(f"Error loading file: {npz_path}") + 
logger.error(f"Error loading file: {latents_cache_path}") raise e return True # TODO remove circular dependency for ImageInfo def _default_cache_batch_latents( - self, - encode_by_vae, - vae_device, - vae_dtype, - image_infos: List, - flip_aug: bool, - alpha_mask: bool, - random_crop: bool, - multi_resolution: bool = False, + self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool ): """ Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. @@ -499,13 +611,8 @@ def _default_cache_batch_latents( original_size = original_sizes[i] crop_ltrb = crop_ltrbs[i] - latents_size = latents.shape[1:3] # H, W - key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW - if self.cache_to_disk: - self.save_latents_to_disk( - info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix - ) + self.save_latents_to_disk(info.latents_cache_path, latents, original_size, crop_ltrb, flipped_latent, alpha_mask) else: info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb @@ -515,56 +622,111 @@ def _default_cache_batch_latents( info.alpha_mask = alpha_mask def load_latents_from_disk( - self, npz_path: str, bucket_reso: Tuple[int, int] - ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: - """ - for SD/SDXL - """ - return self._default_load_latents_from_disk(None, npz_path, bucket_reso) + self, cache_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]: + raise NotImplementedError def _default_load_latents_from_disk( - self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] - ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: - if latents_stride is None: - key_reso_suffix = "" - else: - latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) - key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW - - npz = np.load(npz_path) - if "latents" + key_reso_suffix not in npz: - raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") - - latents = npz["latents" + key_reso_suffix] - original_size = npz["original_size" + key_reso_suffix].tolist() - crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() - flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None - alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None + self, cache_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]: + with safe_open(cache_path, framework="pt") as f: + metadata = f.metadata() + version = metadata.get("format_version", "0.0.0") + major, minor, patch = map(int, version.split(".")) + if major > 1: # or (major == 1 and minor > 0): + if not self.load_version_warning_printed: + self.load_version_warning_printed = True + logger.warning( + f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues." 
+ ) + + keys = f.keys() + + latents_key, flipped_latents_key = self.get_compatible_latents_keys(keys, None, flip_aug=True, bucket_reso=bucket_reso) + + key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None) + alpha_mask_key = "alpha_mask" + key_suffix_without_dtype + + latents = f.get_tensor(latents_key) + flipped_latents = f.get_tensor(flipped_latents_key) if flipped_latents_key is not None else None + alpha_mask = f.get_tensor(alpha_mask_key) if alpha_mask_key in keys else None + + original_size = [int(metadata["width"]), int(metadata["height"])] + crop_ltrb = metadata[f"crop_ltrb" + key_suffix_without_dtype] + crop_ltrb = list(map(int, crop_ltrb.split(","))) + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask def save_latents_to_disk( self, - npz_path, - latents_tensor, - original_size, - crop_ltrb, - flipped_latents_tensor=None, - alpha_mask=None, - key_reso_suffix="", + cache_path: str, + latents_tensor: torch.Tensor, + original_size: Tuple[int, int], + crop_ltrb: List[int], + flipped_latents_tensor: Optional[torch.Tensor] = None, + alpha_mask: Optional[torch.Tensor] = None, ): - kwargs = {} + dtype = latents_tensor.dtype + latents_size = latents_tensor.shape[1:3] # H, W + tensor_dict = {} + + overwrite = False + if os.path.exists(cache_path): + # load existing safetensors and update it + overwrite = True + + # we cannot use safe_open here because it locks the file + # with safe_open(cache_path, framework="pt") as f: + with utils.MemoryEfficientSafeOpen(cache_path) as f: + metadata = f.metadata() + keys = f.keys() + for key in keys: + tensor_dict[key] = f.get_tensor(key) + assert metadata["architecture"] == self.architecture + + file_version = metadata.get("format_version", "0.0.0") + major, minor, patch = map(int, file_version.split(".")) + if major > 1 or (major == 1 and minor > 0): + self.save_version_warning_printed = True + logger.warning( + f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues." 
+ ) + else: + metadata = {} + metadata["architecture"] = self.architecture + metadata["width"] = f"{original_size[0]}" + metadata["height"] = f"{original_size[1]}" + metadata["format_version"] = "1.0.0" - if os.path.exists(npz_path): - # load existing npz and update it - npz = np.load(npz_path) - for key in npz.files: - kwargs[key] = npz[key] + metadata[f"crop_ltrb_{latents_size[0]}x{latents_size[1]}"] = ",".join(map(str, crop_ltrb)) - kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy() - kwargs["original_size" + key_reso_suffix] = np.array(original_size) - kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb) + key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype) + if latents_tensor is not None: + tensor_dict["latents" + key_suffix] = latents_tensor if flipped_latents_tensor is not None: - kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy() + tensor_dict["latents_flipped" + key_suffix] = flipped_latents_tensor if alpha_mask is not None: - kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy() - np.savez(npz_path, **kwargs) + key_suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None) + tensor_dict["alpha_mask" + key_suffix_without_dtype] = alpha_mask + + # remove lower precision latents if higher precision latents are already cached + if overwrite: + available_dtypes = get_available_dtypes() + available_itemsize = None + available_itemsize_flipped = None + for dtype in available_dtypes: + key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype) + if "latents" + key_suffix in tensor_dict: + if available_itemsize is None: + available_itemsize = dtype.itemsize + elif available_itemsize > dtype.itemsize: + # if higher precision latents are already cached, remove lower precision latents + del tensor_dict["latents" + key_suffix] + + if "latents_flipped" + key_suffix in tensor_dict: + if available_itemsize_flipped is None: + available_itemsize_flipped = dtype.itemsize + elif available_itemsize_flipped > dtype.itemsize: + del tensor_dict["latents_flipped" + key_suffix] + + save_file(tensor_dict, cache_path, metadata=metadata) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 5e65927f8..2439acc04 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -195,29 +195,25 @@ def cache_batch_outputs( class FluxLatentsCachingStrategy(LatentsCachingStrategy): - FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz" + ARCHITECTURE = "flux" def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + super().__init__(FluxLatentsCachingStrategy.ARCHITECTURE, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check) - @property - def cache_suffix(self) -> str: - return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX - - def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: - return ( - os.path.splitext(absolute_path)[0] - + f"_{image_size[0]:04d}x{image_size[1]:04d}" - + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX - ) - - def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + def is_disk_cached_latents_expected( + self, + bucket_reso: Tuple[int, int], + cache_path: str, + flip_aug: 
bool, + alpha_mask: bool, + preferred_dtype: Optional[torch.dtype] = None, + ): + return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype) def load_latents_from_disk( - self, npz_path: str, bucket_reso: Tuple[int, int] - ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: - return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution + self, cache_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]: + return self._default_load_latents_from_disk(cache_path, bucket_reso) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -225,9 +221,7 @@ def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents( - encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True - ) + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device) diff --git a/library/strategy_sd.py b/library/strategy_sd.py index d0a3a68bf..65ea294b0 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -134,30 +134,28 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. # and we keep the old npz for the backward compatibility. - SD_OLD_LATENTS_NPZ_SUFFIX = ".npz" - SD_LATENTS_NPZ_SUFFIX = "_sd.npz" - SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz" + ARCHITECTURE_SD = "sd" + ARCHITECTURE_SDXL = "sdxl" def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + arch = SdSdxlLatentsCachingStrategy.ARCHITECTURE_SD if sd else SdSdxlLatentsCachingStrategy.ARCHITECTURE_SDXL + super().__init__(arch, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check) self.sd = sd - self.suffix = ( - SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX - ) - - @property - def cache_suffix(self) -> str: - return self.suffix - - def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: - # support old .npz - old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX - if os.path.exists(old_npz_file): - return old_npz_file - return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix - - def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + def is_disk_cached_latents_expected( + self, + bucket_reso: Tuple[int, int], + cache_path: str, + flip_aug: bool, + alpha_mask: bool, + preferred_dtype: Optional[torch.dtype] = None, + ) -> bool: + return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype) + + def load_latents_from_disk( + self, cache_path: str, bucket_reso: 
Tuple[int, int] + ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]: + return self._default_load_latents_from_disk(cache_path, bucket_reso) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 1d55fe21d..2ad5288b3 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -382,29 +382,25 @@ def cache_batch_outputs( class Sd3LatentsCachingStrategy(LatentsCachingStrategy): - SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + ARCHITECTURE_SD3 = "sd3" def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + super().__init__(Sd3LatentsCachingStrategy.ARCHITECTURE_SD3, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check) - @property - def cache_suffix(self) -> str: - return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX - - def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: - return ( - os.path.splitext(absolute_path)[0] - + f"_{image_size[0]:04d}x{image_size[1]:04d}" - + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX - ) - - def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + def is_disk_cached_latents_expected( + self, + bucket_reso: Tuple[int, int], + cache_path: str, + flip_aug: bool, + alpha_mask: bool, + preferred_dtype: Optional[torch.dtype] = None, + ): + return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype) def load_latents_from_disk( - self, npz_path: str, bucket_reso: Tuple[int, int] - ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: - return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution + self, cache_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]: + return self._default_load_latents_from_disk(cache_path, bucket_reso) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -412,9 +408,7 @@ def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents( - encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True - ) + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device) diff --git a/library/train_util.py b/library/train_util.py index 25cf7640d..02770ca3b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -158,11 +158,10 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, self.bucket_reso: Tuple[int, int] = None self.latents: Optional[torch.Tensor] = None self.latents_flipped: Optional[torch.Tensor] = None - self.latents_npz: Optional[str] = None # set in cache_latents + self.latents_cache_path: Optional[str] = 
None # set in cache_latents self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size - self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( - None # crop left top right bottom in original pixel size, not latents size - ) + # crop left top right bottom in original pixel size, not latents size + self.latents_crop_ltrb: Optional[Tuple[int, int]] = None self.cond_img_path: Optional[str] = None self.image: Optional[Image.Image] = None # optional, original PIL Image self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs @@ -323,6 +322,9 @@ def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]): else: resized_width = bucket_reso[0] resized_height = bucket_reso[0] / image_ar + resized_width = int(resized_width + 0.5) + resized_height = int(resized_height + 0.5) + crop_left = (bucket_reso[0] - resized_width) // 2 crop_top = (bucket_reso[1] - resized_height) // 2 crop_right = crop_left + resized_width @@ -1040,7 +1042,7 @@ def is_text_encoder_output_cacheable(self): ] ) - def new_cache_latents(self, model: Any, accelerator: Accelerator): + def new_cache_latents(self, model: Any, accelerator: Accelerator, force_cache_precision: bool = False): r""" a brand new method to cache latents. This method caches latents with caching strategy. normal cache_latents method is used by default, but this method is used when caching strategy is specified. @@ -1094,17 +1096,18 @@ def submit_batch(batch, cond): try: # iterate images - logger.info("caching latents...") + logger.info(f"Caching latents for dataset with {len(image_infos)} images.") + preferred_dtype = model.dtype if force_cache_precision else None for i, info in enumerate(tqdm(image_infos)): subset = self.image_to_subset[info.image_key] - if info.latents_npz is not None: # fine tuning dataset + if info.latents_cache_path is not None: # fine tuning dataset continue # check disk cache exists and size of latents if caching_strategy.cache_to_disk: # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) + info.latents_cache_path = caching_strategy.get_latents_cache_path(info.absolute_path, info.image_size) # if the modulo of num_processes is not equal to process_index, skip caching # this makes each process cache different latents @@ -1114,7 +1117,7 @@ def submit_batch(batch, cond): # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") cache_available = caching_strategy.is_disk_cached_latents_expected( - info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + info.bucket_reso, info.latents_cache_path, subset.flip_aug, subset.alpha_mask, preferred_dtype ) if cache_available: # do not add to batch continue @@ -1144,81 +1147,6 @@ def submit_batch(batch, cond): finally: executor.shutdown() - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): - # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - logger.info("caching latents.") - - image_infos = list(self.image_data.values()) - - # sort by resolution - image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) - - # split by resolution and some conditions - class Condition: - def __init__(self, reso, flip_aug, alpha_mask, random_crop): - self.reso = reso - self.flip_aug = flip_aug - self.alpha_mask = alpha_mask - self.random_crop = random_crop - - def __eq__(self, other): - 
return ( - self.reso == other.reso - and self.flip_aug == other.flip_aug - and self.alpha_mask == other.alpha_mask - and self.random_crop == other.random_crop - ) - - batches: List[Tuple[Condition, List[ImageInfo]]] = [] - batch: List[ImageInfo] = [] - current_condition = None - - logger.info("checking cache validity...") - for info in tqdm(image_infos): - subset = self.image_to_subset[info.image_key] - - if info.latents_npz is not None: # fine tuning dataset - continue - - # check disk cache exists and size of latents - if cache_to_disk: - info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - if not is_main_process: # store to info only - continue - - cache_available = is_disk_cached_latents_is_expected( - info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask - ) - - if cache_available: # do not add to batch - continue - - # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty - condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) - if len(batch) > 0 and current_condition != condition: - batches.append((current_condition, batch)) - batch = [] - - batch.append(info) - current_condition = condition - - # if number of data in batch is enough, flush the batch - if len(batch) >= vae_batch_size: - batches.append((current_condition, batch)) - batch = [] - current_condition = None - - if len(batch) > 0: - batches.append((current_condition, batch)) - - if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only - return - - # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded - logger.info("caching latents...") - for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): - cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) - def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): r""" a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. 
@@ -1275,131 +1203,6 @@ def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Acceler # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch) - # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype - # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset - # to support SD1/2, it needs a flag for v2, but it is postponed - def cache_text_encoder_outputs( - self, tokenizers, text_encoders, device, output_dtype, cache_to_disk=False, is_main_process=True - ): - assert len(tokenizers) == 2, "only support SDXL" - return self.cache_text_encoder_outputs_common( - tokenizers, text_encoders, [device, device], output_dtype, [output_dtype], cache_to_disk, is_main_process - ) - - # same as above, but for SD3 - def cache_text_encoder_outputs_sd3( - self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None - ): - return self.cache_text_encoder_outputs_common( - [tokenizer], - text_encoders, - devices, - output_dtype, - te_dtypes, - cache_to_disk, - is_main_process, - TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3, - batch_size, - ) - - def cache_text_encoder_outputs_common( - self, - tokenizers, - text_encoders, - devices, - output_dtype, - te_dtypes, - cache_to_disk=False, - is_main_process=True, - file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, - batch_size=None, - ): - # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する - # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - logger.info("caching text encoder outputs.") - - tokenize_strategy = TokenizeStrategy.get_strategy() - - if batch_size is None: - batch_size = self.batch_size - - image_infos = list(self.image_data.values()) - - logger.info("checking cache existence...") - image_infos_to_cache = [] - for info in tqdm(image_infos): - # subset = self.image_to_subset[info.image_key] - if cache_to_disk: - te_out_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - info.text_encoder_outputs_npz = te_out_npz - - if not is_main_process: # store to info only - continue - - if os.path.exists(te_out_npz): - # TODO check varidity of cache here - continue - - image_infos_to_cache.append(info) - - if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only - return - - # prepare tokenizers and text encoders - for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes): - text_encoder.to(device) - if te_dtype is not None: - text_encoder.to(dtype=te_dtype) - - # create batch - is_sd3 = len(tokenizers) == 1 - batch = [] - batches = [] - for info in image_infos_to_cache: - if not is_sd3: - input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) - input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) - batch.append((info, input_ids1, input_ids2)) - else: - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption) - batch.append((info, l_tokens, g_tokens, t5_tokens)) - - if len(batch) >= batch_size: - batches.append(batch) - batch = [] - - if len(batch) > 0: - batches.append(batch) - - # iterate batches: call text encoder and cache outputs for memory or disk - logger.info("caching text encoder outputs...") - if not is_sd3: - for batch in tqdm(batches): - infos, input_ids1, input_ids2 = zip(*batch) - input_ids1 = torch.stack(input_ids1, dim=0) - input_ids2 = 
torch.stack(input_ids2, dim=0) - cache_batch_text_encoder_outputs( - infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, output_dtype - ) - else: - for batch in tqdm(batches): - infos, l_tokens, g_tokens, t5_tokens = zip(*batch) - - # stack tokens - # l_tokens = [tokens[0] for tokens in l_tokens] - # g_tokens = [tokens[0] for tokens in g_tokens] - # t5_tokens = [tokens[0] for tokens in t5_tokens] - - cache_batch_text_encoder_outputs_sd3( - infos, - tokenizers[0], - text_encoders, - self.max_token_length, - cache_to_disk, - (l_tokens, g_tokens, t5_tokens), - output_dtype, - ) - def get_image_size(self, image_path): # return imagesize.get(image_path) image_size = imagesize.get(image_path) @@ -1522,17 +1325,14 @@ def __getitem__(self, index): alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) image = None - elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 + elif image_info.latents_cache_path is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( - self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso) + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_cache_path, image_info.bucket_reso) ) if flipped: latents = flipped_latents - alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem + alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1] del flipped_latents - latents = torch.FloatTensor(latents) - if alpha_mask is not None: - alpha_mask = torch.FloatTensor(alpha_mask) image = None else: @@ -1885,28 +1685,28 @@ def load_dreambooth_dir(subset: DreamBoothSubset): if strategy is not None: logger.info("get image size from name of cache files") - # make image path to npz path mapping - npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) - npz_paths.sort( + # make image path to cache path mapping + cache_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) + cache_paths.sort( key=lambda item: item.rsplit("_", maxsplit=2)[0] ) # sort by name excluding resolution and cache_suffix - npz_path_index = 0 + cache_path_index = 0 size_set_count = 0 for i, img_path in enumerate(tqdm(img_paths)): l = len(os.path.splitext(img_path)[0]) # remove extension found = False - while npz_path_index < len(npz_paths): # until found or end of npz_paths + while cache_path_index < len(cache_paths): # until found or end of npz_paths # npz_paths are sorted, so if npz_path > img_path, img_path is not found - if npz_paths[npz_path_index][:l] > img_path[:l]: + if cache_paths[cache_path_index][:l] > img_path[:l]: break - if npz_paths[npz_path_index][:l] == img_path[:l]: # found + if cache_paths[cache_path_index][:l] == img_path[:l]: # found found = True break - npz_path_index += 1 # next npz_path + cache_path_index += 1 # next npz_path if found: - w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index]) + w, h = strategy.get_image_size_from_disk_cache_path(img_path, cache_paths[cache_path_index]) else: w, h = None, None @@ -2139,8 +1939,8 @@ def __init__( image_info.image_size = img_md.get("train_resolution") if not subset.color_aug and not subset.random_crop: - # if npz exists, use them - image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) + # if cache 
exists, use them + image_info.latents_cache_path, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) self.register_image(image_info, subset) @@ -2161,7 +1961,7 @@ def __init__( for image_info in self.image_data.values(): subset = self.image_to_subset[image_info.image_key] - has_npz = image_info.latents_npz is not None + has_npz = image_info.latents_cache_path is not None npz_any = npz_any or has_npz if subset.flip_aug: @@ -2233,7 +2033,7 @@ def __init__( # npz情報をきれいにしておく if not use_npz_latents: for image_info in self.image_data.values(): - image_info.latents_npz = image_info.latents_npz_flipped = None + image_info.latents_cache_path = image_info.latents_npz_flipped = None def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): base_name = os.path.splitext(image_key)[0] @@ -2382,11 +2182,8 @@ def make_buckets(self): self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): - return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) - - def new_cache_latents(self, model: Any, accelerator: Accelerator): - return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator) + def new_cache_latents(self, model: Any, accelerator: Accelerator, force_cache_precision: bool): + return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator, force_cache_precision) def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) @@ -2485,33 +2282,12 @@ def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): + def new_cache_latents(self, model: Any, accelerator: Accelerator, force_cache_precision: bool): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) - - def new_cache_latents(self, model: Any, accelerator: Accelerator): - for i, dataset in enumerate(self.datasets): - logger.info(f"[Dataset {i}]") - dataset.new_cache_latents(model, accelerator) + dataset.new_cache_latents(model, accelerator, force_cache_precision) accelerator.wait_for_everyone() - def cache_text_encoder_outputs( - self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True - ): - for i, dataset in enumerate(self.datasets): - logger.info(f"[Dataset {i}]") - dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) - - def cache_text_encoder_outputs_sd3( - self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None - ): - for i, dataset in enumerate(self.datasets): - logger.info(f"[Dataset {i}]") - dataset.cache_text_encoder_outputs_sd3( - tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size - ) - def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") @@ -2556,72 +2332,6 @@ def disable_token_padding(self): dataset.disable_token_padding() -def 
is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool): - expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 - - if not os.path.exists(npz_path): - return False - - try: - npz = np.load(npz_path) - if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? - return False - if npz["latents"].shape[1:3] != expected_latents_size: - return False - - if flip_aug: - if "latents_flipped" not in npz: - return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: - return False - - if alpha_mask: - if "alpha_mask" not in npz: - return False - if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso: # HxW => WxH != reso - return False - else: - if "alpha_mask" in npz: - return False - except Exception as e: - logger.error(f"Error loading file: {npz_path}") - raise e - - return True - - -# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) -# TODO update to use CachingStrategy -# def load_latents_from_disk( -# npz_path, -# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: -# npz = np.load(npz_path) -# if "latents" not in npz: -# raise ValueError(f"error: npz is old format. please re-generate {npz_path}") - -# latents = npz["latents"] -# original_size = npz["original_size"].tolist() -# crop_ltrb = npz["crop_ltrb"].tolist() -# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None -# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None -# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask - - -# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): -# kwargs = {} -# if flipped_latents_tensor is not None: -# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() -# if alpha_mask is not None: -# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() -# np.savez( -# npz_path, -# latents=latents_tensor.float().cpu().numpy(), -# original_size=np.array(original_size), -# crop_ltrb=np.array(crop_ltrb), -# **kwargs, -# ) - - def debug_dataset(train_dataset, show_input_ids=False): logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") logger.info( @@ -2865,19 +2575,19 @@ def trim_and_resize_if_required( # for new_cache_latents def load_images_and_masks_for_caching( image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool -) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: +) -> Tuple[torch.Tensor, list[Optional[torch.Tensor]], list[Tuple[int, int]], list[Tuple[int, int, int, int]]]: r""" requires image_infos to have: [absolute_path or image], bucket_reso, resized_size returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs image_tensor: torch.Tensor = torch.Size([B, 3, H, W]), ...], normalized to [-1, 1] - alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1] + alpha_masks: List[torch.Tensor] = [torch.Size([H, W]), ...], List of None if not use_alpha_mask original_sizes: List[Tuple[int, int]] = [(W, H), ...] crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...] 
""" images: List[torch.Tensor] = [] - alpha_masks: List[np.ndarray] = [] + alpha_masks: List[torch.Tensor] = [] original_sizes: List[Tuple[int, int]] = [] crop_ltrbs: List[Tuple[int, int, int, int]] = [] for info in image_infos: @@ -2907,158 +2617,6 @@ def load_images_and_masks_for_caching( return img_tensor, alpha_masks, original_sizes, crop_ltrbs -def cache_batch_latents( - vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool -) -> None: - r""" - requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz - optionally requires image_infos to have: image - if cache_to_disk is True, set info.latents_npz - flipped latents is also saved if flip_aug is True - if cache_to_disk is False, set info.latents - latents_flipped is also set if flip_aug is True - latents_original_size and latents_crop_ltrb are also set - """ - images = [] - alpha_masks: List[np.ndarray] = [] - for info in image_infos: - image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) - # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) - - info.latents_original_size = original_size - info.latents_crop_ltrb = crop_ltrb - - if use_alpha_mask: - if image.shape[2] == 4: - alpha_mask = image[:, :, 3] # [H,W] - alpha_mask = alpha_mask.astype(np.float32) / 255.0 - alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] - else: - alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] - else: - alpha_mask = None - alpha_masks.append(alpha_mask) - - image = image[:, :, :3] # remove alpha channel if exists - image = IMAGE_TRANSFORMS(image) - images.append(image) - - img_tensors = torch.stack(images, dim=0) - img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) - - with torch.no_grad(): - latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - - if flip_aug: - img_tensors = torch.flip(img_tensors, dims=[3]) - with torch.no_grad(): - flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - else: - flipped_latents = [None] * len(latents) - - for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): - # check NaN - if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): - raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") - - if cache_to_disk: - # save_latents_to_disk( - # info.latents_npz, - # latent, - # info.latents_original_size, - # info.latents_crop_ltrb, - # flipped_latent, - # alpha_mask, - # ) - pass - else: - info.latents = latent - if flip_aug: - info.latents_flipped = flipped_latent - info.alpha_mask = alpha_mask - - if not HIGH_VRAM: - clean_memory_on_device(vae.device) - - -def cache_batch_text_encoder_outputs( - image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype -): - input_ids1 = input_ids1.to(text_encoders[0].device) - input_ids2 = input_ids2.to(text_encoders[1].device) - - with torch.no_grad(): - b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl( - max_token_length, - input_ids1, - input_ids2, - tokenizers[0], - tokenizers[1], - text_encoders[0], - text_encoders[1], - dtype, - ) - - # ここでcpuに移動しておかないと、上書きされてしまう - b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768 - b_hidden_state2 = 
b_hidden_state2.detach().to("cpu") # b,n*75+2,1280 - b_pool2 = b_pool2.detach().to("cpu") # b,1280 - - for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2): - if cache_to_disk: - save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, hidden_state1, hidden_state2, pool2) - else: - info.text_encoder_outputs1 = hidden_state1 - info.text_encoder_outputs2 = hidden_state2 - info.text_encoder_pool2 = pool2 - - -def cache_batch_text_encoder_outputs_sd3( - image_infos, tokenizer, text_encoders, max_token_length, cache_to_disk, input_ids, output_dtype -): - # make input_ids for each text encoder - l_tokens, g_tokens, t5_tokens = input_ids - - clip_l, clip_g, t5xxl = text_encoders - with torch.no_grad(): - b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens( - l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, "cpu", output_dtype - ) - b_lg_out = b_lg_out.detach() - b_t5_out = b_t5_out.detach() - b_pool = b_pool.detach() - - for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): - # debug: NaN check - if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any(): - raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}") - - if cache_to_disk: - save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) - else: - info.text_encoder_outputs1 = lg_out - info.text_encoder_outputs2 = t5_out - info.text_encoder_pool2 = pool - - -def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2): - np.savez( - npz_path, - hidden_state1=hidden_state1.cpu().float().numpy(), - hidden_state2=hidden_state2.cpu().float().numpy(), - pool2=pool2.cpu().float().numpy(), - ) - - -def load_text_encoder_outputs_from_disk(npz_path): - with np.load(npz_path) as f: - hidden_state1 = torch.from_numpy(f["hidden_state1"]) - hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None - pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None - return hidden_state1, hidden_state2, pool2 - - # endregion # region モジュール入れ替え部 @@ -4357,6 +3915,12 @@ def add_dataset_arguments( action="store_true", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", ) + parser.add_argument( + "--force_cache_precision", + action="store_true", + help="force cache precision to match the model precision. 
this option re-caches latents if the precision is lower than the model precision" + " / cacheの精度をモデルの精度に合わせる。このオプションを指定すると、精度がモデルの精度よりも低い場合にlatentを再キャッシュします", + ) parser.add_argument( "--skip_cache_check", action="store_true", @@ -5913,7 +5477,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): names.append("unet") names.append("text_encoder1") names.append("text_encoder2") - names.append("text_encoder3") # SD3 + names.append("text_encoder3") # SD3 append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) diff --git a/library/utils.py b/library/utils.py index 07079c6d9..0a0333eb4 100644 --- a/library/utils.py +++ b/library/utils.py @@ -189,6 +189,15 @@ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) raise ValueError(f"Unsupported dtype: {s}") +def dtype_to_normalized_str(dtype: Union[str, torch.dtype]) -> str: + dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype + + # get name of the dtype + dtype_name = str(dtype).split(".")[-1] + + return dtype_name + + def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): """ memory efficient save file @@ -264,8 +273,8 @@ class MemoryEfficientSafeOpen: # does not support metadata loading def __init__(self, filename): self.filename = filename - self.header, self.header_size = self._read_header() self.file = open(filename, "rb") + self.header, self.header_size = self._read_header() def __enter__(self): return self @@ -276,6 +285,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): def keys(self): return [k for k in self.header.keys() if k != "__metadata__"] + def metadata(self) -> Dict[str, str]: + return self.header.get("__metadata__", {}) + def get_tensor(self, key): if key not in self.header: raise KeyError(f"Tensor '{key}' not found in the file") @@ -293,10 +305,9 @@ def get_tensor(self, key): return self._deserialize_tensor(tensor_bytes, metadata) def _read_header(self): - with open(self.filename, "rb") as f: - header_size = struct.unpack(" None: # cache latents with dataset # TODO use DataLoader to speed up - train_dataset_group.new_cache_latents(vae, accelerator) + train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision) accelerator.wait_for_everyone() accelerator.print(f"Finished caching latents to disk.") diff --git a/train_db.py b/train_db.py index 51e209f34..4038b94e0 100644 --- a/train_db.py +++ b/train_db.py @@ -156,7 +156,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator) + train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/train_network.py b/train_network.py index bbf381f99..4714e06a9 100644 --- a/train_network.py +++ b/train_network.py @@ -418,7 +418,7 @@ def train(self, args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator) + train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 5f4657eb9..3cc815cdc 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -378,7 +378,7 @@ def train(self, args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator) + train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision) 
clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone()
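
Note on the new cache format (illustrative addendum, not part of the diff): each image now gets a single "<image>_<WWWW>x<HHHH>_<arch>.safetensors" file instead of an .npz. Tensor keys carry a latent-size suffix in HxW and, for latents, a normalized dtype name (e.g. "latents_32x64_float16"), while "architecture", "width", "height", "format_version" and the per-resolution "crop_ltrb_<H>x<W>" entries live in the safetensors metadata. The sketch below only inspects such a file with the public safetensors API; the file name and the metadata values shown in comments are made up.

# Sketch: inspect a latents cache file written by the new save_latents_to_disk().
# The path is hypothetical; key/metadata layout follows this diff.
from safetensors import safe_open

cache_path = "img_0001_0512x0256_flux.safetensors"  # hypothetical example file

with safe_open(cache_path, framework="pt", device="cpu") as f:
    meta = f.metadata()
    # e.g. {"architecture": "flux", "width": "512", "height": "256",
    #       "format_version": "1.0.0", "crop_ltrb_32x64": "0,0,512,256"}
    print(meta.get("architecture"), meta.get("format_version"))

    for key in f.keys():
        # expected keys:
        #   latents_32x64_float16          latent size HxW, then normalized dtype
        #   latents_flipped_32x64_float16  only present when flip_aug was enabled
        #   alpha_mask_32x64               no dtype suffix
        t = f.get_tensor(key)
        print(key, tuple(t.shape), t.dtype)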
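
Note on --force_cache_precision (sketch mirroring the diff): when the flag is set, new_cache_latents passes the model dtype as preferred_dtype, and an existing cache entry is reused only if its dtype suffix is "compatible" — float32 always qualifies, an fp8 preference also accepts bf16/fp16, and any other preference requires float32 or an exact match. The helper below condenses get_compatible_dtypes()/get_key_suffix(); find_cached_latents_key() is an illustrative wrapper, not a function from this PR.

# Condensed sketch of the dtype-compatibility rule behind --force_cache_precision.
from typing import List, Optional, Set, Tuple
import torch


def get_compatible_dtypes(dtype: Optional[torch.dtype]) -> List[torch.dtype]:
    if dtype is None:  # no preference: any cached precision is acceptable
        return [torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
    compatible = [torch.float32]        # a float32 cache always satisfies the request
    if dtype.itemsize == 1:             # fp8 preference: 16-bit caches are still fine
        compatible += [torch.bfloat16, torch.float16]
    compatible.append(dtype)            # finally, the exact requested dtype
    return compatible


def find_cached_latents_key(
    keys: Set[str], bucket_reso: Tuple[int, int], preferred_dtype: Optional[torch.dtype], stride: int = 8
) -> Optional[str]:
    h, w = bucket_reso[1] // stride, bucket_reso[0] // stride  # bucket_reso is (W, H)
    for dt in get_compatible_dtypes(preferred_dtype):
        key = f"latents_{h}x{w}_{str(dt).split('.')[-1]}"      # e.g. "latents_32x64_float16"
        if key in keys:
            return key
    return None


# A bf16 preference is not satisfied by a float16-only cache (so it is re-cached),
# but an fp8 preference is.
keys = {"latents_32x64_float16"}
print(find_cached_latents_key(keys, (512, 256), torch.bfloat16))       # None -> re-cache
print(find_cached_latents_key(keys, (512, 256), torch.float8_e4m3fn))  # latents_32x64_float16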