diff --git a/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl_lightning.py b/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl_lightning.py
index 910cb61ea..74c62885e 100644
--- a/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl_lightning.py
+++ b/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl_lightning.py
@@ -4,8 +4,8 @@
 
 stub = modal.Stub("stable-diffusion-xl-lightning")
 
-image = modal.Image.debian_slim().pip_install(
-    "diffusers", "transformers", "accelerate"
+image = modal.Image.debian_slim(python_version="3.11").pip_install(
+    "diffusers==0.26.3", "transformers~=4.37.2", "accelerate==0.27.2"
 )
 
 base = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -46,20 +46,66 @@ def load_weights(self):
             self.pipe.scheduler.config, timestep_spacing="trailing"
         )
 
-    @modal.web_endpoint()
-    def inference(
-        self,
-        prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe.",
-    ):
+    def _inference(self, prompt, n_steps=4):
+        negative_prompt = "disfigured, ugly, deformed"
         image = self.pipe(
-            prompt, num_inference_steps=4, guidance_scale=0
+            prompt=prompt,
+            guidance_scale=0,
+            negative_prompt=negative_prompt,
+            num_inference_steps=n_steps,
         ).images[0]
 
-        buffer = io.BytesIO()
-        image.save(buffer, format="JPEG")
+        byte_stream = io.BytesIO()
+        image.save(byte_stream, format="JPEG")
+
+        return byte_stream
+
+    @modal.method()
+    def inference(self, prompt, n_steps=4):
+        return self._inference(
+            prompt,
+            n_steps=n_steps,
+        ).getvalue()
+
+    @modal.web_endpoint()
+    def web_inference(self, prompt, n_steps=4):
+        return Response(
+            content=self._inference(
+                prompt,
+                n_steps=n_steps,
+            ).getvalue(),
+            media_type="image/jpeg",
+        )
+
+
+# And this is our entrypoint, executed when the file is invoked from the CLI. Run this example
+# with: `modal run stable_diffusion_xl_lightning.py --prompt 'An astronaut riding a green horse'`
+
+
+@stub.local_entrypoint()
+def main(
+    prompt: str = "in the style of Dali, a surrealist painting of a weasel in a tuxedo riding a bicycle in the rain",
+):
+    image_bytes = Model().inference.remote(prompt)
+
+    output_dir = Path("/tmp/stable-diffusion-xl-lightning")
+    if not output_dir.exists():
+        output_dir.mkdir(exist_ok=True, parents=True)
+
+    output_path = output_dir / "output.jpeg"
+    print(f"Saving it to {output_path}")
+    with open(output_path, "wb") as f:
+        f.write(image_bytes)
 
-        return Response(content=buffer.getvalue(), media_type="image/jpeg")
 
+# ## A user interface
+#
+# Here we ship a simple web application that exposes a front-end (written in Alpine.js) for
+# our backend deployment.
+#
+# The Model class will automatically serve multiple users from its own shared pool of warm GPU containers.
+#
+# We can deploy this with `modal deploy stable_diffusion_xl_lightning.py`.
 
 frontend_path = Path(__file__).parent / "frontend"
 
@@ -86,9 +132,9 @@ def app():
     with open("/assets/index.html", "w") as f:
         html = template.render(
-            inference_url=Model.inference.web_url,
-            model_name="SDXL Lightning",
-            default_prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe.",
+            inference_url=Model.web_inference.web_url,
+            model_name="Stable Diffusion XL Lightning",
+            default_prompt="A cinematic shot of a baby raccoon wearing an intricate Italian priest robe.",
         )
         f.write(html)
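
Once this change is deployed with `modal deploy stable_diffusion_xl_lightning.py`, the new `web_inference` endpoint can be exercised with any HTTP client. Below is a minimal client sketch, assuming the function's `prompt` and `n_steps` arguments are exposed as query parameters (Modal's `web_endpoint` wraps the method with FastAPI); the URL is a placeholder and should be replaced with the `web_url` printed at deploy time:

import requests

# Placeholder URL: substitute the web_url that `modal deploy` prints for your deployment.
url = "https://your-workspace--stable-diffusion-xl-lightning-model-web-inference.modal.run"

response = requests.get(
    url,
    params={
        "prompt": "A cinematic shot of a baby raccoon wearing an intricate Italian priest robe.",
        "n_steps": 4,
    },
    timeout=120,  # the first request may wait on a GPU container cold start
)
response.raise_for_status()

# The endpoint returns raw image/jpeg bytes in the response body.
with open("output.jpeg", "wb") as f:
    f.write(response.content)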