Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --onnx to wd14 tagger #864

Merged
merged 4 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 41 additions & 8 deletions finetune/tag_images_by_wd14_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import csv
import glob
import os
from pathlib import Path

from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch
from pathlib import Path
from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm

import library.train_util as train_util

Expand Down Expand Up @@ -81,6 +80,8 @@ def main(args):
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
if args.onnx:
FILES.append("model.onnx")
for file in FILES:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
Expand All @@ -96,7 +97,35 @@ def main(args):
print("using existing wd14 tagger model")

# 画像を読み込む
model = load_model(args.model_dir)
if args.onnx:
import onnx
import onnxruntime as ort

onnx_path = f"{args.model_dir}/model.onnx"
print("Running wd14 tagger with onnx")
print(f"loading onnx model: {onnx_path}")
model = onnx.load(onnx_path)
input_name = model.graph.input[0].name
try:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
except:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
if args.batch_size != batch_size and type(batch_size) != str:
# some rebatch model may use 'N' as dynamic axes
print(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
)
args.batch_size = batch_size
ort_sess = ort.InferenceSession(
onnx_path,
providers=["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"],
)
else:
from tensorflow.keras.models import load_model

model = load_model(f"{args.model_dir}")

# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
Expand Down Expand Up @@ -124,8 +153,11 @@ def main(args):
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])

probs = model(imgs, training=False)
probs = probs.numpy()
if args.onnx:
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
else:
probs = model(imgs, training=False)
probs = probs.numpy()

for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
Expand Down Expand Up @@ -301,6 +333,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
)
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference")
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")

return parser
Expand Down
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ huggingface-hub==0.15.1
# requests==2.28.2
# timm==0.6.12
# fairscale==0.4.13
# for WD14 captioning
# for WD14 captioning (tensroflow or onnx)
# tensorflow==2.10.1
# onnx==1.14.1
# onnxruntime-gpu==1.16.0
# onnxruntime==1.16.0
# open clip for SDXL
open-clip-torch==2.20.0
# for kohya_ss library
Expand Down