Skip to content

Commit

Permalink
fixes use in local_entrypoint, pins dependencies (#619)
Browse files Browse the repository at this point in the history
* fixes use in local_entrypoint, pins dependencies

* fixes deployment, fixes typo in default prompt

* actually fixes deployed web inference
  • Loading branch information
charlesfrye authored Mar 4, 2024
1 parent 499c385 commit 5e45207
Showing 1 changed file with 35 additions and 11 deletions.
46 changes: 35 additions & 11 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,17 @@
import io
from pathlib import Path

from modal import Image, Mount, Stub, asgi_app, build, enter, gpu, web_endpoint
from modal import (
Image,
Mount,
Stub,
asgi_app,
build,
enter,
gpu,
method,
web_endpoint,
)

# ## Define a container image
#
Expand All @@ -35,11 +45,11 @@
"libglib2.0-0", "libsm6", "libxrender1", "libxext6", "ffmpeg", "libgl1"
)
.pip_install(
"diffusers~=0.19",
"invisible_watermark~=0.1",
"transformers~=4.31",
"accelerate~=0.21",
"safetensors~=0.3",
"diffusers==0.26.3",
"invisible_watermark==0.2.0",
"transformers~=4.38.2",
"accelerate==0.27.2",
"safetensors==0.4.2",
)
)

Expand Down Expand Up @@ -104,8 +114,7 @@ def enter(self):
# self.base.unet = torch.compile(self.base.unet, mode="reduce-overhead", fullgraph=True)
# self.refiner.unet = torch.compile(self.refiner.unet, mode="reduce-overhead", fullgraph=True)

@web_endpoint()
def inference(self, prompt, n_steps=24, high_noise_frac=0.8):
def _inference(self, prompt, n_steps=24, high_noise_frac=0.8):
negative_prompt = "disfigured, ugly, deformed"
image = self.base(
prompt=prompt,
Expand All @@ -125,7 +134,22 @@ def inference(self, prompt, n_steps=24, high_noise_frac=0.8):
byte_stream = io.BytesIO()
image.save(byte_stream, format="JPEG")

return Response(content=byte_stream.getvalue(), media_type="image/jpeg")
return byte_stream

@method()
def inference(self, prompt, n_steps=24, high_noise_frac=0.8):
    """Generate an image for ``prompt`` and return it as raw bytes.

    All arguments are forwarded unchanged to ``self._inference``; the
    in-memory buffer it hands back is flattened with ``getvalue()`` so the
    caller receives the encoded image bytes rather than a stream object.
    """
    image_buffer = self._inference(
        prompt,
        n_steps=n_steps,
        high_noise_frac=high_noise_frac,
    )
    return image_buffer.getvalue()

@web_endpoint()
def web_inference(self, prompt, n_steps=24, high_noise_frac=0.8):
    """Generate an image for ``prompt`` and wrap it in a web response.

    All arguments are forwarded unchanged to ``self._inference``; the
    bytes read from the buffer it returns become the body of a
    ``Response`` whose media type is ``image/jpeg``.
    """
    image_bytes = self._inference(
        prompt,
        n_steps=n_steps,
        high_noise_frac=high_noise_frac,
    ).getvalue()
    return Response(content=image_bytes, media_type="image/jpeg")


# And this is our entrypoint, where the CLI is invoked. Explore CLI options
Expand Down Expand Up @@ -180,9 +204,9 @@ def app():

with open("/assets/index.html", "w") as f:
html = template.render(
inference_url=Model.inference.web_url,
inference_url=Model.web_inference.web_url,
model_name="Stable Diffusion XL",
default_prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe.",
default_prompt="A cinematic shot of a baby raccoon wearing an intricate italian priest robe.",
)
f.write(html)

Expand Down

0 comments on commit 5e45207

Please sign in to comment.