diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 313e40488..09c680002 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -945,7 +945,7 @@ def __call__( # encode the init image into latents and scale the latents init_image = init_image.to(device=self.device, dtype=latents_dtype) - if init_image.size()[2:] == (height // 8, width // 8): + if init_image.size()[1:] == (height // 8, width // 8): init_latents = init_image else: if vae_batch_size >= batch_size: @@ -1015,7 +1015,7 @@ def __call__( if self.control_nets: if reginonal_network: num_sub_and_neg_prompts = len(text_embeddings) // batch_size - text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt + text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt else: text_emb_last = text_embeddings noise_pred = original_control_net.call_unet_and_control_net( @@ -2318,6 +2318,22 @@ def __getattr__(self, item): else: networks = [] + # upscalerの指定があれば取得する + upscaler = None + if args.highres_fix_upscaler: + print("import upscaler module:", args.highres_fix_upscaler) + imported_module = importlib.import_module(args.highres_fix_upscaler) + + us_kwargs = {} + if args.highres_fix_upscaler_args: + for net_arg in args.highres_fix_upscaler_args.split(";"): + key, value = net_arg.split("=") + us_kwargs[key] = value + + print("create upscaler") + upscaler = imported_module.create_upscaler(**us_kwargs) + upscaler.to(dtype).to(device) + # ControlNetの処理 control_nets: List[ControlNetInfo] = [] if args.control_net_models: @@ -2590,7 +2606,7 @@ def resize_images(imgs, size): np_mask = np_mask[:, :, i] size = np_mask.shape else: - np_mask = np.full(size, 255, dtype=np.uint8) + np_mask = np.full(size, 255, dtype=np.uint8) mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) network.set_region(i, i == len(networks) - 1, mask) mask_images = None @@ -2639,6 +2655,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # highres_fixの処理 if highres_fix and not highres_1st: # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す + is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling + print("process 1st stage") batch_1st = [] for _, base, ext in batch: @@ -2657,12 +2675,32 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): ext.network_muls, ext.num_sub_prompts, ) - batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st)) + batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する print("process 2nd stage") - if args.highres_fix_latents_upscaling: + width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height + + if upscaler: + # upscalerを使って画像を拡大する + lowreso_imgs = None if is_1st_latent else images_1st + lowreso_latents = None if not is_1st_latent else images_1st + + # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents + batch_size = len(images_1st) + vae_batch_size = ( + batch_size + if args.vae_batch_size is None + else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) + ) + vae_batch_size = int(vae_batch_size) + images_1st = upscaler.upscale( + vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size + ) + + elif args.highres_fix_latents_upscaling: + # latentを拡大する org_dtype = images_1st.dtype if images_1st.dtype == torch.bfloat16: images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない @@ -2671,10 +2709,12 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): ) # , antialias=True) images_1st = images_1st.to(org_dtype) + else: + # 画像をLANCZOSで拡大する + images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] + batch_2nd = [] for i, (bd, image) in enumerate(zip(batch, images_1st)): - if not args.highres_fix_latents_upscaling: - image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定 bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) batch_2nd.append(bd_2nd) batch = batch_2nd @@ -3229,6 +3269,16 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="use latents upscaling for highres fix / highres fixでlatentで拡大する", ) + parser.add_argument( + "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" + ) + parser.add_argument( + "--highres_fix_upscaler_args", + type=str, + default=None, + help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数", + ) + parser.add_argument( "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" ) diff --git a/library/train_util.py b/library/train_util.py index cfb5b7ee0..d43e0075d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -845,9 +845,10 @@ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_ # 画像サイズはsizeより大きいのでリサイズする face_size = max(face_w, face_h) + size = min(self.height, self.width) # 短いほう min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) - min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ - max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ + min_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ + max_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ if min_scale >= max_scale: # range指定がmin==max scale = min_scale else: @@ -872,7 +873,7 @@ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_ else: # range指定があるときのみ、すこしだけランダムに(わりと適当) if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]: - if face_size > self.size // 10 and face_size >= 40: + if face_size > size // 10 and face_size >= 40: p1 = p1 + random.randint(-face_size // 20, +face_size // 20) p1 = max(0, min(p1, length - target_size)) diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py new file mode 100644 index 000000000..c69f983cf --- /dev/null +++ b/tools/latent_upscaler.py @@ -0,0 +1,342 @@ +# 外部から簡単にupscalerを呼ぶためのスクリプト +# 単体で動くようにモデル定義も含めている + +import argparse +import glob +import os +import cv2 +from diffusers import AutoencoderKL + +from typing import Dict, List +import numpy as np + +import torch +from torch import nn +from tqdm import tqdm +from PIL import Image + + +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1): + super(ResidualBlock, self).__init__() + + if out_channels is None: + out_channels = in_channels + + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.relu2 = nn.ReLU(inplace=True) # このReLUはresidualに足す前にかけるほうがいいかも + + # initialize weights + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += residual + + out = self.relu2(out) + + return out + + +class Upscaler(nn.Module): + def __init__(self): + super(Upscaler, self).__init__() + + # define layers + # latent has 4 channels + + self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + self.bn1 = nn.BatchNorm2d(128) + self.relu1 = nn.ReLU(inplace=True) + + # resblocks + # 数の暴力で20個:次元数を増やすよりもブロックを増やしたほうがreceptive fieldが広がるはずだぞ + self.resblock1 = ResidualBlock(128) + self.resblock2 = ResidualBlock(128) + self.resblock3 = ResidualBlock(128) + self.resblock4 = ResidualBlock(128) + self.resblock5 = ResidualBlock(128) + self.resblock6 = ResidualBlock(128) + self.resblock7 = ResidualBlock(128) + self.resblock8 = ResidualBlock(128) + self.resblock9 = ResidualBlock(128) + self.resblock10 = ResidualBlock(128) + self.resblock11 = ResidualBlock(128) + self.resblock12 = ResidualBlock(128) + self.resblock13 = ResidualBlock(128) + self.resblock14 = ResidualBlock(128) + self.resblock15 = ResidualBlock(128) + self.resblock16 = ResidualBlock(128) + self.resblock17 = ResidualBlock(128) + self.resblock18 = ResidualBlock(128) + self.resblock19 = ResidualBlock(128) + self.resblock20 = ResidualBlock(128) + + # last convs + self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + self.bn2 = nn.BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + + self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + self.bn3 = nn.BatchNorm2d(64) + self.relu3 = nn.ReLU(inplace=True) + + # final conv: output 4 channels + self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)) + + # initialize weights + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + # initialize final conv weights to 0: 流行りのzero conv + nn.init.constant_(self.conv_final.weight, 0) + + def forward(self, x): + inp = x + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + + # いくつかのresblockを通した後に、residualを足すことで精度向上と学習速度向上が見込めるはず + residual = x + x = self.resblock1(x) + x = self.resblock2(x) + x = self.resblock3(x) + x = self.resblock4(x) + x = x + residual + residual = x + x = self.resblock5(x) + x = self.resblock6(x) + x = self.resblock7(x) + x = self.resblock8(x) + x = x + residual + residual = x + x = self.resblock9(x) + x = self.resblock10(x) + x = self.resblock11(x) + x = self.resblock12(x) + x = x + residual + residual = x + x = self.resblock13(x) + x = self.resblock14(x) + x = self.resblock15(x) + x = self.resblock16(x) + x = x + residual + residual = x + x = self.resblock17(x) + x = self.resblock18(x) + x = self.resblock19(x) + x = self.resblock20(x) + x = x + residual + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + x = self.conv3(x) + x = self.bn3(x) + + # ここにreluを入れないほうがいい気がする + + x = self.conv_final(x) + + # network estimates the difference between the input and the output + x = x + inp + + return x + + def support_latents(self) -> bool: + return False + + def upscale( + self, + vae: AutoencoderKL, + lowreso_images: List[Image.Image], + lowreso_latents: torch.Tensor, + dtype: torch.dtype, + width: int, + height: int, + batch_size: int = 1, + vae_batch_size: int = 1, + ): + # assertion + assert lowreso_images is not None, "Upscaler requires lowreso image" + + # make upsampled image with lanczos4 + upsampled_images = [] + for lowreso_image in lowreso_images: + upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS)) + upsampled_images.append(upsampled_image) + + # convert to tensor: this tensor is too large to be converted to cuda + upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images] + upsampled_images = torch.stack(upsampled_images, dim=0) + upsampled_images = upsampled_images.to(dtype) + + # normalize to [-1, 1] + upsampled_images = upsampled_images / 127.5 - 1.0 + + # convert upsample images to latents with batch size + # print("Encoding upsampled (LANCZOS4) images...") + upsampled_latents = [] + for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)): + batch = upsampled_images[i : i + vae_batch_size].to(vae.device) + with torch.no_grad(): + batch = vae.encode(batch).latent_dist.sample() + upsampled_latents.append(batch) + + upsampled_latents = torch.cat(upsampled_latents, dim=0) + + # upscale (refine) latents with this model with batch size + print("Upscaling latents...") + upscaled_latents = [] + for i in range(0, upsampled_latents.shape[0], batch_size): + with torch.no_grad(): + upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size])) + upscaled_latents = torch.cat(upscaled_latents, dim=0) + + return upscaled_latents * 0.18215 + + +# external interface: returns a model +def create_upscaler(**kwargs): + weights = kwargs["weights"] + model = Upscaler() + + print(f"Loading weights from {weights}...") + model.load_state_dict(torch.load(weights, map_location=torch.device("cpu"))) + return model + + +# another interface: upscale images with a model for given images from command line +def upscale_images(args: argparse.Namespace): + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + us_dtype = torch.float16 # TODO: support fp32/bf16 + os.makedirs(args.output_dir, exist_ok=True) + + # load VAE with Diffusers + assert args.vae_path is not None, "VAE path is required" + print(f"Loading VAE from {args.vae_path}...") + vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae") + vae.to(DEVICE, dtype=us_dtype) + + # prepare model + print("Preparing model...") + upscaler: Upscaler = create_upscaler(weights=args.weights) + # print("Loading weights from", args.weights) + # upscaler.load_state_dict(torch.load(args.weights)) + upscaler.eval() + upscaler.to(DEVICE, dtype=us_dtype) + + # load images + image_paths = glob.glob(args.image_pattern) + images = [] + for image_path in image_paths: + image = Image.open(image_path) + image = image.convert("RGB") + + # make divisible by 8 + width = image.width + height = image.height + if width % 8 != 0: + width = width - (width % 8) + if height % 8 != 0: + height = height - (height % 8) + if width != image.width or height != image.height: + image = image.crop((0, 0, width, height)) + + images.append(image) + + # debug output + if args.debug: + for image, image_path in zip(images, image_paths): + image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS) + + basename = os.path.basename(image_path) + basename_wo_ext, ext = os.path.splitext(basename) + dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}") + image_debug.save(dest_file_name) + + # upscale + print("Upscaling...") + upscaled_latents = upscaler.upscale( + vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size + ) + upscaled_latents /= 0.18215 + + # decode with batch + print("Decoding...") + upscaled_images = [] + for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)): + with torch.no_grad(): + batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample + batch = batch.to("cpu") + upscaled_images.append(batch) + upscaled_images = torch.cat(upscaled_images, dim=0) + + # tensor to numpy + upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy() + upscaled_images = (upscaled_images + 1.0) * 127.5 + upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8) + + upscaled_images = upscaled_images[..., ::-1] + + # save images + for i, image in enumerate(upscaled_images): + basename = os.path.basename(image_paths[i]) + basename_wo_ext, ext = os.path.splitext(basename) + dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}") + cv2.imwrite(dest_file_name, image) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--vae_path", type=str, default=None, help="VAE path") + parser.add_argument("--weights", type=str, default=None, help="Weights path") + parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern") + parser.add_argument("--output_dir", type=str, default=".", help="Output directory") + parser.add_argument("--batch_size", type=int, default=4, help="Batch size") + parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size") + parser.add_argument("--debug", action="store_true", help="Debug mode") + + args = parser.parse_args() + upscale_images(args)