From f8629e3c1a2de36df2c32c993f1d49ca0ace778b Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Tue, 10 Oct 2023 23:36:05 +0900
Subject: [PATCH] multi embed training

---
 tools/split_ti_embeddings.py | 84 ++++++++++++++++++++++++++++++++++++
 train_textual_inversion.py   | 31 +++++++++++--
 2 files changed, 112 insertions(+), 3 deletions(-)
 create mode 100644 tools/split_ti_embeddings.py

diff --git a/tools/split_ti_embeddings.py b/tools/split_ti_embeddings.py
new file mode 100644
index 000000000..4a65d35c9
--- /dev/null
+++ b/tools/split_ti_embeddings.py
@@ -0,0 +1,84 @@
+import argparse
+import os
+
+import torch
+from safetensors import safe_open
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+
+
+def split(args):
+    # load embedding
+    if args.embedding.endswith(".safetensors"):
+        embedding = load_file(args.embedding)
+        with safe_open(args.embedding, framework="pt") as f:
+            metadata = f.metadata()
+    else:
+        embedding = torch.load(args.embedding)
+        metadata = None
+
+    # check format
+    if "emb_params" in embedding:
+        # SD1/2
+        keys = ["emb_params"]
+    elif "clip_l" in embedding:
+        # SDXL
+        keys = ["clip_l", "clip_g"]
+    else:
+        print("Unknown embedding format")
+        exit()
+    num_vectors = embedding[keys[0]].shape[0]
+
+    # prepare output directory
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # prepare splits
+    if args.vectors_per_split is not None:
+        num_splits = (num_vectors + args.vectors_per_split - 1) // args.vectors_per_split
+        vectors_for_split = [args.vectors_per_split] * num_splits
+        if sum(vectors_for_split) > num_vectors:
+            vectors_for_split[-1] -= sum(vectors_for_split) - num_vectors
+        assert sum(vectors_for_split) == num_vectors
+    elif args.vectors is not None:
+        vectors_for_split = args.vectors
+        num_splits = len(vectors_for_split)
+    else:
+        print("Must specify either --vectors_per_split or --vectors / --vectors_per_split または --vectors のどちらかを指定する必要があります")
+        exit()
+
+    assert (
+        sum(vectors_for_split) == num_vectors
+    ), "Sum of vectors must be equal to the number of vectors in the embedding / 分割したベクトルの合計はembeddingのベクトル数と等しくなければなりません"
+
+    # split
+    basename = os.path.splitext(os.path.basename(args.embedding))[0]
+    done_vectors = 0
+    for i, num_vectors in enumerate(vectors_for_split):
+        print(f"Splitting {num_vectors} vectors...")
+
+        split_embedding = {}
+        for key in keys:
+            split_embedding[key] = embedding[key][done_vectors : done_vectors + num_vectors]
+
+        output_file = os.path.join(args.output_dir, f"{basename}_{i}.safetensors")
+        save_file(split_embedding, output_file, metadata)
+        print(f"Saved to {output_file}")
+
+        done_vectors += num_vectors
+
+    print("Done")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Split a Textual Inversion embedding into multiple embeddings")
+    parser.add_argument("--embedding", type=str, help="Embedding to split")
+    parser.add_argument("--output_dir", type=str, help="Output directory")
+    parser.add_argument(
+        "--vectors_per_split",
+        type=int,
+        default=None,
+        help="Number of vectors per split. If num_vectors is 8 and vectors_per_split is 3, then 3, 3, 2 vectors will be split",
+    )
+    parser.add_argument("--vectors", type=int, default=None, nargs="*", help="Number of vectors for each split, e.g. 3 3 2")
+    args = parser.parse_args()
+    split(args)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 252add536..704ff44ee 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -7,10 +7,13 @@
 from tqdm import tqdm
 
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -167,6 +170,13 @@ def train(self, args):
             args.output_name = args.token_string
         use_template = args.use_object_template or args.use_style_template
 
+        assert (
+            args.token_string is not None or args.token_strings is not None
+        ), "token_string or token_strings must be specified / token_stringまたはtoken_stringsを指定してください"
+        assert (
+            not use_template or args.token_strings is None
+        ), "token_strings cannot be used with template / token_stringsはテンプレートと一緒に使えません"
+
         train_util.verify_training_args(args)
         train_util.prepare_dataset_args(args, True)
 
@@ -215,9 +225,17 @@ def train(self, args):
 
         # add new word to tokenizer, count is num_vectors_per_token
         # if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
-        self.assert_token_string(args.token_string, tokenizers)
+        if args.token_strings is not None:
+            token_strings = args.token_strings
+            assert (
+                len(token_strings) == args.num_vectors_per_token
+            ), f"number of token_strings does not match num_vectors_per_token / token_stringsの数がnum_vectors_per_tokenと合いません: {len(token_strings)}"
+            for token_string in token_strings:
+                self.assert_token_string(token_string, tokenizers)
+        else:
+            self.assert_token_string(args.token_string, tokenizers)
+            token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
 
-        token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
         token_ids_list = []
         token_embeds_list = []
         for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
@@ -332,7 +350,7 @@ def train(self, args):
             prompt_replacement = None
         else:
             # サンプル生成用
-            if args.num_vectors_per_token > 1:
+            if args.num_vectors_per_token > 1 and args.token_strings is None:
                 replace_to = " ".join(token_strings)
                 train_dataset_group.add_replacement(args.token_string, replace_to)
                 prompt_replacement = (args.token_string, replace_to)
@@ -752,6 +770,13 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
     )
+    parser.add_argument(
+        "--token_strings",
+        type=str,
+        default=None,
+        nargs="*",
+        help="token strings used in training for multiple embedding / 複数のembeddingsの個別学習時に使用されるトークン文字列",
+    )
     parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
     parser.add_argument(
         "--use_object_template",