Skip to content

Commit

Permalink
[LoRA] depcrecate save_attn_procs(). (#10126)
Browse files Browse the repository at this point in the history
depcrecate save_attn_procs().
  • Loading branch information
sayakpaul authored Dec 6, 2024
1 parent 188bca3 commit fa3a910
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,9 @@ def save_attn_procs(
)
state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
else:
deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`."
deprecate("save_attn_procs", "0.40.0", deprecation_message)

if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")

Expand Down
18 changes: 18 additions & 0 deletions tests/models/unets/test_models_unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,24 @@ def test_load_attn_procs_raise_warning(self):
lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
), "Loading from a saved checkpoint should produce identical results."

@require_peft_backend
def test_save_attn_procs_raise_warning(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)

unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)

assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."

with tempfile.TemporaryDirectory() as tmpdirname:
with self.assertWarns(FutureWarning) as warning:
model.save_attn_procs(tmpdirname)

warning_message = str(warning.warnings[0].message)
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message


@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
Expand Down

0 comments on commit fa3a910

Please sign in to comment.