From 2238b94e7bb084b70a2b59ccf42452ffcc16ba8a Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 28 Nov 2024 21:05:17 +0900
Subject: [PATCH] support new metadata in wd14tagger (WIP), fix typo

---
 finetune/tag_images_by_wd14_tagger.py | 305 ++++++++++++++++++++------
 sdxl_train_network.py                 |   2 +-
 2 files changed, 240 insertions(+), 67 deletions(-)

diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py
index cbc3d2d6b..f81945602 100644
--- a/finetune/tag_images_by_wd14_tagger.py
+++ b/finetune/tag_images_by_wd14_tagger.py
@@ -1,7 +1,13 @@
 import argparse
+from concurrent.futures import ThreadPoolExecutor
 import csv
+import glob
+import json
 import os
 from pathlib import Path
+from typing import Any, Optional, Union
+import zipfile
+import tarfile

 import cv2
 import numpy as np
@@ -63,13 +69,90 @@ def __getitem__(self, idx):

         try:
             image = Image.open(img_path).convert("RGB")
+            size = image.size
             image = preprocess_image(image)
             # tensor = torch.tensor(image)  # これ Tensor に変換する必要ないな……(;・∀・)
         except Exception as e:
             logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
             return None

-        return (image, img_path)
+        return (img_path, image, size)
+
+
+class ArchiveImageLoader:
+    def __init__(self, archive_paths: list[str], batch_size: int, debug: bool = False):
+        self.archive_paths = archive_paths
+        self.batch_size = batch_size
+        self.debug = debug
+        self.current_archive = None
+        self.archive_index = 0
+        self.image_index = 0
+        self.files = None
+        self.executor = ThreadPoolExecutor()
+        self.image_exts = set(train_util.IMAGE_EXTENSIONS)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        images = []
+        while len(images) < self.batch_size:
+            if self.current_archive is None:
+                if self.archive_index >= len(self.archive_paths):
+                    if len(images) == 0:
+                        raise StopIteration
+                    else:
+                        break  # return the remaining images
+
+                if self.debug:
+                    logger.info(f"loading archive: {self.archive_paths[self.archive_index]}")
+
+                current_archive_path = str(self.archive_paths[self.archive_index])
+                if current_archive_path.endswith(".zip"):
+                    self.current_archive = zipfile.ZipFile(current_archive_path)
+                    self.files = self.current_archive.namelist()
+                elif current_archive_path.endswith(".tar"):
+                    self.current_archive = tarfile.open(current_archive_path, "r")
+                    self.files = self.current_archive.getnames()
+                else:
+                    raise ValueError(f"unsupported archive file: {current_archive_path}")
+
+                self.image_index = 0
+
+                # filter by image extensions
+                self.files = [file for file in self.files if os.path.splitext(file)[1].lower() in self.image_exts]
+
+                if self.debug:
+                    logger.info(f"found {len(self.files)} images in the archive")
+
+            new_images = []
+            while len(images) + len(new_images) < self.batch_size:
+                if self.image_index >= len(self.files):
+                    break
+
+                file = self.files[self.image_index]
+                archive_and_image_path = f"{self.archive_paths[self.archive_index]}////{file}"
+                self.image_index += 1
+
+                def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]):
+                    # ZipFile exposes open() for its members, TarFile uses extractfile()
+                    f = archive.open(file) if isinstance(archive, zipfile.ZipFile) else archive.extractfile(file)
+                    with f:
+                        image = Image.open(f).convert("RGB")
+                        size = image.size
+                        image = preprocess_image(image)
+                    return image, size
+
+                new_images.append((archive_and_image_path, self.executor.submit(load_image, file, self.current_archive)))
+
+            # wait for all new_images to load to close the archive
+            new_images = [(image_path, future.result()) for image_path, future in new_images]
+
+            if self.image_index >= len(self.files):
+                self.current_archive.close()
+                self.current_archive = None
+                self.archive_index += 1
+
+            images.extend(new_images)
+
+        return [(image_path, image, size) for image_path, (image, size) in images]


 def collate_fn_remove_corrupted(batch):
@@ -149,15 +232,19 @@ def main(args):
             ort_sess = ort.InferenceSession(
                 onnx_path,
                 providers=(["OpenVINOExecutionProvider"]),
-                provider_options=[{'device_type' : "GPU_FP32"}],
+                provider_options=[{"device_type": "GPU_FP32"}],
             )
         else:
             ort_sess = ort.InferenceSession(
                 onnx_path,
                 providers=(
-                    ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
-                    ["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
-                    ["CPUExecutionProvider"]
+                    ["CUDAExecutionProvider"]
+                    if "CUDAExecutionProvider" in ort.get_available_providers()
+                    else (
+                        ["ROCMExecutionProvider"]
+                        if "ROCMExecutionProvider" in ort.get_available_providers()
+                        else ["CPUExecutionProvider"]
+                    )
                 ),
             )
     else:
@@ -203,7 +290,9 @@ def main(args):
         tag_replacements = escaped_tag_replacements.split(";")
         for tag_replacement in tag_replacements:
             tags = tag_replacement.split(",")  # source, target
-            assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
+            assert (
+                len(tags) == 2
+            ), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
             source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
             logger.info(f"replacing tag: {source} -> {target}")
@@ -216,9 +305,15 @@ def main(args):
             rating_tags[rating_tags.index(source)] = target

     # 画像を読み込む
-    train_data_dir_path = Path(args.train_data_dir)
-    image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-    logger.info(f"found {len(image_paths)} images.")
+    if not args.load_archive:
+        train_data_dir_path = Path(args.train_data_dir)
+        image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
+        logger.info(f"found {len(image_paths)} images.")
+    else:
+        archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
+            os.path.join(args.train_data_dir, "*.tar")
+        )
+        image_paths = [Path(archive_file) for archive_file in archive_files]

     tag_freq = {}
@@ -231,19 +326,23 @@ def main(args):
     if args.always_first_tags is not None:
         always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]

-    def run_batch(path_imgs):
-        imgs = np.array([im for _, im in path_imgs])
+    def run_batch(
+        list_of_path_img_size: list[tuple[str, np.ndarray, tuple[int, int]]],
+        images_metadata: Optional[dict[str, Any]],
+        tags_index: Optional[int] = None,
+    ):
+        imgs = np.array([im for _, im, _ in list_of_path_img_size])

         if args.onnx:
             # if len(imgs) < args.batch_size:
             #     imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
             probs = ort_sess.run(None, {input_name: imgs})[0]  # onnx output numpy
-            probs = probs[: len(path_imgs)]
+            probs = probs[: len(list_of_path_img_size)]
         else:
             probs = model(imgs, training=False)
             probs = probs.numpy()

-        for (image_path, _), prob in zip(path_imgs, probs):
+        for (image_path, _, image_size), prob in zip(list_of_path_img_size, probs):
             combined_tags = []
             rating_tag_text = ""
             character_tag_text = ""
@@ -265,7 +364,7 @@ def run_batch(path_imgs):
                     if tag_name not in undesired_tags:
                         tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                         character_tag_text += caption_separator + tag_name
-                        if args.character_tags_first: # insert to the beginning
+                        if args.character_tags_first:  # insert to the beginning
                             combined_tags.insert(0, tag_name)
                         else:
                             combined_tags.append(tag_name)
@@ -281,7 +380,7 @@ def run_batch(path_imgs):
                     tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
                     rating_tag_text = found_rating
                     if args.use_rating_tags:
-                        combined_tags.insert(0, found_rating) # insert to the beginning
+                        combined_tags.insert(0, found_rating)  # insert to the beginning
                     else:
                         combined_tags.append(found_rating)
@@ -304,12 +403,24 @@ def run_batch(path_imgs):

             tag_text = caption_separator.join(combined_tags)

             if args.append_tags:
-                # Check if file exists
-                if os.path.exists(caption_file):
-                    with open(caption_file, "rt", encoding="utf-8") as f:
-                        # Read file and remove new lines
-                        existing_content = f.read().strip("\n")  # Remove newlines
-
+                existing_content = None
+                if images_metadata is None:
+                    # Check if file exists
+                    if os.path.exists(caption_file):
+                        with open(caption_file, "rt", encoding="utf-8") as f:
+                            # Read file and remove new lines
+                            existing_content = f.read().strip("\n")  # Remove newlines
+                else:
+                    image_md = images_metadata.get(image_path, None)
+                    if image_md is not None:
+                        tags = image_md.get("tags", None)
+                        if tags is not None:
+                            if tags_index is None and len(tags) > 0:
+                                existing_content = tags[-1]
+                            elif tags_index is not None and tags_index < len(tags):
+                                existing_content = tags[tags_index]
+
+                if existing_content is not None:
                     # Split the content into tags and store them in a list
                     existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
@@ -319,19 +430,62 @@ def run_batch(path_imgs):
                     # Create new tag_text
                     tag_text = caption_separator.join(existing_tags + new_tags)

-            with open(caption_file, "wt", encoding="utf-8") as f:
-                f.write(tag_text + "\n")
-                if args.debug:
-                    logger.info("")
-                    logger.info(f"{image_path}:")
-                    logger.info(f"\tRating tags: {rating_tag_text}")
-                    logger.info(f"\tCharacter tags: {character_tag_text}")
-                    logger.info(f"\tGeneral tags: {general_tag_text}")
-
-    # 読み込みの高速化のためにDataLoaderを使うオプション
-    if args.max_data_loader_n_workers is not None:
+            if images_metadata is None:
+                with open(caption_file, "wt", encoding="utf-8") as f:
+                    f.write(tag_text + "\n")
+            else:
+                image_md = images_metadata.get(image_path, None)
+                if image_md is None:
+                    image_md = {"image_size": [image_size[0], image_size[1]]}
+                    images_metadata[image_path] = image_md
+                if "tags" not in image_md:
+                    image_md["tags"] = []
+                if tags_index is None:
+                    image_md["tags"].append(tag_text)
+                else:
+                    while len(image_md["tags"]) <= tags_index:
+                        image_md["tags"].append("")
+                    image_md["tags"][tags_index] = tag_text
+
+            if args.debug:
+                logger.info("")
+                logger.info(f"{image_path}:")
+                logger.info(f"\tRating tags: {rating_tag_text}")
+                logger.info(f"\tCharacter tags: {character_tag_text}")
+                logger.info(f"\tGeneral tags: {general_tag_text}")
+
+    # load metadata if needed
+    metadata = None
+    if args.metadata is not None:
+        if os.path.exists(args.metadata):
+            logger.info(f"loading metadata file: {args.metadata}")
+            with open(args.metadata, "rt", encoding="utf-8") as f:
+                metadata = json.load(f)
+
+            # version check
+            major, minor, patch = (int(v) for v in metadata.get("format_version", "0.0.0").split("."))
+            if major > 1 or (major == 1 and minor > 0):
+                logger.warning(
+                    f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work."
+                )
+
+            if "images" not in metadata:
+                metadata["images"] = {}
+        else:
+            logger.info(f"metadata file not found: {args.metadata}, creating new metadata")
+            metadata = {"format_version": "1.0.0", "images": {}}
+
+    images_metadata = metadata["images"] if metadata is not None else None
+
+    # prepare DataLoader or something similar :)
+    use_loader = False
+    if args.load_archive:
+        loader = ArchiveImageLoader(image_paths, args.batch_size)
+        use_loader = True
+    elif args.max_data_loader_n_workers is not None:
+        # 読み込みの高速化のためにDataLoaderを使うオプション
         dataset = ImageLoadingPrepDataset(image_paths)
-        data = torch.utils.data.DataLoader(
+        loader = torch.utils.data.DataLoader(
             dataset,
             batch_size=args.batch_size,
             shuffle=False,
@@ -339,35 +493,37 @@ def run_batch(path_imgs):
             collate_fn=collate_fn_remove_corrupted,
             drop_last=False,
         )
+        use_loader = True
     else:
-        data = [[(None, ip)] for ip in image_paths]
-
-    b_imgs = []
-    for data_entry in tqdm(data, smoothing=0.0):
-        for data in data_entry:
-            if data is None:
-                continue
-
-            image, image_path = data
-            if image is None:
+        # make batch of image paths
+        loader = []
+        for i in range(0, len(image_paths), args.batch_size):
+            loader.append(image_paths[i : i + args.batch_size])
+
+    for data_entry in tqdm(loader, smoothing=0.0):
+        if use_loader:
+            b_imgs = data_entry
+        else:
+            b_imgs = []
+            for image_path in data_entry:
                 try:
                     image = Image.open(image_path)
                     if image.mode != "RGB":
                         image = image.convert("RGB")
+                    size = image.size
                     image = preprocess_image(image)
                 except Exception as e:
                     logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
                     continue
-            b_imgs.append((image_path, image))
+                b_imgs.append((image_path, image, size))

-        if len(b_imgs) >= args.batch_size:
-            b_imgs = [(str(image_path), image) for image_path, image in b_imgs]  # Convert image_path to string
-            run_batch(b_imgs)
-            b_imgs.clear()
+        b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs]  # Convert image_path to string
+        run_batch(b_imgs, images_metadata, args.tags_index)

-    if len(b_imgs) > 0:
-        b_imgs = [(str(image_path), image) for image_path, image in b_imgs]  # Convert image_path to string
-        run_batch(b_imgs)
+    if args.metadata is not None:
+        logger.info(f"saving metadata file: {args.metadata}")
+        with open(args.metadata, "wt", encoding="utf-8") as f:
+            json.dump(metadata, f, ensure_ascii=False, indent=2)

     if args.frequency_tags:
         sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
@@ -380,9 +536,7 @@ def setup_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
-    )
+    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
     parser.add_argument(
         "--repo_id",
         type=str,
@@ -400,9 +554,7 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
     )
-    parser.add_argument(
-        "--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
-    )
+    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
     parser.add_argument(
         "--max_data_loader_n_workers",
         type=int,
@@ -441,9 +593,7 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
     )
-    parser.add_argument(
-        "--debug", action="store_true", help="debug mode"
-    )
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument( "--undesired_tags", type=str, @@ -453,20 +603,24 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する" ) - parser.add_argument( - "--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する" - ) + parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") parser.add_argument( "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する" ) parser.add_argument( - "--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する", + "--use_rating_tags", + action="store_true", + help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する", ) parser.add_argument( - "--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する", + "--use_rating_tags_as_last_tag", + action="store_true", + help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する", ) parser.add_argument( - "--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する", + "--character_tags_first", + action="store_true", + help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する", ) parser.add_argument( "--always_first_tags", @@ -495,6 +649,25 @@ def setup_parser() -> argparse.ArgumentParser: + " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる", ) + parser.add_argument( + "--metadata", + type=str, + default=None, + help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む", + ) + parser.add_argument( + "--tags_index", + type=int, + default=None, + help="index of the tags in the metadata file. default is None, which means adding tags to the existing tags. 0>= to replace the tags" + " / メタデータファイル内のタグのインデックス。デフォルトはNoneで、既存のタグにタグを追加する。0以上でタグを置き換える", + ) + parser.add_argument( + "--load_archive", + action="store_true", + help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata" + " / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります", + ) return parser diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 3730f1216..0c799f6e3 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -84,7 +84,7 @@ def get_text_encoder_outputs_caching_strategy(self, args): args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, - args.max_tolen_length, + args.max_token_length, is_weighted=args.weighted_captions, ) else: