Skip to content

Commit

Permalink
Merge pull request #1660 from kohya-ss/fast_image_sizes
Browse files Browse the repository at this point in the history
Fast image sizes
  • Loading branch information
kohya-ss authored Oct 13, 2024
2 parents e277b57 + c65cf38 commit d02a6ef
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 22 deletions.
7 changes: 6 additions & 1 deletion library/strategy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,14 @@ def cache_to_disk(self):
def batch_size(self):
return self._batch_size

def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
@property
def cache_suffix(self) -> str:
    """File-name suffix that identifies this strategy's cache files.

    Callers append it to a wildcard to enumerate cache files on disk.
    The base implementation is abstract: every concrete strategy must
    override this property.

    Raises:
        NotImplementedError: always, when not overridden.
    """
    raise NotImplementedError

def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
    """Read the image size encoded in a cache file name.

    The size is taken from the second-to-last ``_``-separated token of
    ``npz_path`` (extension removed), which is expected to look like
    ``<W>x<H>``, e.g. ``img_0512x0768_flux.npz`` -> ``(512, 768)``.

    Args:
        absolute_path: Path of the image the cache belongs to. Not used by
            this implementation; kept so overrides can rely on it.
        npz_path: Path of the cache file whose name carries the size.

    Returns:
        ``(width, height)`` as ints, or ``(None, None)`` when the file name
        does not follow the ``..._<W>x<H>_<suffix>`` convention.
    """
    stem = os.path.splitext(npz_path)[0]
    try:
        w, h = stem.split("_")[-2].split("x")
        return int(w), int(h)
    except (IndexError, ValueError):
        # IndexError: fewer than two "_" tokens; ValueError: token is not
        # "<int>x<int>". The caller already treats (None, None) as "unknown".
        return None, None

def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
    """Build the latents cache (npz) path for an image of the given size.

    Abstract in the base class; concrete strategies must implement it.

    Raises:
        NotImplementedError: always, when not overridden.
    """
    raise NotImplementedError

Expand Down
9 changes: 3 additions & 6 deletions library/strategy_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,9 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
    """Create the strategy, forwarding every option unchanged to the base class."""
    super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
    """Find a cache file for an image and read the image size from its name.

    Searches for ``<image stem>_*<FLUX_LATENTS_NPZ_SUFFIX>`` and parses the
    ``<W>x<H>`` token (second-to-last ``_``-separated token) of the first match.

    Args:
        absolute_path: Path of the image file.

    Returns:
        ``(width, height)`` as ints, or ``(None, None)`` if no cache file exists.
    """
    # Escape the stem so glob metacharacters in the image path ("[", "]",
    # "*", "?") are matched literally instead of being read as a pattern.
    stem_pattern = glob.escape(os.path.splitext(absolute_path)[0])
    matches = glob.glob(stem_pattern + "_*" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX)
    if not matches:
        return None, None
    w, h = os.path.splitext(matches[0])[0].split("_")[-2].split("x")
    return int(w), int(h)
@property
def cache_suffix(self) -> str:
    """File-name suffix of the latents cache files written by the Flux strategy."""
    suffix = FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
    return suffix

def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
Expand Down
12 changes: 4 additions & 8 deletions library/strategy_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,10 @@ def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cac
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)

def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
    """Find a cache file for an image and read the image size from its name.

    Searches for ``<image stem>_*<self.suffix>`` and parses the ``<W>x<H>``
    token (second-to-last ``_``-separated token) of the first match.

    Args:
        absolute_path: Path of the image file.

    Returns:
        ``(width, height)`` as ints, or ``(None, None)`` if no cache file exists.
    """
    # does not include old npz
    # Escape the stem so glob metacharacters in the image path ("[", "]",
    # "*", "?") are matched literally instead of being read as a pattern.
    stem_pattern = glob.escape(os.path.splitext(absolute_path)[0])
    matches = glob.glob(stem_pattern + "_*" + self.suffix)
    if not matches:
        return None, None
    w, h = os.path.splitext(matches[0])[0].split("_")[-2].split("x")
    return int(w), int(h)

@property
def cache_suffix(self) -> str:
    """File-name suffix of this strategy's latents cache files (the value stored in ``self.suffix``)."""
    suffix = self.suffix
    return suffix

def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
# support old .npz
Expand Down
9 changes: 3 additions & 6 deletions library/strategy_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
    """Create the strategy, forwarding every option unchanged to the base class."""
    super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
    """Find a cache file for an image and read the image size from its name.

    Searches for ``<image stem>_*<SD3_LATENTS_NPZ_SUFFIX>`` and parses the
    ``<W>x<H>`` token (second-to-last ``_``-separated token) of the first match.

    Args:
        absolute_path: Path of the image file.

    Returns:
        ``(width, height)`` as ints, or ``(None, None)`` if no cache file exists.
    """
    # Escape the stem so glob metacharacters in the image path ("[", "]",
    # "*", "?") are matched literally instead of being read as a pattern.
    stem_pattern = glob.escape(os.path.splitext(absolute_path)[0])
    matches = glob.glob(stem_pattern + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
    if not matches:
        return None, None
    w, h = os.path.splitext(matches[0])[0].split("_")[-2].split("x")
    return int(w), int(h)
@property
def cache_suffix(self) -> str:
    """File-name suffix of the latents cache files written by the SD3 strategy."""
    suffix = Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
    return suffix

def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
Expand Down
23 changes: 22 additions & 1 deletion library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1828,9 +1828,30 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
strategy = LatentsCachingStrategy.get_strategy()
if strategy is not None:
logger.info("get image size from name of cache files")

# make image path to npz path mapping
npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix))
npz_paths.sort()
npz_path_index = 0

size_set_count = 0
for i, img_path in enumerate(tqdm(img_paths)):
w, h = strategy.get_image_size_from_disk_cache_path(img_path)
l = len(os.path.splitext(img_path)[0]) # remove extension
found = False
while npz_path_index < len(npz_paths): # until found or end of npz_paths
# npz_paths are sorted, so if npz_path > img_path, img_path is not found
if npz_paths[npz_path_index][:l] > img_path[:l]:
break
if npz_paths[npz_path_index][:l] == img_path[:l]: # found
found = True
break
npz_path_index += 1 # next npz_path

if found:
w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index])
else:
w, h = None, None

if w is not None and h is not None:
sizes[i] = [w, h]
size_set_count += 1
Expand Down

0 comments on commit d02a6ef

Please sign in to comment.