From 2244cf5b835cc35179f29b1babb4a2d19f54bfae Mon Sep 17 00:00:00 2001
From: kohya-ss
Date: Sun, 13 Oct 2024 18:22:19 +0900
Subject: [PATCH] load images in parallel when caching latents

---
 library/train_util.py | 93 ++++++++++++++++++++++++-------------------
 1 file changed, 53 insertions(+), 40 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 4e6b3408d..1db470d63 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3,6 +3,7 @@
 import argparse
 import ast
 import asyncio
+from concurrent.futures import Future, ThreadPoolExecutor
 import datetime
 import importlib
 import json
@@ -1058,7 +1059,6 @@ def __eq__(self, other):
                     and self.random_crop == other.random_crop
                 )
 
-        batches: List[Tuple[Condition, List[ImageInfo]]] = []
         batch: List[ImageInfo] = []
         current_condition = None
 
@@ -1066,57 +1066,70 @@ def __eq__(self, other):
         num_processes = accelerator.num_processes
         process_index = accelerator.process_index
 
-        logger.info("checking cache validity...")
-        for i, info in enumerate(tqdm(image_infos)):
-            subset = self.image_to_subset[info.image_key]
+        # define a function to submit a batch to cache
+        def submit_batch(batch, cond):
+            for info in batch:
+                if info.image is not None and isinstance(info.image, Future):
+                    info.image = info.image.result()  # future to image
+            caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop)
 
-            if info.latents_npz is not None:  # fine tuning dataset
-                continue
+        # define ThreadPoolExecutor to load images in parallel
+        max_workers = min(os.cpu_count(), len(image_infos))
+        max_workers = max(1, max_workers // num_processes)  # consider multi-gpu
+        max_workers = min(max_workers, caching_strategy.batch_size)  # max_workers should be less than batch_size
+        executor = ThreadPoolExecutor(max_workers)
 
-            # check disk cache exists and size of latents
-            if caching_strategy.cache_to_disk:
-                # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
-                info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
+        try:
+            # iterate images
+            logger.info("caching latents...")
+            for i, info in enumerate(tqdm(image_infos)):
+                subset = self.image_to_subset[info.image_key]
 
-            # if the modulo of num_processes is not equal to process_index, skip caching
-            # this makes each process cache different latents
-            if i % num_processes != process_index:
-                continue
+                if info.latents_npz is not None:  # fine tuning dataset
+                    continue
 
-            # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
+                # check disk cache exists and size of latents
+                if caching_strategy.cache_to_disk:
+                    # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
+                    info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
 
-            cache_available = caching_strategy.is_disk_cached_latents_expected(
-                info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
-            )
-            if cache_available:  # do not add to batch
-                continue
+                # if the modulo of num_processes is not equal to process_index, skip caching
+                # this makes each process cache different latents
+                if i % num_processes != process_index:
+                    continue
 
-            # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
-            condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
-            if len(batch) > 0 and current_condition != condition:
-                batches.append((current_condition, batch))
-                batch = []
+                # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
 
-            batch.append(info)
-            current_condition = condition
+                cache_available = caching_strategy.is_disk_cached_latents_expected(
+                    info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
+                )
+                if cache_available:  # do not add to batch
+                    continue
 
-            # if number of data in batch is enough, flush the batch
-            if len(batch) >= caching_strategy.batch_size:
-                batches.append((current_condition, batch))
-                batch = []
-                current_condition = None
+                # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
+                condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+                if len(batch) > 0 and current_condition != condition:
+                    submit_batch(batch, current_condition)
+                    batch = []
 
-        if len(batch) > 0:
-            batches.append((current_condition, batch))
+                if info.image is None:
+                    # load image in parallel
+                    info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask)
 
-        if len(batches) == 0:
-            logger.info("no latents to cache")
-            return
+                batch.append(info)
+                current_condition = condition
 
-        # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded
-        logger.info("caching latents...")
-        for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
-            caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
+                # if number of data in batch is enough, flush the batch
+                if len(batch) >= caching_strategy.batch_size:
+                    submit_batch(batch, current_condition)
+                    batch = []
+                    current_condition = None
+
+            if len(batch) > 0:
+                submit_batch(batch, current_condition)
+
+        finally:
+            executor.shutdown()
 
     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
         # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
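
Note: the following is a minimal, self-contained sketch of the overlap pattern this patch introduces. Image reads are handed to a ThreadPoolExecutor and stored on each item as a Future; the futures are resolved only when the batch is flushed, so disk I/O for upcoming images overlaps with encoding of the current batch. All names in the sketch (ImageInfo, load_image, encode_batch, cache_all) are illustrative stand-ins rather than the sd-scripts API.

from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import List
import os
import time


@dataclass
class ImageInfo:
    path: str
    image: object = None  # holds a Future while the load is in flight


def load_image(path: str) -> bytes:
    time.sleep(0.01)  # stand-in for disk read + decode, the part worth parallelizing
    return b"pixels"


def encode_batch(batch: List[ImageInfo]) -> None:
    # stand-in for the GPU-side latent encoding
    for info in batch:
        assert not isinstance(info.image, Future), "futures must be resolved first"


def cache_all(infos: List[ImageInfo], batch_size: int = 8) -> None:
    # bound the worker count by CPU count and batch size, as the patch does
    max_workers = max(1, min(os.cpu_count() or 1, batch_size, len(infos)))
    executor = ThreadPoolExecutor(max_workers)

    def flush(batch: List[ImageInfo]) -> None:
        for info in batch:
            if isinstance(info.image, Future):
                info.image = info.image.result()  # block until the load finishes
        encode_batch(batch)

    batch: List[ImageInfo] = []
    try:
        for info in infos:
            # start the read immediately; it is resolved later in flush()
            info.image = executor.submit(load_image, info.path)
            batch.append(info)
            if len(batch) >= batch_size:
                flush(batch)
                batch = []
        if batch:
            flush(batch)
    finally:
        executor.shutdown()  # waits for outstanding loads and reclaims threads


if __name__ == "__main__":
    cache_all([ImageInfo(f"img_{i}.png") for i in range(20)], batch_size=8)

Capping max_workers at the batch size follows the patch's own comment that the worker count should not exceed batch_size, and the try/finally guarantees the pool is shut down even if encoding raises. For context, the Japanese comment on the trailing context line of the diff belongs to the older cache_latents method below and says that it does not support multi-GPU, so tools/cache_latents.py should be used for that.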