diff --git a/library/train_util.py b/library/train_util.py
index b2645505c..4e3014d91 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -5639,10 +5639,11 @@ def sample_image_inference(
     prompt: str = prompt_dict.get("prompt", "")
     sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
 
-    if prompt_replacement is not None:
-        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
-        if negative_prompt is not None:
-            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
+    if prompt_replacement is not None and len(prompt_replacement) > 0:
+        for to_replace, replaced_by in prompt_replacement:
+            prompt = prompt.replace(to_replace, replaced_by)
+            if negative_prompt is not None:
+                negative_prompt = negative_prompt.replace(to_replace, replaced_by)
 
     if seed is not None:
         torch.manual_seed(seed)
diff --git a/train_network.py b/train_network.py
index d45bf8605..300d691fe 100644
--- a/train_network.py
+++ b/train_network.py
@@ -138,8 +138,41 @@ def all_reduce_network(self, accelerator, network):
             if param.grad is not None:
                 param.grad = accelerator.reduce(param.grad, reduction="mean")
 
-    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
-        train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
+    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
+        train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement)
+
+    def create_embedding_from_data(self, data, name='unknown'):
+        if 'string_to_param' in data:  # textual inversion embeddings
+            param_dict = data['string_to_param']
+            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
+
+            assert (
+                len(param_dict) == 1
+            ), f"embedding file has multiple terms in it: {name}"
+
+            emb = next(iter(param_dict.items()))[1]
+            vec = emb.detach()
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
+        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+            vec = {k: v.detach() for k, v in data.items()}
+            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+            vectors = data['clip_g'].shape[0]
+        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
+            assert (
+                len(data.keys()) == 1
+            ), f"embedding file has multiple terms in it: {name}"
+
+            emb = next(iter(data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
+            vec = emb.detach()
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
+        else:
+            raise Exception(f"Couldn't identify {name} as either a textual inversion embedding or a diffuser concept.")
+
+        return vec, shape, vectors
 
     def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
 
@@ -220,6 +253,69 @@ def train(self, args):
         tokenizer = self.load_tokenizer(args)
         tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
 
+        # acceleratorを準備する
+        logger.info("preparing accelerator")
+        accelerator = train_util.prepare_accelerator(args)
+        is_main_process = accelerator.is_main_process
+
+        # mixed precisionに対応した型を用意しておき適宜castする
+        weight_dtype, save_dtype = train_util.prepare_dtype(args)
+        vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
+
+        # モデルを読み込む
+        model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
+
+        # text_encoder is List[CLIPTextModel] or CLIPTextModel
+        text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
+
+        # モデルに xformers とか memory efficient attention を組み込む
+        train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
+        if torch.__version__ >= "2.0.0":  # PyTorch 2.0.0 以上対応のxformersなら以下が使える
+            vae.set_use_memory_efficient_attention_xformers(args.xformers)
+
+        # prepare embeddings (for pivotal tuning)
+        # load from existing files
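+        # Each embeddings file may be an A1111-style embedding ({"string_to_param": ...}), an SDXL
+        # embedding holding "clip_g"/"clip_l" tensors, or a single-tensor diffusers concept; see
+        # create_embedding_from_data above. The file name (without extension) becomes the trigger
+        # token, and a multi-vector embedding additionally registers "name1", "name2", ... tokens.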
+        embeddings_map = {}
+        embedding_to_token_ids = {}
+        if len(args.embeddings) > 0:
+            for embeds_file in args.embeddings:
+                if model_util.is_safetensors(embeds_file):
+                    from safetensors.torch import load_file
+
+                    data = load_file(embeds_file)
+                else:
+                    data = torch.load(embeds_file, map_location="cpu")
+
+                token_string = os.path.splitext(os.path.basename(embeds_file))[0]
+                embeds, _shape, num_vectors_per_token = self.create_embedding_from_data(data, token_string)
+                embedding_to_token_ids[token_string] = []
+
+                token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
+                accelerator.print("Loaded token strings", token_strings)
+                for i, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
+                    num_added_tokens = tokenizer.add_tokens(token_strings)
+
+                    assert (
+                        num_added_tokens == num_vectors_per_token
+                    ), f"The tokenizer already contains {token_string}. Please pass a different word that is not already in the tokenizer. / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
+
+                    token_ids = tokenizer.convert_tokens_to_ids(token_strings)
+                    accelerator.print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
+                    assert (
+                        min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1
+                    ), f"token ids is not ordered : tokenizer {i+1}, {token_ids}"
+                    assert (
+                        len(tokenizer) - 1 == token_ids[-1]
+                    ), f"token ids is not end of tokenize: tokenizer {i+1}, {token_ids}, {len(tokenizer)}"
+
+                    # Resize the token embeddings as we are adding new special tokens to the tokenizer
+                    text_encoder.resize_token_embeddings(len(tokenizer))
+
+                    for token_id, embed in zip(token_ids, embeds):
+                        text_encoder.get_input_embeddings().weight.data[token_id] = embed
+                    embeddings_map[token_string] = embeds
+                    embedding_to_token_ids[token_string].append(token_ids)
+
         # データセットを準備する
         if args.dataset_class is None:
             blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
@@ -272,6 +368,16 @@ def train(self, args):
         ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
         collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
 
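+        # A multi-vector embedding is referenced by a single trigger word, so expand it to all of
+        # its sub-tokens ("emb" -> "emb emb1 emb2 ...") both in the training captions (via
+        # add_replacement) and in the sample prompts (via prompt_replacements).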
+        prompt_replacements = []
+        for emb_name in embeddings_map.keys():
+            emb_token_ids = embedding_to_token_ids[emb_name]
+            if len(emb_token_ids[0]) > 1:
+                token_strings = [emb_name] + [f"{emb_name}{i+1}" for i in range(len(emb_token_ids[0]) - 1)]
+                replace_to = " ".join(token_strings)
+                train_dataset_group.add_replacement(emb_name, replace_to)
+                prompt_replacement = (emb_name, replace_to)
+                prompt_replacements.append(prompt_replacement)
+
         if args.debug_dataset:
             train_util.debug_dataset(train_dataset_group)
             return
@@ -292,26 +398,6 @@ def train(self, args):
 
         self.assert_extra_args(args, train_dataset_group)
 
-        # acceleratorを準備する
-        logger.info("preparing accelerator")
-        accelerator = train_util.prepare_accelerator(args)
-        is_main_process = accelerator.is_main_process
-
-        # mixed precisionに対応した型を用意しておき適宜castする
-        weight_dtype, save_dtype = train_util.prepare_dtype(args)
-        vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
-
-        # モデルを読み込む
-        model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
-
-        # text_encoder is List[CLIPTextModel] or CLIPTextModel
-        text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
-
-        # モデルに xformers とか memory efficient attention を組み込む
-        train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
-        if torch.__version__ >= "2.0.0":  # PyTorch 2.0.0 以上対応のxformersなら以下が使える
-            vae.set_use_memory_efficient_attention_xformers(args.xformers)
-
         # 差分追加学習のためにモデルを読み込む
         sys.path.append(os.path.dirname(__file__))
         accelerator.print("import network module:", args.network_module)
@@ -413,6 +499,14 @@ def train(self, args):
         # 後方互換性を確保するよ
         try:
             trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
+            # Add embeddings params when continuing the inversion
+            if args.continue_inversion:
+                # TODO: might be good to add the embedding to the LoRA module directly to continue training ("bundle_emb.{emb_name}.string_to_param.*")
+                for text_encoder in text_encoders:
+                    trainable_params.append({
+                        "params": text_encoder.get_input_embeddings().parameters(),
+                        "lr": args.embedding_lr or args.text_encoder_lr or args.learning_rate
+                    })
         except TypeError:
             accelerator.print(
                 "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
@@ -551,6 +645,29 @@ def train(self, args):
 
         training_model = network
 
+        if args.continue_inversion:
+            token_ids_list = []
+            for emb_name in embeddings_map.keys():
+                token_ids_group = []
+                for sublist in embedding_to_token_ids[emb_name]:
+                    token_ids_group.extend(sublist)
+                token_ids_list.append(token_ids_group)
+            index_no_updates_list = []
+            orig_embeds_params_list = []
+            for tokenizer, token_ids, t_enc in zip(tokenizers, token_ids_list, text_encoders):
+                index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
+                index_no_updates_list.append(index_no_updates)
+                orig_embeds_params = accelerator.unwrap_model(t_enc).get_input_embeddings().weight.detach().clone()
+                orig_embeds_params_list.append(orig_embeds_params)
+
+                # Freeze all parameters except for the token embeddings in text encoder
+                t_enc.requires_grad_(True)
+                t_enc.text_model.encoder.requires_grad_(False)
+                t_enc.text_model.final_layer_norm.requires_grad_(False)
+                t_enc.text_model.embeddings.position_embedding.requires_grad_(False)
+                # t_enc.text_model.embeddings.requires_grad_(True)
+                # t_enc.text_model.embeddings.token_embedding.requires_grad_(True)
+
         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
             if (args.optimizer_type.lower().endswith("schedulefree")):
@@ -573,8 +690,6 @@ def train(self, args):
 
             del t_enc
 
-        accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
-
         if not cache_latents:  # キャッシュしない場合はVAEを使うのでVAEを準備する
             vae.requires_grad_(False)
             vae.eval()
@@ -880,14 +995,8 @@ def load_model_hook(models, input_dir):
 
         del train_dataset_group
 
-        # callback for step start
-        if hasattr(accelerator.unwrap_model(network), "on_step_start"):
-            on_step_start = accelerator.unwrap_model(network).on_step_start
-        else:
-            on_step_start = lambda *args, **kwargs: None
-
         # function for saving/removing
-        def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
+        def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, embeddings_map, force_sync_upload=False):
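+            # embeddings_map carries the (possibly updated) Textual Inversion vectors so that they
+            # can be bundled into the saved file together with the LoRA weights (see below).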
             os.makedirs(args.output_dir, exist_ok=True)
             ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
@@ -900,7 +1009,40 @@ def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False
             sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
             metadata_to_save.update(sai_metadata)
 
-            unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
+            if len(embeddings_map.keys()) > 0:
+                # Bundle embeddings in LoRA state dict
+                state_dict = unwrapped_nw.state_dict()
+                for emb_name in embeddings_map.keys():
+                    accelerator.print(f"Bundling embedding: {emb_name}, {embedding_to_token_ids[emb_name]}")
+                    key = f"bundle_emb.{emb_name}.string_to_param.*"
+                    state_dict[key] = embeddings_map[emb_name]
+
+                if metadata_to_save is not None and len(metadata_to_save) == 0:
+                    metadata_to_save = None
+
+                # Save LoRA
+                if save_dtype is not None:
+                    for key in list(state_dict.keys()):
+                        v = state_dict[key]
+                        v = v.detach().clone().to("cpu").to(save_dtype)
+                        state_dict[key] = v
+
+                if os.path.splitext(ckpt_file)[1] == ".safetensors":
+                    from safetensors.torch import save_file
+
+                    # Precalculate model hashes to save time on indexing
+                    if metadata_to_save is None:
+                        metadata_to_save = {}
+                    model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata_to_save)
+                    metadata_to_save["sshs_model_hash"] = model_hash
+                    metadata_to_save["sshs_legacy_hash"] = legacy_hash
+
+                    save_file(state_dict, ckpt_file, metadata_to_save)
+                else:
+                    torch.save(state_dict, ckpt_file)
+            else:
+                unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
+
             if args.huggingface_repo_id is not None:
                 huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
 
@@ -911,7 +1053,15 @@ def remove_model(old_ckpt_name):
                 os.remove(old_ckpt_file)
 
         # For --sample_at_first
-        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement=prompt_replacements)
+
+        # callback for step start
+        if hasattr(accelerator.unwrap_model(network), "on_step_start"):
+            on_step_start = accelerator.unwrap_model(network).on_step_start
+        else:
+            on_step_start = lambda *args, **kwargs: None
+
+        accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
 
         # training loop
        for epoch in range(num_train_epochs):
@@ -953,7 +1103,7 @@ def remove_model(old_ckpt_name):
                     # print(f"set multiplier: {multipliers}")
                     accelerator.unwrap_model(network).set_multiplier(multipliers)
 
-                with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
+                with torch.set_grad_enabled(is_train and (train_text_encoder or args.continue_inversion)), accelerator.autocast():
                     # Get the text embedding for conditioning
                     if args.weighted_captions:
                         text_encoder_conds = get_weighted_text_embeddings(
@@ -1046,6 +1196,23 @@ def remove_model(old_ckpt_name):
                 with accelerator.autocast():  # not sure if necessary
                     check_and_update_ema(args, e, i)
 
+                # Let's make sure we don't update any embedding weights besides the added pivots
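+                # index_no_updates masks every token id below the first added pivot token; restoring
+                # those rows from the snapshot taken before training means only the new embedding
+                # vectors keep their gradient updates.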
+                if args.continue_inversion:
+                    with torch.no_grad():
+                        for text_encoder, orig_embeds_params, index_no_updates in zip(
+                            text_encoders, orig_embeds_params_list, index_no_updates_list
+                        ):
+                            accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+                                index_no_updates
+                            ] = orig_embeds_params[index_no_updates]
+
+                    # Update embeddings map (for saving)
+                    # TODO: this is not optimal, might need to be refactored
+                    for emb_name in embeddings_map.keys():
+                        emb_token_ids = embedding_to_token_ids[emb_name]
+                        updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[emb_token_ids].detach().clone()
+                        embeddings_map[emb_name] = updated_embs
+
                 if args.scale_weight_norms:
                     keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
                         args.scale_weight_norms, accelerator.device
@@ -1062,14 +1229,14 @@ def remove_model(old_ckpt_name):
                     progress_bar.update(1)
                     global_step += 1
 
-                    self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+                    self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement=prompt_replacements)
 
                    # 指定ステップごとにモデルを保存
                    if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                        accelerator.wait_for_everyone()
                        if accelerator.is_main_process:
                            ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
-                            save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
+                            save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch, embeddings_map)
 
                            if args.save_state:
                                train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -1128,7 +1295,7 @@ def remove_model(old_ckpt_name):
                saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
                if is_main_process and saving:
                    ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
-                    save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
+                    save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1, embeddings_map)
 
                    remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                    if remove_epoch_no is not None:
@@ -1138,7 +1305,7 @@ def remove_model(old_ckpt_name):
            if args.save_state:
                train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
 
-            self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+            self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement=prompt_replacements)
 
        # end of epoch
 
@@ -1155,7 +1322,7 @@ def remove_model(old_ckpt_name):
 
        if is_main_process:
            ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
-            save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
+            save_model(ckpt_name, network, global_step, num_train_epochs, embeddings_map, force_sync_upload=True)
 
            if args.enable_ema and args.ema_type == 'traditional':
                # save directly
@@ -1295,7 +1462,17 @@ def setup_parser() -> argparse.ArgumentParser:
        type=int,
        default=None,
        help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset"
-    )
+    )
+    # Pivotal tuning
+    parser.add_argument(
+        "--embeddings",
+        type=str,
+        default=[],
+        nargs="*",
+        help="Embedding files of Textual Inversion / Textual Inversionのembeddings",
+    )
+    parser.add_argument("--continue_inversion", action="store_true", help="Continue the textual inversion when training the LoRA")
+    parser.add_argument("--embedding_lr", type=float, default=None, help="Learning rate used when continuing the textual inversion")
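+    # Example invocation (file names and values are illustrative):
+    #   --embeddings myconcept.safetensors --continue_inversion --embedding_lr 1e-4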
 
    return parser
 
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index fe6f29ce9..e7e7cfd37 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -93,7 +93,6 @@
    "a large painting in the style of {}",
 ]
 
-
 class TextualInversionTrainer:
    def __init__(self):
        self.vae_scale_factor = 0.18215
@@ -324,6 +323,7 @@ def train(self, args):
        collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
 
        # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
+        prompt_replacements = []
        if use_template:
            accelerator.print(f"use template for training captions. is object: {args.use_object_template}")
            templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
@@ -336,16 +336,15 @@ def train(self, args):
            # サンプル生成用
            if args.num_vectors_per_token > 1:
                prompt_replacement = (args.token_string, replace_to)
-            else:
-                prompt_replacement = None
+                prompt_replacements.append(prompt_replacement)
+
        else:
            # サンプル生成用
            if args.num_vectors_per_token > 1:
                replace_to = " ".join(token_strings)
                train_dataset_group.add_replacement(args.token_string, replace_to)
                prompt_replacement = (args.token_string, replace_to)
-            else:
-                prompt_replacement = None
+                prompt_replacements.append(prompt_replacement)
 
        if args.debug_dataset:
            train_util.debug_dataset(train_dataset_group, show_input_ids=True)
@@ -558,7 +557,7 @@ def remove_model(old_ckpt_name):
            tokenizer_or_list,
            text_encoder_or_list,
            unet,
-            prompt_replacement,
+            prompt_replacement=prompt_replacements,
        )
 
        # training loop
@@ -666,7 +665,7 @@ def remove_model(old_ckpt_name):
                        tokenizer_or_list,
                        text_encoder_or_list,
                        unet,
-                        prompt_replacement,
+                        prompt_replacement=prompt_replacements,
                    )
 
                    # 指定ステップごとにモデルを保存
@@ -749,7 +748,7 @@ def remove_model(old_ckpt_name):
                tokenizer_or_list,
                text_encoder_or_list,
                unet,
-                prompt_replacement,
+                prompt_replacement=prompt_replacements,
            )
 
        # end of epoch