Skip to content

Commit

Permalink
Rewrite it to use decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Dec 20, 2023
1 parent 26c2c5f commit 48ae570
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
10 changes: 5 additions & 5 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import time
from pathlib import Path

from modal import Image, Stub, method
from modal import Image, Stub, method, build, enter

# All Modal programs need a [`Stub`](/docs/reference/modal.Stub) — an object that acts as a recipe for
# the application. Let's give it a friendly name.
Expand Down Expand Up @@ -72,7 +72,7 @@
.pip_install("xformers", pre=True)
)

with image.run_inside():
with image.imports():
import diffusers
import torch

Expand All @@ -98,7 +98,9 @@

@stub.cls(image=image, gpu="A10G")
class StableDiffusion:
def __enter__(self):
@build()
@enter()
def initialize(self):
scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
Expand All @@ -119,8 +121,6 @@ def __enter__(self):
)
self.pipe.enable_xformers_memory_efficient_attention()

__build__ = __enter__

@method()
def run_inference(
self, prompt: str, steps: int = 20, batch_size: int = 4
Expand Down
10 changes: 6 additions & 4 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import io
from pathlib import Path

from modal import Image, Mount, Stub, asgi_app, gpu, method
from modal import Image, Mount, Stub, asgi_app, build, enter, gpu, method

# ## Define a container image
#
Expand All @@ -45,7 +45,7 @@

stub = Stub("stable-diffusion-xl", image=image)

with image.run_inside():
with image.imports():
import fastapi.staticfiles
import torch
from diffusers import DiffusionPipeline
Expand All @@ -64,7 +64,8 @@

@stub.cls(gpu=gpu.A10G(), container_idle_timeout=240)
class Model:
def __build__(self):
@build()
def build(self):
ignore = [
"*.bin",
"*.onnx_data",
Expand All @@ -78,7 +79,8 @@ def __build__(self):
ignore_patterns=ignore,
)

def __enter__(self):
@enter()
def enter(self):
load_options = dict(
torch_dtype=torch.float16,
use_safetensors=True,
Expand Down

0 comments on commit 48ae570

Please sign in to comment.