From 32f7f490a86942a62ab1c4c5a8de157129e95628 Mon Sep 17 00:00:00 2001 From: Erik Bernhardsson Date: Wed, 20 Dec 2023 18:22:07 -0500 Subject: [PATCH] Rewrite it to use decorators --- 06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py | 10 +++++----- 06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py | 8 +++++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py b/06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py index 0cd5fc0ef..217d79493 100644 --- a/06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py +++ b/06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py @@ -37,7 +37,7 @@ import time from pathlib import Path -from modal import Image, Stub, method +from modal import Image, Stub, build, enter, method # All Modal programs need a [`Stub`](/docs/reference/modal.Stub) — an object that acts as a recipe for # the application. Let's give it a friendly name. @@ -72,7 +72,7 @@ .pip_install("xformers", pre=True) ) -with image.run_inside(): +with image.imports(): import diffusers import torch @@ -98,7 +98,9 @@ @stub.cls(image=image, gpu="A10G") class StableDiffusion: - def __enter__(self): + @build() + @enter() + def initialize(self): scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained( model_id, subfolder="scheduler", @@ -119,8 +121,6 @@ def __enter__(self): ) self.pipe.enable_xformers_memory_efficient_attention() - __build__ = __enter__ - @method() def run_inference( self, prompt: str, steps: int = 20, batch_size: int = 4 diff --git a/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py b/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py index 19dcd6a93..08ea18c20 100644 --- a/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py +++ b/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py @@ -18,7 +18,7 @@ import io from pathlib import Path -from modal import Image, Mount, Stub, asgi_app, gpu, method +from modal import Image, Mount, Stub, asgi_app, build, enter, gpu, method # ## Define a container image # @@ -64,7 +64,8 @@ @stub.cls(gpu=gpu.A10G(), container_idle_timeout=240, image=sdxl_image) class Model: - def __build__(self): + @build() + def build(self): ignore = [ "*.bin", "*.onnx_data", @@ -78,7 +79,8 @@ def __build__(self): ignore_patterns=ignore, ) - def __enter__(self): + @enter() + def enter(self): load_options = dict( torch_dtype=torch.float16, use_safetensors=True,