-
Notifications
You must be signed in to change notification settings - Fork 3
/
RealTimeSketch.py
97 lines (86 loc) · 2.97 KB
/
RealTimeSketch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
from PIL import Image
import cv2
import numpy as np
import sys
import webbrowser
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, LCMScheduler
from diffusers.utils import load_image
import torch
# Local directories holding the model weights used by the generator.
sd_model_path = "Models/SD"
lcm_lora_path = "Models/lcm_lora"
controlnet_canny_path = "Models/controlnet/canny"
controlnet_scribble_path = "Models/controlnet/scribble"

# The diffusion pipeline is kept in a module-level global so it is built
# lazily once and then reused by every generation request.
pipe = None

# Hysteresis thresholds passed to cv2.Canny.
low_threshold, high_threshold = 100, 200

# Canvas / output size in pixels (width x height).
img_width, img_height = 512, 768

# A plain white RGB canvas written to disk so the sketch component can use it
# as its initial value.
init_img_path = 'init.png'
init_img = Image.fromarray(
    np.full((img_height, img_width, 3), 255, dtype=np.uint8)
)
init_img.save(init_img_path)
def _edges_to_control_image(np_img):
    """Return a 3-channel PIL image of the Canny edges of *np_img*.

    The single-channel edge map produced by ``cv2.Canny`` is replicated into
    three identical channels to form an RGB-shaped control image.
    """
    edges = cv2.Canny(np_img, low_threshold, high_threshold)
    edges = np.concatenate([edges[:, :, None]] * 3, axis=2)
    return Image.fromarray(edges)


def _get_pipeline():
    """Return the shared ControlNet + LCM-LoRA pipeline, building it on first use.

    The ControlNet weights, scheduler swap and LoRA load happen only once, when
    the global ``pipe`` is still None; later calls reuse the built pipeline.
    """
    global pipe
    if pipe is None:
        controlnet = ControlNetModel.from_pretrained(
            controlnet_scribble_path, torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            sd_model_path,
            controlnet=controlnet,
            torch_dtype=torch.float16,
            safety_checker=None,
        ).to("cuda")
        # Replace the default scheduler with LCMScheduler and attach the
        # LCM-LoRA weights; done once so adapters are not loaded repeatedly.
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        pipe.load_lora_weights(lcm_lora_path)
    return pipe


def Illust_generation_scribble(
    np_img: "np.ndarray | None", prompt: str, c_weight_input: float
) -> "Image.Image | None":
    """Generate an illustration conditioned on the edges of a sketch.

    Args:
        np_img: Image value from the Gradio sketch component. Presumably an
            RGB uint8 array (gr.Image default type) — TODO confirm. May be
            None when the canvas is cleared.
        prompt: Text prompt for the diffusion model.
        c_weight_input: ControlNet conditioning scale (slider in [0, 1]).

    Returns:
        The first generated PIL image, or None when ``np_img`` is None.
    """
    # Check for a missing image before any processing: cv2.Canny cannot
    # handle None, and there is no reason to load models for an empty input.
    if np_img is None:
        return None

    canny_image = _edges_to_control_image(np_img)
    pipeline = _get_pipeline()

    # Fixed seed so the same sketch/prompt/weight reproduces the same output.
    generator = torch.Generator("cuda").manual_seed(1)
    return pipeline(
        prompt,
        image=canny_image,
        num_inference_steps=4,
        guidance_scale=1.5,
        controlnet_conditioning_scale=float(c_weight_input),
        cross_attention_kwargs={"scale": 1},
        generator=generator,
    ).images[0]
with gr.Blocks() as ui:
    # Controls shared by every generation request.
    prompt_input = gr.Textbox(label="prompt", value="1girl")
    c_weight_input = gr.Slider(
        minimum=0, maximum=1.0, value=0.5, label="control weight"
    )

    # Left column: editable sketch canvas; right column: generated result.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                source="upload",
                tool="color-sketch",
                value=init_img_path,
                width=img_width,
                height=img_height,
                interactive=True,
            )
        with gr.Column():
            image_color_output = gr.Image(width=img_width, height=img_height)

    # Regenerate whenever the sketch, the prompt, or the control weight
    # changes; each trigger feeds the same three inputs to the generator.
    generation_inputs = [image_input, prompt_input, c_weight_input]
    for trigger in generation_inputs:
        trigger.change(
            fn=Illust_generation_scribble,
            inputs=list(generation_inputs),
            outputs=[image_color_output],
            show_progress='hidden',
        )
# Only start the server when run as a script, so importing this module
# (e.g. for testing or reuse of the generator) has no launch side effect.
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the web server.
    ui.queue()
    ui.launch()