Skip to content

Commit

Permalink
fix: missing AutoencoderKL lora adapter (#9807)
Browse files Browse the repository at this point in the history
* fix: missing AutoencoderKL lora adapter

* fix

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
beniz and sayakpaul authored Dec 3, 2024
1 parent 30f2e9b commit 963ffca
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/diffusers/models/autoencoders/autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import deprecate
from ...utils.accelerate_utils import apply_forward_hook
Expand All @@ -34,7 +35,7 @@
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder


class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
Expand Down
38 changes: 38 additions & 0 deletions tests/models/autoencoders/test_models_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
backend_empty_cache,
enable_full_determinism,
floats_tensor,
is_peft_available,
load_hf_numpy,
require_peft_backend,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_gpu,
Expand All @@ -50,6 +52,10 @@
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


if is_peft_available():
from peft import LoraConfig


enable_full_determinism()


Expand Down Expand Up @@ -263,6 +269,38 @@ def test_output_pretrained(self):

self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

@require_peft_backend
def test_lora_adapter(self):
    """Attach a PEFT LoRA adapter to the VAE and verify it becomes the single active adapter.

    Regression test for AutoencoderKL missing the PeftAdapterMixin: without the
    mixin, ``add_adapter`` / ``active_adapters`` do not exist on the model.
    """
    init_dict, _ = self.prepare_init_args_and_inputs_for_common()
    vae = self.model_class(**init_dict)

    # Convolution and attention projection sub-modules inside the VAE that
    # LoRA should wrap; names must match module names in the VAE graph.
    target_modules_vae = [
        "conv1",
        "conv2",
        "conv_in",
        "conv_shortcut",
        "conv",
        "conv_out",
        "skip_conv_1",
        "skip_conv_2",
        "skip_conv_3",
        "skip_conv_4",
        "to_k",
        "to_q",
        "to_v",
        "to_out.0",
    ]
    vae_lora_config = LoraConfig(
        r=16,
        init_lora_weights="gaussian",
        target_modules=target_modules_vae,
    )

    vae.add_adapter(vae_lora_config, adapter_name="vae_lora")
    active_lora = vae.active_adapters()
    # assertEqual (not assertTrue on a comparison) so a failure reports the
    # actual value instead of just "False is not true".
    self.assertEqual(len(active_lora), 1)
    self.assertEqual(active_lora[0], "vae_lora")


class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AsymmetricAutoencoderKL
Expand Down

0 comments on commit 963ffca

Please sign in to comment.