#5773: Move SD model to demo folder
Sudharsan-V committed May 10, 2024
1 parent 54881fd commit a438715
Showing 63 changed files with 8,068 additions and 56 deletions.
32 changes: 32 additions & 0 deletions models/demos/wormhole/stable_diffusion/README.md
@@ -0,0 +1,32 @@
# Stable Diffusion Model

## Introduction
Stable Diffusion is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input.

## Details
The entry point to the `functional_stable_diffusion` model is `UNet2DConditionModel` in `models/demos/wormhole/stable_diffusion/tt2/ttnn_functional_unet_2d_condition_model.py`. The model loads certain configs and weights from a Hugging Face pretrained model. We used the `CompVis/stable-diffusion-v1-4` version from Hugging Face as our reference.

## Inputs
By default, inputs are read from `input_data.json`. If you wish to change the inputs, provide a different path to `test_demo`. We do not recommend modifying the `input_data.json` file.

## How to Run

To run the demo, make sure to build the project, activate the environment, and set the appropriate environment variables.
For more information, refer to the [installation and build guide](https://tenstorrent.github.io/tt-metal/latest/get_started/get_started.html#install-and-build).

Use `pytest --disable-warnings --input-path="models/demos/wormhole/stable_diffusion/demo/input_data.json" models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo` to run the demo.

If you wish to run the demo with a different input, use `pytest --disable-warnings --input-path="<address_to_your_json_file.json>" models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo`.

Our second demo runs on the `poloclub/diffusiondb` dataset. Run it with `pytest --disable-warnings models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb`.

If you wish to run for `num_prompts` samples and `num_inference_steps` denoising steps, use `pytest --disable-warnings models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb[<num_prompts>-<num_inference_steps>]`

Note: ttnn stable diffusion uses `PNDMScheduler` and requires `num_inference_steps` to be greater than or equal to 4. [Reference](https://arxiv.org/pdf/2202.09778)
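
As a rough illustration of the two points above, the sketch below assembles the `test_demo_diffusiondb` invocation and rejects step counts below 4 before pytest is launched. The helper name `build_diffusiondb_command` is hypothetical and not part of this repository, and the `[<num_prompts>-<num_inference_steps>]` ID must still match a parameter combination that `demo.py` actually defines.

```python
import shlex

# Test path and minimum step count are taken from this README.
DEMO_TEST = "models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb"
MIN_INFERENCE_STEPS = 4


def build_diffusiondb_command(num_prompts: int, num_inference_steps: int) -> list[str]:
    """Hypothetical helper: build the pytest argv for the DiffusionDB demo."""
    if num_prompts < 1:
        raise ValueError("num_prompts must be at least 1")
    if num_inference_steps < MIN_INFERENCE_STEPS:
        raise ValueError(
            f"num_inference_steps must be >= {MIN_INFERENCE_STEPS}, got {num_inference_steps}"
        )
    test_id = f"{DEMO_TEST}[{num_prompts}-{num_inference_steps}]"
    return ["pytest", "--disable-warnings", test_id]


if __name__ == "__main__":
    # shlex.join quotes the square brackets, which some shells treat as glob characters.
    print(shlex.join(build_diffusiondb_command(1, 50)))
```

Quoting matters here because an unquoted `[1-50]` can be interpreted as a glob pattern by shells such as zsh.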

## Metrics Interpretation
`FID Score (Fréchet Inception Distance)` evaluates the quality of generated images by measuring the similarity between their feature distributions and those of real images. A lower FID score indicates better similarity between generated and real images.
For more information, refer to [FID Score](https://lightning.ai/docs/torchmetrics/stable/image/frechet_inception_distance.html).
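
To give a feel for why lower is better, here is a minimal sketch of the Fréchet distance between two Gaussians, under the simplifying assumption that the covariances are diagonal. This is not the full FID: the real metric fits full covariance matrices to Inception features and is not computed this way in the demo. The function name `diagonal_frechet_distance` is hypothetical.

```python
import math
from statistics import fmean, pvariance


def diagonal_frechet_distance(real_features, generated_features):
    """Fréchet distance between two Gaussians fitted per dimension.

    Assumption: covariances are treated as diagonal, so this is NOT the full FID.
    Each argument is a list of equal-length feature vectors.
    """
    distance = 0.0
    # zip(*features) iterates over feature dimensions instead of samples.
    for real_dim, gen_dim in zip(zip(*real_features), zip(*generated_features)):
        mean_gap = fmean(real_dim) - fmean(gen_dim)
        std_real = math.sqrt(pvariance(real_dim))
        std_gen = math.sqrt(pvariance(gen_dim))
        # Per-dimension term: (mu_r - mu_g)^2 + (sigma_r - sigma_g)^2
        distance += mean_gap**2 + (std_real - std_gen) ** 2
    return distance
```

Identical feature sets give a distance of 0, and the value grows as the means or spreads of the two sets drift apart.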

`CLIP Score` measures the similarity between the generated images and the input prompts. Higher CLIP scores indicate better alignment between the generated images and the provided text prompts.
For more information, refer to [CLIP Score](https://lightning.ai/docs/torchmetrics/stable/multimodal/clip_score.html).
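
The linked torchmetrics page defines the score as 100 times the cosine similarity between the image and text embeddings, clamped at 0. A minimal sketch of that arithmetic follows, assuming the embeddings have already been produced by a CLIP model and are passed in as plain lists of floats. The function name `clip_score` is hypothetical and this is not the demo's own implementation.

```python
import math


def clip_score(image_embedding, text_embedding):
    """CLIP-style score: 100 * cosine similarity, clamped at 0."""
    dot = sum(i * t for i, t in zip(image_embedding, text_embedding))
    image_norm = math.sqrt(sum(i * i for i in image_embedding))
    text_norm = math.sqrt(sum(t * t for t in text_embedding))
    if image_norm == 0.0 or text_norm == 0.0:
        raise ValueError("embeddings must be non-zero")
    # Negative similarities are clamped to 0, so the score lies in [0, 100].
    return max(100.0 * dot / (image_norm * text_norm), 0.0)
```

Embeddings pointing in the same direction score 100, while orthogonal or opposed embeddings score 0.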
41 changes: 41 additions & 0 deletions models/demos/wormhole/stable_diffusion/custom_preprocessing.py
@@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
from torch import nn
import ttnn


def preprocess_groupnorm_parameter(parameter, *, dtype):
    parameter = ttnn.from_torch(parameter, dtype=dtype, layout=ttnn.TILE_LAYOUT)
    return parameter


def preprocess_conv_parameter(parameter, *, dtype):
    parameter = ttnn.from_torch(parameter, dtype=dtype, layout=ttnn.TILE_LAYOUT)
    return parameter


def custom_preprocessor(model, name):
    parameters = {}
    if isinstance(model, nn.GroupNorm):
        parameters["weight"] = preprocess_groupnorm_parameter(model.weight, dtype=ttnn.bfloat16)
        parameters["bias"] = preprocess_groupnorm_parameter(model.bias, dtype=ttnn.bfloat16)

    if isinstance(model, nn.Conv2d):
        weight = torch.permute(model.weight, (2, 3, 0, 1))
        parameters["weight"] = preprocess_conv_parameter(weight, dtype=ttnn.bfloat16)
        parameters["bias"] = preprocess_conv_parameter(model.bias, dtype=ttnn.bfloat16)

    if isinstance(model, (nn.Linear, nn.LayerNorm)):
        weight = model.weight.T.contiguous()
        while len(weight.shape) < 4:
            weight = weight.unsqueeze(0)
        parameters["weight"] = ttnn.from_torch(weight, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
        if model.bias is not None:
            bias = model.bias
            while len(bias.shape) < 4:
                bias = bias.unsqueeze(0)
            parameters["bias"] = ttnn.from_torch(bias, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
    return parameters
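
The shape handling in `custom_preprocessor` can be traced without any tensors. The sketch below only does shape bookkeeping for the `torch.permute` call and the `unsqueeze(0)` loops. The helper names `permute_shape` and `pad_shape_to_4d` are hypothetical, and the layer sizes are illustrative values, not taken from the model.

```python
def permute_shape(shape, dims):
    """Shape that torch.permute(tensor, dims) would produce."""
    return tuple(shape[d] for d in dims)


def pad_shape_to_4d(shape):
    """Shape after repeatedly calling unsqueeze(0) until the rank is 4."""
    shape = tuple(shape)
    while len(shape) < 4:
        shape = (1,) + shape
    return shape


# A PyTorch Conv2d weight is (out_channels, in_channels, kH, kW);
# the (2, 3, 0, 1) permutation moves the kernel dimensions to the front.
conv_weight = permute_shape((320, 4, 3, 3), (2, 3, 0, 1))

# A PyTorch Linear weight is (out_features, in_features); .T swaps the two axes
# before the rank is padded to 4.
linear_weight = pad_shape_to_4d((320, 1280)[::-1])
linear_bias = pad_shape_to_4d((320,))
```

With these illustrative sizes, the convolution weight ends up laid out as (kH, kW, out_channels, in_channels), and the linear weight and bias gain leading singleton dimensions.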