diff --git a/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py b/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py
index c1ff6b824..19dcd6a93 100644
--- a/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py
+++ b/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py
@@ -15,6 +15,7 @@
 # ## Basic setup
 
+import io
 from pathlib import Path
 
 from modal import Image, Mount, Stub, asgi_app, gpu, method
 
@@ -28,18 +29,6 @@
 # triggers a rebuild.
 
 
-def download_models():
-    from huggingface_hub import snapshot_download
-
-    ignore = ["*.bin", "*.onnx_data", "*/diffusion_pytorch_model.safetensors"]
-    snapshot_download(
-        "stabilityai/stable-diffusion-xl-base-1.0", ignore_patterns=ignore
-    )
-    snapshot_download(
-        "stabilityai/stable-diffusion-xl-refiner-1.0", ignore_patterns=ignore
-    )
-
-
 sdxl_image = (
     Image.debian_slim()
     .apt_install(
@@ -52,11 +41,18 @@ def download_models():
         "accelerate~=0.21",
         "safetensors~=0.3",
     )
-    .run_function(download_models)
 )
 
 stub = Stub("stable-diffusion-xl")
 
+with sdxl_image.imports():
+    import fastapi.staticfiles
+    import torch
+    from diffusers import DiffusionPipeline
+    from fastapi import FastAPI
+    from fastapi.responses import Response
+    from huggingface_hub import snapshot_download
+
 # ## Load model and run inference
 #
 # The container lifecycle [`__enter__` function](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
@@ -68,10 +64,21 @@ def download_models():
 
 @stub.cls(gpu=gpu.A10G(), container_idle_timeout=240, image=sdxl_image)
 class Model:
-    def __enter__(self):
-        import torch
-        from diffusers import DiffusionPipeline
+    def __build__(self):
+        ignore = [
+            "*.bin",
+            "*.onnx_data",
+            "*/diffusion_pytorch_model.safetensors",
+        ]
+        snapshot_download(
+            "stabilityai/stable-diffusion-xl-base-1.0", ignore_patterns=ignore
+        )
+        snapshot_download(
+            "stabilityai/stable-diffusion-xl-refiner-1.0",
+            ignore_patterns=ignore,
+        )
 
+    def __enter__(self):
         load_options = dict(
             torch_dtype=torch.float16,
             use_safetensors=True,
@@ -115,8 +122,6 @@ def inference(self, prompt, n_steps=24, high_noise_frac=0.8):
             image=image,
         ).images[0]
 
-        import io
-
         byte_stream = io.BytesIO()
         image.save(byte_stream, format="PNG")
         image_bytes = byte_stream.getvalue()
@@ -160,15 +165,10 @@ def main(prompt: str):
 )
 @asgi_app()
 def app():
-    import fastapi.staticfiles
-    from fastapi import FastAPI
-
     web_app = FastAPI()
 
     @web_app.get("/infer/{prompt}")
     async def infer(prompt: str):
-        from fastapi.responses import Response
-
         image_bytes = Model().inference.remote(prompt)
 
         return Response(image_bytes, media_type="image/png")