Skip to content

Commit

Permalink
rewrite SDXL
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Dec 21, 2023
1 parent c3acb72 commit 5e8fdc2
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

# ## Basic setup

import io
from pathlib import Path

from modal import Image, Mount, Stub, asgi_app, gpu, method
Expand All @@ -28,18 +29,6 @@
# triggers a rebuild.


def download_models():
from huggingface_hub import snapshot_download

ignore = ["*.bin", "*.onnx_data", "*/diffusion_pytorch_model.safetensors"]
snapshot_download(
"stabilityai/stable-diffusion-xl-base-1.0", ignore_patterns=ignore
)
snapshot_download(
"stabilityai/stable-diffusion-xl-refiner-1.0", ignore_patterns=ignore
)


sdxl_image = (
Image.debian_slim()
.apt_install(
Expand All @@ -52,11 +41,18 @@ def download_models():
"accelerate~=0.21",
"safetensors~=0.3",
)
.run_function(download_models)
)

stub = Stub("stable-diffusion-xl")

with sdxl_image.imports():
import fastapi.staticfiles
import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI
from fastapi.responses import Response
from huggingface_hub import snapshot_download

# ## Load model and run inference
#
# The container lifecycle [`__enter__` function](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
Expand All @@ -68,10 +64,21 @@ def download_models():

@stub.cls(gpu=gpu.A10G(), container_idle_timeout=240, image=sdxl_image)
class Model:
def __enter__(self):
import torch
from diffusers import DiffusionPipeline
def __build__(self):
ignore = [
"*.bin",
"*.onnx_data",
"*/diffusion_pytorch_model.safetensors",
]
snapshot_download(
"stabilityai/stable-diffusion-xl-base-1.0", ignore_patterns=ignore
)
snapshot_download(
"stabilityai/stable-diffusion-xl-refiner-1.0",
ignore_patterns=ignore,
)

def __enter__(self):
load_options = dict(
torch_dtype=torch.float16,
use_safetensors=True,
Expand Down Expand Up @@ -115,8 +122,6 @@ def inference(self, prompt, n_steps=24, high_noise_frac=0.8):
image=image,
).images[0]

import io

byte_stream = io.BytesIO()
image.save(byte_stream, format="PNG")
image_bytes = byte_stream.getvalue()
Expand Down Expand Up @@ -160,15 +165,10 @@ def main(prompt: str):
)
@asgi_app()
def app():
import fastapi.staticfiles
from fastapi import FastAPI

web_app = FastAPI()

@web_app.get("/infer/{prompt}")
async def infer(prompt: str):
from fastapi.responses import Response

image_bytes = Model().inference.remote(prompt)

return Response(image_bytes, media_type="image/png")
Expand Down

0 comments on commit 5e8fdc2

Please sign in to comment.