Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup stable diffusion #981

Merged
merged 2 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed 06_gpu_and_ml/stable_diffusion/demo_images/dog.png
Binary file not shown.
1 change: 0 additions & 1 deletion 06_gpu_and_ml/stable_diffusion/foo.json

This file was deleted.

Binary file not shown.
90 changes: 81 additions & 9 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
# output-directory: "/tmp/stable-diffusion"
# args: ["--prompt", "A 1600s oil painting of the New York City skyline"]
# tags: ["use-case-image-video-3d"]
# deploy: true
# ---

# # Run Stable Diffusion 3.5 Large Turbo from the command line
# # Run Stable Diffusion 3.5 Large Turbo as a CLI, API, and web UI

# This example shows how to run [Stable Diffusion 3.5 Large Turbo](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo) on Modal
# and generate images from your local command line.
# to generate images from your local command line, via an API, and as a web UI.

# Inference takes about one minute to cold start,
# at which point images are generated at a rate of one image every 1-2 seconds
Expand Down Expand Up @@ -47,29 +48,35 @@
.pip_install(
"accelerate==0.33.0",
"diffusers==0.31.0",
"fastapi[standard]==0.115.4",
"huggingface-hub[hf_transfer]==0.25.2",
"sentencepiece==0.2.0",
"torch==2.5.1",
"torchvision==0.20.1",
"transformers~=4.44.0",
)
.entrypoint([]) # deactivate default entrypoint to reduce log verbosity
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) # faster downloads
)

with image.imports():
import diffusers
import torch
from fastapi import Response

# ## Implementing SD3.5 Large Turbo inference on Modal

# We wrap inference in a Modal [Cls](https://modal.com/docs/guide/lifecycle-methods)
# that ensures models are downloaded when we `build` our container image (just like our dependencies)
# and that models are loaded and then moved to the GPU when a new container starts.

# The `run_inference` function just wraps a `diffusers` pipeline.
# The `run` function just wraps a `diffusers` pipeline.
# It sends the output image back to the client as bytes.

# We also include a `web` wrapper that makes it possible
# to trigger inference via an API call.
# See the `/docs` route of the URL ending in `inference-web.modal.run`
# that appears when you deploy the app for details.

# Identifier of the model repository used by this example.
# NOTE(review): presumably a Hugging Face repo id passed to the diffusers pipeline loader — the call site is not visible here; confirm.
model_id = "adamo1139/stable-diffusion-3.5-large-turbo-ungated"
# Revision identifier paired with `model_id`.
# NOTE(review): looks like a commit hash that pins the exact model snapshot — confirm against the loader call.
model_revision_id = "9ad870ac0b0e5e48ced156bb02f85d324b7275d2"

Expand All @@ -79,7 +86,7 @@
gpu="H100",
timeout=10 * MINUTES,
)
class StableDiffusion:
class Inference:
@modal.build()
@modal.enter()
def initialize(self):
Expand All @@ -94,7 +101,7 @@ def move_to_gpu(self):
self.pipe.to("cuda")

@modal.method()
def run_inference(
def run(
self, prompt: str, batch_size: int = 4, seed: int = None
) -> list[bytes]:
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
Expand All @@ -116,15 +123,29 @@ def run_inference(
torch.cuda.empty_cache() # reduce fragmentation
return image_output

@modal.web_endpoint(docs=True)
def web(self, prompt: str, seed: int = None):
    """Generate a single image for `prompt` and return it as a PNG HTTP response.

    Args:
        prompt: Text prompt forwarded to `run`.
        seed: Optional seed forwarded to `run`; `None` is passed through unchanged.

    Returns:
        A `Response` whose body is the first (and only) image produced by
        `run` and whose media type is `image/png`.
    """
    # `.local` runs inference in the same container instead of making a remote call.
    generated = self.run.local(prompt, batch_size=1, seed=seed)
    # A batch size of 1 was requested, so the image of interest is the first element.
    png_bytes = generated[0]
    return Response(content=png_bytes, media_type="image/png")


# ## Generating images from the command line
# ## Generating Stable Diffusion images from the command line

# This is the command we'll use to generate images. It takes a text `prompt`,
# a `batch_size` that determines the number of images to generate per prompt,
# and the number of times to run image generation (`samples`).

# You can also provide a `seed` to make sampling more deterministic.

# Run it with
# ```bash
# modal run stable_diffusion_cli.py
# ```


@app.local_entrypoint()
def entrypoint(
Expand All @@ -144,11 +165,11 @@ def entrypoint(
output_dir = Path("/tmp/stable-diffusion")
output_dir.mkdir(exist_ok=True, parents=True)

sd = StableDiffusion()
inference_service = Inference()

for sample_idx in range(samples):
start = time.time()
images = sd.run_inference.remote(prompt, batch_size, seed)
images = inference_service.run.remote(prompt, batch_size, seed)
duration = time.time() - start
print(f"Run {sample_idx+1} took {duration:.3f}s")
if sample_idx:
Expand All @@ -169,5 +190,56 @@ def entrypoint(
output_path.write_bytes(image_bytes)


# ## Generating Stable Diffusion images in a web UI

# Lastly, we add a simple web application that exposes a front-end (written in Alpine.js) for
# our image generation backend.

# The `Inference` class automatically serves multiple users from its own shared pool of warm GPU containers.

# We can deploy this with `modal deploy stable_diffusion_cli.py`.

# Directory named "frontend" that sits next to this file; it is mounted into the
# web UI container below.
frontend_path = Path(__file__).parent / "frontend"

# Container image for the web UI function: Debian slim with Python 3.12,
# with jinja2 (unpinned) and FastAPI (pinned to 0.115.4) installed via pip.
web_image = modal.Image.debian_slim(python_version="3.12").pip_install(
"jinja2", "fastapi[standard]==0.115.4"
)


@app.function(
    image=web_image,
    mounts=[modal.Mount.from_local_dir(frontend_path, remote_path="/assets")],
    allow_concurrent_inputs=1000,
)
@modal.asgi_app()
def ui():
    """Build and return the FastAPI application that serves the web front-end.

    The local `frontend_path` directory is mounted at `/assets` in the
    container. The app exposes two things:

    * `GET /` renders `index.html` from `/assets` as a Jinja2 template,
      passing it the inference endpoint URL, a model name, and a default prompt.
    * `/static` serves the files in `/assets` as static files.

    Returns:
        The configured `FastAPI` instance, which `modal.asgi_app` serves.
    """
    # Imported inside the function: these packages are installed in `web_image`.
    import fastapi.staticfiles
    from fastapi import FastAPI, Request
    from fastapi.templating import Jinja2Templates

    web_app = FastAPI()
    templates = Jinja2Templates(directory="/assets")

    @web_app.get("/")
    async def read_root(request: Request):
        # Pass `request` explicitly by keyword: the older name-first call form
        # with the request tucked into the context dict is deprecated.
        # The request is still available to the template as `request`.
        return templates.TemplateResponse(
            request=request,
            name="index.html",
            context={
                "inference_url": Inference.web.web_url,
                "model_name": "Stable Diffusion 3.5 Large Turbo",
                "default_prompt": "A cinematic shot of a baby raccoon wearing an intricate italian priest robe.",
            },
        )

    web_app.mount(
        "/static",
        fastapi.staticfiles.StaticFiles(directory="/assets"),
        name="static",
    )

    return web_app


def slugify(s: str) -> str:
    """Return *s* with each non-alphanumeric character replaced by "-".

    Leading and trailing dashes are stripped from the result; runs of
    dashes in the middle are kept as-is.
    """
    dashed = [ch if ch.isalnum() else "-" for ch in s]
    return "".join(dashed).strip("-")
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
220 changes: 0 additions & 220 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py

This file was deleted.

Loading
Loading