Skip to content

Commit

Permalink
Merge pull request #47 from laksjdjf/tleco
Browse files Browse the repository at this point in the history
support tleco
  • Loading branch information
laksjdjf authored Mar 21, 2024
2 parents 71734b4 + 53f9420 commit 237cb3b
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 21 deletions.
18 changes: 1 addition & 17 deletions modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,10 @@ class TrainerConfig:
validation_args: Dict[str, Any] = field(default_factory=dict)
additional_conf: Dict[str, Any] = field(default_factory=dict)

@dataclass
class DatasetArgs:
    """Argument schema for a single dataset entry (populated from the user config).

    NOTE(review): `MISSING` presumably comes from omegaconf and marks values the
    config file must supply — confirm against this module's imports.
    """
    batch_size: int = MISSING
    path: str = MISSING
    # Metadata file name — presumably resolved relative to `path`; TODO confirm in the loader.
    metadata: str = "buckets.json"
    original_size: Optional[str] = None
    # Names/keys for the cached inputs below; None appears to disable the input — verify in loader.
    latent: Optional[str] = "latents"
    caption: Optional[str] = "captions"
    image: Optional[str] = None
    text_emb: Optional[str] = None
    control: Optional[str] = None
    prompt: Optional[str] = None
    prefix: str = ""  # presumably prepended to captions — verify against the dataset code
    shuffle: bool = False
    ucg: float = 0.0  # NOTE(review): looks like an unconditional-caption drop rate — confirm

@dataclass
class DatasetConfig:
    """Configuration for one dataset: the module to instantiate and its kwargs.

    The diff overlay left two conflicting declarations of `args` (the removed
    `DatasetArgs`-typed field and the new dict-typed one); only the dict-typed
    field is kept, matching the removal of the rigid `DatasetArgs` schema.
    """
    # Dotted path of the dataset module/class to load.
    module: str = MISSING
    # Free-form keyword arguments forwarded verbatim to the dataset constructor.
    args: Dict[str, Any] = field(default_factory=dict)

@dataclass
class DataLoaderArgs:
Expand Down
108 changes: 108 additions & 0 deletions modules/tleco/tleco_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from torch.utils.data import Dataset
import random

class TLECODataset(Dataset):
    """Dataset of (source, target) prompt pairs for TLECO training.

    Each line of *path* must contain exactly two comma-separated fields:
    "source, target".  The shorter side of a pair is padded with pad tokens so
    both sides have the same token length, and a random selection of extra
    tags may be placed before or after both prompts.
    """

    def __init__(
        self,
        text_model,
        path,  # file with the texts to be replaced, one "source, target" pair per line
        batch_size=1,
        pad_tokens=None,  # padding token strings (each must encode to a single token)
        selected_tags=None,  # path to a file listing tags that may be added
        drop_rate=0.0,  # probability of dropping the source/target pair (regularization)
        shuffle=True,
        prefix="",  # string prepended to every prompt
        repeat=1,  # number of times the dataset is repeated
        max_add_length=16,  # maximum number of added tags
        position_rate=0.5,  # probability controlling tag position (head vs. tail)
    ):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_rate = drop_rate
        self.prefix = prefix
        self.max_add_length = max_add_length
        self.position_rate = position_rate

        with open(path, 'r') as f:
            self.texts = f.read().splitlines()
        self.text_pairs = [[tag.strip() for tag in text.split(",")] for text in self.texts]

        # FIX: always define pad_tokens so __getitem__ cannot raise
        # AttributeError when the argument is omitted (the original only
        # assigned the attribute inside the `is not None` branch).
        self.pad_tokens = []
        if pad_tokens is not None:
            for pad_token in pad_tokens:
                pad_token_ids = text_model.tokenizer([pad_token]).input_ids[0]
                # 3 ids = special tokens + the token itself — presumably
                # BOS + token + EOS; TODO confirm for the tokenizer in use.
                assert len(pad_token_ids) == 3, f"pad_token:{pad_token} is not a single token"
            self.pad_tokens = pad_tokens

        self.text_infos = []
        for source, target in self.text_pairs:
            src_tokens_ids, tgt_tokens_ids = text_model.tokenizer(
                [source, target], add_special_tokens=False
            ).input_ids
            self.text_infos.append({
                "source": source,
                "target": target,
                "source_length": len(src_tokens_ids),
                "target_length": len(tgt_tokens_ids),
            })

        # FIX: same defaulting for selected_tags (unset attribute otherwise).
        self.selected_tags = []
        if selected_tags is not None:
            with open(selected_tags, 'r') as f:
                self.selected_tags = f.read().splitlines()

        self.text_infos = self.text_infos * repeat

        self.create_batch()

    def create_batch(self):
        """(Re)build the list of batches, shuffling the pairs first when enabled."""
        if self.shuffle:
            random.shuffle(self.text_infos)
        self.batch = []
        for i in range(0, len(self.text_infos), self.batch_size):
            self.batch.append(self.text_infos[i:i + self.batch_size])
        return

    def __len__(self):
        return len(self.batch)

    def __getitem__(self, idx):
        """Return {"source": [...], "target": [...]} for batch *idx*."""
        # Rebuild (and reshuffle) the batches at the start of every epoch.
        if idx == 0:
            self.create_batch()
        batch = self.batch[idx]
        sources = []
        targets = []
        for dic in batch:
            source = dic["source"]
            target = dic["target"]
            source_length = dic["source_length"]
            target_length = dic["target_length"]

            # FIX: cap the sample size at the number of available tags so
            # random.sample cannot raise ValueError when fewer tags than
            # max_add_length were provided (or none at all).
            n_tags = random.randint(0, min(self.max_add_length, len(self.selected_tags)))
            add_tags = ", ".join(random.sample(self.selected_tags, n_tags))

            # Pad the shorter side so source and target match in token length.
            padding = random.choices(self.pad_tokens, k=abs(source_length - target_length))
            if source_length > target_length:
                target += " " + " ".join(padding)
            elif source_length < target_length:
                source += " " + " ".join(padding)

            last_source = self.prefix
            last_target = self.prefix
            if random.random() > self.drop_rate:
                # Extra tags go after (head branch) or before (tail branch) the pair.
                if random.random() > self.position_rate:
                    last_source += ", " + source + ", " + add_tags
                    last_target += ", " + target + ", " + add_tags
                else:
                    last_source += ", " + add_tags + ", " + source
                    last_target += ", " + add_tags + ", " + target
            else:
                # Regularization: drop the pair entirely, keeping only the tags.
                last_source += ", " + add_tags
                last_target += ", " + add_tags

            sources.append(last_source)
            targets.append(last_target)

        return {
            "source": sources,
            "target": targets
        }


40 changes: 40 additions & 0 deletions modules/tleco/tleco_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
from modules.trainer import BaseTrainer

class TLECOTrainer(BaseTrainer):
    """Trainer for TLECO: pulls the text-encoder outputs (and, when UNet LoRA
    modules exist, the UNet cross-attention key/value projections) of source
    prompts toward those of the corresponding target prompts."""

    def loss(self, batch):
        """Return the TLECO alignment loss for one batch.

        batch: dict with "source" and "target" lists of prompt strings, as
        produced by TLECODataset.__getitem__.
        """
        source = batch["source"]
        target = batch["target"]
        self.batch_size = len(source)

        with torch.autocast("cuda", dtype=self.autocast_dtype):
            # Target embeddings are the fixed reference: computed with the
            # trainable network disabled (multiplier 0) and without gradients.
            with torch.no_grad(), self.network.set_temporary_multiplier(0.0):
                tgt_hidden, tgt_pool = self.text_model(target)

            # Source embeddings keep gradients so the network can learn.
            src_hidden, src_pool = self.text_model(source)

            if len(self.network.unet_modules) > 0:
                with torch.no_grad(), self.network.set_temporary_multiplier(0.0):
                    target_kvs = self.kv_emb(tgt_hidden)
                # NOTE(review): indentation reconstructed — this call appears to
                # sit OUTSIDE the no_grad block so gradients flow through the
                # source k/v projections; confirm against the repository.
                source_kvs = self.kv_emb(src_hidden)

            loss_hidden = torch.nn.functional.mse_loss(src_hidden, tgt_hidden)
            loss_pool = torch.nn.functional.mse_loss(src_pool, tgt_pool)

            loss = loss_hidden + loss_pool

            if len(self.network.unet_modules) > 0:
                # Add the MEAN of the per-projection k/v losses to the total.
                loss_kvs =[torch.nn.functional.mse_loss(src, tgt) for src, tgt in zip(source_kvs, target_kvs)]
                for loss_kv in loss_kvs:
                    loss += loss_kv / len(loss_kvs)

        return loss

    def kv_emb(self, text_emb):
        """Apply every UNet cross-attention (attn2) key/value projection to
        *text_emb* and return the projected tensors in named_modules order."""
        outputs = []
        for name, module in self.diffusion.unet.named_modules():
            if "attn2.to_k" in name:
                outputs.append(module(text_emb))
            elif "attn2.to_v" in name:
                outputs.append(module(text_emb))
        return outputs
9 changes: 5 additions & 4 deletions networks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,17 @@ def __init__(

# unetのloraを作る
self.unet_modules = []
self.unet_modules += self.create_modules(LORA_PREFIX_UNET, unet, UNET_TARGET_REPLACE_MODULE_TRANSFORMER, state_dict, module_args, unet_key_filters, unet_keys)
if state_dict or module_args is not None:
self.unet_modules += self.create_modules(LORA_PREFIX_UNET, unet, UNET_TARGET_REPLACE_MODULE_TRANSFORMER, state_dict, module_args, unet_key_filters, unet_keys)
if state_dict or conv_module_args is not None:
self.unet_modules += self.create_modules(LORA_PREFIX_UNET, unet, UNET_TARGET_REPLACE_MODULE_CONV, state_dict, conv_module_args, unet_key_filters, unet_keys)
if state_dict or text_module_args is not None:
self.text_encoder_modules = []
if text_model.sdxl:
self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER_1, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, state_dict, text_module_args, te1_keys)
self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER_2, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, state_dict, text_module_args, te2_keys)
self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER_1, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, state_dict, text_module_args, None, te1_keys)
self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER_2, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, state_dict, text_module_args, None, te2_keys)
else:
self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, state_dict, text_module_args, te_keys)
self.text_encoder_modules += self.create_modules(LORA_PREFIX_TEXT_ENCODER, text_model, TEXT_ENCODER_TARGET_REPLACE_MODULE, state_dict, text_module_args, None, te_keys)
else:
self.text_encoder_modules = []

Expand Down

0 comments on commit 237cb3b

Please sign in to comment.