From 28e9352cc5a4211344b604a6c4f4dfa7783c42ed Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Dec 2024 22:04:37 +0900 Subject: [PATCH] feat: Florence-2 captioninig (WIP) --- finetune/caption_images_by_florence2.py | 232 ++++++++++++++++++++++++ finetune/tag_images_by_wd14_tagger.py | 128 ++----------- finetune/tagger_utils.py | 171 +++++++++++++++++ 3 files changed, 415 insertions(+), 116 deletions(-) create mode 100644 finetune/caption_images_by_florence2.py create mode 100644 finetune/tagger_utils.py diff --git a/finetune/caption_images_by_florence2.py b/finetune/caption_images_by_florence2.py new file mode 100644 index 000000000..adcabc611 --- /dev/null +++ b/finetune/caption_images_by_florence2.py @@ -0,0 +1,232 @@ +# add caption to images by Florence-2 + + +import argparse +import json +import os +import glob +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm +from transformers import AutoProcessor, AutoModelForCausalLM + +from library import device_utils, train_util +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import tagger_utils + +TASK_PROMPT = "" + + +def main(args): + assert args.load_archive == ( + args.metadata is not None + ), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります" + + device = args.device if args.device is not None else device_utils.get_preferred_device() + if type(device) is str: + device = torch.device(device) + torch_dtype = torch.float16 if device.type == "cuda" else torch.float32 + logger.info(f"device: {device}, dtype: {torch_dtype}") + + logger.info("Loading Florence-2-large model / Florence-2-largeモデルをロード中") + + support_flash_attn = False + try: + import flash_attn + + support_flash_attn = True + except ImportError: + pass + + if support_flash_attn: + model = AutoModelForCausalLM.from_pretrained( + "microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True + ).to(device) + else: + logger.info( + "flash_attn is not available. Trying to load without it / flash_attnが利用できません。flash_attnを使わずにロードを試みます" + ) + + # https://github.com/huggingface/transformers/issues/31793#issuecomment-2295797330 + # Removing the unnecessary flash_attn import which causes issues on CPU or MPS backends + from transformers.dynamic_module_utils import get_imports + from unittest.mock import patch + + def fixed_get_imports(filename) -> list[str]: + if not str(filename).endswith("modeling_florence2.py"): + return get_imports(filename) + imports = get_imports(filename) + imports.remove("flash_attn") + return imports + + # workaround for unnecessary flash_attn requirement + with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): + model = AutoModelForCausalLM.from_pretrained( + "microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True + ).to(device) + + model.eval() + processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True) + + # 画像を読み込む + if not args.load_archive: + train_data_dir_path = Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + logger.info(f"found {len(image_paths)} images.") + else: + archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob( + os.path.join(args.train_data_dir, "*.tar") + ) + image_paths = [Path(archive_file) for archive_file in archive_files] + + # load metadata if needed + if args.metadata is not None: + metadata = tagger_utils.load_metadata(args.metadata) + images_metadata = metadata["images"] + else: + images_metadata = metadata = None + + # define preprocess_image function + def preprocess_image(image: Image.Image): + inputs = processor(text=TASK_PROMPT, images=image, return_tensors="pt").to(device, torch_dtype) + return inputs + + # prepare DataLoader or something similar :) + # Loader returns: list of (image_path, processed_image_or_something, image_size) + if args.load_archive: + loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug) + else: + # we cannot use DataLoader with ImageLoadingPrepDataset because processor is not pickleable + loader = tagger_utils.ImageLoader(image_paths, args.batch_size, preprocess_image, args.debug) + + def run_batch( + list_of_path_inputs_size: list[tuple[str, dict[str, torch.Tensor], tuple[int, int]]], + images_metadata: Optional[dict[str, Any]], + caption_index: Optional[int] = None, + ): + input_ids = torch.cat([inputs["input_ids"] for _, inputs, _ in list_of_path_inputs_size]) + pixel_values = torch.cat([inputs["pixel_values"] for _, inputs, _ in list_of_path_inputs_size]) + + if args.debug: + logger.info(f"input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}") + with torch.no_grad(): + generated_ids = model.generate( + input_ids=input_ids, + pixel_values=pixel_values, + max_new_tokens=args.max_new_tokens, + num_beams=args.num_beams, + ) + if args.debug: + logger.info(f"generate done: {generated_ids.shape}") + generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False) + if args.debug: + logger.info(f"decode done: {len(generated_texts)}") + + for generated_text, (image_path, _, image_size) in zip(generated_texts, list_of_path_inputs_size): + parsed_answer = processor.post_process_generation(generated_text, task=TASK_PROMPT, image_size=image_size) + caption_text = parsed_answer[""] + + caption_text = caption_text.strip().replace("", "") + original_caption_text = caption_text + + if args.remove_mood: + p = caption_text.find("The overall ") + if p != -1: + caption_text = caption_text[:p].strip() + + caption_file = os.path.splitext(image_path)[0] + args.caption_extension + + if images_metadata is None: + with open(caption_file, "wt", encoding="utf-8") as f: + f.write(caption_text + "\n") + else: + image_md = images_metadata.get(image_path, None) + if image_md is None: + image_md = {"image_size": list(image_size)} + images_metadata[image_path] = image_md + if "caption" not in image_md: + image_md["caption"] = [] + if caption_index is None: + image_md["caption"].append(caption_text) + else: + while len(image_md["caption"]) <= caption_index: + image_md["caption"].append("") + image_md["caption"][caption_index] = caption_text + + if args.debug: + logger.info("") + logger.info(f"{image_path}:") + logger.info(f"\tCaption: {caption_text}") + if args.remove_mood and original_caption_text != caption_text: + logger.info(f"\tCaption (prior to removing mood): {original_caption_text}") + + for data_entry in tqdm(loader, smoothing=0.0): + b_imgs = data_entry + b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs] # Convert image_path to string + run_batch(b_imgs, images_metadata, args.caption_index) + + if args.metadata is not None: + logger.info(f"saving metadata file: {args.metadata}") + with open(args.metadata, "wt", encoding="utf-8") as f: + json.dump(metadata, f, ensure_ascii=False, indent=2) + + logger.info("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument( + "--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子" + ) + parser.add_argument("--recursive", action="store_true", help="search images recursively / 画像を再帰的に検索する") + parser.add_argument( + "--remove_mood", action="store_true", help="remove mood from the caption / キャプションからムードを削除する" + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=1024, + help="maximum number of tokens to generate. default is 1024 / 生成するトークンの最大数。デフォルトは1024", + ) + parser.add_argument( + "--num_beams", + type=int, + default=3, + help="number of beams for beam search. default is 3 / ビームサーチのビーム数。デフォルトは3", + ) + parser.add_argument( + "--device", + type=str, + default=None, + help="device for model. default is None, which means using an appropriate device / モデルのデバイス。デフォルトはNoneで、適切なデバイスを使用する", + ) + parser.add_argument( + "--caption_index", + type=int, + default=None, + help="index of the caption in the metadata file. default is None, which means adding caption to the existing captions. 0>= to replace the caption" + " / メタデータファイル内のキャプションのインデックス。デフォルトはNoneで、新しく追加する。0以上でキャプションを置き換える", + ) + parser.add_argument("--debug", action="store_true", help="debug mode") + tagger_utils.add_archive_arguments(parser) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index e79e22045..049ecd7f4 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -1,13 +1,10 @@ import argparse -from concurrent.futures import ThreadPoolExecutor import csv import glob import json import os from pathlib import Path -from typing import Any, Optional, Union -import zipfile -import tarfile +from typing import Any, Optional import cv2 import numpy as np @@ -16,14 +13,17 @@ from PIL import Image from tqdm import tqdm -import library.train_util as train_util -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) +import library.train_util as train_util +from library.utils import pil_resize +import tagger_utils + # from wd14 tagger IMAGE_SIZE = 448 @@ -79,83 +79,6 @@ def __getitem__(self, idx): return (image, img_path, size) -class ArchiveImageLoader: - def __init__(self, archive_paths: list[str], batch_size: int, debug: bool = False): - self.archive_paths = archive_paths - self.batch_size = batch_size - self.debug = debug - self.current_archive = None - self.archive_index = 0 - self.image_index = 0 - self.files = None - self.executor = ThreadPoolExecutor() - self.image_exts = set(train_util.IMAGE_EXTENSIONS) - - def __iter__(self): - return self - - def __next__(self): - images = [] - while len(images) < self.batch_size: - if self.current_archive is None: - if self.archive_index >= len(self.archive_paths): - if len(images) == 0: - raise StopIteration - else: - break # return the remaining images - - if self.debug: - logger.info(f"loading archive: {self.archive_paths[self.archive_index]}") - - current_archive_path = self.archive_paths[self.archive_index] - if current_archive_path.endswith(".zip"): - self.current_archive = zipfile.ZipFile(current_archive_path) - self.files = self.current_archive.namelist() - elif current_archive_path.endswith(".tar"): - self.current_archive = tarfile.open(current_archive_path, "r") - self.files = self.current_archive.getnames() - else: - raise ValueError(f"unsupported archive file: {self.current_archive_path}") - - self.image_index = 0 - - # filter by image extensions - self.files = [file for file in self.files if os.path.splitext(file)[1].lower() in self.image_exts] - - if self.debug: - logger.info(f"found {len(self.files)} images in the archive") - - new_images = [] - while len(images) + len(new_images) < self.batch_size: - if self.image_index >= len(self.files): - break - - file = self.files[self.image_index] - archive_and_image_path = f"{self.archive_paths[self.archive_index]}////{file}" - self.image_index += 1 - - def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]): - with archive.open(file) as f: - image = Image.open(f).convert("RGB") - size = image.size - image = preprocess_image(image) - return image, size - - new_images.append((archive_and_image_path, self.executor.submit(load_image, file, self.current_archive))) - - # wait for all new_images to load to close the archive - new_images = [(image_path, future.result()) for image_path, future in new_images] - - if self.image_index >= len(self.files): - self.current_archive.close() - self.current_archive = None - self.archive_index += 1 - - images.extend(new_images) - - return [(image_path, image, size) for image_path, (image, size) in images] - - def collate_fn_remove_corrupted(batch): """Collate function that allows to remove corrupted examples in the dataloader. It expects that the dataloader returns 'None' when that occurs. @@ -460,33 +383,16 @@ def run_batch( logger.info(f"\tGeneral tags: {general_tag_text}") # load metadata if needed - metadata = None if args.metadata is not None: - if os.path.exists(args.metadata): - logger.info(f"loading metadata file: {args.metadata}") - with open(args.metadata, "rt", encoding="utf-8") as f: - metadata = json.load(f) - - # version check - major, minor, patch = metadata.get("format_version", "0.0.0").split(".") - major, minor, patch = int(major), int(minor), int(patch) - if major > 1 or (major == 1 and minor > 0): - logger.warning( - f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work." - ) - - if "images" not in metadata: - metadata["images"] = {} - else: - logger.info(f"metadata file not found: {args.metadata}, creating new metadata") - metadata = {"format_version": "1.0.0", "images": {}} - + metadata = tagger_utils.load_metadata(args.metadata) images_metadata = metadata["images"] + else: + images_metadata = metadata = None # prepare DataLoader or something similar :) use_loader = False if args.load_archive: - loader = ArchiveImageLoader([str(p) for p in image_paths], args.batch_size) + loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug) use_loader = True elif args.max_data_loader_n_workers is not None: # 読み込みの高速化のためにDataLoaderを使うオプション @@ -655,12 +561,6 @@ def setup_parser() -> argparse.ArgumentParser: + " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる", ) - parser.add_argument( - "--metadata", - type=str, - default=None, - help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む", - ) parser.add_argument( "--tags_index", type=int, @@ -668,12 +568,8 @@ def setup_parser() -> argparse.ArgumentParser: help="index of the tags in the metadata file. default is None, which means adding tags to the existing tags. 0>= to replace the tags" " / メタデータファイル内のタグのインデックス。デフォルトはNoneで、既存のタグにタグを追加する。0以上でタグを置き換える", ) - parser.add_argument( - "--load_archive", - action="store_true", - help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata" - " / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります", - ) + tagger_utils.add_archive_arguments(parser) + return parser diff --git a/finetune/tagger_utils.py b/finetune/tagger_utils.py new file mode 100644 index 000000000..4ca4bce5c --- /dev/null +++ b/finetune/tagger_utils.py @@ -0,0 +1,171 @@ +import argparse +import json +import math +import os +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Union +import zipfile +import tarfile + +from PIL import Image + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import train_util + + +class ArchiveImageLoader: + def __init__(self, archive_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False): + self.archive_paths = archive_paths + self.batch_size = batch_size + self.preprocess = preprocess + self.debug = debug + self.current_archive = None + self.archive_index = 0 + self.image_index = 0 + self.files = None + self.executor = ThreadPoolExecutor() + self.image_exts = set(train_util.IMAGE_EXTENSIONS) + + def __iter__(self): + return self + + def __next__(self): + images = [] + while len(images) < self.batch_size: + if self.current_archive is None: + if self.archive_index >= len(self.archive_paths): + if len(images) == 0: + raise StopIteration + else: + break # return the remaining images + + if self.debug: + logger.info(f"loading archive: {self.archive_paths[self.archive_index]}") + + current_archive_path = self.archive_paths[self.archive_index] + if current_archive_path.endswith(".zip"): + self.current_archive = zipfile.ZipFile(current_archive_path) + self.files = self.current_archive.namelist() + elif current_archive_path.endswith(".tar"): + self.current_archive = tarfile.open(current_archive_path, "r") + self.files = self.current_archive.getnames() + else: + raise ValueError(f"unsupported archive file: {self.current_archive_path}") + + self.image_index = 0 + + # filter by image extensions + self.files = [file for file in self.files if os.path.splitext(file)[1].lower() in self.image_exts] + + if self.debug: + logger.info(f"found {len(self.files)} images in the archive") + + new_images = [] + while len(images) + len(new_images) < self.batch_size: + if self.image_index >= len(self.files): + break + + file = self.files[self.image_index] + archive_and_image_path = f"{self.archive_paths[self.archive_index]}////{file}" + self.image_index += 1 + + def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]): + with archive.open(file) as f: + image = Image.open(f).convert("RGB") + size = image.size + image = self.preprocess(image) + return image, size + + new_images.append((archive_and_image_path, self.executor.submit(load_image, file, self.current_archive))) + + # wait for all new_images to load to close the archive + new_images = [(image_path, future.result()) for image_path, future in new_images] + + if self.image_index >= len(self.files): + self.current_archive.close() + self.current_archive = None + self.archive_index += 1 + + images.extend(new_images) + + return [(image_path, image, size) for image_path, (image, size) in images] + + +class ImageLoader: + def __init__(self, image_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False): + self.image_paths = image_paths + self.batch_size = batch_size + self.preprocess = preprocess + self.debug = debug + self.image_index = 0 + self.executor = ThreadPoolExecutor() + + def __len__(self): + return math.ceil(len(self.image_paths) / self.batch_size) + + def __iter__(self): + return self + + def __next__(self): + if self.image_index >= len(self.image_paths): + raise StopIteration + + images = [] + while len(images) < self.batch_size and self.image_index < len(self.image_paths): + + def load_image(file): + image = Image.open(file).convert("RGB") + size = image.size + image = self.preprocess(image) + return image, size + + image_path = self.image_paths[self.image_index] + images.append((image_path, self.executor.submit(load_image, image_path))) + self.image_index += 1 + + images = [(image_path, future.result()) for image_path, future in images] + return [(image_path, image, size) for image_path, (image, size) in images] + + +def load_metadata(metadata_file: str): + if os.path.exists(metadata_file): + logger.info(f"loading metadata file: {metadata_file}") + with open(metadata_file, "rt", encoding="utf-8") as f: + metadata = json.load(f) + + # version check + major, minor, patch = metadata.get("format_version", "0.0.0").split(".") + major, minor, patch = int(major), int(minor), int(patch) + if major > 1 or (major == 1 and minor > 0): + logger.warning( + f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work." + ) + + if "images" not in metadata: + metadata["images"] = {} + else: + logger.info(f"metadata file not found: {metadata_file}, creating new metadata") + metadata = {"format_version": "1.0.0", "images": {}} + + return metadata + + +def add_archive_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--metadata", + type=str, + default=None, + help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む", + ) + parser.add_argument( + "--load_archive", + action="store_true", + help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata" + " / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります", + )