-
Notifications
You must be signed in to change notification settings - Fork 184
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Ubuntu
committed
Dec 5, 2023
1 parent
8eef68f
commit b5ddb28
Showing
2 changed files
with
234 additions
and
0 deletions.
There are no files selected for viewing
100 changes: 100 additions & 0 deletions
100
06_gpu_and_ml/stable_diffusion/stable_diffusion_xl_turbo_webcam.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import base64 | ||
from pathlib import Path | ||
import json | ||
|
||
from fastapi import FastAPI, Request, Response | ||
from fastapi.staticfiles import StaticFiles | ||
from modal import Image, Mount, Stub, asgi_app, gpu, method | ||
|
||
|
||
def download_models():
    """Pre-fetch SDXL Turbo weights into the image's Hugging Face cache.

    Runs at image-build time (see `run_function` below) so containers start
    without downloading the model.
    """
    from huggingface_hub import snapshot_download

    # Skip artifacts the fp16/safetensors pipeline never loads, to cut
    # download time: raw .bin weights, ONNX data, and the non-variant
    # diffusion safetensors files.
    skip_patterns = [
        "*.bin",
        "*.onnx_data",
        "*/diffusion_pytorch_model.safetensors",
    ]
    snapshot_download("stabilityai/sdxl-turbo", ignore_patterns=skip_patterns)
|
||
|
||
# Container image: slim Debian plus the inference stack, with the model
# weights baked in at build time via download_models.
_base_image = Image.debian_slim().pip_install(
    "Pillow~=10.1.0",
    "diffusers~=0.24",
    "transformers~=4.35",
    "accelerate~=0.25",
    "safetensors~=0.4",
)
image = _base_image.run_function(download_models)
|
||
# Modal application object; all functions/classes below attach to it.
stub = Stub("stable-diffusion-xl-turbo", image=image)
|
||
|
||
@stub.cls(gpu=gpu.A100(memory=40), container_idle_timeout=240)
class Model:
    """SDXL Turbo image-to-image pipeline served from a warm A100 container."""

    def __enter__(self):
        # Modal invokes __enter__ once per container, so the pipeline is
        # loaded a single time and reused across requests.
        import torch
        from diffusers import AutoPipelineForImage2Image

        self.pipe = AutoPipelineForImage2Image.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16,
            variant="fp16",
            device_map="auto",
            # Fix: `use_safetensors` is a from_pretrained() loading option,
            # not a pipeline __call__ argument — it was previously passed to
            # the inference call below, where it has no effect.
            use_safetensors=True,
        )

    @method()
    async def inference(self, image_bytes, prompt):
        """Run one img2img generation.

        Args:
            image_bytes: encoded source image (any format PIL can open).
            prompt: text prompt guiding the transformation.

        Returns:
            PNG-encoded bytes of the generated 512x512 image.
        """
        from io import BytesIO

        from diffusers.utils import load_image
        from PIL import Image

        init_image = load_image(Image.open(BytesIO(image_bytes))).resize(
            (512, 512)
        )
        num_inference_steps = 2
        strength = 0.50
        # SDXL Turbo must perform at least one denoising step after the
        # schedule is scaled by `strength`: steps * strength >= 1.
        assert num_inference_steps * strength >= 1

        image = self.pipe(
            prompt,
            image=init_image,
            num_inference_steps=num_inference_steps,
            strength=strength,
            guidance_scale=0.0,  # Turbo is distilled for CFG-free sampling
        ).images[0]

        byte_stream = BytesIO()
        image.save(byte_stream, format="PNG")
        return byte_stream.getvalue()
|
||
|
||
# FastAPI app served through Modal; the static frontend (webcam/) lives
# next to this file and is mounted into the container at /assets.
web_app = FastAPI()
static_path = Path(__file__).with_name("webcam").resolve()
|
||
|
||
@web_app.post("/generate")
async def generate(request: Request):
    """Accept JSON {"image": <data URI>, "prompt": <str>}; return the
    generated frame as a PNG data URI in the response body."""
    # Let FastAPI/Starlette parse the JSON body instead of reading raw
    # bytes and calling json.loads by hand.
    body_json = await request.json()
    # Drop the "data:image/...;base64," prefix; maxsplit=1 keeps the
    # payload intact even if a comma ever appears after the header.
    img_data_in = base64.b64decode(body_json["image"].split(",", 1)[1])
    prompt = body_json["prompt"]
    img_data_out = await Model().inference.remote.aio(img_data_in, prompt)
    output_data = b"data:image/png;base64," + base64.b64encode(img_data_out)
    return Response(content=output_data)
|
||
|
||
@stub.function(
    mounts=[Mount.from_local_dir(static_path, remote_path="/assets")],
)
@asgi_app()
def fastapi_app():
    """Expose web_app via Modal, serving the static frontend at the root."""
    # html=True makes StaticFiles serve index.html for "/".
    web_app.mount("/", StaticFiles(directory="/assets", html=True))
    return web_app
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
<h1>SDXL Turbo Image-to-image</h1>
<!-- Source image picker and text prompt for generation. -->
<input type="file" id="imageLoader" name="imageLoader" />
<input type="text" id="textInput" name="textInput" value="wizard, gandalf, lord of the rings, disney, cute, 8k" />
<!-- Base canvas holds the uploaded photo; hidden until a file is chosen. -->
<canvas id="canvas" style="display: none"> </canvas>
<div style="position: relative">
  <img id="outputImg" style="position: absolute; top: 0; left: 0" />
</div>

<!-- Overlay canvas that captures the user's brush strokes. -->
<canvas id="drawingCanvas"></canvas>
|
||
<script> | ||
// Grab UI elements and wire the file picker to the image loader.
var imageLoader = document.getElementById('imageLoader');
var textInput = document.getElementById('textInput');
imageLoader.addEventListener('change', handleImage, false);

// Base canvas: the uploaded photo. Drawing canvas: user strokes on top.
var canvas = document.getElementById('canvas');
var ctx = canvas.getContext('2d');

var drawingCanvas = document.getElementById("drawingCanvas");
var drawingCtx = drawingCanvas.getContext("2d");

// Generation loop waits until a file has been chosen.
var isImageUploaded = false;

// Both canvases share a fixed 512x512 working size (matches the model input).
canvas.height = 512;
canvas.width = 512;
drawingCanvas.height = canvas.height;
drawingCanvas.width = canvas.width;

// True while the mouse button is held during a brush stroke.
var painting = false;
|
||
// Begin a brush stroke on mousedown and paint the first point immediately.
function startPosition(e) {
    painting = true;
    draw(e);
}
|
||
// Finish the stroke on mouseup; reset the path so consecutive strokes
// are not connected by a line.
function endPosition() {
    painting = false;
    drawingCtx.beginPath();
}
|
||
// Paint a round 5px line segment to the current mouse position while the
// button is held; no-op otherwise.
function draw(e) {
    if (!painting) return;

    drawingCtx.lineWidth = 5;
    drawingCtx.lineCap = 'round';

    // Fix: clientX/clientY are viewport-relative, but offsetLeft/offsetTop
    // are document-relative, so the old math was off once the page
    // scrolled. getBoundingClientRect() is viewport-relative like clientX.
    var rect = canvas.getBoundingClientRect();
    var x = e.clientX - rect.left;
    var y = e.clientY - rect.top;

    drawingCtx.lineTo(x, y);
    drawingCtx.stroke();
    drawingCtx.beginPath();
    drawingCtx.moveTo(x, y);
}
|
||
// Mouse wiring for the brush: press starts a stroke, release ends it,
// movement extends it.
canvas.addEventListener('mousedown', startPosition);
canvas.addEventListener('mouseup', endPosition);
canvas.addEventListener('mousemove', draw);
|
||
// Flatten the photo layer and the stroke layer into a single PNG data URI.
function getCombinedImageData() {
    var composite = document.createElement('canvas');
    composite.width = canvas.width;
    composite.height = canvas.height;

    var compositeCtx = composite.getContext('2d');
    compositeCtx.drawImage(canvas, 0, 0);        // photo first
    compositeCtx.drawImage(drawingCanvas, 0, 0); // strokes on top

    return composite.toDataURL("image/png");
}
|
||
// Load the chosen file, scale it to fit within 512x512 while preserving
// aspect ratio, show the canvas, and kick off the generation loop.
function handleImage(e) {
    var file = e.target.files[0];
    // Fix: the change event also fires when the picker is cancelled, in
    // which case files[0] is undefined and readAsDataURL would throw.
    if (!file) return;

    var reader = new FileReader();
    reader.onload = function (event) {
        var img = new Image();
        img.onload = function () {
            ctx.clearRect(0, 0, canvas.width, canvas.height);

            var aspectRatio = img.width / img.height;
            var newWidth, newHeight;
            if (img.width > img.height) {
                // Landscape: clamp width to 512, derive height.
                newWidth = 512;
                newHeight = 512 / aspectRatio;
            } else {
                // Portrait or square: clamp height to 512, derive width.
                newHeight = 512;
                newWidth = 512 * aspectRatio;
            }

            ctx.drawImage(img, 0, 0, newWidth, newHeight);
            canvas.style.display = 'block';
            isImageUploaded = true;
            getNextFrameLoop();
        };
        img.src = event.target.result;
    };
    reader.readAsDataURL(file);
}
|
||
// Continuously POST the composited canvas to /generate and paint the
// returned frame. Polls every 5s until an image is uploaded; backs off
// for 1s after any failure. (Also removes an unused `context` local.)
const getNextFrameLoop = () => {
    if (!isImageUploaded) {
        setTimeout(getNextFrameLoop, 5000);
        return;
    }

    const data = getCombinedImageData();
    fetch("/generate", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({
            image: data,
            prompt: textInput.value,
        }),
    })
        .then((res) => {
            // Fix: previously a non-OK response's body (e.g. an error page)
            // was silently assigned to img.src; surface it to .catch instead.
            if (!res.ok) throw new Error("generate failed: " + res.status);
            return res.text();
        })
        .then((text) => {
            var img = new Image();
            img.onload = function() {
                ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
            };
            img.src = text; // server returns a data:image/png;base64 URI
            setTimeout(getNextFrameLoop, 10);
        })
        .catch((e) => {
            setTimeout(getNextFrameLoop, 1000);
        });
};
</script> |