Skip to content

Commit

Permalink
adds local entrypoint for playground, matches infra with deployed ver…
Browse files Browse the repository at this point in the history
…sion, pins deps (#623)

* adds local entrypoint for playground, matches infra with deployed version

* pins dependencies

* pins python version
  • Loading branch information
charlesfrye authored Mar 4, 2024
1 parent 76a9839 commit c77c1f2
Showing 1 changed file with 50 additions and 15 deletions.
65 changes: 50 additions & 15 deletions 06_gpu_and_ml/stable_diffusion/playground.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# ---
# output-directory: "/tmp/playground-2-5"
# args: ["--prompt", "A cinematic shot of a baby raccoon wearing an intricate Italian priest robe."]
# ---

from pathlib import Path

import modal

stub = modal.Stub("playground-2-5")

DIFFUSERS_GIT_SHA = "2e31a759b5bd8ca2b288b5c61709636a96c4bae9"

image = (
modal.Image.debian_slim()
modal.Image.debian_slim(python_version="3.10")
.apt_install("git")
.pip_install(
"git+https://github.com/huggingface/diffusers.git",
"transformers",
"accelerate",
"safetensors",
f"git+https://github.com/huggingface/diffusers.git@{DIFFUSERS_GIT_SHA}",
"transformers~=4.38.1",
"accelerate==0.27.2",
"safetensors==0.4.2",
)
)

Expand All @@ -24,7 +31,7 @@
from fastapi import Response


@stub.cls(image=image, gpu="a100")
@stub.cls(image=image, gpu="H100")
class Model:
@modal.build()
@modal.enter()
Expand All @@ -39,19 +46,33 @@ def load_weights(self):
# from diffusers import EDMDPMSolverMultistepScheduler
# pipe.scheduler = EDMDPMSolverMultistepScheduler()

@modal.web_endpoint()
def inference(
self,
prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe.",
):
def _inference(self, prompt, n_steps=24, high_noise_frac=0.8):
image = self.pipe(
prompt, num_inference_steps=50, guidance_scale=3
prompt,
negative_prompt="disfigured, ugly, deformed",
num_inference_steps=50,
guidance_scale=3,
).images[0]

buffer = io.BytesIO()
image.save(buffer, format="JPEG")

return Response(buffer.getvalue(), media_type="image/jpeg")
return buffer

@modal.method()
def inference(self, prompt, n_steps=24, high_noise_frac=0.8):
return self._inference(
prompt, n_steps=n_steps, high_noise_frac=high_noise_frac
).getvalue()

@modal.web_endpoint()
def web_inference(self, prompt, n_steps=24, high_noise_frac=0.8):
return Response(
content=self._inference(
prompt, n_steps=n_steps, high_noise_frac=high_noise_frac
).getvalue(),
media_type="image/jpeg",
)


frontend_path = Path(__file__).parent / "frontend"
Expand Down Expand Up @@ -79,9 +100,9 @@ def app():

with open("/assets/index.html", "w") as f:
html = template.render(
inference_url=Model.inference.web_url,
inference_url=Model.web_inference.web_url,
model_name="Playground 2.5",
default_prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
default_prompt="Astronaut in the ocean, cold color palette, muted colors, detailed, 8k",
)
f.write(html)

Expand All @@ -90,3 +111,17 @@ def app():
)

return web_app


@stub.local_entrypoint()
def main(prompt: str):
image_bytes = Model().inference.remote(prompt)

dir = Path("/tmp/playground-2-5")
if not dir.exists():
dir.mkdir(exist_ok=True, parents=True)

output_path = dir / "output.png"
print(f"Saving it to {output_path}")
with open(output_path, "wb") as f:
f.write(image_bytes)

0 comments on commit c77c1f2

Please sign in to comment.