load images in parallel when caching latents
kohya-ss committed Oct 13, 2024
1 parent 74228c9 commit 2244cf5
Showing 1 changed file with 53 additions and 40 deletions.
library/train_util.py (53 additions & 40 deletions)
@@ -3,6 +3,7 @@
 import argparse
 import ast
 import asyncio
+from concurrent.futures import Future, ThreadPoolExecutor
 import datetime
 import importlib
 import json
@@ -1058,65 +1059,77 @@ def __eq__(self, other):
                     and self.random_crop == other.random_crop
                 )

-        batches: List[Tuple[Condition, List[ImageInfo]]] = []
         batch: List[ImageInfo] = []
         current_condition = None

         # support multiple-gpus
         num_processes = accelerator.num_processes
         process_index = accelerator.process_index

-        logger.info("checking cache validity...")
-        for i, info in enumerate(tqdm(image_infos)):
-            subset = self.image_to_subset[info.image_key]
+        # define a function to submit a batch to cache
+        def submit_batch(batch, cond):
+            for info in batch:
+                if info.image is not None and isinstance(info.image, Future):
+                    info.image = info.image.result()  # future to image
+            caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop)

-            if info.latents_npz is not None:  # fine tuning dataset
-                continue
+        # define ThreadPoolExecutor to load images in parallel
+        max_workers = min(os.cpu_count(), len(image_infos))
+        max_workers = max(1, max_workers // num_processes)  # consider multi-gpu
+        max_workers = min(max_workers, caching_strategy.batch_size)  # max_workers should be less than batch_size
+        executor = ThreadPoolExecutor(max_workers)

-            # check disk cache exists and size of latents
-            if caching_strategy.cache_to_disk:
-                # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
-                info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
+        try:
+            # iterate images
+            logger.info("caching latents...")
+            for i, info in enumerate(tqdm(image_infos)):
+                subset = self.image_to_subset[info.image_key]

-                # if the modulo of num_processes is not equal to process_index, skip caching
-                # this makes each process cache different latents
-                if i % num_processes != process_index:
+                if info.latents_npz is not None:  # fine tuning dataset
                     continue

-                # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
+                # check disk cache exists and size of latents
+                if caching_strategy.cache_to_disk:
+                    # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
+                    info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)

-                cache_available = caching_strategy.is_disk_cached_latents_expected(
-                    info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
-                )
-                if cache_available:  # do not add to batch
-                    continue
+                    # if the modulo of num_processes is not equal to process_index, skip caching
+                    # this makes each process cache different latents
+                    if i % num_processes != process_index:
+                        continue

-            # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
-            condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
-            if len(batch) > 0 and current_condition != condition:
-                batches.append((current_condition, batch))
-                batch = []
+                    # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")

-            batch.append(info)
-            current_condition = condition
+                    cache_available = caching_strategy.is_disk_cached_latents_expected(
+                        info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
+                    )
+                    if cache_available:  # do not add to batch
+                        continue

-            # if number of data in batch is enough, flush the batch
-            if len(batch) >= caching_strategy.batch_size:
-                batches.append((current_condition, batch))
-                batch = []
-                current_condition = None
+                # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
+                condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+                if len(batch) > 0 and current_condition != condition:
+                    submit_batch(batch, current_condition)
+                    batch = []

-        if len(batch) > 0:
-            batches.append((current_condition, batch))
+                if info.image is None:
+                    # load image in parallel
+                    info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask)

-        if len(batches) == 0:
-            logger.info("no latents to cache")
-            return
+                batch.append(info)
+                current_condition = condition

-        # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded
-        logger.info("caching latents...")
-        for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
-            caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
+                # if number of data in batch is enough, flush the batch
+                if len(batch) >= caching_strategy.batch_size:
+                    submit_batch(batch, current_condition)
+                    batch = []
+                    current_condition = None

+            if len(batch) > 0:
+                submit_batch(batch, current_condition)

+        finally:
+            executor.shutdown()

     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
         # this method does not support multi-GPU; use tools/cache_latents.py for that
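
The idea behind the change, distilled into a standalone sketch: an image load is submitted to a ThreadPoolExecutor as soon as the item joins a batch, a Future stands in for the image while the batch is assembled, and each Future is resolved only when the batch is flushed for encoding, so disk I/O overlaps with VAE work. The names below (load_image_stub, encode_batch, BATCH_SIZE, NUM_PROCESSES) are hypothetical stand-ins for the repository's load_image, cache_batch_latents, caching_strategy.batch_size, and accelerator.num_processes; this is a minimal illustration, not the project's code.

import os
from concurrent.futures import Future, ThreadPoolExecutor

BATCH_SIZE = 4  # hypothetical stand-in for caching_strategy.batch_size
NUM_PROCESSES = 1  # hypothetical stand-in for accelerator.num_processes


def load_image_stub(path: str) -> str:
    # stand-in for the real image loader; the disk I/O happens here
    return f"pixels:{path}"


def encode_batch(batch: list) -> None:
    # stand-in for cache_batch_latents; every Future has already been
    # resolved to a concrete image by the time this runs
    for path, image in batch:
        print(f"encoding {path} -> {image}")


def flush(batch: list) -> None:
    # resolve Futures to images just before encoding; .result() blocks
    # only if the background load has not finished yet
    resolved = [
        (path, img.result() if isinstance(img, Future) else img)
        for path, img in batch
    ]
    encode_batch(resolved)


paths = [f"img_{i:03d}.png" for i in range(10)]

# size the pool the way the commit does: bounded by the CPU count,
# split across processes, and never larger than the batch size
max_workers = min(os.cpu_count() or 1, len(paths))
max_workers = max(1, max_workers // NUM_PROCESSES)
max_workers = min(max_workers, BATCH_SIZE)

executor = ThreadPoolExecutor(max_workers)
try:
    batch = []
    for path in paths:
        # submit the load immediately; the Future stands in for the image
        # while the rest of the batch is still being assembled
        batch.append((path, executor.submit(load_image_stub, path)))
        if len(batch) >= BATCH_SIZE:
            flush(batch)
            batch = []
    if batch:
        flush(batch)
finally:
    executor.shutdown()

Capping max_workers at the batch size mirrors the commit's sizing logic: more loader threads than slots per batch cannot shorten the wait at flush time, and dividing by the process count keeps multi-GPU runs from oversubscribing the CPU. The try/finally guarantees the pool is released even if encoding raises.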
