Skip to content

Commit

Permalink
Templating for SDXL examples (#604)
Browse files Browse the repository at this point in the history
  • Loading branch information
aksh-at authored Feb 28, 2024
1 parent 3b088b3 commit 0cd9f2a
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 39 deletions.
18 changes: 3 additions & 15 deletions 06_gpu_and_ml/stable_diffusion/frontend/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
<script src="https://cdn.tailwindcss.com"></script>

<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Stable Diffusion XL — Modal</title>
<title>{{ model_name }} — Modal</title>
</head>
<body x-data="state()">
<div class="max-w-3xl mx-auto pt-4 pb-8 px-10 sm:py-12 sm:px-6 lg:px-8">
<h2 class="text-3xl font-medium text-center mb-10">
Stable Diffusion XL on Modal
{{ model_name }} on Modal
</h2>

<form
Expand Down Expand Up @@ -67,33 +67,21 @@ <h2 class="text-3xl font-medium text-center mb-10">
// Factory for the page's state object; referenced from the markup via
// x-data="state()" on <body>.
function state() {
return {
// Text currently typed into the prompt field (pre-filled with an example).
prompt: "a beautiful Japanese temple, butterflies flying around",
features: [],
// Snapshot of the prompt / features that were last sent to the backend.
submitted: "",
submittedFeatures: [],
// True while a request is in flight.
loading: false,
// Object URL of the most recently returned image ("" until one arrives).
imageURL: "",
// Sends the current prompt to the inference endpoint and stores an object
// URL for the returned image in imageURL. Does nothing for an empty prompt.
async submitPrompt() {
if (!this.prompt) return;
this.submitted = this.prompt;
this.submittedFeatures = [...this.features];
this.loading = true;

// NOTE(review): two `const res` declarations follow in the same scope. This
// listing appears to interleave the older `/infer/...` request (with its
// `features` query string) and the templated `{{ inference_url }}` request
// that replaces it; only one of them can exist in the real file — confirm.
const queryString = new URLSearchParams(
this.features.map((f) => ["features", f])
).toString();
const res = await fetch(`/infer/${this.submitted}?${queryString}`);
// NOTE(review): the prompt is interpolated into the URL as-is. Characters
// such as `&`, `#` or `?` in a prompt would change the request; consider
// wrapping it in encodeURIComponent(...).
const res = await fetch(`{{ inference_url }}?prompt=${this.submitted}`);

// Read the response body as a Blob and expose it through an object URL.
// NOTE(review): res.ok is not checked, so an error response body would be
// treated as image data as well.
const blob = await res.blob();
this.imageURL = URL.createObjectURL(blob);
this.loading = false;
console.log(this.imageURL);
},
// Adds featureName to `features` when absent, otherwise removes it.
toggleFeature(featureName) {
let index = this.features.indexOf(featureName);
index == -1
? this.features.push(featureName)
: this.features.splice(index, 1);
},
};
}
</script>
Expand Down
29 changes: 19 additions & 10 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import io
from pathlib import Path

from modal import Image, Mount, Stub, asgi_app, build, enter, gpu, method
from modal import Image, Mount, Stub, asgi_app, build, enter, gpu, web_endpoint

# ## Define a container image
#
Expand Down Expand Up @@ -48,6 +48,7 @@
with sdxl_image.imports():
import torch
from diffusers import DiffusionPipeline
from fastapi import Response
from huggingface_hub import snapshot_download

# ## Load model and run inference
Expand Down Expand Up @@ -103,7 +104,7 @@ def enter(self):
# self.base.unet = torch.compile(self.base.unet, mode="reduce-overhead", fullgraph=True)
# self.refiner.unet = torch.compile(self.refiner.unet, mode="reduce-overhead", fullgraph=True)

@method()
@web_endpoint()
def inference(self, prompt, n_steps=24, high_noise_frac=0.8):
negative_prompt = "disfigured, ugly, deformed"
image = self.base(
Expand All @@ -122,10 +123,9 @@ def inference(self, prompt, n_steps=24, high_noise_frac=0.8):
).images[0]

byte_stream = io.BytesIO()
image.save(byte_stream, format="PNG")
image_bytes = byte_stream.getvalue()
image.save(byte_stream, format="JPEG")

return image_bytes
return Response(content=byte_stream.getvalue(), media_type="image/jpeg")


# And this is our entrypoint, where the CLI is invoked. Explore CLI options
Expand Down Expand Up @@ -157,24 +157,33 @@ def main(prompt: str):

frontend_path = Path(__file__).parent / "frontend"

web_image = Image.debian_slim().pip_install("jinja2")


@stub.function(
image=web_image,
mounts=[Mount.from_local_dir(frontend_path, remote_path="/assets")],
allow_concurrent_inputs=20,
)
@asgi_app()
def app():
import fastapi.staticfiles
from fastapi import FastAPI
from fastapi.responses import Response
from jinja2 import Template

web_app = FastAPI()

@web_app.get("/infer/{prompt}")
async def infer(prompt: str):
image_bytes = Model().inference.remote(prompt)
with open("/assets/index.html", "r") as f:
template_html = f.read()

template = Template(template_html)

return Response(image_bytes, media_type="image/png")
with open("/assets/index.html", "w") as f:
html = template.render(
inference_url=Model.inference.web_url,
model_name="Stable Diffusion XL",
)
f.write(html)

web_app.mount(
"/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True)
Expand Down
47 changes: 33 additions & 14 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_xl_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,44 @@

base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.pth" # Use the correct ckpt for your step setting!
ckpt = "sdxl_lightning_4step_unet.safetensors"


with image.imports():
import io

import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline
from diffusers import (
EulerDiscreteScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from fastapi import Response
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


@stub.cls(image=image, gpu="a100")
class Model:
@modal.build()
@modal.enter()
def load_weights(self):
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(
"cuda", torch.float16
)
unet.load_state_dict(
load_file(hf_hub_download(repo, ckpt), device="cuda")
)
self.pipe = StableDiffusionXLPipeline.from_pretrained(
base, torch_dtype=torch.float16, variant="fp16"
base, unet=unet, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
self.pipe.unet.load_state_dict(
torch.load(hf_hub_download(repo, ckpt), map_location="cuda")
)

self.pipe.scheduler = EulerDiscreteScheduler.from_config(
self.pipe.scheduler.config, timestep_spacing="trailing"
)

@modal.method()
def generate(
@modal.web_endpoint()
def inference(
self,
prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe.",
):
Expand All @@ -48,29 +58,38 @@ def generate(
buffer = io.BytesIO()
image.save(buffer, format="JPEG")

return buffer.getvalue()
return Response(content=buffer.getvalue(), media_type="image/jpeg")


frontend_path = Path(__file__).parent / "frontend"

web_image = modal.Image.debian_slim().pip_install("jinja2")


@stub.function(
image=web_image,
mounts=[modal.Mount.from_local_dir(frontend_path, remote_path="/assets")],
allow_concurrent_inputs=20,
)
@modal.asgi_app()
def app():
import fastapi.staticfiles
from fastapi import FastAPI
from fastapi.responses import Response
from jinja2 import Template

web_app = FastAPI()

@web_app.get("/infer/{prompt}")
async def infer(prompt: str):
image_bytes = Model().generate.remote(prompt)
with open("/assets/index.html", "r") as f:
template_html = f.read()

return Response(image_bytes, media_type="image/jpeg")
template = Template(template_html)

with open("/assets/index.html", "w") as f:
html = template.render(
inference_url=Model.inference.web_url,
model_name="SDXL Lightning",
)
f.write(html)

web_app.mount(
"/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True)
Expand Down

0 comments on commit 0cd9f2a

Please sign in to comment.