Skip to content

Commit

Permalink
fix face_crop_aug not working on finetune method, prepare upscaler
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Apr 22, 2023
1 parent 2204362 commit 884e6bf
Show file tree
Hide file tree
Showing 3 changed files with 403 additions and 10 deletions.
64 changes: 57 additions & 7 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ def __call__(

# encode the init image into latents and scale the latents
init_image = init_image.to(device=self.device, dtype=latents_dtype)
if init_image.size()[2:] == (height // 8, width // 8):
if init_image.size()[1:] == (height // 8, width // 8):
init_latents = init_image
else:
if vae_batch_size >= batch_size:
Expand Down Expand Up @@ -1015,7 +1015,7 @@ def __call__(
if self.control_nets:
if reginonal_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
else:
text_emb_last = text_embeddings
noise_pred = original_control_net.call_unet_and_control_net(
Expand Down Expand Up @@ -2318,6 +2318,22 @@ def __getattr__(self, item):
else:
networks = []

# upscalerの指定があれば取得する
upscaler = None
if args.highres_fix_upscaler:
print("import upscaler module:", args.highres_fix_upscaler)
imported_module = importlib.import_module(args.highres_fix_upscaler)

us_kwargs = {}
if args.highres_fix_upscaler_args:
for net_arg in args.highres_fix_upscaler_args.split(";"):
key, value = net_arg.split("=")
us_kwargs[key] = value

print("create upscaler")
upscaler = imported_module.create_upscaler(**us_kwargs)
upscaler.to(dtype).to(device)

# ControlNetの処理
control_nets: List[ControlNetInfo] = []
if args.control_net_models:
Expand Down Expand Up @@ -2590,7 +2606,7 @@ def resize_images(imgs, size):
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
np_mask = np.full(size, 255, dtype=np.uint8)
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
network.set_region(i, i == len(networks) - 1, mask)
mask_images = None
Expand Down Expand Up @@ -2639,6 +2655,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
# highres_fixの処理
if highres_fix and not highres_1st:
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling

print("process 1st stage")
batch_1st = []
for _, base, ext in batch:
Expand All @@ -2657,12 +2675,32 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
ext.network_muls,
ext.num_sub_prompts,
)
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
batch_1st.append(BatchData(is_1st_latent, base, ext_1st))
images_1st = process_batch(batch_1st, True, True)

# 2nd stageのバッチを作成して以下処理する
print("process 2nd stage")
if args.highres_fix_latents_upscaling:
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height

if upscaler:
# upscalerを使って画像を拡大する
lowreso_imgs = None if is_1st_latent else images_1st
lowreso_latents = None if not is_1st_latent else images_1st

# 戻り値はPIL.Image.Imageかtorch.Tensorのlatents
batch_size = len(images_1st)
vae_batch_size = (
batch_size
if args.vae_batch_size is None
else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size)
)
vae_batch_size = int(vae_batch_size)
images_1st = upscaler.upscale(
vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size
)

elif args.highres_fix_latents_upscaling:
# latentを拡大する
org_dtype = images_1st.dtype
if images_1st.dtype == torch.bfloat16:
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
Expand All @@ -2671,10 +2709,12 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
) # , antialias=True)
images_1st = images_1st.to(org_dtype)

else:
# 画像をLANCZOSで拡大する
images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st]

batch_2nd = []
for i, (bd, image) in enumerate(zip(batch, images_1st)):
if not args.highres_fix_latents_upscaling:
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext)
batch_2nd.append(bd_2nd)
batch = batch_2nd
Expand Down Expand Up @@ -3229,6 +3269,16 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="use latents upscaling for highres fix / highres fixでlatentで拡大する",
)
parser.add_argument(
"--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名"
)
parser.add_argument(
"--highres_fix_upscaler_args",
type=str,
default=None,
help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
)

parser.add_argument(
"--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"
)
Expand Down
7 changes: 4 additions & 3 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,9 +845,10 @@ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_

# 画像サイズはsizeより大きいのでリサイズする
face_size = max(face_w, face_h)
size = min(self.height, self.width) # 短いほう
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
min_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
max_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
if min_scale >= max_scale: # range指定がmin==max
scale = min_scale
else:
Expand All @@ -872,7 +873,7 @@ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_
else:
# range指定があるときのみ、すこしだけランダムに(わりと適当)
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
if face_size > self.size // 10 and face_size >= 40:
if face_size > size // 10 and face_size >= 40:
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)

p1 = max(0, min(p1, length - target_size))
Expand Down
Loading

0 comments on commit 884e6bf

Please sign in to comment.