implement sketch-to-image pipeline and route #231

Open · wants to merge 9 commits into base: main
105 changes: 105 additions & 0 deletions cmd/examples/sketch-to-image/main.go
@@ -0,0 +1,105 @@
// Package main provides a small example of how to run the 'sketch-to-image' pipeline using the AI worker package.
package main

import (
	"context"
	"flag"
	"log/slog"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"time"

	"github.com/livepeer/ai-worker/worker"
	"github.com/oapi-codegen/runtime/types"
)

func main() {
	aiModelsDir := flag.String("aiModelsDir", "runner/models", "path to the models directory")
	flag.Parse()

	containerName := "sketch-to-image"
	baseOutputPath := "output"

	containerImageID := "livepeer/ai-runner:latest"
	gpus := []string{"0"}

	modelsDir, err := filepath.Abs(*aiModelsDir)
	if err != nil {
		slog.Error("Error getting absolute path for 'aiModelsDir'", slog.String("error", err.Error()))
		return
	}

	modelID := "xinsir/controlnet-scribble-sdxl-1.0"

	w, err := worker.NewWorker(containerImageID, gpus, modelsDir)
	if err != nil {
		slog.Error("Error creating worker", slog.String("error", err.Error()))
		return
	}

	slog.Info("Warming container")

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	if err := w.Warm(ctx, containerName, modelID, worker.RunnerEndpoint{}, worker.OptimizationFlags{}); err != nil {
		slog.Error("Error warming container", slog.String("error", err.Error()))
		return
	}

	slog.Info("Warm container is up")

	// Use flag.Args() rather than os.Args[1:] so positional arguments are
	// read correctly after flag.Parse() has consumed any flags.
	args := flag.Args()
	if len(args) < 3 {
		slog.Error("Usage: <runs> <prompt> <image path>")
		return
	}

	runs, err := strconv.Atoi(args[0])
	if err != nil {
		slog.Error("Invalid runs arg", slog.String("error", err.Error()))
		return
	}

	prompt := args[1]
	imagePath := args[2]

	imageBytes, err := os.ReadFile(imagePath)
	if err != nil {
		slog.Error("Error reading image", slog.String("imagePath", imagePath), slog.String("error", err.Error()))
		return
	}
	imageFile := types.File{}
	imageFile.InitFromBytes(imageBytes, imagePath)

	req := worker.GenSketchToImageMultipartRequestBody{
		Image:   imageFile,
		ModelId: &modelID,
		Prompt:  prompt,
	}

	for i := 0; i < runs; i++ {
		slog.Info("Running sketch-to-image", slog.Int("num", i))

		resp, err := w.SketchToImage(ctx, req)
		if err != nil {
			slog.Error("Error running sketch-to-image", slog.String("error", err.Error()))
			return
		}

		for j, media := range resp.Images {
			outputPath := path.Join(baseOutputPath, strconv.Itoa(i)+"_"+strconv.Itoa(j)+".png")
			if err := worker.SaveImageB64DataUrl(media.Url, outputPath); err != nil {
				slog.Error("Error saving b64 data url as image", slog.String("error", err.Error()))
				return
			}

			slog.Info("Output written", slog.String("outputPath", outputPath))
		}
	}

	slog.Info("Sleeping 2 seconds and then stopping container")

	time.Sleep(2 * time.Second)

	w.Stop(ctx)

	time.Sleep(1 * time.Second)
}
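The example above drives the runner through the worker package. Once a runner container exposing this pipeline is up, the route added by this PR can also be exercised directly over HTTP. The following Python sketch is illustrative only: the port, endpoint path, and multipart field names are assumptions inferred from the repo's other pipelines, not code shown in this diff.

# Hypothetical direct HTTP call to a running ai-runner container. The port,
# endpoint path, and form field names are assumptions based on the repo's
# other multipart routes, not code from this PR.
import requests

with open("sketch.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/sketch-to-image",  # assumed port and path
        files={"image": ("sketch.png", f, "image/png")},
        data={
            "prompt": "a watercolor landscape",
            "model_id": "xinsir/controlnet-scribble-sdxl-1.0",
        },
    )
resp.raise_for_status()
print(resp.json())  # expected to carry base64 data URLs, per SaveImageB64DataUrl above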
11 changes: 11 additions & 0 deletions runner/app/main.py
@@ -56,11 +56,16 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
            return SegmentAnything2Pipeline(model_id)
        case "llm":
            from app.pipelines.llm import LLMPipeline

            return LLMPipeline(model_id)
        case "image-to-text":
            from app.pipelines.image_to_text import ImageToTextPipeline

            return ImageToTextPipeline(model_id)
        case "sketch-to-image":
            from app.pipelines.sketch_to_image import SketchToImagePipeline

            return SketchToImagePipeline(model_id)
        case _:
            raise EnvironmentError(
                f"{pipeline} is not a valid pipeline for model {model_id}"
@@ -97,10 +102,16 @@ def load_route(pipeline: str) -> any:
            return segment_anything_2.router
        case "llm":
            from app.routes import llm

            return llm.router
        case "image-to-text":
            from app.routes import image_to_text

            return image_to_text.router
        case "sketch-to-image":
            from app.routes import sketch_to_image

            return sketch_to_image.router
        case _:
            raise EnvironmentError(f"{pipeline} is not a valid pipeline")

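Note that load_route imports runner/app/routes/sketch_to_image, which is part of this PR but not visible in the quoted hunks. As a rough, self-contained sketch of the shape load_route expects (a module exposing a FastAPI router attribute), something like the following would fit; the endpoint path, form fields, and response handling are illustrative assumptions, not the PR's actual route code.

# Hypothetical, simplified stand-in for runner/app/routes/sketch_to_image.py;
# the real module exists in this PR but is not shown in the quoted hunks.
import io

from fastapi import APIRouter, File, Form, UploadFile
from PIL import Image

router = APIRouter()  # load_route("sketch-to-image") returns this attribute


@router.post("/sketch-to-image")  # path inferred from the pipeline name
async def sketch_to_image(
    prompt: str = Form(...),
    image: UploadFile = File(...),
):
    sketch = Image.open(io.BytesIO(await image.read())).convert("RGB")
    # In the real app the shared SketchToImagePipeline instance would be
    # resolved here (e.g. via a dependency) and invoked roughly as:
    #   images, has_nsfw = pipeline(prompt, sketch)
    return {"images": []}  # placeholder response shape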
123 changes: 123 additions & 0 deletions runner/app/pipelines/sketch_to_image.py
@@ -0,0 +1,123 @@
import logging
import os
from enum import Enum
from typing import List, Optional, Tuple

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
    SafetyChecker,
    get_model_dir,
    get_torch_device,
)
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
    StableDiffusionXLControlNetPipeline,
)
from huggingface_hub import file_download
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)


class ModelName(Enum):
    """Enumeration mapping model names to their corresponding IDs."""

    SCRIBBLE_SDXL = "xinsir/controlnet-scribble-sdxl-1.0"

    @classmethod
    def list(cls):
        """Return a list of all model IDs."""
        return list(map(lambda c: c.value, cls))


class SketchToImagePipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model_id = model_id
        kwargs = {"cache_dir": get_model_dir()}

        torch_device = get_torch_device()
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(get_model_dir(), folder_name)
        # Load the fp16 variant when the cached weights provide one and we are
        # not running on CPU.
        has_fp16_variant = any(
            ".fp16.safetensors" in fname
            for _, _, files in os.walk(folder_path)
            for fname in files
        )

        torch_dtype = torch.float
        if torch_device.type != "cpu" and has_fp16_variant:
            logger.info("SketchToImagePipeline loading fp16 variant for %s", model_id)
            torch_dtype = torch.float16
            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="scheduler",
            cache_dir=get_model_dir(),
        )
        controlnet = ControlNetModel.from_pretrained(
            self.model_id,
            torch_dtype=torch_dtype,
            cache_dir=get_model_dir(),
        )
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
        )
        self.ldm = StableDiffusionXLControlNetPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            controlnet=controlnet,
            vae=vae,
            scheduler=eulera_scheduler,
            **kwargs,  # cache_dir plus the fp16 dtype/variant selected above
        ).to(torch_device)

        safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
        self._safety_checker = SafetyChecker(device=safety_checker_device)

    def __call__(
        self, prompt: str, image: PIL.Image, **kwargs
    ) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
        seed = kwargs.pop("seed", None)
        safety_check = kwargs.pop("safety_check", True)

        # Turn a seed (or list of seeds) into torch.Generator objects so that
        # generation is reproducible.
        if seed is not None:
            if isinstance(seed, int):
                kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
                    seed
                )
            elif isinstance(seed, list):
                kwargs["generator"] = [
                    torch.Generator(get_torch_device()).manual_seed(s) for s in seed
                ]
        # Drop an unset or invalid step count so diffusers falls back to its default.
        if "num_inference_steps" in kwargs and (
            kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
        ):
            del kwargs["num_inference_steps"]

        output = self.ldm(prompt, image=image, **kwargs)

        if safety_check:
            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
        else:
            has_nsfw_concept = [None] * len(output.images)

        return output.images, has_nsfw_concept

    def __str__(self) -> str:
        return f"SketchToImagePipeline model_id={self.model_id}"