support wuerstchen training #853

Draft
wants to merge 2 commits into base: dev
7 changes: 6 additions & 1 deletion library/train_util.py
@@ -1995,7 +1995,7 @@ def debug_dataset(train_dataset, show_input_ids=False):

if show_input_ids:
print(f"input ids: {iid}")
-if "input_ids2" in example:
+if "input_ids2" in example and example["input_ids2"] is not None:
print(f"input ids2: {example['input_ids2'][j]}")
if example["images"] is not None:
im = example["images"][j]
@@ -2012,6 +2012,11 @@ def debug_dataset(train_dataset, show_input_ids=False):
cond_img = cond_img[:, :, ::-1]
if os.name == "nt":
cv2.imshow("cond_img", cond_img)

+for key in example.keys():
+    if key in ["images", "conditioning_images", "input_ids", "input_ids2"]:
+        continue
+    print(f"{key}: {example[key][j] if example[key] is not None else None}")

if os.name == "nt": # only windows
cv2.imshow("img", im)
196 changes: 196 additions & 0 deletions wuerstchen/wuerstchen_inference.py
@@ -0,0 +1,196 @@
# use Diffusers' pipeline to generate images

import argparse
import datetime
import math
import os
import random
import re
from einops import repeat
import numpy as np
import torch

try:
    import intel_extension_for_pytorch as ipex

    if torch.xpu.is_available():
        from library.ipex import ipex_init

        ipex_init()
except Exception:
    pass
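# the Intel Extension for PyTorch import above is optional: if it is missing or no XPU
# device is available, the script falls back to the default CUDA/CPU backends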
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModel, PreTrainedTokenizerFast
from diffusers.pipelines.wuerstchen.modeling_wuerstchen_prior import WuerstchenPrior
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler

# from diffusers.pipelines.wuerstchen.pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS
from wuerstchen_train import EfficientNetEncoder
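# EfficientNetEncoder (the image encoder that produces Würstchen's compressed Stage C latents)
# is defined in the accompanying wuerstchen_train script; it is not used directly below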


def generate(args):
    dtype = torch.float32
    if args.fp16:
        dtype = torch.float16
    elif args.bf16:
        dtype = torch.bfloat16

    device = args.device

    os.makedirs(args.outdir, exist_ok=True)

    # load tokenizer
    print("load tokenizer")
    tokenizer = PreTrainedTokenizerFast.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")

    # load text encoder
    print("load text encoder")
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=dtype
    )

    # load prior model
    print("load prior model")
    prior: WuerstchenPrior = WuerstchenPrior.from_pretrained(
        args.pretrained_prior_model_name_or_path, subfolder="prior", torch_dtype=dtype
    )
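    # the prior is Würstchen's Stage C model: conditioned on the text embeddings, it generates
    # highly compressed image latents that the decoder pipeline (Stage B + VQGAN) turns into pixels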

    # helper to set the Diffusers xformers flag recursively
    def set_diffusers_xformers_flag(model, valid):
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        fn_recursive_set_mem_eff(model)

    # enable xformers / memory efficient attention in the model
    if args.diffusers_xformers:
        print("Use xformers by Diffusers")
        set_diffusers_xformers_flag(prior, True)

    # load pipeline
    print("load pipeline")
    pipeline = AutoPipelineForText2Image.from_pretrained(
        args.pretrained_decoder_model_name_or_path,
        prior_prior=prior,
        prior_text_encoder=text_encoder,
        prior_tokenizer=tokenizer,
    )
    pipeline = pipeline.to(device, torch_dtype=dtype)
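    # AutoPipelineForText2Image resolves the decoder repo to the combined (prior + decoder)
    # Würstchen pipeline; the prior_* kwargs above replace its prior components with the ones
    # loaded locally, so e.g. a fine-tuned prior can be paired with the stock decoder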

    # generate image
    while True:
        width = args.w
        height = args.h
        seed = args.seed
        negative_prompt = None

        if args.interactive:
            print("prompt:")
            prompt = input()
            if prompt == "":
                break

            # parse prompt
            prompt_args = prompt.split(" --")
            prompt = prompt_args[0]
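            # interactive prompts accept inline options after " --",
            # e.g.: a cat in a hat --w 768 --h 1344 --d 42 --n lowres, blurry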

            for parg in prompt_args[1:]:
                try:
                    m = re.match(r"w (\d+)", parg, re.IGNORECASE)
                    if m:
                        width = int(m.group(1))
                        print(f"width: {width}")
                        continue

                    m = re.match(r"h (\d+)", parg, re.IGNORECASE)
                    if m:
                        height = int(m.group(1))
                        print(f"height: {height}")
                        continue

                    m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
                    if m:  # seed
                        seed = int(m.group(1))
                        print(f"seed: {seed}")
                        continue

                    m = re.match(r"n (.+)", parg, re.IGNORECASE)
                    if m:  # negative prompt
                        negative_prompt = m.group(1)
                        print(f"negative prompt: {negative_prompt}")
                        continue
                except ValueError as ex:
                    print(f"Exception in parsing / 解析エラー: {parg}")
                    print(ex)
        else:
            prompt = args.prompt
            negative_prompt = args.negative_prompt

        if seed is None:
            generator = None
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        with torch.autocast(device):
            image = pipeline(
                prompt,
                negative_prompt=negative_prompt,
                # prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
                generator=generator,
                width=width,
                height=height,
            ).images[0]

        # save image
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        image.save(os.path.join(args.outdir, f"image_{timestamp}.png"))

        if not args.interactive:
            break

    print("Done!")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    # train_util.add_sd_models_arguments(parser)
    parser.add_argument(
        "--pretrained_prior_model_name_or_path",
        type=str,
        default="warp-ai/wuerstchen-prior",
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_decoder_model_name_or_path",
        type=str,
        default="warp-ai/wuerstchen",
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )

    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--outdir", type=str, default=".")
    parser.add_argument("--w", type=int, default=1024)
    parser.add_argument("--h", type=int, default=1024)
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    generate(args)
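
For reference, a typical invocation of the new inference script (relying on the default warp-ai/wuerstchen-prior and warp-ai/wuerstchen checkpoints from the argument parser above) might look like:

python wuerstchen/wuerstchen_inference.py --bf16 --outdir outputs --seed 42 --prompt "A photo of a cat"

With --interactive, prompts are read from stdin and the inline --w/--h/--d/--n options described above can be appended to each prompt; an empty prompt exits.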