#0: Added interactive SD demo
AleksKnezevic committed May 23, 2024
1 parent a0370d0 commit e681231
Showing 1 changed file with 84 additions and 70 deletions.
154 changes: 84 additions & 70 deletions models/experimental/functional_stable_diffusion/demo/demo.py
@@ -11,6 +11,7 @@
from loguru import logger
from tqdm.auto import tqdm
from datasets import load_dataset
import os

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
@@ -234,46 +235,93 @@ def run_demo_inference_diffusiondb(
device, reset_seeds, input_path, num_prompts, num_inference_steps, image_size=(256, 256)
):
disable_persistent_kernel_cache()
device.enable_program_cache()

# 0. Load a sample prompt from the dataset
dataset = load_dataset("poloclub/diffusiondb", "2m_random_1k")
data_1k = dataset["train"]

height, width = image_size

for i in range(num_prompts):
experiment_name = f"diffusiondb_{i}__{height}x{width}"
input_prompt = [f"{data_1k['prompt'][i]}"]
logger.info(f"input_prompts: {input_prompt}")
torch_device = "cpu"
# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.to(torch_device)
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

image = np.array(data_1k["image"][i])
ref_images = Image.fromarray(image)
ref_img_path = f"{experiment_name}_ref.png"
ref_images.save(ref_img_path)
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
# 3. The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
# 4. Load the PNDM scheduler with some fitting parameters.
ttnn_scheduler = TtPNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, device=device
)
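# Note: this beta range (0.00085 -> 0.012, scaled_linear) matches the stock
# Stable Diffusion v1.4 scheduler configuration; TtPNDMScheduler is
# presumably the ttnn-side port of diffusers' PNDMScheduler.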

# 3. The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
text_encoder.to(torch_device)
unet.to(torch_device)

# 4. Load the PNDM scheduler with some fitting parameters.
ttnn_scheduler = TtPNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, device=device
)
config = unet.config
parameters = preprocess_model_parameters(
initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device
)
input_height = 64
input_width = 64
reader_patterns_cache = {} if height == 512 and width == 512 else None
model = UNet2D(device, parameters, 2, input_height, input_width, reader_patterns_cache)
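# Hoisting this setup out of the per-prompt loop is the main restructuring in
# this commit: weights are preprocessed and the ttnn UNet2D is built once, so
# each iteration below only tokenizes, encodes, and denoises.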

torch_device = "cpu"
vae.to(torch_device)
text_encoder.to(torch_device)
unet.to(torch_device)
guidance_scale = 7.5 # Scale for classifier-free guidance
generator = torch.manual_seed(174) # 10233 Seed generator to create the initial latent noise
batch_size = 1

guidance_scale = 7.5 # Scale for classifier-free guidance
generator = torch.manual_seed(174) # 10233 Seed generator to create the initial latent noise
batch_size = len(input_prompt)
# Initial random noise
latents = torch.randn(
(batch_size, unet.config.in_channels, height // vae_scale_factor, width // vae_scale_factor),
generator=generator,
)
latents = latents.to(torch_device)
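# Worked example: SD v1.4's VAE config has four block_out_channels entries,
# so vae_scale_factor = 2**(4 - 1) = 8 and, with unet.config.in_channels = 4,
# a 256x256 run denoises latents of shape (1, 4, 32, 32).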

ttnn_scheduler.set_timesteps(num_inference_steps)

latents = latents * ttnn_scheduler.init_noise_sigma
rand_latents = torch.tensor(latents)
rand_latents = ttnn.from_torch(rand_latents, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# ttnn_latents = ttnn.from_torch(ttnn_latents, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
ttnn_latent_model_input = ttnn.concat([rand_latents, rand_latents], dim=0)
_tlist = []
for t in ttnn_scheduler.timesteps:
_t = constant_prop_time_embeddings(t, ttnn_latent_model_input, unet.time_proj)
_t = _t.unsqueeze(0).unsqueeze(0)
_t = _t.permute(2, 0, 1, 3) # pre-permute temb
_t = ttnn.from_torch(_t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
_tlist.append(_t)
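# The timestep embeddings depend only on the schedule, not on the latents,
# so they are constant-propagated: computed once on the host for every
# timestep, pre-permuted, and cached in _tlist so the device loop below
# never recomputes them.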

time_step = ttnn_scheduler.timesteps.tolist()

interactive = os.environ.get("INTERACTIVE_SD_DEMO", "0") == "1"
i = 0
while i < num_prompts:
ttnn_scheduler.set_timesteps(num_inference_steps)
if interactive:
print("Enter the input promt, or q to exit:")
input_prompt = [input()]
if input_prompt[0] == "q":
break
else:
input_prompt = [f"{data_1k['prompt'][i]}"]

image = np.array(data_1k["image"][i])
ref_images = Image.fromarray(image)
i = i + 1

experiment_name = f"diffusiondb_{i}__{height}x{width}"
ref_img_path = f"{experiment_name}_ref.png"
ref_images.save(ref_img_path)
logger.info(f"input_prompts: {input_prompt}")

## First, we get the text_embeddings for the prompt. These embeddings will be used to condition the UNet model.
# Tokenizer and Text Encoder
@@ -297,44 +345,10 @@ def run_demo_inference_diffusiondb(
ttnn_text_embeddings = ttnn.from_torch(
ttnn_text_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device
)

vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
# Initial random noise
latents = torch.randn(
(batch_size, unet.config.in_channels, height // vae_scale_factor, width // vae_scale_factor),
generator=generator,
)
latents = latents.to(torch_device)

ttnn_scheduler.set_timesteps(num_inference_steps)

latents = latents * ttnn_scheduler.init_noise_sigma
ttnn_latents = torch.tensor(latents)
ttnn_latents = ttnn.from_torch(ttnn_latents, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

config = unet.config
parameters = preprocess_model_parameters(
initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device
)
input_height = 64
input_width = 64
reader_patterns_cache = {} if height == 512 and width == 512 else None
# ttnn_latents = ttnn.from_torch(ttnn_latents, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
ttnn_latent_model_input = ttnn.concat([ttnn_latents, ttnn_latents], dim=0)
_tlist = []
for t in ttnn_scheduler.timesteps:
_t = constant_prop_time_embeddings(t, ttnn_latent_model_input, unet.time_proj)
_t = _t.unsqueeze(0).unsqueeze(0)
_t = _t.permute(2, 0, 1, 3) # pre-permute temb
_t = ttnn.from_torch(_t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
_tlist.append(_t)

time_step = ttnn_scheduler.timesteps.tolist()

model = UNet2D(device, parameters, 2, input_height, input_width, reader_patterns_cache)
iter = 0
ttnn_latents = rand_latents
# Denoising loop
for index in range(len(time_step)):
for index in tqdm(range(len(time_step))):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
ttnn_latent_model_input = ttnn.concat([ttnn_latents, ttnn_latents], dim=0)
_t = _tlist[index]
@@ -351,12 +365,12 @@ def run_demo_inference_diffusiondb(
return_dict=True,
config=config,
)
print(f"Sample: {iter}")

# perform guidance
noise_pred = tt_guide(ttnn_output, guidance_scale)
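# Sketch of the guidance step (assumption: tt_guide implements standard
# classifier-free guidance on the doubled batch):
#   uncond, text = split(ttnn_output)
#   noise_pred = uncond + guidance_scale * (text - uncond)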
ttnn_latents = ttnn_scheduler.step(noise_pred, t, ttnn_latents).prev_sample
_save_image_and_latents(ttnn_latents, iter, vae, pre_fix=f"{experiment_name}_tt", pre_fix2="")
if not interactive:
_save_image_and_latents(ttnn_latents, iter, vae, pre_fix=f"{experiment_name}_tt", pre_fix2="")

iter += 1
enable_persistent_kernel_cache()
@@ -375,15 +389,15 @@ def run_demo_inference_diffusiondb(
ttnn_output_path = f"{experiment_name}_ttnn.png"
pil_images.save(ttnn_output_path)

ref_paths = [ref_img_path, ref_img_path]
ttnn_paths = [ttnn_output_path, ttnn_output_path]

ref_images = preprocess_images(ref_paths)
ttnn_images = preprocess_images(ttnn_paths)
if not interactive:
ref_paths = [ref_img_path, ref_img_path]
ref_images = preprocess_images(ref_paths)

# Calculate FID scores
fid_score_ref_ttnn = calculate_fid_score(ref_images, ttnn_images)
logger.info(f"FID Score (Reference vs TTNN): {fid_score_ref_ttnn}")
# Calculate FID scores
fid_score_ref_ttnn = calculate_fid_score(ref_images, ttnn_images)
logger.info(f"FID Score (Reference vs TTNN): {fid_score_ref_ttnn}")

# calculate Clip score
clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
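For readers who want to try the new interactive mode, a minimal launcher sketch follows. The pytest entry point and the exact test path are assumptions inferred from the file touched by this commit, not something the commit itself specifies.

# Hypothetical launcher for the interactive demo added in this commit.
# Assumptions: the demo is collected by pytest and the path below matches
# the diff header; adjust to the repo's actual test target.
import os
import subprocess

env = dict(os.environ, INTERACTIVE_SD_DEMO="1")  # enable the prompt loop
subprocess.run(
    ["pytest", "models/experimental/functional_stable_diffusion/demo/demo.py"],
    env=env,
    check=True,
)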
