From 53f94200c8ce23a398ee4e9a5c773fff65364e74 Mon Sep 17 00:00:00 2001
From: laksjdjf
Date: Thu, 21 Mar 2024 17:35:16 +0900
Subject: [PATCH] support tleco

---
 modules/config.py              |  18 +-----
 modules/tleco/tleco_dataset.py | 112 +++++++++++++++++++++++++++++++++
 modules/tleco/tleco_trainer.py |  40 ++++++++++++
 networks/manager.py            |   9 +--
 4 files changed, 158 insertions(+), 21 deletions(-)
 create mode 100644 modules/tleco/tleco_dataset.py
 create mode 100644 modules/tleco/tleco_trainer.py

diff --git a/modules/config.py b/modules/config.py
index 17189dd..1c767cb 100644
--- a/modules/config.py
+++ b/modules/config.py
@@ -45,26 +45,10 @@ class TrainerConfig:
     validation_args: Dict[str, Any] = field(default_factory=dict)
     additional_conf: Dict[str, Any] = field(default_factory=dict)
 
-@dataclass
-class DatasetArgs:
-    batch_size: int = MISSING
-    path: str = MISSING
-    metadata: str = "buckets.json"
-    original_size: Optional[str] = None
-    latent: Optional[str] = "latents"
-    caption: Optional[str] = "captions"
-    image: Optional[str] = None
-    text_emb: Optional[str] = None
-    control: Optional[str] = None
-    prompt: Optional[str] = None
-    prefix: str = ""
-    shuffle: bool = False
-    ucg: float = 0.0
-
 @dataclass
 class DatasetConfig:
     module: str = MISSING
-    args: DatasetArgs = field(default_factory=DatasetArgs)
+    args: Dict[str, Any] = field(default_factory=dict)
 
 @dataclass
 class DataLoaderArgs:
diff --git a/modules/tleco/tleco_dataset.py b/modules/tleco/tleco_dataset.py
new file mode 100644
index 0000000..60401ef
--- /dev/null
+++ b/modules/tleco/tleco_dataset.py
@@ -0,0 +1,112 @@
+from torch.utils.data import Dataset
+import random
+
+class TLECODataset(Dataset):
+    def __init__(
+        self,
+        text_model,
+        path, # path to a text file of "source, target" replacement pairs (one per line)
+        batch_size=1,
+        pad_tokens=None, # list of padding token strings
+        selected_tags=None, # path to a file listing tags that may be appended
+        drop_rate=0.0, # probability of dropping the replacement pair (regularization)
+        shuffle=True,
+        prefix="", # string prepended to every caption
+        repeat=1, # how many times the dataset is repeated per epoch
+        max_add_length=16, # maximum number of extra tags to append
+        position_rate=0.5, # probability of placing extra tags at the end vs. the front
+    ):
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.drop_rate = drop_rate
+        self.prefix = prefix
+        self.max_add_length = max_add_length
+        self.position_rate = position_rate
+
+        with open(path, 'r') as f:
+            self.texts = f.read().splitlines()
+        self.text_pairs = [[tag.strip() for tag in text.split(",")] for text in self.texts]
+
+        # sanity check: every pad token must encode to a single token (BOS + token + EOS = 3 ids)
+        if pad_tokens is not None:
+            for pad_token in pad_tokens:
+                pad_token_ids = text_model.tokenizer([pad_token]).input_ids[0]
+                assert len(pad_token_ids) == 3, f"pad_token:{pad_token} is not a single token"
+        self.pad_tokens = pad_tokens if pad_tokens is not None else [] # fix: None crashed random.choices
+
+        self.text_infos = []
+        for source, target in self.text_pairs:
+            src_tokens_ids, tgt_tokens_ids = text_model.tokenizer([source, target], add_special_tokens=False).input_ids
+            self.text_infos.append({
+                "source": source,
+                "target": target,
+                "source_length": len(src_tokens_ids),
+                "target_length": len(tgt_tokens_ids),
+            })
+
+        if selected_tags is not None:
+            with open(selected_tags, 'r') as f:
+                self.selected_tags = f.read().splitlines()
+        else:
+            self.selected_tags = [] # fix: was left undefined, crashing __getitem__
+
+        self.text_infos = self.text_infos * repeat
+
+        self.create_batch()
+
+    def create_batch(self):
+        if self.shuffle:
+            random.shuffle(self.text_infos)
+        self.batch = []
+        for i in range(0, len(self.text_infos), self.batch_size):
+            self.batch.append(self.text_infos[i:i+self.batch_size])
+        return
+
+    def __len__(self):
+        return len(self.batch)
+
+    def __getitem__(self, idx):
+        if idx == 0:
+            self.create_batch() # reshuffle at the start of every epoch
+        batch = self.batch[idx]
+        sources = []
+        targets = []
+        for dic in batch:
+            source = dic["source"]
+            target = dic["target"]
+            source_length = dic["source_length"]
+            target_length = dic["target_length"]
+
+            # fix: clamp the sample size so it never exceeds the tag population
+            add_tags = random.sample(self.selected_tags, random.randint(0, min(self.max_add_length, len(self.selected_tags))))
+            add_tags = ", ".join(add_tags)
+
+            # pad the shorter caption so both sides encode to the same token length
+            padding = random.choices(self.pad_tokens, k=abs(source_length - target_length))
+
+            if source_length > target_length:
+                target += " " + " ".join(padding)
+            elif source_length < target_length:
+                source += " " + " ".join(padding)
+
+            last_source = self.prefix
+            last_target = self.prefix
+            if random.random() > self.drop_rate:
+                if random.random() > self.position_rate:
+                    last_source += ", " + source + ", " + add_tags
+                    last_target += ", " + target + ", " + add_tags
+                else:
+                    last_source += ", " + add_tags + ", " + source
+                    last_target += ", " + add_tags + ", " + target
+            else:
+                last_source += ", " + add_tags
+                last_target += ", " + add_tags
+
+            sources.append(last_source)
+            targets.append(last_target)
+
+        return {
+            "source": sources,
+            "target": targets
+        }
diff --git a/modules/tleco/tleco_trainer.py b/modules/tleco/tleco_trainer.py
new file mode 100644
index 0000000..e26c44f
--- /dev/null
+++ b/modules/tleco/tleco_trainer.py
@@ -0,0 +1,40 @@
+import torch
+from modules.trainer import BaseTrainer
+
+class TLECOTrainer(BaseTrainer):
+    def loss(self, batch):
+        source = batch["source"]
+        target = batch["target"]
+        self.batch_size = len(source)
+
+        with torch.autocast("cuda", dtype=self.autocast_dtype):
+            with torch.no_grad(), self.network.set_temporary_multiplier(0.0): # reference pass with LoRA disabled
+                tgt_hidden, tgt_pool = self.text_model(target)
+
+            src_hidden, src_pool = self.text_model(source)
+
+            if len(self.network.unet_modules) > 0:
+                with torch.no_grad(), self.network.set_temporary_multiplier(0.0):
+                    target_kvs = self.kv_emb(tgt_hidden)
+                source_kvs = self.kv_emb(src_hidden)
+
+        loss_hidden = torch.nn.functional.mse_loss(src_hidden, tgt_hidden)
+        loss_pool = torch.nn.functional.mse_loss(src_pool, tgt_pool)
+
+        loss = loss_hidden + loss_pool
+
+        if len(self.network.unet_modules) > 0:
+            loss_kvs = [torch.nn.functional.mse_loss(src, tgt) for src, tgt in zip(source_kvs, target_kvs)]
+            for loss_kv in loss_kvs: # average the per-projection losses
+                loss += loss_kv / len(loss_kvs)
+
+        return loss
+
+    def kv_emb(self, text_emb): # project text embeddings through every cross-attention K/V layer
+        outputs = []
+        for name, module in self.diffusion.unet.named_modules():
+            if "attn2.to_k" in name:
+                outputs.append(module(text_emb))
+            elif "attn2.to_v" in name:
+                outputs.append(module(text_emb))
+        return outputs
diff --git a/networks/manager.py b/networks/manager.py
index 12c3f35..d518bb3 100644
--- a/networks/manager.py
+++ b/networks/manager.py
@@ -62,16 +62,17 @@ def __init__(
 
         # unetのloraを作る
         self.unet_modules = []
-        self.unet_modules += self.create_modules(LORA_PREFIX_UNET, unet, UNET_TARGET_REPLACE_MODULE_TRANSFORMER, state_dict, module_args, unet_key_filters, unet_keys)
+        if state_dict or module_args is not None:
+            self.unet_modules += self.create_modules(LORA_PREFIX_UNET, unet, UNET_TARGET_REPLACE_MODULE_TRANSFORMER, state_dict, module_args, unet_key_filters, unet_keys)
         if state_dict or conv_module_args is not None:
             self.unet_modules += self.create_modules(LORA_PREFIX_UNET, unet, UNET_TARGET_REPLACE_MODULE_CONV, state_dict, conv_module_args, unet_key_filters, unet_keys)
 
         if state_dict or text_module_args is not None:
             self.text_encoder_modules = []
             if text_model.sdxl:
-                self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER_1, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, state_dict, text_module_args, te1_keys)
-                self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER_2, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, state_dict, text_module_args, te2_keys)
+                self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER_1, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, state_dict, text_module_args, None, te1_keys)
+                self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER_2, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, state_dict, text_module_args, None, te2_keys)
             else:
-                self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, state_dict, text_module_args, te_keys)
+                self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, state_dict, text_module_args, None, te_keys)
         else:
             self.text_encoder_modules = []