perf updates for turbo example
aksh-at committed Dec 12, 2023
1 parent c6da4fc commit 8e84b3f
Showing 1 changed file with 27 additions and 12 deletions.
@@ -20,6 +20,8 @@ def download_models():
 
     snapshot_download("stabilityai/sdxl-turbo", ignore_patterns=ignore)
+
+    snapshot_download("madebyollin/sdxl-vae-fp16-fix", ignore_patterns=ignore)
 
 
 stub = Stub("stable-diffusion-xl-turbo")
 
@@ -41,39 +43,53 @@ def download_models():
     from io import BytesIO
 
     import torch
-    from diffusers import AutoPipelineForImage2Image
+    from diffusers import AutoencoderKL, AutoPipelineForImage2Image
     from diffusers.utils import load_image
-    from PIL import Image, ImageChops
+    from PIL import Image
 
 
 @stub.cls(
-    gpu=gpu.A100(memory=40), container_idle_timeout=240, image=inference_image, keep_warm=1
+    gpu=gpu.A100(memory=40),
+    image=inference_image,
+    keep_warm=1,
+    cloud="oci",  # remove this later
 )
 class Model:
     def __enter__(self):
         self.pipe = AutoPipelineForImage2Image.from_pretrained(
             "stabilityai/sdxl-turbo",
             torch_dtype=torch.float16,
             device_map="auto",
-            variant="fp16"
+            variant="fp16",
+            vae=AutoencoderKL.from_pretrained(
+                "madebyollin/sdxl-vae-fp16-fix",
+                torch_dtype=torch.float16,
+                device_map="auto",
+            ),
         )
 
     @web_endpoint(method="POST")
     async def inference(self, request: Request):
-        t0 = time.time()
+        t00 = time.time()
         body = await request.body()
+        print("loading time:", time.time() - t00)
         body_json = json.loads(body)
         img_data_in = base64.b64decode(
             body_json["image"].split(",")[1]
         )  # read data-uri
         prompt = body_json["prompt"]
 
-        init_image = load_image(Image.open(BytesIO(img_data_in))).resize((512,512))
+        init_image = load_image(Image.open(BytesIO(img_data_in))).resize(
+            (512, 512)
+        )
         num_inference_steps = 2
         # note: anything under 0.5 strength gives blurry results
-        strength = 0.5
+        strength = 0.7
         assert num_inference_steps * strength >= 1
 
+        print("start time:", time.time() - t00)
+
+        t0 = time.time()
         image = self.pipe(
             prompt,
             image=init_image,
@@ -86,12 +102,12 @@ async def inference(self, request: Request):
         print("infer time:", time.time() - t0)
 
         byte_stream = BytesIO()
-        image.save(byte_stream, format="png")
+        image.save(byte_stream, format="jpeg")
         img_data_out = byte_stream.getvalue()
 
         print("total time:", time.time() - t0)
 
-        output_data = b"data:image/png;base64," + base64.b64encode(
+        output_data = b"data:image/jpeg;base64," + base64.b64encode(
             img_data_out
         )
 
@@ -106,6 +122,7 @@ async def inference(self, request: Request):
     mounts=[Mount.from_local_dir(static_path, remote_path="/assets")],
     image=web_image,
     keep_warm=1,
+    allow_concurrent_inputs=10,
 )
 @asgi_app()
 def fastapi_app():
@@ -118,9 +135,7 @@ def fastapi_app():
     template = Template(template_html)
 
     with open("/assets/index.html", "w") as f:
-        html = template.render(
-            inference_url=Model.inference.web_url
-        )
+        html = template.render(inference_url=Model.inference.web_url)
         f.write(html)
 
     web_app.mount("/", StaticFiles(directory="/assets", html=True))
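On the png -> jpeg switch: for a latency-sensitive endpoint, JPEG encoding of a 512x512 frame is typically much faster and produces a much smaller payload than PNG. A quick standalone comparison (illustrative only; exact numbers vary by machine and image content):

import time
from io import BytesIO

from PIL import Image

# Compare encode time and payload size for a noisy 512x512 RGB image.
img = Image.effect_noise((512, 512), 64).convert("RGB")
for fmt in ("png", "jpeg"):
    buf = BytesIO()
    t0 = time.time()
    img.save(buf, format=fmt)
    print(fmt, round(time.time() - t0, 4), "s,", len(buf.getvalue()), "bytes")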

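For completeness, a hypothetical client for the endpoint: the handler expects a JSON body with a prompt and a base64 data-URI image, and, as the diff suggests, responds with a data:image/jpeg;base64,... payload. The URL below is a placeholder and requests is an assumed dependency:

import base64

import requests  # assumed; any HTTP client works

ENDPOINT_URL = "https://<your-workspace>--inference.modal.run"  # placeholder

# Encode the input image as a data URI, matching what the handler parses.
with open("input.png", "rb") as f:
    data_uri = "data:image/png;base64," + base64.b64encode(f.read()).decode()

resp = requests.post(
    ENDPOINT_URL,
    json={"prompt": "a watercolor painting", "image": data_uri},
)

# Strip the "data:image/jpeg;base64," prefix and decode the result.
img_b64 = resp.content.split(b",", 1)[1]
with open("output.jpg", "wb") as f:
    f.write(base64.b64decode(img_b64))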