Skip to content

Commit

Permalink
add inpainting pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul committed Nov 29, 2023
1 parent 9683cd7 commit 274b9e1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
20 changes: 19 additions & 1 deletion benchmarks/base_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image
from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image, AutoPipelineForInpainting
from diffusers.utils import load_image


Expand Down Expand Up @@ -80,3 +80,21 @@ def run_inference(self, pipe, args):
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
)


class InpatingPipeline(ImageToImagePipeline):
pipeline_class = AutoPipelineForInpainting
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
mask = load_image(mask_url).convert("RGB")

def run_inference(self, pipe, args):
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
self.mask = self.mask.resize(RESOLUTION_MAPPING[args.ckpt])

_ = pipe(
prompt=PROMPT,
image=self.image,
mask_image=self.mask,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
)
28 changes: 28 additions & 0 deletions benchmarks/benchmark_sd_inpating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import argparse
import sys


sys.path.append(".")
from benchmarks.base_classes import InpatingPipeline # noqa: E402


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--ckpt",
type=str,
default="runwayml/stable-diffusion-v1-5",
choices=[
"runwayml/stable-diffusion-v1-5",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-xl-base-1.0",
],
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--model_cpu_offload", action="store_true")
parser.add_argument("--run_compile", action="store_true")
args = parser.parse_args()

benchmark_pipe = InpatingPipeline(args)
benchmark_pipe.benchmark()

0 comments on commit 274b9e1

Please sign in to comment.