Skip to content

Commit

Permalink
First draft
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Dec 5, 2023
1 parent 8eef68f commit b5ddb28
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 0 deletions.
100 changes: 100 additions & 0 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_xl_turbo_webcam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import base64
from pathlib import Path
import json

from fastapi import FastAPI, Request, Response
from fastapi.staticfiles import StaticFiles
from modal import Image, Mount, Stub, asgi_app, gpu, method


def download_models():
    """Pre-fetch the SDXL Turbo weights so they are baked into the image."""
    from huggingface_hub import snapshot_download

    # Skip formats we never load (PyTorch .bin weights, ONNX payloads,
    # duplicate per-component safetensors) to keep the download fast.
    skip_patterns = [
        "*.bin",
        "*.onnx_data",
        "*/diffusion_pytorch_model.safetensors",
    ]

    snapshot_download("stabilityai/sdxl-turbo", ignore_patterns=skip_patterns)


# Container image: slim Debian plus the inference stack, with the model
# weights downloaded at build time via `download_models` so cold starts
# don't pay for the snapshot.
image = (
    Image.debian_slim()
    .pip_install(
        "Pillow~=10.1.0",
        "diffusers~=0.24",
        "transformers~=4.35",
        "accelerate~=0.25",
        "safetensors~=0.4",
    )
    .run_function(download_models)
)

# Modal application; the class and functions below all run inside `image`.
stub = Stub("stable-diffusion-xl-turbo", image=image)


@stub.cls(gpu=gpu.A100(memory=40), container_idle_timeout=240)
class Model:
    """SDXL Turbo image-to-image pipeline, loaded once per container."""

    def __enter__(self):
        # Runs once at container startup so the weights are resident
        # before the first request arrives.
        import torch
        from diffusers import AutoPipelineForImage2Image

        self.pipe = AutoPipelineForImage2Image.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16,
            variant="fp16",
            # NOTE: `use_safetensors` is a load-time option — it belongs
            # here, not in the pipeline call (where it was silently ignored).
            use_safetensors=True,
            device_map="auto",
        )

    @method()
    async def inference(self, image_bytes, prompt):
        """Run one img2img pass over `image_bytes` steered by `prompt`.

        Args:
            image_bytes: source image in any format PIL can decode.
            prompt: text prompt for the generation.

        Returns:
            PNG-encoded bytes of the generated 512x512 image.
        """
        from io import BytesIO

        from diffusers.utils import load_image
        from PIL import Image

        init_image = load_image(Image.open(BytesIO(image_bytes))).resize(
            (512, 512)
        )
        num_inference_steps = 2
        strength = 0.50
        # SDXL Turbo requires at least one effective denoising step:
        # num_inference_steps * strength must be >= 1.
        assert num_inference_steps * strength >= 1

        image = self.pipe(
            prompt,
            image=init_image,
            num_inference_steps=num_inference_steps,
            strength=strength,
            # Turbo is distilled for guidance-free sampling.
            guidance_scale=0.0,
        ).images[0]

        byte_stream = BytesIO()
        image.save(byte_stream, format="PNG")
        return byte_stream.getvalue()


# Shared FastAPI app, served by `fastapi_app` below.
web_app = FastAPI()
# Directory holding the static frontend (index.html), next to this file.
static_path = Path(__file__).with_name("webcam").resolve()


@web_app.post("/generate")
async def generate(request: Request):
    """Accept JSON {"image": <data URI>, "prompt": str}; return a PNG data URI.

    The response body is a `data:image/png;base64,...` string the client
    can assign directly to an <img> src.
    """
    body = await request.body()
    body_json = json.loads(body)
    # The client sends a data URI ("data:image/png;base64,<payload>").
    # `partition` keeps everything after the first comma and degrades to an
    # empty payload on malformed input, where split(",")[1] would raise
    # IndexError (no comma) or truncate (payload containing a comma).
    _, _, b64_payload = body_json["image"].partition(",")
    img_data_in = base64.b64decode(b64_payload)
    prompt = body_json["prompt"]
    img_data_out = await Model().inference.remote.aio(img_data_in, prompt)
    output_data = b"data:image/png;base64," + base64.b64encode(img_data_out)
    return Response(content=output_data)


@stub.function(
    mounts=[Mount.from_local_dir(static_path, remote_path="/assets")],
)
@asgi_app()
def fastapi_app():
    """Expose the API routes plus the static frontend mounted at /assets."""
    static_files = StaticFiles(directory="/assets", html=True)
    web_app.mount("/", static_files)
    return web_app
134 changes: 134 additions & 0 deletions 06_gpu_and_ml/stable_diffusion/webcam/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
<!-- Minimal frontend: upload an image, type a prompt, doodle on top, and
     the page keeps POSTing the composited frame to /generate in a loop. -->
<h1>SDXL Turbo Image-to-image</h1>
<input type="file" id="imageLoader" name="imageLoader" />
<input type="text" id="textInput" name="textInput" value="wizard, gandalf, lord of the rings, disney, cute, 8k" />
<canvas id="canvas" style="display: none"> </canvas>
<div style="position: relative">
<img id="outputImg" style="position: absolute; top: 0; left: 0" />
</div>

<canvas id="drawingCanvas"></canvas>

<script>
// Element handles.
var imageLoader = document.getElementById('imageLoader');
var textInput = document.getElementById('textInput');
imageLoader.addEventListener('change', handleImage, false);

// Hidden canvas holding the uploaded (scaled) source image.
var canvas = document.getElementById('canvas');
var ctx = canvas.getContext('2d');

// Visible canvas the user doodles on; composited over `canvas` per frame.
var drawingCanvas = document.getElementById("drawingCanvas");
var drawingCtx = drawingCanvas.getContext("2d");

// Shared state: whether a source image has been loaded yet.
var isImageUploaded = false;
canvas.height = 512;
canvas.width = 512;
drawingCanvas.height = canvas.height;
drawingCanvas.width = canvas.width;

// True while the mouse button is held down during a stroke.
var painting = false;

function startPosition(e) {
  // Enter drawing mode and lay down the first point immediately.
  painting = true;
  draw(e);
}

function endPosition() {
  painting = false;
  // Start a fresh path so the next stroke doesn't connect to this one.
  drawingCtx.beginPath();
}

function draw(e) {
  // Paint a short line segment toward the current pointer position.
  if (!painting) return;

  // Map the viewport-relative mouse position onto canvas coordinates.
  // offsetLeft/offsetTop are relative to the offset parent and ignore
  // scrolling, so mixing them with clientX/clientY puts strokes in the
  // wrong place; getBoundingClientRect is the correct reference frame.
  var rect = drawingCanvas.getBoundingClientRect();
  var x = e.clientX - rect.left;
  var y = e.clientY - rect.top;

  drawingCtx.lineWidth = 5;      // stroke width in px
  drawingCtx.lineCap = 'round';  // rounded stroke ends

  drawingCtx.lineTo(x, y);
  drawingCtx.stroke();
  drawingCtx.beginPath();
  drawingCtx.moveTo(x, y);
}

// Wire mouse events for freehand drawing over the uploaded image.
canvas.addEventListener('mousedown', startPosition);
canvas.addEventListener('mouseup', endPosition);
canvas.addEventListener('mousemove', draw);

function getCombinedImageData() {
  // Composite the uploaded image and the user's strokes into one PNG data URI.
  var combined = document.createElement('canvas');
  combined.width = canvas.width;
  combined.height = canvas.height;

  var combinedCtx = combined.getContext('2d');
  combinedCtx.drawImage(canvas, 0, 0);        // base image first
  combinedCtx.drawImage(drawingCanvas, 0, 0); // strokes on top

  return combined.toDataURL("image/png");
}

function handleImage(e) {
  // Read the chosen file, scale it to fit 512x512 preserving aspect ratio,
  // paint it onto the hidden source canvas, and start the generate loop.
  var file = e.target.files && e.target.files[0];
  // Clearing the file picker fires 'change' with no file selected;
  // readAsDataURL(undefined) would throw.
  if (!file) return;

  var reader = new FileReader();
  reader.onload = function (event) {
    var img = new Image();
    img.onload = function () {
      ctx.clearRect(0, 0, canvas.width, canvas.height);

      var aspectRatio = img.width / img.height;
      var newWidth, newHeight;
      if (img.width > img.height) {
        // Landscape: cap the width at 512.
        newWidth = 512;
        newHeight = 512 / aspectRatio;
      } else {
        // Portrait or square: cap the height at 512.
        newHeight = 512;
        newWidth = 512 * aspectRatio;
      }

      ctx.drawImage(img, 0, 0, newWidth, newHeight);
      canvas.style.display = 'block';
      isImageUploaded = true;
      getNextFrameLoop();
    };
    img.src = event.target.result;
  };
  reader.readAsDataURL(file);
}

const getNextFrameLoop = () => {
  // Poll loop: once an image is loaded, keep POSTing the composited frame
  // to /generate and painting the returned image back onto the canvas.
  if (!isImageUploaded) {
    // Nothing to send yet; check again later.
    setTimeout(getNextFrameLoop, 5000);
    return;
  }

  fetch("/generate", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      image: getCombinedImageData(),
      prompt: textInput.value,
    }),
  })
    .then((res) => {
      // Treat HTTP errors like network errors so we back off and retry
      // instead of assigning an error page body to img.src.
      if (!res.ok) throw new Error("generate failed: " + res.status);
      return res.text();
    })
    .then((dataUri) => {
      var img = new Image();
      img.onload = function () {
        ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
      };
      img.src = dataUri;
      setTimeout(getNextFrameLoop, 10);
    })
    .catch(() => {
      // Best-effort: retry after a short delay on any failure.
      setTimeout(getNextFrameLoop, 1000);
    });
};
</script>

0 comments on commit b5ddb28

Please sign in to comment.