From 93b491b4014fa5d3a97037c3fc011c2bc270a30e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 4 Dec 2023 09:59:04 +0530 Subject: [PATCH] root_ckpt --- .github/workflows/benchmark.yml | 2 +- benchmarks/base_classes.py | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 08e670c29d7d..1c807c436665 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -41,7 +41,7 @@ jobs: - name: Environment run: | python utils/print_env.py - - name: Stable Diffusion Benchmarking Tests + - name: Diffusers Benchmarking env: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} BASE_PATH: benchmark_outputs diff --git a/benchmarks/base_classes.py b/benchmarks/base_classes.py index 31789e31457b..6d3adb23ed43 100644 --- a/benchmarks/base_classes.py +++ b/benchmarks/base_classes.py @@ -141,21 +141,17 @@ def run_inference(self, pipe, args): class ControlNetBenchmark(TextToImageBenchmark): pipeline_class = StableDiffusionControlNetPipeline aux_network_class = ControlNetModel + root_ckpt = "runwayml/stable-diffusion-v1-5" url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png" image = load_image(url).convert("RGB") def __init__(self, args): - if isinstance(self.pipeline_class, StableDiffusionControlNetPipeline): - root_ckpt = "runwayml/stable-diffusion-v1-5" - elif isinstance(self.pipeline_class, StableDiffusionXLControlNetPipeline): - root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0" - aux_network = self.aux_network_class.from_pretrained( args.ckpt, torch_dtype=torch.float16, use_safetensors=True ) pipe = self.pipeline_class.from_pretrained( - root_ckpt, controlnet=aux_network, torch_dtype=torch.float16, use_safetensors=True + self.root_ckpt, controlnet=aux_network, torch_dtype=torch.float16, use_safetensors=True ) pipe = pipe.to("cuda") @@ -179,3 +175,4 @@ def run_inference(self, pipe, args): class ControlNetSDXLBenchmark(ControlNetBenchmark): pipeline_class = StableDiffusionXLControlNetPipeline + root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"