diff --git a/06_gpu_and_ml/stable_diffusion/stable_diffusion_turbo_input.png b/06_gpu_and_ml/stable_diffusion/stable_diffusion_turbo_input.png new file mode 100644 index 000000000..022c2d13a Binary files /dev/null and b/06_gpu_and_ml/stable_diffusion/stable_diffusion_turbo_input.png differ diff --git a/06_gpu_and_ml/stable_diffusion/stable_diffusion_turbo_output.png b/06_gpu_and_ml/stable_diffusion/stable_diffusion_turbo_output.png new file mode 100644 index 000000000..6123597cf Binary files /dev/null and b/06_gpu_and_ml/stable_diffusion/stable_diffusion_turbo_output.png differ diff --git a/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl_turbo.py b/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl_turbo.py index 0982eb0fb..c6b7fc78a 100644 --- a/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl_turbo.py +++ b/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl_turbo.py @@ -1,12 +1,39 @@ +# --- +# output-directory: "/tmp/stable-diffusion-xl-turbo" +# args: [] +# runtimes: ["runc", "gvisor"] +# --- +# # Stable Diffusion XL Turbo Image-to-image +# +# This example is similar to the [Stable Diffusion XL](/docs/examples/stable_diffusion_xl) +# example, but it's a distilled model trained for real-time synthesis and is image-to-image. Learn more about it [here](https://stability.ai/news/stability-ai-sdxl-turbo). +# +# Input prompt: +# `dog wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k` +# +# Input | Output +# :-------------------------:|:-------------------------: +# ![](./stable_diffusion_turbo_input.png) | ![](./stable_diffusion_turbo_output.png) + +# ## Basic setup + from pathlib import Path from modal import Image, Stub, gpu, method +# ## Define a container image + def download_models(): from huggingface_hub import snapshot_download - ignore = ["*.bin", "*.onnx_data", "*/diffusion_pytorch_model.safetensors"] + # Ignore files that we don't need to speed up download time. + ignore = [ + "*.bin", + "*.onnx_data", + "*/diffusion_pytorch_model.safetensors", + ] + snapshot_download("stabilityai/sdxl-turbo", ignore_patterns=ignore) @@ -15,15 +42,23 @@ def download_models(): .pip_install( "Pillow~=10.1.0", "diffusers~=0.24", - "transformers~=4.35", - "accelerate~=0.25", - "safetensors~=0.4", + "transformers~=4.35", # This is needed for `import torch` + "accelerate~=0.25", # Allows `device_map="auto"``, which allows computation of optimized device_map + "safetensors~=0.4", # Enables safetensor format as opposed to using unsafe pickle format ) .run_function(download_models) ) stub = Stub("stable-diffusion-xl-turbo", image=image) +# ## Load model and run inference +# +# The container lifecycle [`__enter__` function](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta) +# loads the model at startup. Then, we evaluate it in the `inference` function. +# +# To avoid excessive cold-starts, we set the idle timeout to 240 seconds, meaning once a GPU has loaded the model it will stay +# online for 4 minutes before spinning down. This can be adjusted for cost/experience trade-offs. + @stub.cls(gpu=gpu.A10G(), container_idle_timeout=240) class Model: @@ -48,11 +83,17 @@ def inference(self, image_bytes, prompt): init_image = load_image(Image.open(BytesIO(image_bytes))).resize( (512, 512) ) + num_inference_steps = 4 + strength = 0.9 + # "When using SDXL-Turbo for image-to-image generation, make sure that num_inference_steps * strength is larger or equal to 1" + # See: https://huggingface.co/stabilityai/sdxl-turbo + assert num_inference_steps * strength >= 1 + image = self.pipe( prompt, image=init_image, - num_inference_steps=4, - strength=0.9, + num_inference_steps=num_inference_steps, + strength=strength, guidance_scale=0.0, ).images[0] @@ -80,3 +121,11 @@ def main( print(f"Saving it to {output_path}") with open(output_path, "wb") as f: f.write(output_image_bytes) + + +# ## Running the model +# +# We can run the model with different parameters using the following command, +# ``` +# modal run stable_diffusion_xl_turbo.py --prompt="harry potter, glasses, wizard" --image-path="dog.png" +# ```