diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py
index 31ee93bc0..ffe94e7df 100644
--- a/finetune/tag_images_by_wd14_tagger.py
+++ b/finetune/tag_images_by_wd14_tagger.py
@@ -2,15 +2,14 @@
 import csv
 import glob
 import os
+from pathlib import Path
 
-from PIL import Image
 import cv2
-from tqdm import tqdm
 import numpy as np
-from tensorflow.keras.models import load_model
-from huggingface_hub import hf_hub_download
 import torch
-from pathlib import Path
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from tqdm import tqdm
 
 import library.train_util as train_util
 
@@ -81,6 +80,8 @@ def main(args):
     # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
     if not os.path.exists(args.model_dir) or args.force_download:
         print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
+        if args.onnx:
+            FILES.append("model.onnx")
         for file in FILES:
             hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
         for file in SUB_DIR_FILES:
@@ -96,7 +97,35 @@ def main(args):
         print("using existing wd14 tagger model")
 
     # 画像を読み込む
-    model = load_model(args.model_dir)
+    if args.onnx:
+        import onnx
+        import onnxruntime as ort
+
+        onnx_path = f"{args.model_dir}/model.onnx"
+        print("Running wd14 tagger with onnx")
+        print(f"loading onnx model: {onnx_path}")
+        model = onnx.load(onnx_path)
+        input_name = model.graph.input[0].name
+        try:
+            batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
+        except:
+            batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
+        if args.batch_size != batch_size and type(batch_size) != str:
+            # some rebatch model may use 'N' as dynamic axes
+            print(
+                f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
+            )
+            args.batch_size = batch_size
+        ort_sess = ort.InferenceSession(
+            onnx_path,
+            providers=["CUDAExecutionProvider"]
+            if "CUDAExecutionProvider" in ort.get_available_providers()
+            else ["CPUExecutionProvider"],
+        )
+    else:
+        from tensorflow.keras.models import load_model
+
+        model = load_model(f"{args.model_dir}")
 
     # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
     # 依存ライブラリを増やしたくないので自力で読むよ
@@ -124,8 +153,11 @@ def main(args):
     def run_batch(path_imgs):
         imgs = np.array([im for _, im in path_imgs])
 
-        probs = model(imgs, training=False)
-        probs = probs.numpy()
+        if args.onnx:
+            probs = ort_sess.run(None, {input_name: imgs})[0]  # onnx output numpy
+        else:
+            probs = model(imgs, training=False)
+            probs = probs.numpy()
 
         for (image_path, _), prob in zip(path_imgs, probs):
             # 最初の4つはratingなので無視する
@@ -301,6 +333,7 @@ def setup_parser() -> argparse.ArgumentParser:
         help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
     )
     parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
+    parser.add_argument("--onnx", action="store_true", help="use onnx model for inference")
     parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
 
     return parser
diff --git a/requirements.txt b/requirements.txt
index 4ca393f52..75de48cb9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,8 +19,11 @@ huggingface-hub==0.15.1
 # requests==2.28.2
 # timm==0.6.12
 # fairscale==0.4.13
-# for WD14 captioning
+# for WD14 captioning (tensorflow or onnx)
 # tensorflow==2.10.1
+# onnx==1.14.1
+# onnxruntime-gpu==1.16.0
+# onnxruntime==1.16.0
 # open clip for SDXL
 open-clip-torch==2.20.0
 # for kohya_ss library
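
For reference, here is a minimal standalone sketch (not part of the patch) of the ONNX inference path the diff adds: read the graph input's name and batch dimension, prefer the CUDA execution provider when available, and run one batch. The model path, the 448x448 input shape, and the dummy input are illustrative assumptions; the real script feeds its own preprocessed image batches and selects this path via the new --onnx flag.

import numpy as np
import onnx
import onnxruntime as ort

onnx_path = "wd14_tagger_model/model.onnx"  # assumed download location (args.model_dir)

# Inspect the graph input, as the patched script does with dim_value / dim_param.
model = onnx.load(onnx_path)
graph_input = model.graph.input[0]
input_name = graph_input.name
dim0 = graph_input.type.tensor_type.shape.dim[0]
batch_size = dim0.dim_value if dim0.dim_value > 0 else dim0.dim_param  # int, or symbolic e.g. "N"
print(f"input: {input_name}, batch dim: {batch_size}")

# Prefer CUDA when onnxruntime-gpu can provide it, otherwise fall back to CPU.
providers = (
    ["CUDAExecutionProvider"]
    if "CUDAExecutionProvider" in ort.get_available_providers()
    else ["CPUExecutionProvider"]
)
ort_sess = ort.InferenceSession(onnx_path, providers=providers)

# One dummy 448x448 image stands in for a preprocessed batch (assumed input shape).
imgs = np.zeros((1, 448, 448, 3), dtype=np.float32)
probs = ort_sess.run(None, {input_name: imgs})[0]  # numpy array of per-tag probabilities
print(probs.shape)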