From 16b7ad57bf377f408e103f5a65fe557b1d53aede Mon Sep 17 00:00:00 2001
From: lawrence-cj
Date: Wed, 27 Nov 2024 18:35:16 +0800
Subject: [PATCH] apply pre-commit formatting; still needs to be reworked into
 the current codebase

Signed-off-by: lawrence-cj
---
 README.md                     |   4 +-
 train_scripts/make_buckets.py | 116 ++++++++++++++++++----------------
 train_scripts/train_local.py  |  89 ++++++++++++++------------
 3 files changed, 113 insertions(+), 96 deletions(-)

diff --git a/README.md b/README.md
index 277ec91..372d384 100644
--- a/README.md
+++ b/README.md
@@ -222,6 +222,7 @@ bash train_scripts/train.sh \
 ```

 Local training with bucketing and VAE embedding caching:
+
 ```bash
 # Prepare buckets and cache VAE embeds
 python train_scripts/make_buckets.py \
@@ -235,11 +236,10 @@ bash train_scripts/train_local.sh \
   --data.buckets_file=buckets.json \
   --train.train_batch_size=30
 ```
+
 Using the AdamW optimizer, training with a batch size of 30 on 1024x1024 resolution consumes ~48GB VRAM on an NVIDIA A6000 GPU.
 Each training iteration takes ~7.5 seconds.

-
-
 # 💻 4. Metric toolkit

 Refer to [Toolkit Manual](asset/docs/metrics_toolkit.md).
diff --git a/train_scripts/make_buckets.py b/train_scripts/make_buckets.py
index 4a4f64f..36c4ed4 100644
--- a/train_scripts/make_buckets.py
+++ b/train_scripts/make_buckets.py
@@ -1,17 +1,20 @@
-import torch
-from diffusion.model.builder import get_vae, vae_encode
-from diffusion.utils.config import SanaConfig
-import pyrallis
-from PIL import Image
-import torchvision.transforms as T
+import json
+import math
 import os
 import os.path as osp
-from torchvision.transforms import InterpolationMode
-import json
+from itertools import chain
+
+import pyrallis
+import torch
+import torchvision.transforms as T
+from PIL import Image
 from torch.utils.data import DataLoader
+from torchvision.transforms import InterpolationMode
 from tqdm import tqdm
-import math
-from itertools import chain
+
+from diffusion.model.builder import get_vae, vae_encode
+from diffusion.utils.config import SanaConfig
+

 @pyrallis.wrap()
 def main(config: SanaConfig) -> None:
@@ -22,16 +25,16 @@ def main(config: SanaConfig) -> None:
     step = 32

     ratios_array = []
-    while(min_size != max_size):
+    while min_size != max_size:
         width = int(preferred_pixel_count / min_size)
-        if(width % step != 0):
-            mod = width % step 
-            if(mod < step//2):
+        if width % step != 0:
+            mod = width % step
+            if mod < step // 2:
                 width -= mod
             else:
                 width += step - mod

-        ratio = min_size / width 
+        ratio = min_size / width
         ratios_array.append((ratio, (int(min_size), width)))

         min_size += step
@@ -43,25 +46,31 @@ def get_closest_ratio(height: float, width: float):

     def get_preffered_size(height: float, width: float):
         pixel_count = height * width
-        
+
         scale = math.sqrt(pixel_count / preferred_pixel_count)
         return height / scale, width / scale

     class BucketsDataset(torch.utils.data.Dataset):
         def __init__(self, data_dir, skip_files):
             valid_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"}
-            self.files = ([
-                osp.join(data_dir, f) for f in os.listdir(data_dir)
-                if osp.isfile(osp.join(data_dir, f)) and osp.splitext(f)[1].lower() in valid_extensions and osp.join(data_dir, f) not in skip_files ])
-
-            self.transform = T.Compose([
-                T.ToTensor(),
-                T.Normalize([0.5], [0.5]),
-            ])
-
+            self.files = [
+                osp.join(data_dir, f)
+                for f in os.listdir(data_dir)
+                if osp.isfile(osp.join(data_dir, f))
+                and osp.splitext(f)[1].lower() in valid_extensions
+                and osp.join(data_dir, f) not in skip_files
+            ]
+
+            self.transform = T.Compose(
+                [
+                    T.ToTensor(),
+                    T.Normalize([0.5], [0.5]),
+                ]
+            )
+
         def __len__(self):
             return len(self.files)
-    
+
         def __getitem__(self, idx):
             path = self.files[idx]
             img = Image.open(path).convert("RGB")
@@ -70,11 +79,11 @@ def __getitem__(self, idx):
             crop = T.Resize(ratio[1], interpolation=InterpolationMode.BICUBIC)

             return {
-                'img': self.transform(crop(img)),
-                'size': torch.tensor([ratio[1][0], ratio[1][1]]),
-                'prefsize': torch.tensor([prefsize[0], prefsize[1]]),
-                'ratio': ratio[0],
-                'path': path
+                "img": self.transform(crop(img)),
+                "size": torch.tensor([ratio[1][0], ratio[1][1]]),
+                "prefsize": torch.tensor([prefsize[0], prefsize[1]]),
+                "ratio": ratio[0],
+                "path": path,
             }

     vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, "cuda").to(torch.float16)
@@ -82,14 +91,16 @@ def __getitem__(self, idx):
     def encode_images(batch, vae):
         with torch.no_grad():
             z = vae_encode(
-                config.vae.vae_type, vae, batch,
+                config.vae.vae_type,
+                vae,
+                batch,
                 sample_posterior=config.vae.sample_posterior,  # Adjust as necessary
-                device="cuda"
+                device="cuda",
             )
         return z

     if os.path.exists(config.data.buckets_file):
-        with open(config.data.buckets_file, 'r') as json_file:
+        with open(config.data.buckets_file) as json_file:
             buckets = json.load(json_file)
         existings_images = set(chain.from_iterable(buckets.values()))
     else:
@@ -101,36 +112,35 @@ def add_to_list(key, item):
             buckets[key].append(item)
         else:
             buckets[key] = [item]
-    
+
     for path in config.data.data_dir:
-        print(f'Processing {path}')
+        print(f"Processing {path}")
         dataset = BucketsDataset(path, existings_images)
         dataloader = DataLoader(dataset, batch_size=1)

         for batch in tqdm(dataloader):
-            img = batch['img']
-            size = batch['size']
-            ratio = batch['ratio']
-            image_path = batch['path']
-            prefsize = batch['prefsize']
+            img = batch["img"]
+            size = batch["size"]
+            ratio = batch["ratio"]
+            image_path = batch["path"]
+            prefsize = batch["prefsize"]
             encoded = encode_images(img.to(torch.half), vae)
-            
+
             for i in range(0, len(encoded)):
                 filename_wo_ext = os.path.splitext(os.path.basename(image_path[i]))[0]
                 add_to_list(str(ratio[i].item()), image_path[i])
-                
-                torch.save({
-                    'img': encoded[i].detach().clone(),
-                    'size': size[i],
-                    'prefsize': prefsize[i],
-                    'ratio': ratio[i]
-                }, f"{path}/{filename_wo_ext}_img.npz")
-            
-    with open(config.data.buckets_file, 'w') as json_file:
+
+                torch.save(
+                    {"img": encoded[i].detach().clone(), "size": size[i], "prefsize": prefsize[i], "ratio": ratio[i]},
+                    f"{path}/{filename_wo_ext}_img.npz",
+                )
+
+    with open(config.data.buckets_file, "w") as json_file:
         json.dump(buckets, json_file, indent=4)

     for ratio in buckets.keys():
-        print(f'{float(ratio):.2f}: {len(buckets[ratio])}')
+        print(f"{float(ratio):.2f}: {len(buckets[ratio])}")
+

 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/train_scripts/train_local.py b/train_scripts/train_local.py
index 084cf83..bc20452 100644
--- a/train_scripts/train_local.py
+++ b/train_scripts/train_local.py
@@ -29,14 +29,20 @@
 import numpy as np
 import pyrallis
 import torch
+import torch.utils
+import torch.utils.data
 from accelerate import Accelerator, InitProcessGroupKwargs
 from accelerate.utils import DistributedType
 from PIL import Image
 from termcolor import colored
-import torch.utils
-import torch.utils.data
+
 warnings.filterwarnings("ignore")  # ignore warning

+import gc
+import json
+import math
+import random
+
 from diffusion import DPMS, FlowEuler, Scheduler
 from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode, vae_encode
 from diffusion.model.respace import compute_density_for_timestep_sampling
@@ -47,10 +53,6 @@
 from diffusion.utils.lr_scheduler import build_lr_scheduler
 from diffusion.utils.misc import DebugUnderflowOverflow, init_random_seed, set_random_seed
 from diffusion.utils.optimizer import build_optimizer
-import json
-import random
-import math
-import gc

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -61,10 +63,13 @@ def set_fsdp_env():
     os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE"
     os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "SanaBlock"

+
 image_index = 0
+
+
 @torch.inference_mode()
 def log_validation(accelerator, config, model, logger, step, device, vae=None, init_noise=None):
-    
+
     torch.cuda.empty_cache()
     vis_sampler = config.scheduler.vis_sampler
     model = accelerator.unwrap_model(model).eval()
@@ -127,7 +132,7 @@ def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"):
                 model_kwargs=model_kwargs,
                 schedule="FLOW",
             )
-            
+
             denoised = dpm_solver.sample(
                 z,
                 steps=24,
@@ -141,7 +146,7 @@ def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"):

         latents.append(denoised)
     torch.cuda.empty_cache()
-    
+
     del_vae = False
     if vae is None:
         vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, accelerator.device).to(torch.float16)
@@ -236,12 +241,9 @@ def concatenate_images(image_caption, images_per_row=5, image_format="webp"):
     return image_logs


-class RatioBucketsDataset():
-    def __init__(
-        self,
-        buckets_file
-    ):
-        with open(buckets_file, 'r') as file:
+class RatioBucketsDataset:
+    def __init__(self, buckets_file):
+        with open(buckets_file) as file:
             self.buckets = json.load(file)

     def __getitem__(self, idx):
@@ -249,29 +251,29 @@ def __getitem__(self, idx):
         loader = random.choice(self.loaders)

         try:
-            return next(loader) 
+            return next(loader)
         except StopIteration:
             self.loaders.remove(loader)
             print(f"bucket ended, {len(self.loaders)}")

     def __len__(self):
         return self.size
-    
+
     def make_loaders(self, batch_size):
         self.loaders = []
         self.size = 0
         for bucket in self.buckets.keys():
             dataset = ImageDataset(self.buckets[bucket])
-            
-            loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=False, drop_last=False)
+
+            loader = torch.utils.data.DataLoader(
+                dataset, batch_size=batch_size, shuffle=True, pin_memory=False, drop_last=False
+            )
             self.loaders.append(iter(loader))
             self.size += math.ceil(len(dataset) / batch_size)

+
 class ImageDataset(torch.utils.data.Dataset):
-    def __init__(
-        self,
-        images
-    ):
+    def __init__(self, images):
         self.images = images

     def getdata(self, idx):
@@ -279,16 +281,16 @@ def getdata(self, idx):
         filename_wo_ext = os.path.splitext(os.path.basename(path))[0]
         text_file = os.path.join(os.path.dirname(path), f"{filename_wo_ext}.txt")

-        with open(text_file, 'r') as file:
+        with open(text_file) as file:
             prompt = file.read()

         cache_file = os.path.join(os.path.dirname(path), f"{filename_wo_ext}_img.npz")
         cached_data = torch.load(cache_file)

-        size = cached_data['prefsize']
-        ratio = cached_data['ratio']
-        vae_embed = cached_data['img']
-        
+        size = cached_data["prefsize"]
+        ratio = cached_data["ratio"]
+        vae_embed = cached_data["img"]
+
         data_info = {
             "img_hw": size,
             "aspect_ratio": torch.tensor(ratio.item()),
@@ -313,6 +315,7 @@ def __getitem__(self, idx):
     def __len__(self):
         return len(self.images)

+
 def train(config, args, accelerator, model, optimizer, lr_scheduler, dataset, train_diffusion, logger):
     if getattr(config.train, "debug_nan", False):
         DebugUnderflowOverflow(model)
@@ -358,19 +361,21 @@ def check_nan_inf(model):
         shuffled_prompts = []
         for prompt in prompts:
             tags = prompt.split(",")  # Split the string into a list of tags
-            random.shuffle(tags) # Shuffle the tags
+            random.shuffle(tags)  # Shuffle the tags
             shuffled_prompts.append(",".join(tags))  # Join them back into a string

         if "T5" in config.text_encoder.text_encoder_name:
             with torch.no_grad():
                 txt_tokens = tokenizer(
-                    shuffled_prompts, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+                    shuffled_prompts,
+                    max_length=max_length,
+                    padding="max_length",
+                    truncation=True,
+                    return_tensors="pt",
                 ).to(accelerator.device)
                 y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None]
                 y_mask = txt_tokens.attention_mask[:, None, None]
-        elif (
-            "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name
-        ):
+        elif "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name:
             with torch.no_grad():
                 if not config.text_encoder.chi_prompt:
                     max_length_all = config.text_encoder.model_max_length
@@ -430,13 +435,13 @@ def check_nan_inf(model):
             # Check if the loss is NaN
             if torch.isnan(loss):
                 loss_nan_timer += 1
-                print(f'Skip nan: {loss_nan_timer}')
+                print(f"Skip nan: {loss_nan_timer}")
                 continue  # Skip the rest of the loop iteration if loss is NaN

             accelerator.backward(loss)
             if accelerator.sync_gradients:
                 grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.train.gradient_clip)
-            
+
             optimizer.step()
             lr_scheduler.step()
             accelerator.wait_for_everyone()
@@ -462,9 +467,7 @@ def check_nan_inf(model):
                 )

             log_buffer.average()
-            current_step = (
-                global_step - step // config.train.train_batch_size
-            ) % len(dataset)
+            current_step = (global_step - step // config.train.train_batch_size) % len(dataset)
             current_step = len(dataset) if current_step == 0 else current_step
             info = (
                 f"Epoch: {epoch} | Global Step: {global_step} | Local Step: {current_step} // {len(dataset)}, "
@@ -639,7 +642,7 @@ def main(cfg: SanaConfig) -> None:
         if getattr(config.train, "deterministic_validation", False)
         else None
     )
-    
+
     tokenizer = text_encoder = None
     if not config.data.load_text_feat:
         tokenizer, text_encoder = get_tokenizer_and_text_encoder(
@@ -732,7 +735,11 @@ def main(cfg: SanaConfig) -> None:
             torch.save({"caption_embeds": caption_emb, "emb_mask": caption_emb_mask}, prompt_embed_path)

     null_tokens = tokenizer(
-        "bad artwork,ugly,sketch,poorly drawn,messy,noisy,score: 0/10,blurry,low quality,old", max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+        "bad artwork,ugly,sketch,poorly drawn,messy,noisy,score: 0/10,blurry,low quality,old",
+        max_length=max_length,
+        padding="max_length",
+        truncation=True,
+        return_tensors="pt",
     ).to(accelerator.device)
     if "T5" in config.text_encoder.text_encoder_name:
         null_token_emb = text_encoder(null_tokens.input_ids, attention_mask=null_tokens.attention_mask)[0]
@@ -849,7 +856,7 @@ def main(cfg: SanaConfig) -> None:
             config.train.lr_schedule_args["num_warmup_steps"] * num_replicas
         )
     lr_scheduler = build_lr_scheduler(config.train, optimizer, dataset, 1)
-    
+
     logger.warning(
         f"{colored(f'Basic Setting: ', 'green', attrs=['bold'])}"
         f"lr: {config.train.optimizer['lr']:.9f}, bs: {config.train.train_batch_size}, gc: {config.train.grad_checkpointing}, "