multi embed training
kohya-ss committed Oct 10, 2023
1 parent 33ee0ac commit f8629e3
Showing 2 changed files with 112 additions and 3 deletions.
84 changes: 84 additions & 0 deletions tools/split_ti_embeddings.py
@@ -0,0 +1,84 @@
import argparse
import os

import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm


def split(args):
# load embedding
if args.embedding.endswith(".safetensors"):
embedding = load_file(args.embedding)
with safe_open(args.embedding, framework="pt") as f:
metadata = f.metadata()
else:
embedding = torch.load(args.embedding)
metadata = None

# check format
if "emb_params" in embedding:
# SD1/2
keys = ["emb_params"]
elif "clip_l" in embedding:
# SDXL
keys = ["clip_l", "clip_g"]
else:
print("Unknown embedding format")
exit()
num_vectors = embedding[keys[0]].shape[0]

# prepare output directory
os.makedirs(args.output_dir, exist_ok=True)

# prepare splits
if args.vectors_per_split is not None:
num_splits = (num_vectors + args.vectors_per_split - 1) // args.vectors_per_split
vectors_for_split = [args.vectors_per_split] * num_splits
if sum(vectors_for_split) > num_vectors:
vectors_for_split[-1] -= sum(vectors_for_split) - num_vectors
assert sum(vectors_for_split) == num_vectors
elif args.vectors is not None:
vectors_for_split = args.vectors
num_splits = len(vectors_for_split)
else:
print("Must specify either --vectors_per_split or --vectors / --vectors_per_split または --vectors のどちらかを指定する必要があります")
exit()

assert (
sum(vectors_for_split) == num_vectors
), "Sum of vectors must be equal to the number of vectors in the embedding / 分割したベクトルの合計はembeddingのベクトル数と等しくなければなりません"

# split
basename = os.path.splitext(os.path.basename(args.embedding))[0]
done_vectors = 0
for i, num_vectors in enumerate(vectors_for_split):
print(f"Splitting {num_vectors} vectors...")

split_embedding = {}
for key in keys:
split_embedding[key] = embedding[key][done_vectors : done_vectors + num_vectors]

output_file = os.path.join(args.output_dir, f"{basename}_{i}.safetensors")
save_file(split_embedding, output_file, metadata)
print(f"Saved to {output_file}")

done_vectors += num_vectors

print("Done")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Split a Textual Inversion embedding into multiple embeddings")
parser.add_argument("--embedding", type=str, help="Embedding to split")
parser.add_argument("--output_dir", type=str, help="Output directory")
parser.add_argument(
"--vectors_per_split",
type=int,
default=None,
help="Number of vectors per split. If num_vectors is 8 and vectors_per_split is 3, then 3, 3, 2 vectors will be split",
)
parser.add_argument("--vectors", type=int, default=None, nargs="*", help="number of vectors for each split. e.g. 3 3 2")
args = parser.parse_args()
split(args)
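
As a usage sketch (not part of the commit): given a hypothetical 8-vector SD1/2 embedding my_embed.safetensors, passing --embedding my_embed.safetensors --output_dir split_out --vectors_per_split 3 to the new script should write my_embed_0.safetensors, my_embed_1.safetensors and my_embed_2.safetensors holding 3, 3 and 2 vectors respectively. A quick check of the result, assuming those hypothetical file names:

from safetensors.torch import load_file

# Hypothetical output files from the invocation above; "emb_params" is the SD1/2 key,
# SDXL embeddings would use "clip_l" / "clip_g" instead.
for i in range(3):
    emb = load_file(f"split_out/my_embed_{i}.safetensors")
    print(i, emb["emb_params"].shape[0])  # expected vector counts: 3, 3, 2
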
31 changes: 28 additions & 3 deletions train_textual_inversion.py
@@ -7,10 +7,13 @@

from tqdm import tqdm
import torch

try:
import intel_extension_for_pytorch as ipex

if torch.xpu.is_available():
from library.ipex import ipex_init

ipex_init()
except Exception:
pass
@@ -167,6 +170,13 @@ def train(self, args):
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template

assert (
args.token_string is not None or args.token_strings is not None
), "token_string or token_strings must be specified / token_stringまたはtoken_stringsを指定してください"
assert (
not use_template or args.token_strings is None
), "token_strings cannot be used with template / token_stringsはテンプレートと一緒に使えません"

train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)

@@ -215,9 +225,17 @@ def train(self, args):
# add new word to tokenizer, count is num_vectors_per_token
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added

self.assert_token_string(args.token_string, tokenizers)
if args.token_strings is not None:
token_strings = args.token_strings
assert (
len(token_strings) == args.num_vectors_per_token
), f"num_vectors_per_token is mismatch for token_strings / token_stringsの数がnum_vectors_per_tokenと合いません: {len(token_strings)}"
for token_string in token_strings:
self.assert_token_string(token_string, tokenizers)
else:
self.assert_token_string(args.token_string, tokenizers)
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]

token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
token_ids_list = []
token_embeds_list = []
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
@@ -332,7 +350,7 @@ def train(self, args):
prompt_replacement = None
else:
# for sample generation
if args.num_vectors_per_token > 1:
if args.num_vectors_per_token > 1 and args.token_strings is None:
replace_to = " ".join(token_strings)
train_dataset_group.add_replacement(args.token_string, replace_to)
prompt_replacement = (args.token_string, replace_to)
@@ -752,6 +770,13 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument(
"--token_strings",
type=str,
default=None,
nargs="*",
help="token strings used in training for multiple embedding / 複数のembeddingsの個別学習時に使用されるトークン文字列",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--use_object_template",

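For reference, a minimal sketch of the token-naming behavior this commit adds to train_textual_inversion.py (the values are hypothetical; the logic mirrors the diff above). With only --token_string, the trained tokens are the string plus numbered variants; with the new --token_strings, the supplied names are used as-is and their count must equal --num_vectors_per_token:

# Hypothetical values; mirrors the token_strings handling added in this commit.
token_string = "hoge"
num_vectors_per_token = 3
token_strings = None  # e.g. ["charA", "charB", "charC"] when --token_strings is given

if token_strings is not None:
    # enforced by the new assert in train()
    assert len(token_strings) == num_vectors_per_token
else:
    # default behavior: "hoge", "hoge1", "hoge2"
    token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]

print(token_strings)  # ['hoge', 'hoge1', 'hoge2']

Note that when --token_strings is given, the sample-generation prompt replacement (which joins the numbered tokens for a single --token_string) is skipped, as the added "and args.token_strings is None" condition in the diff shows.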