
Commit

update
DN6 committed Dec 13, 2023
1 parent 1177d37 commit f07d36e
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions tests/lora/test_lora_layers_old_backend.py
@@ -2285,8 +2285,8 @@ def test_sdxl_1_0_lora_fusion_efficiency(self):
         lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
         lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"

-        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
-        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
         pipe.enable_model_cpu_offload()

         start_time = time.time()
@@ -2299,13 +2299,13 @@ def test_sdxl_1_0_lora_fusion_efficiency(self):

         del pipe

-        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
-        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
         pipe.fuse_lora()
         pipe.enable_model_cpu_offload()

-        start_time = time.time()
         generator = torch.Generator().manual_seed(0)
+        start_time = time.time()
         for _ in range(3):
             pipe(
                 "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
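For context, the pattern this test exercises after the change is roughly: load the SDXL base pipeline in half precision, attach the LoRA weights, enable model CPU offload, and time a few cheap runs, first unfused and then again after fuse_lora(). The following is a minimal sketch assembled from the diff above, not the full test; the model and LoRA identifiers, the fixed seed, and the 3-iteration, 2-step timing loop are taken directly from it.

import time

import torch
from diffusers import DiffusionPipeline

lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"

# Load the base pipeline in fp16 and attach the LoRA weights.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.enable_model_cpu_offload()

# Time a few short runs with a fixed seed (unfused LoRA).
generator = torch.Generator().manual_seed(0)
start_time = time.time()
for _ in range(3):
    pipe(
        "masterpiece, best quality, mountain",
        output_type="np",
        generator=generator,
        num_inference_steps=2,
    )
unfused_time = time.time() - start_time

# The test then rebuilds the pipeline, calls pipe.fuse_lora() before
# enable_model_cpu_offload(), and repeats the same loop to compare timings.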
4 changes: 2 additions & 2 deletions tests/lora/test_lora_layers_peft.py
@@ -1934,14 +1934,14 @@ def test_sdxl_1_0_lora_fusion_efficiency(self):
         pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
         pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
         pipe.fuse_lora()

         # We need to unload the lora weights since in the previous API `fuse_lora` led to lora weights being
         # silently deleted - otherwise this will CPU OOM
         pipe.unload_lora_weights()

         pipe.enable_model_cpu_offload()

-        start_time = time.time()
         generator = torch.Generator().manual_seed(0)
+        start_time = time.time()
         for _ in range(3):
             pipe(
                 "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
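The peft-backend version of the test adds one step the first sketch omits: after pipe.fuse_lora() it calls pipe.unload_lora_weights(), because, per the comment in the hunk above, the previous API silently dropped the standalone LoRA weights on fusion, and keeping a second copy around together with model CPU offload can exhaust CPU RAM. A hedged sketch of that fused path, reusing the same model and LoRA identifiers as above:

import time

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.load_lora_weights(
    "hf-internal-testing/sdxl-1.0-lora",
    weight_name="sd_xl_offset_example-lora_1.0.safetensors",
)

# Fold the LoRA weights into the base weights, then release the standalone
# LoRA parameters so a second copy is not kept in CPU RAM during offload.
pipe.fuse_lora()
pipe.unload_lora_weights()
pipe.enable_model_cpu_offload()

# Same timed loop as before, now with the fused weights.
generator = torch.Generator().manual_seed(0)
start_time = time.time()
for _ in range(3):
    pipe(
        "masterpiece, best quality, mountain",
        output_type="np",
        generator=generator,
        num_inference_steps=2,
    )
fused_time = time.time() - start_time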
