Skip to content

Commit

Permalink
[feat: Benchmarking Workflow] add stuff for a benchmarking workflow (h…
Browse files Browse the repository at this point in the history
…uggingface#5839)

* add poc for benchmarking workflow.

* import

* fix argument

* fix: argument

* fix: path

* fix

* fix

* path

* output csv files.

* workflow cleanup

* append token

* add utility to push to hf dataset

* fix: kw arg

* better reporting

* fix: headers

* better formatting of the numbers.

* better type annotation

* fix: formatting

* moentarily disable check

* push results.

* remove disable check

* introduce base classes.

* img2img class

* add inpainting pipeline

* intoduce base benchmark class.

* add img2img and inpainting

* feat: utility to compare changes

* fix

* fix import

* add args

* basepath

* better exception handling

* better path handling

* fix

* fix

* remove

* ifx

* fix

* add: support for controlnet.

* image_url -> url

* move images to huggingface hub

* correct urls.

* root_ckpt

* flush before benchmarking

* don't install accelerate from source

* add runner

* simplify Diffusers Benchmarking step

* change runner

* fix: subprocess call.

* filter percentage values

* fix controlnet benchmark

* add t2i adapters.

* fix filter columns

* fix t2i adapter benchmark

* fix init.

* fix

* remove safetensors flag

* fix args print

* fix

* feat: run_command

* add adapter resolution mapping

* benchmark t2i adapter fix.

* fix adapter input

* fix

* convert to L.

* add flush() add appropriate places

* better filtering

* okay

* get env for torch

* convert to float

* fix

* filter out nans.

* better coment

* sdxl

* sdxl for other benchmarks.

* fix: condition

* fix: condition for inpainting

* fix: mapping for resolution

* fix

* include kandinsky and wuerstchen

* fix: Wuerstchen

* Empty-Commit

* [Community] AnimateDiff + Controlnet Pipeline (huggingface#5928)

* begin work on animatediff + controlnet pipeline

* complete todos, uncomment multicontrolnet, input checks

Co-Authored-By: EdoardoBotta <[email protected]>

* update

Co-Authored-By: EdoardoBotta <[email protected]>

* add example

* update community README

* Update examples/community/README.md

---------

Co-authored-by: EdoardoBotta <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* EulerDiscreteScheduler add `rescale_betas_zero_snr` (huggingface#6024)

* EulerDiscreteScheduler add `rescale_betas_zero_snr`

* Revert "[Community] AnimateDiff + Controlnet Pipeline (huggingface#5928)"

This reverts commit 821726d.

* Revert "EulerDiscreteScheduler add `rescale_betas_zero_snr` (huggingface#6024)"

This reverts commit 3dc2362.

* add SDXL turbo

* add lcm lora to the mix as well.

* fix

* increase steps to 2 when running turbo i2i

* debug

* debug

* debug

* fix for good

* fix and isolate better

* fuse lora so that torch compile works with peft

* fix: LCMLoRA

* better identification for LCM

* change to cron job

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: Aryan V S <[email protected]>
Co-authored-by: EdoardoBotta <[email protected]>
Co-authored-by: Beinsezii <[email protected]>
  • Loading branch information
6 people authored and donhardman committed Dec 18, 2023
1 parent 39269f1 commit d38bcbc
Show file tree
Hide file tree
Showing 12 changed files with 791 additions and 1 deletion.
52 changes: 52 additions & 0 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
name: Benchmarking tests

on:
schedule:
- cron: "30 1 1,15 * *" # every 2 weeks on the 1st and the 15th of every month at 1:30 AM

env:
DIFFUSERS_IS_CI: yes
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8

jobs:
torch_pipelines_cuda_benchmark_tests:
name: Torch Core Pipelines CUDA Benchmarking Tests
strategy:
fail-fast: false
max-parallel: 1
runs-on: [single-gpu, nvidia-gpu, a10, ci]
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: |
nvidia-smi
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
python -m pip install pandas
- name: Environment
run: |
python utils/print_env.py
- name: Diffusers Benchmarking
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
BASE_PATH: benchmark_outputs
run: |
export TOTAL_GPU_MEMORY=$(python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / (1024**3))")
cd benchmarks && mkdir ${BASE_PATH} && python run_all.py && python push_results.py
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: benchmark_test_reports
path: benchmarks/benchmark_outputs
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src

check_dirs := examples scripts src tests utils
check_dirs := examples scripts src tests utils benchmarks

modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
Expand Down
297 changes: 297 additions & 0 deletions benchmarks/base_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
import os
import sys

import torch

from diffusers import (
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
ControlNetModel,
LCMScheduler,
StableDiffusionAdapterPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionXLAdapterPipeline,
StableDiffusionXLControlNetPipeline,
T2IAdapter,
WuerstchenCombinedPipeline,
)
from diffusers.utils import load_image


sys.path.append(".")

from utils import ( # noqa: E402
BASE_PATH,
PROMPT,
BenchmarkInfo,
benchmark_fn,
bytes_to_giga_bytes,
flush,
generate_csv_dict,
write_to_csv,
)


RESOLUTION_MAPPING = {
"runwayml/stable-diffusion-v1-5": (512, 512),
"lllyasviel/sd-controlnet-canny": (512, 512),
"diffusers/controlnet-canny-sdxl-1.0": (1024, 1024),
"TencentARC/t2iadapter_canny_sd14v1": (512, 512),
"TencentARC/t2i-adapter-canny-sdxl-1.0": (1024, 1024),
"stabilityai/stable-diffusion-2-1": (768, 768),
"stabilityai/stable-diffusion-xl-base-1.0": (1024, 1024),
"stabilityai/stable-diffusion-xl-refiner-1.0": (1024, 1024),
"stabilityai/sdxl-turbo": (512, 512),
}


class BaseBenchmak:
pipeline_class = None

def __init__(self, args):
super().__init__()

def run_inference(self, args):
raise NotImplementedError

def benchmark(self, args):
raise NotImplementedError

def get_result_filepath(self, args):
pipeline_class_name = str(self.pipe.__class__.__name__)
name = (
args.ckpt.replace("/", "_")
+ "_"
+ pipeline_class_name
+ f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
)
filepath = os.path.join(BASE_PATH, name)
return filepath


class TextToImageBenchmark(BaseBenchmak):
pipeline_class = AutoPipelineForText2Image

def __init__(self, args):
pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

if args.run_compile:
if not isinstance(pipe, WuerstchenCombinedPipeline):
pipe.unet.to(memory_format=torch.channels_last)
print("Run torch compile")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

if hasattr(pipe, "movq") and getattr(pipe, "movq", None) is not None:
pipe.movq.to(memory_format=torch.channels_last)
pipe.movq = torch.compile(pipe.movq, mode="reduce-overhead", fullgraph=True)
else:
print("Run torch compile")
pipe.decoder = torch.compile(pipe.decoder, mode="reduce-overhead", fullgraph=True)
pipe.vqgan = torch.compile(pipe.vqgan, mode="reduce-overhead", fullgraph=True)

pipe.set_progress_bar_config(disable=True)
self.pipe = pipe

def run_inference(self, pipe, args):
_ = pipe(
prompt=PROMPT,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
)

def benchmark(self, args):
flush()

print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")

time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
benchmark_info = BenchmarkInfo(time=time, memory=memory)

pipeline_class_name = str(self.pipe.__class__.__name__)
flush()
csv_dict = generate_csv_dict(
pipeline_cls=pipeline_class_name, ckpt=args.ckpt, args=args, benchmark_info=benchmark_info
)
filepath = self.get_result_filepath(args)
write_to_csv(filepath, csv_dict)
print(f"Logs written to: {filepath}")
flush()


class TurboTextToImageBenchmark(TextToImageBenchmark):
def __init__(self, args):
super().__init__(args)

def run_inference(self, pipe, args):
_ = pipe(
prompt=PROMPT,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
guidance_scale=0.0,
)


class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
lora_id = "latent-consistency/lcm-lora-sdxl"

def __init__(self, args):
super().__init__(args)
self.pipe.load_lora_weights(self.lora_id)
self.pipe.fuse_lora()
self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)

def get_result_filepath(self, args):
pipeline_class_name = str(self.pipe.__class__.__name__)
name = (
self.lora_id.replace("/", "_")
+ "_"
+ pipeline_class_name
+ f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
)
filepath = os.path.join(BASE_PATH, name)
return filepath

def run_inference(self, pipe, args):
_ = pipe(
prompt=PROMPT,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
guidance_scale=1.0,
)


class ImageToImageBenchmark(TextToImageBenchmark):
pipeline_class = AutoPipelineForImage2Image
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/1665_Girl_with_a_Pearl_Earring.jpg"
image = load_image(url).convert("RGB")

def __init__(self, args):
super().__init__(args)
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])

def run_inference(self, pipe, args):
_ = pipe(
prompt=PROMPT,
image=self.image,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
)


class TurboImageToImageBenchmark(ImageToImageBenchmark):
def __init__(self, args):
super().__init__(args)

def run_inference(self, pipe, args):
_ = pipe(
prompt=PROMPT,
image=self.image,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
guidance_scale=0.0,
strength=0.5,
)


class InpaintingBenchmark(ImageToImageBenchmark):
pipeline_class = AutoPipelineForInpainting
mask_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/overture-creations-5sI6fQgYIuo_mask.png"
mask = load_image(mask_url).convert("RGB")

def __init__(self, args):
super().__init__(args)
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
self.mask = self.mask.resize(RESOLUTION_MAPPING[args.ckpt])

def run_inference(self, pipe, args):
_ = pipe(
prompt=PROMPT,
image=self.image,
mask_image=self.mask,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
)


class ControlNetBenchmark(TextToImageBenchmark):
pipeline_class = StableDiffusionControlNetPipeline
aux_network_class = ControlNetModel
root_ckpt = "runwayml/stable-diffusion-v1-5"

url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png"
image = load_image(url).convert("RGB")

def __init__(self, args):
aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
pipe = self.pipeline_class.from_pretrained(self.root_ckpt, controlnet=aux_network, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

pipe.set_progress_bar_config(disable=True)
self.pipe = pipe

if args.run_compile:
pipe.unet.to(memory_format=torch.channels_last)
pipe.controlnet.to(memory_format=torch.channels_last)

print("Run torch compile")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)

self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])

def run_inference(self, pipe, args):
_ = pipe(
prompt=PROMPT,
image=self.image,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
)


class ControlNetSDXLBenchmark(ControlNetBenchmark):
pipeline_class = StableDiffusionXLControlNetPipeline
root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"

def __init__(self, args):
super().__init__(args)


class T2IAdapterBenchmark(ControlNetBenchmark):
pipeline_class = StableDiffusionAdapterPipeline
aux_network_class = T2IAdapter
root_ckpt = "CompVis/stable-diffusion-v1-4"

url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter.png"
image = load_image(url).convert("L")

def __init__(self, args):
aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
pipe = self.pipeline_class.from_pretrained(self.root_ckpt, adapter=aux_network, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

pipe.set_progress_bar_config(disable=True)
self.pipe = pipe

if args.run_compile:
pipe.unet.to(memory_format=torch.channels_last)
pipe.adapter.to(memory_format=torch.channels_last)

print("Run torch compile")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.adapter = torch.compile(pipe.adapter, mode="reduce-overhead", fullgraph=True)

self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])


class T2IAdapterSDXLBenchmark(T2IAdapterBenchmark):
pipeline_class = StableDiffusionXLAdapterPipeline
root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"

url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter_sdxl.png"
image = load_image(url)

def __init__(self, args):
super().__init__(args)
26 changes: 26 additions & 0 deletions benchmarks/benchmark_controlnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import argparse
import sys


sys.path.append(".")
from base_classes import ControlNetBenchmark, ControlNetSDXLBenchmark # noqa: E402


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--ckpt",
type=str,
default="lllyasviel/sd-controlnet-canny",
choices=["lllyasviel/sd-controlnet-canny", "diffusers/controlnet-canny-sdxl-1.0"],
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--model_cpu_offload", action="store_true")
parser.add_argument("--run_compile", action="store_true")
args = parser.parse_args()

benchmark_pipe = (
ControlNetBenchmark(args) if args.ckpt == "lllyasviel/sd-controlnet-canny" else ControlNetSDXLBenchmark(args)
)
benchmark_pipe.benchmark(args)
Loading

0 comments on commit d38bcbc

Please sign in to comment.