diff --git a/modules/dummy/dummy_dataset.py b/modules/dummy/dummy_dataset.py
new file mode 100644
index 0000000..a219ce8
--- /dev/null
+++ b/modules/dummy/dummy_dataset.py
@@ -0,0 +1,45 @@
+from torch.utils.data import Dataset
+import torch
+
+class DummyDataset(Dataset):
+    def __init__(
+        self,
+        text_model,
+        batch_size=1,
+        size=(512, 512),
+        num_batch=100,
+        cache_latent=False,
+        cache_text_emb=False,
+    ):
+        self.batch_size = batch_size
+        self.width, self.height = size
+        self.num_batch = num_batch
+        self.cache_latent = cache_latent
+        self.cache_text_emb = cache_text_emb
+
+        self.sdxl = text_model.sdxl
+
+    def __len__(self):
+        return self.num_batch
+
+    def __getitem__(self, i):
+
+        batch = {}
+        if self.cache_latent:
+            batch["latents"] = torch.randn(self.batch_size, 4, self.height//8, self.width//8)
+        else:
+            batch["images"] = torch.randn(self.batch_size, 3, self.height, self.width)
+
+        if self.sdxl:
+            size_list = [self.height, self.width, 0, 0, self.height, self.width]
+            batch["size_condition"] = torch.tensor(size_list).repeat(self.batch_size, 1)
+
+        if self.cache_text_emb:
+            dim = 2048 if self.sdxl else 768  # SD2? never heard of her
+            batch["encoder_hidden_states"] = torch.randn(self.batch_size, 77, dim)
+            if self.sdxl:
+                batch["pooled_outputs"] = torch.randn(self.batch_size, dim)
+        else:
+            batch["captions"] = ["" for _ in range(self.batch_size)]
+
+        return batch
\ No newline at end of file
diff --git a/modules/trainer.py b/modules/trainer.py
index 49fe1d8..d6d0713 100644
--- a/modules/trainer.py
+++ b/modules/trainer.py
@@ -39,6 +39,7 @@ def __init__(self, config, diffusion:DiffusionModel, text_model:TextModel, vae:A
         self.diffusers_scheduler = scheduler # used only when saving the model
         self.scheduler = BaseScheduler(scheduler.config.prediction_type == "v_prediction")
         self.sdxl = text_model.sdxl
+        self.scaling_factor = 0.13025 if self.sdxl else 0.18215
 
         if config is not None and config.merging_loras:
             for lora in config.merging_loras:
@@ -218,10 +219,10 @@ def prepare_lr_scheduler(self, total_steps):
 
     def loss(self, batch):
         if "latents" in batch:
-            latents = batch["latents"].to(self.device) * self.vae.scaling_factor
+            latents = batch["latents"].to(self.device) * self.scaling_factor
         else:
             with torch.autocast("cuda", dtype=self.vae_dtype), torch.no_grad():
-                latents = self.vae.encode(batch['images'].to(self.device)).latent_dist.sample() * self.vae.scaling_factor
+                latents = self.vae.encode(batch['images'].to(self.device)).latent_dist.sample() * self.scaling_factor
 
         self.batch_size = latents.shape[0] # also used in the step method
 
@@ -316,7 +317,7 @@ def sample(
             latents = torch.zeros(batch_size, 4, height // 8, width // 8, device=self.device, dtype=self.autocast_dtype)
         else:
             with torch.autocast("cuda", dtype=self.vae_dtype):
-                latents = self.encode_latents(images) * self.vae.scaling_factor
+                latents = self.encode_latents(images) * self.scaling_factor
             latents.to(dtype=self.autocast_dtype)
 
         noise = torch.randn_like(latents)
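
The hard-coded factors match the scaling_factor in the stock VAE configs (0.18215 for SD1.x/2.x, 0.13025 for SDXL), so the loss path no longer has to touch self.vae at all when latents arrive pre-cached from DummyDataset. As a sanity check on the shapes the dataset emits, a minimal sketch; the _FakeTextModel stub is hypothetical, since only its .sdxl attribute is read:

from modules.dummy.dummy_dataset import DummyDataset

class _FakeTextModel:
    sdxl = True  # pretend we are benchmarking an SDXL-shaped model

ds = DummyDataset(_FakeTextModel(), batch_size=4, size=(1024, 1024),
                  cache_latent=True, cache_text_emb=True)
batch = ds[0]
assert batch["latents"].shape == (4, 4, 128, 128)             # latents live at 1/8 pixel resolution
assert batch["size_condition"].shape == (4, 6)                # (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
assert batch["encoder_hidden_states"].shape == (4, 77, 2048)  # 77 tokens, SDXL cross-attention width
assert batch["pooled_outputs"].shape == (4, 2048)

Note that each __getitem__ already returns a whole batch of random tensors, which is why batch_size lives on the dataset rather than on the DataLoader.
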
diff --git a/networks/lora.py b/networks/lora.py
index bc1dd52..9e39dd6 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -36,9 +36,10 @@ def get_weight(self, multiplier=None):
 
 
 class LoRAModule(BaseModule):
-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, state_dict=None, rank=4, alpha=1):
+    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, state_dict=None, rank=4, alpha=1, forward_mode="sequential"):
         super().__init__()
         self.lora_name = lora_name
+        self.forward_mode = forward_mode
 
         if state_dict is not None:
             up_weight = state_dict[f"{lora_name}.lora_up.weight"]
@@ -55,6 +56,9 @@ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, state
             self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
             self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)
 
+            self.functional = torch.nn.functional.linear
+            self.functional_args = {}
+
         elif 'Conv' in org_module.__class__.__name__: # ["Conv2d", "LoRACompatibleConv"]
             in_dim = org_module.in_channels
             out_dim = org_module.out_channels
@@ -70,6 +74,12 @@
                 in_dim, self.rank, kernel_size, stride, padding, bias=False)
             self.lora_up = torch.nn.Conv2d(
                 self.rank, out_dim, (1, 1), (1, 1), bias=False)
+
+            self.functional = torch.nn.functional.conv2d
+            self.functional_args = {
+                "stride": stride,
+                "padding": padding,
+            }
 
         self.shape = org_module.weight.shape
 
@@ -108,5 +118,9 @@ def lora_forward(self, x):
     def forward(self, x, scale = None):
         if self.multiplier == 0.0:
             return self.org_forward(x)
-        else:
-            return self.org_forward(x) + self.lora_forward(x)
\ No newline at end of file
+        if self.forward_mode == "sequential":
+            return self.org_forward(x) + self.lora_forward(x)
+        elif self.forward_mode == "merge":
+            weight = self.org_module[0].state_dict()["weight"]
+            bias = self.org_module[0].state_dict().get("bias", None)
+            return self.functional(x, weight + self.get_weight(), bias, **self.functional_args)
\ No newline at end of file
diff --git a/networks/manager.py b/networks/manager.py
index 800fd0d..48d768c 100644
--- a/networks/manager.py
+++ b/networks/manager.py
@@ -59,6 +59,11 @@ def __init__(
         te2_keys = [key for key in keys if LORA_PREFIX_TEXT_ENCODER_2 in key]
 
         self.module = get_attr_from_config(module)
+
+        if hasattr(conv_module_args, "same") and conv_module_args.same:
+            conv_module_args = module_args
+        if hasattr(text_module_args, "same") and text_module_args.same:
+            text_module_args = module_args
 
         # create the LoRA modules for the UNet
         self.unet_modules = []
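
The new "merge" mode folds the LoRA delta into the base weight and makes a single functional call instead of two stacked forwards, which is exactly the kind of trade-off the speed test below is meant to measure. Assuming get_weight() returns the usual LoRA delta, multiplier * (alpha / rank) * (up @ down) (its definition lives in BaseModule, outside this diff), the two modes agree up to floating-point error. A self-contained sketch for the Linear case:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_dim, out_dim, rank = 16, 32, 4
base = torch.nn.Linear(in_dim, out_dim)
down = torch.nn.Linear(in_dim, rank, bias=False)  # lora_down
up = torch.nn.Linear(rank, out_dim, bias=False)   # lora_up
scale = 1.0 * 1 / rank                            # multiplier * alpha / rank

x = torch.randn(2, in_dim)
sequential = base(x) + up(down(x)) * scale                    # forward_mode="sequential"
delta = up.weight @ down.weight                               # (out_dim, in_dim): what get_weight() is assumed to return, times scale
merged = F.linear(x, base.weight + scale * delta, base.bias)  # forward_mode="merge": one fused matmul
assert torch.allclose(sequential, merged, atol=1e-5)

The manager.py change is config sugar on top of this: setting same on conv_module_args or text_module_args reuses module_args wholesale, so a sweep only has to vary one block.
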
diff --git a/speedtest.py b/speedtest.py
new file mode 100644
index 0000000..aedf317
--- /dev/null
+++ b/speedtest.py
@@ -0,0 +1,125 @@
+from omegaconf import OmegaConf
+import sys
+import math
+from accelerate.utils import set_seed
+from modules.utils import get_attr_from_config, collate_fn
+from modules.config import Config
+from tqdm import tqdm
+import logging
+import subprocess
+import time
+import json
+import pandas as pd
+from itertools import product
+import torch
+import gc
+
+logger = logging.getLogger("テストちゃん")
+
+def get_gpu_memory_usage():
+    cmd = ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,noheader,nounits']
+    result = subprocess.run(cmd, stdout=subprocess.PIPE)
+    return int(result.stdout.decode('utf-8').strip())
+
+def setattr_recursive(obj, key, value):
+    if "." in key:
+        key, rest = key.split(".", 1)
+        setattr_recursive(getattr(obj, key), rest, value)
+    else:
+        setattr(obj, key, value)
+
+def main(config):
+
+    set_seed(config.main.seed)
+    logger.info(f"Seed is {config.main.seed}!")
+
+    logger.info(f"Loading the model from {config.main.model_path}!")
+    trainer_cls = get_attr_from_config(config.trainer.module)
+    trainer = trainer_cls.from_pretrained(config.main.model_path, config.main.sdxl, config.main.clip_skip, config.trainer)
+
+    dataset_cls = get_attr_from_config(config.dataset.module)
+    dataset = dataset_cls(trainer.text_model, **config.dataset.args)
+
+    dataloader_cls = get_attr_from_config(config.dataloader.module)
+    dataloader = dataloader_cls(dataset, collate_fn=collate_fn, **config.dataloader.args)
+
+    trainer.prepare_modules_for_training()
+    trainer.prepare_network(config.network)
+    trainer.prepare_controlnet(config.controlnet)
+    trainer.apply_module_settings()
+
+    trainer.prepare_optimizer()
+
+    steps_per_epoch = len(dataloader)
+    total_steps = config.main.steps or steps_per_epoch * config.main.epochs
+    total_epochs = config.main.epochs or math.floor(total_steps / steps_per_epoch)
+    logger.info(f"Total number of steps: {total_steps}")
+
+    trainer.prepare_lr_scheduler(total_steps)
+
+    peak_memory = get_gpu_memory_usage()
+    current_step = 0
+
+    progress_bar = None
+    for epoch in range(total_epochs):
+        for batch in dataloader:
+            if progress_bar is None:
+                start_time = time.time()
+                progress_bar = tqdm(total=total_steps, desc="Training")
+            logs = trainer.step(batch)
+            peak_memory = max(peak_memory, get_gpu_memory_usage())
+            logs.update({"peak_memory": peak_memory})
+            progress_bar.update(1)
+            progress_bar.set_postfix(logs)
+            current_step += 1
+
+            if current_step == total_steps:
+                logger.info("Training is done!")
+                end_time = time.time()
+                seconds = end_time - start_time
+                samples_per_second = total_steps * dataset.batch_size / seconds
+                print(f"Total time: {seconds:.2f} seconds")
+                print(f"Peak VRAM usage: {peak_memory}MB")
+                print(f"Samples per second: {samples_per_second}")
+                del trainer.diffusion.unet, trainer.vae, trainer.text_model
+                del trainer
+                gc.collect()
+                torch.cuda.empty_cache()
+                return seconds, total_steps, samples_per_second, peak_memory
+
+        logger.info(f"Epoch {epoch+1} finished!")
+
+if __name__ == "__main__":
+    base_config = OmegaConf.load(sys.argv[1])
+    base_config = OmegaConf.merge(OmegaConf.structured(Config), base_config)
+
+    logging.basicConfig(level=logging.WARNING)
+    print(OmegaConf.to_yaml(base_config))
+
+    if len(sys.argv) == 3:
+        with open(sys.argv[2], "r") as f:
+            validation = json.load(f)
+
+        keys = list(validation.keys())
+        values = list(validation.values())
+        columns = [key.split(".")[-1] for key in keys] + ["time", "steps", "samples/s", "vram"]
+        df = pd.DataFrame(columns=columns)
+
+        for settings in product(*values):
+            print({keys[i]: setting for i, setting in enumerate(settings)})
+            for i, setting in enumerate(settings):
+                setattr_recursive(base_config, keys[i], setting)
+
+            try:
+                seconds, steps, samples_per_second, memory = main(base_config)
+            except Exception as e:
+                print(e)
+                seconds, steps, samples_per_second, memory = 0, 0, 0, 0
+
+            data = list(settings) + [seconds, steps, samples_per_second, memory]
+            df.loc[len(df)] = data
+
+            df.to_csv("speed_test.csv")
+
+    else:
+        main(base_config)
\ No newline at end of file
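
For reference, the sweep mode takes a second CLI argument: a JSON file mapping dotted config paths to lists of candidate values. A hypothetical driver; the dotted keys are illustrative, since the valid paths depend on modules.config.Config:

import json
import subprocess

# Hypothetical sweep grid: each key is a dotted config path handed to
# setattr_recursive, each value a list of candidates. speedtest.py runs the
# cartesian product and appends one row per combination to speed_test.csv.
grid = {
    "network.args.forward_mode": ["sequential", "merge"],
    "dataset.args.batch_size": [1, 4],
}
with open("grid.json", "w") as f:
    json.dump(grid, f)

subprocess.run(["python", "speedtest.py", "config.yaml", "grid.json"], check=True)  # grid sweep
# subprocess.run(["python", "speedtest.py", "config.yaml"], check=True)             # single run
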