feat: Add sample blueprint to run stable diffusion model on inferentia2 using rayserve (#406)
Co-authored-by: Vara Bonthu <[email protected]>
1 parent 1fd8438, commit 8d3515d
Showing 24 changed files with 861 additions and 73 deletions.
@@ -49,3 +49,5 @@ site
# Checks
.tfsec

examples/gradio-ui/*
54 changes: 54 additions & 0 deletions
ai-ml/trainium-inferentia/examples/gradio-ui/README-StableDiffusion.md
@@ -0,0 +1,54 @@
# Steps to Deploy Gradio on Your Mac

## Pre-requisites
Deploy the `trainium-inferentia` blueprint using this [link](https://awslabs.github.io/data-on-eks/docs/blueprints/ai-ml/trainium).

## Step 1: Execute Port Forward to the StableDiffusion Ray Service
First, execute a port forward to the StableDiffusion Ray Service using kubectl:

```bash
kubectl -n stablediffusion port-forward svc/stablediffusion-service 8000:8000
```

## Step 2: Deploy Gradio WebUI Locally

### 2.1. Create a Virtual Environment
Create a virtual environment for the Gradio application:

```bash
cd ai-ml/trainium-inferentia/examples/gradio-ui
python3 -m venv .venv
source .venv/bin/activate
```

### 2.2. Install the Gradio WebUI App

Install the Gradio WebUI app dependencies with pip:

```bash
pip install gradio requests
```

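Before invoking the WebUI, you can optionally confirm that the port-forwarded Ray Serve endpoint from Step 1 responds. This is a minimal sketch, not part of the blueprint, and it assumes the `/imagine` route and `prompt` query parameter used by `gradio-app-stablediffusion.py` in this commit:

```python
import requests

# Quick sanity check against the port-forwarded Ray Serve endpoint.
# Assumes the /imagine route and "prompt" query parameter used by the
# sample Gradio script in this commit.
resp = requests.get(
    "http://localhost:8000/imagine",
    params={"prompt": "a lighthouse on a cliff at sunrise"},
    timeout=180,
)
print(resp.status_code, resp.headers.get("content-type"))
```

A `200` status with an image content type indicates the service is ready; a connection error usually means the port forward from Step 1 is not running.
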
### 2.3. Invoke the WebUI
Run the Gradio WebUI using the following command:

NOTE: `gradio-app-stablediffusion.py` points at the port-forward URL from Step 1, i.e. `service_name = "http://localhost:8000"`. If you expose the service behind a load balancer instead, update `service_name` accordingly (see the optional sketch at the end of this step).

```bash
python gradio-app-stablediffusion.py
```

You should see output similar to the following:
```text
Running on local URL: http://127.0.0.1:7860
To create a public link, set `share=True` in `launch()`.
```

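If you would rather not edit the script when switching between the local port forward and an ELB DNS name, one optional tweak (not part of the sample) is to read the endpoint from an environment variable. `SD_SERVICE_URL` below is an illustrative name, not something the blueprint defines:

```python
import os

# Hypothetical: let an environment variable override the endpoint, falling back
# to the local port forward used in this guide.
service_name = os.environ.get("SD_SERVICE_URL", "http://localhost:8000")
```
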
### 2.4. Access the WebUI from Your Browser
Open your web browser and access the Gradio WebUI by navigating to the following URL:

http://127.0.0.1:7860

![gradio-sd](gradio-app-stable-diffusion-xl.png)

You should now be able to interact with the Gradio application from your local machine.
Binary file added (BIN, +441 KB): ai-ml/trainium-inferentia/examples/gradio-ui/gradio-app-stable-diffusion-xl.png
33 changes: 33 additions & 0 deletions
ai-ml/trainium-inferentia/examples/gradio-ui/gradio-app-stablediffusion.py
@@ -0,0 +1,33 @@
import gradio as gr
import requests
import json
from PIL import Image
from io import BytesIO

# Constants for model endpoint and service name
model_endpoint = "/imagine"
# service_name = "http://<REPLACE_ME_WITH_ELB_DNS_NAME>/serve"
service_name = "http://localhost:8000"  # Replace with your actual service name


# Function to generate image based on prompt
def generate_image(prompt):

    # Create the URL for the inference
    url = f"{service_name}{model_endpoint}"

    try:
        # Send the request to the model service
        response = requests.get(url, params={"prompt": prompt}, timeout=180)
        response.raise_for_status()  # Raise an exception for HTTP errors
        i = Image.open(BytesIO(response.content))
        return i

    except requests.exceptions.RequestException as e:
        # Handle any request exceptions (e.g., connection errors)
        return f"AI: Error: {str(e)}"

# Define the Gradio PromptInterface
demo = gr.Interface(fn=generate_image,
                    inputs=[gr.Textbox(label="Enter the Prompt")],
                    outputs=gr.Image(type='pil')).launch(debug='True')

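One optional refinement, not part of the commit: because the output component is `gr.Image`, returning an error string from `generate_image` will not render cleanly in the UI. A sketch of an alternative error path, assuming the same `service_name` and `model_endpoint` as the script above, is to raise `gr.Error` so the failure message is shown by the WebUI itself:

```python
import gradio as gr
import requests
from PIL import Image
from io import BytesIO

model_endpoint = "/imagine"
service_name = "http://localhost:8000"  # same assumption as the sample script

def generate_image(prompt):
    url = f"{service_name}{model_endpoint}"
    try:
        response = requests.get(url, params={"prompt": prompt}, timeout=180)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except requests.exceptions.RequestException as e:
        # gr.Error surfaces the message as an error in the Gradio UI instead of
        # handing a plain string to the Image output component.
        raise gr.Error(f"StableDiffusion service request failed: {e}")

demo = gr.Interface(fn=generate_image,
                    inputs=[gr.Textbox(label="Enter the Prompt")],
                    outputs=gr.Image(type="pil"))

if __name__ == "__main__":
    demo.launch()
```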