Skip to content

Commit

Permalink
Update SDXL Turbo
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Dec 21, 2023
1 parent d8d2555 commit 98a802c
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_xl_turbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,14 @@

# ## Basic setup

from io import BytesIO
from pathlib import Path

from modal import Image, Stub, gpu, method
from modal import Image, Stub, build, enter, gpu, method

# ## Define a container image


def download_models():
from huggingface_hub import snapshot_download

# Ignore files that we don't need to speed up download time.
ignore = [
"*.bin",
"*.onnx_data",
"*/diffusion_pytorch_model.safetensors",
]

snapshot_download("stabilityai/sdxl-turbo", ignore_patterns=ignore)


image = (
Image.debian_slim()
.pip_install(
Expand All @@ -46,11 +34,18 @@ def download_models():
"accelerate~=0.25", # Allows `device_map="auto"``, which allows computation of optimized device_map
"safetensors~=0.4", # Enables safetensor format as opposed to using unsafe pickle format
)
.run_function(download_models)
)

stub = Stub("stable-diffusion-xl-turbo", image=image)

with image.imports():
import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
from huggingface_hub import snapshot_download
from PIL import Image


# ## Load model and run inference
#
# The container lifecycle [`__enter__` function](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
Expand All @@ -62,10 +57,19 @@ def download_models():

@stub.cls(gpu=gpu.A10G(), container_idle_timeout=240)
class Model:
def __enter__(self):
import torch
from diffusers import AutoPipelineForImage2Image

@build()
def download_models(self):
# Ignore files that we don't need to speed up download time.
ignore = [
"*.bin",
"*.onnx_data",
"*/diffusion_pytorch_model.safetensors",
]

snapshot_download("stabilityai/sdxl-turbo", ignore_patterns=ignore)

@enter()
def enter(self):
self.pipe = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/sdxl-turbo",
torch_dtype=torch.float16,
Expand All @@ -75,11 +79,6 @@ def __enter__(self):

@method()
def inference(self, image_bytes, prompt):
from io import BytesIO

from diffusers.utils import load_image
from PIL import Image

init_image = load_image(Image.open(BytesIO(image_bytes))).resize(
(512, 512)
)
Expand Down

0 comments on commit 98a802c

Please sign in to comment.