Skip to content

Commit

Permalink
root_ckpt
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul committed Dec 4, 2023
1 parent b358c87 commit 93b491b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
- name: Environment
run: |
python utils/print_env.py
- name: Stable Diffusion Benchmarking Tests
- name: Diffusers Benchmarking
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
BASE_PATH: benchmark_outputs
Expand Down
9 changes: 3 additions & 6 deletions benchmarks/base_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,21 +141,17 @@ def run_inference(self, pipe, args):
class ControlNetBenchmark(TextToImageBenchmark):
pipeline_class = StableDiffusionControlNetPipeline
aux_network_class = ControlNetModel
root_ckpt = "runwayml/stable-diffusion-v1-5"

url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png"
image = load_image(url).convert("RGB")

def __init__(self, args):
if isinstance(self.pipeline_class, StableDiffusionControlNetPipeline):
root_ckpt = "runwayml/stable-diffusion-v1-5"
elif isinstance(self.pipeline_class, StableDiffusionXLControlNetPipeline):
root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"

aux_network = self.aux_network_class.from_pretrained(
args.ckpt, torch_dtype=torch.float16, use_safetensors=True
)
pipe = self.pipeline_class.from_pretrained(
root_ckpt, controlnet=aux_network, torch_dtype=torch.float16, use_safetensors=True
self.root_ckpt, controlnet=aux_network, torch_dtype=torch.float16, use_safetensors=True
)
pipe = pipe.to("cuda")

Expand All @@ -179,3 +175,4 @@ def run_inference(self, pipe, args):

class ControlNetSDXLBenchmark(ControlNetBenchmark):
pipeline_class = StableDiffusionXLControlNetPipeline
root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"

0 comments on commit 93b491b

Please sign in to comment.