Skip to content

Commit

Permalink
Merge pull request kohya-ss#427 from kohya-ss/dev
Browse files Browse the repository at this point in the history
fix lora_interrogator, wd14 tagger for '^_^' etc
  • Loading branch information
kohya-ss authored Apr 19, 2023
2 parents 334589a + 589a90b commit ee5cec7
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 16 deletions.
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

## Change History

### 17 Apr. 2023, 2023/4/17:

- Added the `--recursive` option to each script in the `finetune` folder to process folders recursively. Please refer to [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) for details. Thanks to Linaqruf!
- `finetune`フォルダ内の各スクリプトに再起的にフォルダを処理するオプション`--recursive`を追加しました。詳細は [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) を参照してください。Linaqruf 氏に感謝します。
### 19 Apr. 2023, 2023/4/19:
- Fixed `lora_interrogator.py` not working. Please refer to [PR #392](https://github.com/kohya-ss/sd-scripts/pull/392) for details. Thank you A2va and heyalexchoi!
- Fixed the handling of tags containing `_` in `tag_images_by_wd14_tagger.py`.
- `lora_interrogator.py`が動作しなくなっていたのを修正しました。詳細は [PR #392](https://github.com/kohya-ss/sd-scripts/pull/392) をご参照ください。A2va氏およびheyalexchoi氏に感謝します。
- `tag_images_by_wd14_tagger.py``_`を含むタグの取り扱いを修正しました。

### Naming of LoRA

Expand Down Expand Up @@ -164,6 +165,11 @@ LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additio

LoRA-C3Liarを使いWeb UIで生成するには拡張を使用してください。

### 17 Apr. 2023, 2023/4/17:

- Added the `--recursive` option to each script in the `finetune` folder to process folders recursively. Please refer to [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) for details. Thanks to Linaqruf!
- `finetune`フォルダ内の各スクリプトに再起的にフォルダを処理するオプション`--recursive`を追加しました。詳細は [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) を参照してください。Linaqruf 氏に感謝します。

### 14 Apr. 2023, 2023/4/14:
- Fixed a bug that caused an error when loading DyLoRA with the `--network_weight` option in `train_network.py`.
- `train_network.py`で、DyLoRAを`--network_weight`オプションで読み込むとエラーになる不具合を修正しました。
Expand Down
14 changes: 8 additions & 6 deletions finetune/tag_images_by_wd14_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,19 @@ def run_batch(path_imgs):
character_tag_text = ""
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i].replace("_", " ") if args.remove_underscore else general_tags[i]
tag_name = general_tags[i]
if args.remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
tag_name = tag_name.replace("_", " ")

if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += ", " + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = (
character_tags[i - len(general_tags)].replace("_", " ")
if args.remove_underscore
else character_tags[i - len(general_tags)]
)
tag_name = character_tags[i - len(general_tags)]
if args.remove_underscore and len(tag_name) > 3:
tag_name = tag_name.replace("_", " ")

if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += ", " + tag_name
Expand Down
23 changes: 17 additions & 6 deletions networks/lora_interrogator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from tqdm import tqdm
from library import model_util
import library.train_util as train_util
import argparse
from transformers import CLIPTokenizer
import torch
Expand All @@ -16,16 +17,20 @@


def interrogate(args):
weights_dtype = torch.float16

# いろいろ準備する
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
args.pretrained_model_name_or_path = args.sd_model
args.vae = None
text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE)

print(f"loading LoRA: {args.model}")
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)

# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
has_te_weight = False
for key in network.weights_sd.keys():
for key in weights_sd.keys():
if 'lora_te' in key:
has_te_weight = True
break
Expand All @@ -40,9 +45,9 @@ def interrogate(args):
else:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)

text_encoder.to(DEVICE)
text_encoder.to(DEVICE, dtype=weights_dtype)
text_encoder.eval()
unet.to(DEVICE)
unet.to(DEVICE, dtype=weights_dtype)
unet.eval() # U-Netは呼び出さないので不要だけど

# トークンをひとつひとつ当たっていく
Expand Down Expand Up @@ -78,9 +83,14 @@ def get_all_embeddings(text_encoder):
orig_embs = get_all_embeddings(text_encoder)

network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
network.to(DEVICE)
info = network.load_state_dict(weights_sd, strict=False)
print(f"Loading LoRA weights: {info}")

network.to(DEVICE, dtype=weights_dtype)
network.eval()

del unet

print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
print("get text encoder embeddings with lora.")
lora_embs = get_all_embeddings(text_encoder)
Expand All @@ -107,6 +117,7 @@ def get_all_embeddings(text_encoder):

def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--sd_model", type=str, default=None,
Expand Down

0 comments on commit ee5cec7

Please sign in to comment.