diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
new file mode 100644
index 000000000000..c4c3c101dbfd
--- /dev/null
+++ b/.github/workflows/benchmark.yml
@@ -0,0 +1,52 @@
+name: Benchmarking tests
+
+on:
+  schedule:
+    - cron: "30 1 1,15 * *" # on the 1st and the 15th of every month at 1:30 AM
+
+env:
+  DIFFUSERS_IS_CI: yes
+  HF_HOME: /mnt/cache
+  OMP_NUM_THREADS: 8
+  MKL_NUM_THREADS: 8
+
+jobs:
+  torch_pipelines_cuda_benchmark_tests:
+    name: Torch Core Pipelines CUDA Benchmarking Tests
+    strategy:
+      fail-fast: false
+      max-parallel: 1
+    runs-on: [single-gpu, nvidia-gpu, a10, ci]
+    container:
+      image: diffusers/diffusers-pytorch-cuda
+      options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
+    steps:
+      - name: Checkout diffusers
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 2
+      - name: NVIDIA-SMI
+        run: |
+          nvidia-smi
+      - name: Install dependencies
+        run: |
+          apt-get update && apt-get install libsndfile1-dev libgl1 -y
+          python -m pip install -e .[quality,test]
+          python -m pip install pandas
+      - name: Environment
+        run: |
+          python utils/print_env.py
+      - name: Diffusers Benchmarking
+        env:
+          HUGGING_FACE_HUB_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
+          BASE_PATH: benchmark_outputs
+        run: |
+          export TOTAL_GPU_MEMORY=$(python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / (1024**3))")
+          cd benchmarks && mkdir ${BASE_PATH} && python run_all.py && python push_results.py
+
+      - name: Test suite reports artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@v2
+        with:
+          name: benchmark_test_reports
+          path: benchmarks/benchmark_outputs
\ No newline at end of file
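For reference, the `TOTAL_GPU_MEMORY` export in the benchmarking step above reduces to the following query (a minimal standalone sketch; the availability guard is an addition here for machines without a GPU, not part of the workflow):

```python
# Rough standalone equivalent of the TOTAL_GPU_MEMORY export above.
import torch

if torch.cuda.is_available():
    total_memory_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"TOTAL_GPU_MEMORY={total_memory_bytes / (1024**3)}")  # bytes -> GiB
```

The value is later picked up by `benchmarks/utils.py` so result files can record how much memory the benchmarking GPU actually had.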
diff --git a/Makefile b/Makefile
index 70bfced8c7b4..c92285b48c71 100644
--- a/Makefile
+++ b/Makefile
@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src
 
-check_dirs := examples scripts src tests utils
+check_dirs := examples scripts src tests utils benchmarks
 
 modified_only_fixup:
 	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
diff --git a/benchmarks/base_classes.py b/benchmarks/base_classes.py
new file mode 100644
index 000000000000..5d328f62b904
--- /dev/null
+++ b/benchmarks/base_classes.py
@@ -0,0 +1,297 @@
+import os
+import sys
+
+import torch
+
+from diffusers import (
+    AutoPipelineForImage2Image,
+    AutoPipelineForInpainting,
+    AutoPipelineForText2Image,
+    ControlNetModel,
+    LCMScheduler,
+    StableDiffusionAdapterPipeline,
+    StableDiffusionControlNetPipeline,
+    StableDiffusionXLAdapterPipeline,
+    StableDiffusionXLControlNetPipeline,
+    T2IAdapter,
+    WuerstchenCombinedPipeline,
+)
+from diffusers.utils import load_image
+
+
+sys.path.append(".")
+
+from utils import (  # noqa: E402
+    BASE_PATH,
+    PROMPT,
+    BenchmarkInfo,
+    benchmark_fn,
+    bytes_to_giga_bytes,
+    flush,
+    generate_csv_dict,
+    write_to_csv,
+)
+
+
+RESOLUTION_MAPPING = {
+    "runwayml/stable-diffusion-v1-5": (512, 512),
+    "lllyasviel/sd-controlnet-canny": (512, 512),
+    "diffusers/controlnet-canny-sdxl-1.0": (1024, 1024),
+    "TencentARC/t2iadapter_canny_sd14v1": (512, 512),
+    "TencentARC/t2i-adapter-canny-sdxl-1.0": (1024, 1024),
+    "stabilityai/stable-diffusion-2-1": (768, 768),
+    "stabilityai/stable-diffusion-xl-base-1.0": (1024, 1024),
+    "stabilityai/stable-diffusion-xl-refiner-1.0": (1024, 1024),
+    "stabilityai/sdxl-turbo": (512, 512),
+}
+
+
+class BaseBenchmark:
+    pipeline_class = None
+
+    def __init__(self, args):
+        super().__init__()
+
+    def run_inference(self, pipe, args):
+        raise NotImplementedError
+
+    def benchmark(self, args):
+        raise NotImplementedError
+
+    def get_result_filepath(self, args):
+        pipeline_class_name = str(self.pipe.__class__.__name__)
+        name = (
+            args.ckpt.replace("/", "_")
+            + "_"
+            + pipeline_class_name
+            + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
+        )
+        filepath = os.path.join(BASE_PATH, name)
+        return filepath
+
+
+class TextToImageBenchmark(BaseBenchmark):
+    pipeline_class = AutoPipelineForText2Image
+
+    def __init__(self, args):
+        pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
+        pipe = pipe.to("cuda")
+
+        if args.run_compile:
+            if not isinstance(pipe, WuerstchenCombinedPipeline):
+                pipe.unet.to(memory_format=torch.channels_last)
+                print("Run torch compile")
+                pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+                if hasattr(pipe, "movq") and getattr(pipe, "movq", None) is not None:
+                    pipe.movq.to(memory_format=torch.channels_last)
+                    pipe.movq = torch.compile(pipe.movq, mode="reduce-overhead", fullgraph=True)
+            else:
+                print("Run torch compile")
+                pipe.decoder = torch.compile(pipe.decoder, mode="reduce-overhead", fullgraph=True)
+                pipe.vqgan = torch.compile(pipe.vqgan, mode="reduce-overhead", fullgraph=True)
+
+        pipe.set_progress_bar_config(disable=True)
+        self.pipe = pipe
+
+    def run_inference(self, pipe, args):
+        _ = pipe(
+            prompt=PROMPT,
+            num_inference_steps=args.num_inference_steps,
+            num_images_per_prompt=args.batch_size,
+        )
+
+    def benchmark(self, args):
+        flush()
+
+        print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
+
+        time = benchmark_fn(self.run_inference, self.pipe, args)  # in seconds.
+        memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GBs.
+        benchmark_info = BenchmarkInfo(time=time, memory=memory)
+
+        pipeline_class_name = str(self.pipe.__class__.__name__)
+        flush()
+        csv_dict = generate_csv_dict(
+            pipeline_cls=pipeline_class_name, ckpt=args.ckpt, args=args, benchmark_info=benchmark_info
+        )
+        filepath = self.get_result_filepath(args)
+        write_to_csv(filepath, csv_dict)
+        print(f"Logs written to: {filepath}")
+        flush()
+
+
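The `--run_compile` branch above applies the same recipe to every module it compiles. Pulled out of the class, the pattern is roughly the following (a sketch assuming a CUDA machine; the checkpoint is one of the benchmarked defaults):

```python
import torch
from diffusers import AutoPipelineForText2Image

# Load in fp16 and move to the GPU, as TextToImageBenchmark.__init__ does.
pipe = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# channels_last (NHWC) plus torch.compile in "reduce-overhead" mode is the
# combination these benchmarks apply whenever --run_compile is passed.
pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
```

Wuerstchen gets a special case above because its combined pipeline has no `unet`; its `decoder` and `vqgan` are compiled instead.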
+class TurboTextToImageBenchmark(TextToImageBenchmark):
+    def __init__(self, args):
+        super().__init__(args)
+
+    def run_inference(self, pipe, args):
+        _ = pipe(
+            prompt=PROMPT,
+            num_inference_steps=args.num_inference_steps,
+            num_images_per_prompt=args.batch_size,
+            guidance_scale=0.0,
+        )
+
+
+class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
+    lora_id = "latent-consistency/lcm-lora-sdxl"
+
+    def __init__(self, args):
+        super().__init__(args)
+        self.pipe.load_lora_weights(self.lora_id)
+        self.pipe.fuse_lora()
+        self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
+
+    def get_result_filepath(self, args):
+        pipeline_class_name = str(self.pipe.__class__.__name__)
+        name = (
+            self.lora_id.replace("/", "_")
+            + "_"
+            + pipeline_class_name
+            + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
+        )
+        filepath = os.path.join(BASE_PATH, name)
+        return filepath
+
+    def run_inference(self, pipe, args):
+        _ = pipe(
+            prompt=PROMPT,
+            num_inference_steps=args.num_inference_steps,
+            num_images_per_prompt=args.batch_size,
+            guidance_scale=1.0,
+        )
+
+
+class ImageToImageBenchmark(TextToImageBenchmark):
+    pipeline_class = AutoPipelineForImage2Image
+    url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/1665_Girl_with_a_Pearl_Earring.jpg"
+    image = load_image(url).convert("RGB")
+
+    def __init__(self, args):
+        super().__init__(args)
+        self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
+
+    def run_inference(self, pipe, args):
+        _ = pipe(
+            prompt=PROMPT,
+            image=self.image,
+            num_inference_steps=args.num_inference_steps,
+            num_images_per_prompt=args.batch_size,
+        )
+
+
+class TurboImageToImageBenchmark(ImageToImageBenchmark):
+    def __init__(self, args):
+        super().__init__(args)
+
+    def run_inference(self, pipe, args):
+        _ = pipe(
+            prompt=PROMPT,
+            image=self.image,
+            num_inference_steps=args.num_inference_steps,
+            num_images_per_prompt=args.batch_size,
+            guidance_scale=0.0,
+            strength=0.5,
+        )
+
+
+class InpaintingBenchmark(ImageToImageBenchmark):
+    pipeline_class = AutoPipelineForInpainting
+    mask_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/overture-creations-5sI6fQgYIuo_mask.png"
+    mask = load_image(mask_url).convert("RGB")
+
+    def __init__(self, args):
+        super().__init__(args)
+        self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
+        self.mask = self.mask.resize(RESOLUTION_MAPPING[args.ckpt])
+
+    def run_inference(self, pipe, args):
+        _ = pipe(
+            prompt=PROMPT,
+            image=self.image,
+            mask_image=self.mask,
+            num_inference_steps=args.num_inference_steps,
+            num_images_per_prompt=args.batch_size,
+        )
+
+
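For orientation, `LCMLoRATextToImageBenchmark` above amounts to the standard LCM-LoRA recipe. Written out linearly, and assuming a CUDA machine, it is roughly:

```python
import torch
from diffusers import AutoPipelineForText2Image, LCMScheduler

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
pipe.fuse_lora()  # merge the LoRA weights so inference pays no adapter-dispatch overhead
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Few-step inference: the benchmark script defaults to 4 steps with guidance_scale=1.0.
image = pipe(
    "ghibli style, a fantasy landscape with castles",
    num_inference_steps=4,
    guidance_scale=1.0,
).images[0]
```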
+class ControlNetBenchmark(TextToImageBenchmark):
+    pipeline_class = StableDiffusionControlNetPipeline
+    aux_network_class = ControlNetModel
+    root_ckpt = "runwayml/stable-diffusion-v1-5"
+
+    url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png"
+    image = load_image(url).convert("RGB")
+
+    def __init__(self, args):
+        aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
+        pipe = self.pipeline_class.from_pretrained(self.root_ckpt, controlnet=aux_network, torch_dtype=torch.float16)
+        pipe = pipe.to("cuda")
+
+        pipe.set_progress_bar_config(disable=True)
+        self.pipe = pipe
+
+        if args.run_compile:
+            pipe.unet.to(memory_format=torch.channels_last)
+            pipe.controlnet.to(memory_format=torch.channels_last)
+
+            print("Run torch compile")
+            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+            pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
+
+        self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
+
+    def run_inference(self, pipe, args):
+        _ = pipe(
+            prompt=PROMPT,
+            image=self.image,
+            num_inference_steps=args.num_inference_steps,
+            num_images_per_prompt=args.batch_size,
+        )
+
+
+class ControlNetSDXLBenchmark(ControlNetBenchmark):
+    pipeline_class = StableDiffusionXLControlNetPipeline
+    root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
+
+    def __init__(self, args):
+        super().__init__(args)
+
+
+class T2IAdapterBenchmark(ControlNetBenchmark):
+    pipeline_class = StableDiffusionAdapterPipeline
+    aux_network_class = T2IAdapter
+    root_ckpt = "CompVis/stable-diffusion-v1-4"
+
+    url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter.png"
+    image = load_image(url).convert("L")
+
+    def __init__(self, args):
+        aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
+        pipe = self.pipeline_class.from_pretrained(self.root_ckpt, adapter=aux_network, torch_dtype=torch.float16)
+        pipe = pipe.to("cuda")
+
+        pipe.set_progress_bar_config(disable=True)
+        self.pipe = pipe
+
+        if args.run_compile:
+            pipe.unet.to(memory_format=torch.channels_last)
+            pipe.adapter.to(memory_format=torch.channels_last)
+
+            print("Run torch compile")
+            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+            pipe.adapter = torch.compile(pipe.adapter, mode="reduce-overhead", fullgraph=True)
+
+        self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
+
+
+class T2IAdapterSDXLBenchmark(T2IAdapterBenchmark):
+    pipeline_class = StableDiffusionXLAdapterPipeline
+    root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
+
+    url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter_sdxl.png"
+    image = load_image(url)
+
+    def __init__(self, args):
+        super().__init__(args)
diff --git a/benchmarks/benchmark_controlnet.py b/benchmarks/benchmark_controlnet.py
new file mode 100644
index 000000000000..9217004461dc
--- /dev/null
+++ b/benchmarks/benchmark_controlnet.py
@@ -0,0 +1,26 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import ControlNetBenchmark, ControlNetSDXLBenchmark  # noqa: E402
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--ckpt",
+        type=str,
+        default="lllyasviel/sd-controlnet-canny",
+        choices=["lllyasviel/sd-controlnet-canny", "diffusers/controlnet-canny-sdxl-1.0"],
+    )
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--num_inference_steps", type=int, default=50)
+    parser.add_argument("--model_cpu_offload", action="store_true")
+    parser.add_argument("--run_compile", action="store_true")
+    args = parser.parse_args()
+
+    benchmark_pipe = (
+        ControlNetBenchmark(args) if args.ckpt == "lllyasviel/sd-controlnet-canny" else ControlNetSDXLBenchmark(args)
+    )
+    benchmark_pipe.benchmark(args)
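Note that the ControlNet and T2I-Adapter benchmarks treat `--ckpt` as the auxiliary network and pair it with a fixed `root_ckpt` base model. As a standalone sketch of that assembly (CUDA assumed, using the script's default checkpoints):

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# The benchmarked checkpoint supplies only the auxiliary network; the base
# (root) checkpoint is fixed per benchmark class.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
```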
diff --git a/benchmarks/benchmark_sd_img.py b/benchmarks/benchmark_sd_img.py
new file mode 100644
index 000000000000..491e7c9a65a9
--- /dev/null
+++ b/benchmarks/benchmark_sd_img.py
@@ -0,0 +1,29 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import ImageToImageBenchmark, TurboImageToImageBenchmark  # noqa: E402
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--ckpt",
+        type=str,
+        default="runwayml/stable-diffusion-v1-5",
+        choices=[
+            "runwayml/stable-diffusion-v1-5",
+            "stabilityai/stable-diffusion-2-1",
+            "stabilityai/stable-diffusion-xl-refiner-1.0",
+            "stabilityai/sdxl-turbo",
+        ],
+    )
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--num_inference_steps", type=int, default=50)
+    parser.add_argument("--model_cpu_offload", action="store_true")
+    parser.add_argument("--run_compile", action="store_true")
+    args = parser.parse_args()
+
+    benchmark_pipe = ImageToImageBenchmark(args) if "turbo" not in args.ckpt else TurboImageToImageBenchmark(args)
+    benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_sd_inpainting.py b/benchmarks/benchmark_sd_inpainting.py
new file mode 100644
index 000000000000..8f36883e16f3
--- /dev/null
+++ b/benchmarks/benchmark_sd_inpainting.py
@@ -0,0 +1,28 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import InpaintingBenchmark  # noqa: E402
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--ckpt",
+        type=str,
+        default="runwayml/stable-diffusion-v1-5",
+        choices=[
+            "runwayml/stable-diffusion-v1-5",
+            "stabilityai/stable-diffusion-2-1",
+            "stabilityai/stable-diffusion-xl-base-1.0",
+        ],
+    )
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--num_inference_steps", type=int, default=50)
+    parser.add_argument("--model_cpu_offload", action="store_true")
+    parser.add_argument("--run_compile", action="store_true")
+    args = parser.parse_args()
+
+    benchmark_pipe = InpaintingBenchmark(args)
+    benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_t2i_adapter.py b/benchmarks/benchmark_t2i_adapter.py
new file mode 100644
index 000000000000..44b04b470ea6
--- /dev/null
+++ b/benchmarks/benchmark_t2i_adapter.py
@@ -0,0 +1,28 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import T2IAdapterBenchmark, T2IAdapterSDXLBenchmark  # noqa: E402
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--ckpt",
+        type=str,
+        default="TencentARC/t2iadapter_canny_sd14v1",
+        choices=["TencentARC/t2iadapter_canny_sd14v1", "TencentARC/t2i-adapter-canny-sdxl-1.0"],
+    )
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--num_inference_steps", type=int, default=50)
+    parser.add_argument("--model_cpu_offload", action="store_true")
+    parser.add_argument("--run_compile", action="store_true")
+    args = parser.parse_args()
+
+    benchmark_pipe = (
+        T2IAdapterBenchmark(args)
+        if args.ckpt == "TencentARC/t2iadapter_canny_sd14v1"
+        else T2IAdapterSDXLBenchmark(args)
+    )
+    benchmark_pipe.benchmark(args)
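One non-obvious interaction in `benchmark_sd_img.py` above: `TurboImageToImageBenchmark` pins `strength=0.5`, and `run_all.py` (further below) passes `--num_inference_steps 2` for sdxl-turbo. Image-to-image pipelines run roughly `int(num_inference_steps * strength)` denoising steps, so the two settings combine into single-step inference; the exact count depends on the scheduler, so treat this as an approximation:

```python
# Why run_all.py passes --num_inference_steps 2 for sdxl-turbo img2img:
# img2img runs roughly int(num_inference_steps * strength) denoising steps.
num_inference_steps, strength = 2, 0.5
print(int(num_inference_steps * strength))  # 1 -> single-step turbo inference
```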
diff --git a/benchmarks/benchmark_t2i_lcm_lora.py b/benchmarks/benchmark_t2i_lcm_lora.py
new file mode 100644
index 000000000000..957e0a463e28
--- /dev/null
+++ b/benchmarks/benchmark_t2i_lcm_lora.py
@@ -0,0 +1,23 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import LCMLoRATextToImageBenchmark  # noqa: E402
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--ckpt",
+        type=str,
+        default="stabilityai/stable-diffusion-xl-base-1.0",
+    )
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--num_inference_steps", type=int, default=4)
+    parser.add_argument("--model_cpu_offload", action="store_true")
+    parser.add_argument("--run_compile", action="store_true")
+    args = parser.parse_args()
+
+    benchmark_pipe = LCMLoRATextToImageBenchmark(args)
+    benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_text_to_image.py b/benchmarks/benchmark_text_to_image.py
new file mode 100644
index 000000000000..caa97b0c5e3b
--- /dev/null
+++ b/benchmarks/benchmark_text_to_image.py
@@ -0,0 +1,40 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import TextToImageBenchmark, TurboTextToImageBenchmark  # noqa: E402
+
+
+ALL_T2I_CKPTS = [
+    "runwayml/stable-diffusion-v1-5",
+    "segmind/SSD-1B",
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    "kandinsky-community/kandinsky-2-2-decoder",
+    "warp-ai/wuerstchen",
+    "stabilityai/sdxl-turbo",
+]
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--ckpt",
+        type=str,
+        default="runwayml/stable-diffusion-v1-5",
+        choices=ALL_T2I_CKPTS,
+    )
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--num_inference_steps", type=int, default=50)
+    parser.add_argument("--model_cpu_offload", action="store_true")
+    parser.add_argument("--run_compile", action="store_true")
+    args = parser.parse_args()
+
+    benchmark_cls = None
+    if "turbo" in args.ckpt:
+        benchmark_cls = TurboTextToImageBenchmark
+    else:
+        benchmark_cls = TextToImageBenchmark
+
+    benchmark_pipe = benchmark_cls(args)
+    benchmark_pipe.benchmark(args)
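Before the upload script, it may help to see how these pieces are meant to be driven end to end. The CI job's benchmarking step boils down to the following, sketched here as a hypothetical local run from the `benchmarks/` directory (pushing requires a `HUGGING_FACE_HUB_TOKEN` with write access to `diffusers/benchmarks`):

```python
import os
import subprocess

# Mirror the CI job's "Diffusers Benchmarking" step from a local checkout.
os.environ.setdefault("BASE_PATH", "benchmark_outputs")
os.makedirs(os.environ["BASE_PATH"], exist_ok=True)
subprocess.check_call(["python", "run_all.py"])
subprocess.check_call(["python", "push_results.py"])  # needs Hub write access
```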
diff --git a/benchmarks/push_results.py b/benchmarks/push_results.py
new file mode 100644
index 000000000000..962e07c6d74c
--- /dev/null
+++ b/benchmarks/push_results.py
@@ -0,0 +1,72 @@
+import glob
+import sys
+
+import pandas as pd
+from huggingface_hub import hf_hub_download, upload_file
+from huggingface_hub.utils._errors import EntryNotFoundError
+
+
+sys.path.append(".")
+from utils import BASE_PATH, FINAL_CSV_FILE, GITHUB_SHA, REPO_ID, collate_csv  # noqa: E402
+
+
+def has_previous_benchmark() -> str:
+    csv_path = None
+    try:
+        csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILE)
+    except EntryNotFoundError:
+        csv_path = None
+    return csv_path
+
+
+def filter_float(value):
+    if isinstance(value, str):
+        return float(value.split()[0])
+    return value
+
+
+def push_to_hf_dataset():
+    all_csvs = sorted(glob.glob(f"{BASE_PATH}/*.csv"))
+    collate_csv(all_csvs, FINAL_CSV_FILE)
+
+    # If there's an existing benchmark file, we should report the changes.
+    csv_path = has_previous_benchmark()
+    if csv_path is not None:
+        current_results = pd.read_csv(FINAL_CSV_FILE)
+        previous_results = pd.read_csv(csv_path)
+
+        numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns
+        numeric_columns = [
+            c for c in numeric_columns if c not in ["batch_size", "num_inference_steps", "actual_gpu_memory (gbs)"]
+        ]
+
+        for column in numeric_columns:
+            previous_results[column] = previous_results[column].map(lambda x: filter_float(x))
+
+            # Calculate the percentage change.
+            current_results[column] = current_results[column].astype(float)
+            previous_results[column] = previous_results[column].astype(float)
+            percent_change = ((current_results[column] - previous_results[column]) / previous_results[column]) * 100
+
+            # Format the values with a '+' or '-' sign and append them to the original values.
+            current_results[column] = current_results[column].map(str) + percent_change.map(
+                lambda x: f" ({'+' if x > 0 else ''}{x:.2f}%)"
+            )
+            # Newly added rows have no previous entry, which yields NaN deltas; strip those annotations.
+            current_results[column] = current_results[column].map(lambda x: x.replace(" (nan%)", ""))
+
+        # Overwrite the current result file.
+        current_results.to_csv(FINAL_CSV_FILE, index=False)
+
+    commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results"
+    upload_file(
+        repo_id=REPO_ID,
+        path_in_repo=FINAL_CSV_FILE,
+        path_or_fileobj=FINAL_CSV_FILE,
+        repo_type="dataset",
+        commit_message=commit_message,
+    )
+
+
+if __name__ == "__main__":
+    push_to_hf_dataset()
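The column-annotation logic in `push_to_hf_dataset()` is easiest to see on toy data. This is an illustration of the same expressions on made-up numbers, not part of the script:

```python
import pandas as pd

# Toy "current" and "previous" measurements for one numeric column.
current = pd.Series([1.10, 4.00])
previous = pd.Series([1.00, 5.00])

percent_change = ((current - previous) / previous) * 100
annotated = current.map(str) + percent_change.map(
    lambda x: f" ({'+' if x > 0 else ''}{x:.2f}%)"
)
print(annotated.tolist())  # ['1.1 (+10.00%)', '4.0 (-20.00%)']
```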
diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py
new file mode 100644
index 000000000000..c70fb2227383
--- /dev/null
+++ b/benchmarks/run_all.py
@@ -0,0 +1,97 @@
+import glob
+import subprocess
+import sys
+from typing import List
+
+
+sys.path.append(".")
+from benchmark_text_to_image import ALL_T2I_CKPTS  # noqa: E402
+
+
+PATTERN = "benchmark_*.py"
+
+
+class SubprocessCallException(Exception):
+    pass
+
+
+# Taken from `test_examples_utils.py`
+def run_command(command: List[str], return_stdout=False):
+    """
+    Runs `command` with `subprocess.check_output` and optionally returns its `stdout`. Properly captures
+    any error raised while running `command`.
+    """
+    try:
+        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
+        if return_stdout:
+            if hasattr(output, "decode"):
+                output = output.decode("utf-8")
+            return output
+    except subprocess.CalledProcessError as e:
+        raise SubprocessCallException(
+            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+        ) from e
+
+
+def main():
+    python_files = glob.glob(PATTERN)
+
+    for file in python_files:
+        print(f"****** Running file: {file} ******")
+
+        # Run with canonical settings.
+        if file != "benchmark_text_to_image.py":
+            command = f"python {file}"
+            run_command(command.split())
+
+            command += " --run_compile"
+            run_command(command.split())
+
+    # Run variants.
+    for file in python_files:
+        if file == "benchmark_text_to_image.py":
+            for ckpt in ALL_T2I_CKPTS:
+                command = f"python {file} --ckpt {ckpt}"
+
+                if "turbo" in ckpt:
+                    command += " --num_inference_steps 1"
+
+                run_command(command.split())
+
+                command += " --run_compile"
+                run_command(command.split())
+
+        elif file == "benchmark_sd_img.py":
+            for ckpt in ["stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/sdxl-turbo"]:
+                command = f"python {file} --ckpt {ckpt}"
+
+                if ckpt == "stabilityai/sdxl-turbo":
+                    command += " --num_inference_steps 2"
+
+                run_command(command.split())
+                command += " --run_compile"
+                run_command(command.split())
+
+        elif file == "benchmark_sd_inpainting.py":
+            sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
+            command = f"python {file} --ckpt {sdxl_ckpt}"
+            run_command(command.split())
+
+            command += " --run_compile"
+            run_command(command.split())
+
+        elif file in ["benchmark_controlnet.py", "benchmark_t2i_adapter.py"]:
+            sdxl_ckpt = (
+                "diffusers/controlnet-canny-sdxl-1.0"
+                if "controlnet" in file
+                else "TencentARC/t2i-adapter-canny-sdxl-1.0"
+            )
+            command = f"python {file} --ckpt {sdxl_ckpt}"
+            run_command(command.split())
+
+            command += " --run_compile"
+            run_command(command.split())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmarks/utils.py b/benchmarks/utils.py
new file mode 100644
index 000000000000..5fce920ac6c3
--- /dev/null
+++ b/benchmarks/utils.py
@@ -0,0 +1,98 @@
+import argparse
+import csv
+import gc
+import os
+from dataclasses import dataclass
+from typing import Dict, List, Union
+
+import torch
+import torch.utils.benchmark as benchmark
+
+
+GITHUB_SHA = os.getenv("GITHUB_SHA", None)
+BENCHMARK_FIELDS = [
+    "pipeline_cls",
+    "ckpt_id",
+    "batch_size",
+    "num_inference_steps",
+    "model_cpu_offload",
+    "run_compile",
+    "time (secs)",
+    "memory (gbs)",
+    "actual_gpu_memory (gbs)",
+    "github_sha",
+]
+
+PROMPT = "ghibli style, a fantasy landscape with castles"
+BASE_PATH = os.getenv("BASE_PATH", ".")
+TOTAL_GPU_MEMORY = float(os.getenv("TOTAL_GPU_MEMORY", torch.cuda.get_device_properties(0).total_memory / (1024**3)))
+
+REPO_ID = "diffusers/benchmarks"
+FINAL_CSV_FILE = "collated_results.csv"
+
+
+@dataclass
+class BenchmarkInfo:
+    time: float
+    memory: float
+
+
+def flush():
+    """Frees GPU memory caches and resets peak-memory statistics."""
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_max_memory_allocated()
+    torch.cuda.reset_peak_memory_stats()
+
+
+def bytes_to_giga_bytes(bytes):
+    return f"{(bytes / 1024 / 1024 / 1024):.3f}"
+
+
+def benchmark_fn(f, *args, **kwargs):
+    t0 = benchmark.Timer(
+        stmt="f(*args, **kwargs)",
+        globals={"args": args, "kwargs": kwargs, "f": f},
+        num_threads=torch.get_num_threads(),
+    )
+    return f"{(t0.blocked_autorange().mean):.3f}"
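The timing in `benchmark_fn()` above comes from `torch.utils.benchmark.Timer`, which handles warmup and repeats the measured statement until it has stable readings. A self-contained illustration on a toy workload:

```python
import torch
import torch.utils.benchmark as benchmark


def f(x):
    # Stand-in for the pipeline call being benchmarked.
    return x @ x


x = torch.randn(256, 256)
timer = benchmark.Timer(
    stmt="f(x)",
    globals={"f": f, "x": x},
    num_threads=torch.get_num_threads(),
)
print(f"{timer.blocked_autorange().mean:.3f} s per call")
```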
+def generate_csv_dict(
+    pipeline_cls: str, ckpt: str, args: argparse.Namespace, benchmark_info: BenchmarkInfo
+) -> Dict[str, Union[str, bool, float]]:
+    """Packs benchmarking data into a dictionary for later serialization."""
+    data_dict = {
+        "pipeline_cls": pipeline_cls,
+        "ckpt_id": ckpt,
+        "batch_size": args.batch_size,
+        "num_inference_steps": args.num_inference_steps,
+        "model_cpu_offload": args.model_cpu_offload,
+        "run_compile": args.run_compile,
+        "time (secs)": benchmark_info.time,
+        "memory (gbs)": benchmark_info.memory,
+        "actual_gpu_memory (gbs)": f"{(TOTAL_GPU_MEMORY):.3f}",
+        "github_sha": GITHUB_SHA,
+    }
+    return data_dict
+
+
+def write_to_csv(file_name: str, data_dict: Dict[str, Union[str, bool, float]]):
+    """Serializes a dictionary into a CSV file."""
+    with open(file_name, mode="w", newline="") as csvfile:
+        writer = csv.DictWriter(csvfile, fieldnames=BENCHMARK_FIELDS)
+        writer.writeheader()
+        writer.writerow(data_dict)
+
+
+def collate_csv(input_files: List[str], output_file: str):
+    """Collates multiple identically structured CSVs into a single CSV file."""
+    with open(output_file, mode="w", newline="") as outfile:
+        writer = csv.DictWriter(outfile, fieldnames=BENCHMARK_FIELDS)
+        writer.writeheader()
+
+        for file in input_files:
+            with open(file, mode="r") as infile:
+                reader = csv.DictReader(infile)
+                for row in reader:
+                    writer.writerow(row)
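Finally, a hypothetical usage of the CSV helpers above. The row values are illustrative only, and importing `utils` on a machine without a GPU requires `TOTAL_GPU_MEMORY` to be set in the environment (here given a made-up value):

```python
import os
import sys

os.environ.setdefault("TOTAL_GPU_MEMORY", "24.0")  # hypothetical value for off-GPU runs
sys.path.append("benchmarks")
from utils import collate_csv, write_to_csv  # noqa: E402

row = {
    "pipeline_cls": "StableDiffusionPipeline",  # illustrative values throughout
    "ckpt_id": "runwayml/stable-diffusion-v1-5",
    "batch_size": 1,
    "num_inference_steps": 50,
    "model_cpu_offload": False,
    "run_compile": False,
    "time (secs)": "1.234",
    "memory (gbs)": "8.765",
    "actual_gpu_memory (gbs)": "24.000",
    "github_sha": None,
}
write_to_csv("example.csv", row)
collate_csv(["example.csv"], "collated_results.csv")
```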