From c77c1f2c9b0f9bf54dd672180890d3778216b975 Mon Sep 17 00:00:00 2001 From: Charles Frye Date: Mon, 4 Mar 2024 15:53:37 -0800 Subject: [PATCH] adds local entrypoint for playground, matches infra with deployed version, pins deps (#623) * adds local entrypoint for playground, matches infra with deployed version * pins dependencies * pins python version --- 06_gpu_and_ml/stable_diffusion/playground.py | 65 +++++++++++++++----- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/06_gpu_and_ml/stable_diffusion/playground.py b/06_gpu_and_ml/stable_diffusion/playground.py index 352650fdf..f2819c190 100644 --- a/06_gpu_and_ml/stable_diffusion/playground.py +++ b/06_gpu_and_ml/stable_diffusion/playground.py @@ -1,17 +1,24 @@ +# --- +# output-directory: "/tmp/playground-2-5" +# args: ["--prompt", "A cinematic shot of a baby raccoon wearing an intricate Italian priest robe."] +# --- + from pathlib import Path import modal stub = modal.Stub("playground-2-5") +DIFFUSERS_GIT_SHA = "2e31a759b5bd8ca2b288b5c61709636a96c4bae9" + image = ( - modal.Image.debian_slim() + modal.Image.debian_slim(python_version="3.10") .apt_install("git") .pip_install( - "git+https://github.com/huggingface/diffusers.git", - "transformers", - "accelerate", - "safetensors", + f"git+https://github.com/huggingface/diffusers.git@{DIFFUSERS_GIT_SHA}", + "transformers~=4.38.1", + "accelerate==0.27.2", + "safetensors==0.4.2", ) ) @@ -24,7 +31,7 @@ from fastapi import Response -@stub.cls(image=image, gpu="a100") +@stub.cls(image=image, gpu="H100") class Model: @modal.build() @modal.enter() @@ -39,19 +46,33 @@ def load_weights(self): # from diffusers import EDMDPMSolverMultistepScheduler # pipe.scheduler = EDMDPMSolverMultistepScheduler() - @modal.web_endpoint() - def inference( - self, - prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe.", - ): + def _inference(self, prompt, n_steps=24, high_noise_frac=0.8): image = self.pipe( - prompt, num_inference_steps=50, guidance_scale=3 + prompt, + negative_prompt="disfigured, ugly, deformed", + num_inference_steps=50, + guidance_scale=3, ).images[0] buffer = io.BytesIO() image.save(buffer, format="JPEG") - return Response(buffer.getvalue(), media_type="image/jpeg") + return buffer + + @modal.method() + def inference(self, prompt, n_steps=24, high_noise_frac=0.8): + return self._inference( + prompt, n_steps=n_steps, high_noise_frac=high_noise_frac + ).getvalue() + + @modal.web_endpoint() + def web_inference(self, prompt, n_steps=24, high_noise_frac=0.8): + return Response( + content=self._inference( + prompt, n_steps=n_steps, high_noise_frac=high_noise_frac + ).getvalue(), + media_type="image/jpeg", + ) frontend_path = Path(__file__).parent / "frontend" @@ -79,9 +100,9 @@ def app(): with open("/assets/index.html", "w") as f: html = template.render( - inference_url=Model.inference.web_url, + inference_url=Model.web_inference.web_url, model_name="Playground 2.5", - default_prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + default_prompt="Astronaut in the ocean, cold color palette, muted colors, detailed, 8k", ) f.write(html) @@ -90,3 +111,17 @@ def app(): ) return web_app + + +@stub.local_entrypoint() +def main(prompt: str): + image_bytes = Model().inference.remote(prompt) + + dir = Path("/tmp/playground-2-5") + if not dir.exists(): + dir.mkdir(exist_ok=True, parents=True) + + output_path = dir / "output.png" + print(f"Saving it to {output_path}") + with open(output_path, "wb") as f: + f.write(image_bytes)