From e856faea4f354c9e4e7b0bc2dd93d40945505b0c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 4 Dec 2023 17:30:03 +0530 Subject: [PATCH 01/12] feat: make linear_cls a class member when needed. --- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/activations.py | 1 + src/diffusers/models/attention_processor.py | 7 ++----- src/diffusers/models/embeddings.py | 1 + src/diffusers/models/resnet.py | 1 + src/diffusers/models/transformer_2d.py | 1 + .../pipelines/wuerstchen/modeling_wuerstchen_common.py | 1 + .../pipelines/wuerstchen/modeling_wuerstchen_prior.py | 1 + 8 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 49ee3ee6af6b..e3794939e25e 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -33,8 +33,8 @@ _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["controlnet"] = ["ControlNetModel"] _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] - _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["embeddings"] = ["ImageProjection"] + _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["prior_transformer"] = ["PriorTransformer"] _import_structure["t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformer_2d"] = ["Transformer2DModel"] diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 47570eca8443..5f23ad154a0d 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -88,6 +88,7 @@ class GEGLU(nn.Module): def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + self.linear_cls = linear_cls self.proj = linear_cls(dim_in, dim_out * 2, bias=bias) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 40a335527ace..f37c11ffb7b4 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -175,11 +175,8 @@ def __init__( f"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None, 'layer_norm' or 'group_norm'" ) - if USE_PEFT_BACKEND: - linear_cls = nn.Linear - else: - linear_cls = LoRACompatibleLinear - + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + self.linear_cls = linear_cls self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) if not self.only_cross_attention: diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index bdd2930d20f9..5ee980ccaab1 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -200,6 +200,7 @@ def __init__( ): super().__init__() linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + self.linear_cls = linear_cls self.linear_1 = linear_cls(in_channels, time_embed_dim) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 970d2be05b7a..9bf7ab875040 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -649,6 +649,7 @@ def __init__( self.skip_time_act = skip_time_act linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + self.linear_cls = linear_cls conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv if groups_out is None: diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 3aecc43f0f5b..aebd02294831 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -107,6 +107,7 @@ def __init__( conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + self.linear_cls = linear_cls # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # Define whether input is continuous or discrete depending on configuration diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py index 00d6f01beced..3ae0c6b493bf 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py @@ -35,6 +35,7 @@ class TimestepBlock(nn.Module): def __init__(self, c, c_timestep): super().__init__() linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + self.linear_cls = linear_cls self.mapper = linear_cls(c_timestep, c * 2) def forward(self, x, t): diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index a7d9e32fb6c9..2f6088213867 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -43,6 +43,7 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro super().__init__() conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + self.linear_cls = linear_cls self.c_r = c_r self.projection = conv_cls(c_in, c, kernel_size=1) From 5f6164cdc59bf6a3a0092c76f9834263732a99d7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 4 Dec 2023 17:53:49 +0530 Subject: [PATCH 02/12] replace linear_cls in case peft is installed. 
--- examples/text_to_image/train_text_to_image_lora.py | 7 +++++-- src/diffusers/training_utils.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index c030c59693c3..c77e9b2ecc40 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -42,8 +42,8 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_snr -from diffusers.utils import check_min_version, is_wandb_available +from diffusers.training_utils import compute_snr, replace_linear_cls +from diffusers.utils import check_min_version, is_peft_available, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -466,6 +466,9 @@ def main(): unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) + if is_peft_available(): + replace_linear_cls(unet) + # freeze parameters of models to save more memory unet.requires_grad_(False) vae.requires_grad_(False) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 992ae7d1b194..922e13799c61 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -53,6 +53,19 @@ def compute_snr(noise_scheduler, timesteps): return snr +def replace_linear_cls(model): + from .models.lora import LoRACompatibleLinear + + for name, module in model.named_children(): + if isinstance(module, torch.nn.Linear): + bias = True if hasattr(module, "bias") else False + new_linear_cls = LoRACompatibleLinear(module.in_features, module.out_features, bias=bias) + setattr(model, name, new_linear_cls) + elif len(list(module.children())) > 0: + # Recursively apply the same operation to child modules + replace_linear_cls(module) + + def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: r""" Returns: From 86b44367e940be95f91dfb212b911306f8b94cba Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 4 Dec 2023 18:24:35 +0530 Subject: [PATCH 03/12] setting the params too --- src/diffusers/training_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 922e13799c61..9855ff54f7b5 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -53,6 +53,9 @@ def compute_snr(noise_scheduler, timesteps): return snr +torch.no_grad() + + def replace_linear_cls(model): from .models.lora import LoRACompatibleLinear @@ -60,6 +63,9 @@ def replace_linear_cls(model): if isinstance(module, torch.nn.Linear): bias = True if hasattr(module, "bias") else False new_linear_cls = LoRACompatibleLinear(module.in_features, module.out_features, bias=bias) + new_linear_cls.weight.copy_(module.weight.data) + if bias: + new_linear_cls.bias.copy_(module.weight.bias) setattr(model, name, new_linear_cls) elif len(list(module.children())) > 0: # Recursively apply the same operation to child modules From 1d228b8eb04f1903c43a2677810f69e8c5e4db15 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Dec 2023 17:11:48 +0530 Subject: [PATCH 04/12] torch.no_grad deco --- src/diffusers/training_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git 
a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 9855ff54f7b5..1c00d618b3dd 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -53,9 +53,7 @@ def compute_snr(noise_scheduler, timesteps): return snr -torch.no_grad() - - +@torch.no_grad() def replace_linear_cls(model): from .models.lora import LoRACompatibleLinear From 0bb32cc28519c8f968e3ff9bb8e84421cc1fa9fd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Dec 2023 17:12:16 +0530 Subject: [PATCH 05/12] import --- src/diffusers/training_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 1c00d618b3dd..c4b5b6a39a83 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -7,6 +7,7 @@ import torch from .models import UNet2DConditionModel +from .models.lora import LoRACompatibleLinear from .utils import deprecate, is_transformers_available @@ -55,8 +56,6 @@ def compute_snr(noise_scheduler, timesteps): @torch.no_grad() def replace_linear_cls(model): - from .models.lora import LoRACompatibleLinear - for name, module in model.named_children(): if isinstance(module, torch.nn.Linear): bias = True if hasattr(module, "bias") else False From 4c38c229e112e58b8a9ef1800cb25cfa114cec84 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Dec 2023 17:15:49 +0530 Subject: [PATCH 06/12] fix: bias copy --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index c4b5b6a39a83..f4fc45ffb698 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -62,7 +62,7 @@ def replace_linear_cls(model): new_linear_cls = LoRACompatibleLinear(module.in_features, module.out_features, bias=bias) new_linear_cls.weight.copy_(module.weight.data) if bias: - new_linear_cls.bias.copy_(module.weight.bias) + new_linear_cls.bias.copy_(module.bias.data) setattr(model, name, new_linear_cls) elif len(list(module.children())) > 0: # Recursively apply the same operation to child modules From c580ff04d5c936c2b656672837f7b1b38a7bcebd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Dec 2023 17:17:06 +0530 Subject: [PATCH 07/12] better handle bias --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index f4fc45ffb698..ecaa1ee23b99 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -58,7 +58,7 @@ def compute_snr(noise_scheduler, timesteps): def replace_linear_cls(model): for name, module in model.named_children(): if isinstance(module, torch.nn.Linear): - bias = True if hasattr(module, "bias") else False + bias = True if hasattr(module, "bias") and getattr(module, "bias", None) is not None else False new_linear_cls = LoRACompatibleLinear(module.in_features, module.out_features, bias=bias) new_linear_cls.weight.copy_(module.weight.data) if bias: From 3e051abd75a1d97aabbac2ad48936c3897e883f1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Dec 2023 17:19:05 +0530 Subject: [PATCH 08/12] device and dtype --- src/diffusers/training_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index ecaa1ee23b99..f32469689e31 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -60,10 +60,13 @@ def 
replace_linear_cls(model): if isinstance(module, torch.nn.Linear): bias = True if hasattr(module, "bias") and getattr(module, "bias", None) is not None else False new_linear_cls = LoRACompatibleLinear(module.in_features, module.out_features, bias=bias) + new_linear_cls.to(device=module.weight.data.device, dtype=module.weight.data.dtype) + new_linear_cls.weight.copy_(module.weight.data) if bias: new_linear_cls.bias.copy_(module.bias.data) setattr(model, name, new_linear_cls) + elif len(list(module.children())) > 0: # Recursively apply the same operation to child modules replace_linear_cls(module) From 3b42e96912036925d5c28683f4047ad27be885fa Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Dec 2023 17:23:43 +0530 Subject: [PATCH 09/12] better/ --- src/diffusers/training_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index f32469689e31..527c91fcaaca 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -59,12 +59,12 @@ def replace_linear_cls(model): for name, module in model.named_children(): if isinstance(module, torch.nn.Linear): bias = True if hasattr(module, "bias") and getattr(module, "bias", None) is not None else False - new_linear_cls = LoRACompatibleLinear(module.in_features, module.out_features, bias=bias) - new_linear_cls.to(device=module.weight.data.device, dtype=module.weight.data.dtype) - + new_linear_cls = LoRACompatibleLinear(module.in_features, module.out_features, bias=bias) new_linear_cls.weight.copy_(module.weight.data) + new_linear_cls.weight.data.to(device=module.weight.data.device, dtype=module.weight.data.dtype) if bias: new_linear_cls.bias.copy_(module.bias.data) + new_linear_cls.bias.data.to(device=module.bias.data.device, dtype=module.bias.data.dtype) setattr(model, name, new_linear_cls) elif len(list(module.children())) > 0: From 212afef9bfa9a91546248e48b063273777d63aae Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Dec 2023 17:27:26 +0530 Subject: [PATCH 10/12] debug --- src/diffusers/models/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index daac8f902cd6..0d85de9295c4 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -217,6 +217,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_dtype = hidden_states.dtype dtype = self.down.weight.dtype + print(f"hidden_states: {hidden_states.device}, Weight: {self.down.weight.data.device}") down_hidden_states = self.down(hidden_states.to(dtype)) up_hidden_states = self.up(down_hidden_states) From 8285e4e72d0d5682d6a2a917d21ceec06ab438c5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Dec 2023 17:30:38 +0530 Subject: [PATCH 11/12] potentially fix device placement --- examples/text_to_image/train_text_to_image_lora.py | 6 ++++-- src/diffusers/models/lora.py | 1 - 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index c77e9b2ecc40..f2c91505f7bf 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -466,8 +466,6 @@ def main(): unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) - if is_peft_available(): - replace_linear_cls(unet) # freeze parameters of models to save more memory 
unet.requires_grad_(False) @@ -487,6 +485,10 @@ def main(): vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) + # Replace the `nn.Linear` layers with `LoRACompatibleLinear` layers. + if is_peft_available(): + replace_linear_cls(unet) + # now we will add new LoRA weights to the attention layers # It's important to realize here how many attention weights will be added and of which sizes # The sizes of the attention layers consist only of two different variables: diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 0d85de9295c4..daac8f902cd6 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -217,7 +217,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_dtype = hidden_states.dtype dtype = self.down.weight.dtype - print(f"hidden_states: {hidden_states.device}, Weight: {self.down.weight.data.device}") down_hidden_states = self.down(hidden_states.to(dtype)) up_hidden_states = self.up(down_hidden_states) From fa9df3466005ff29a640000ae0218433e5411c72 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 8 Dec 2023 17:41:03 +0530 Subject: [PATCH 12/12] check --- examples/text_to_image/train_text_to_image_lora.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index f2c91505f7bf..d3db8857d349 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -481,7 +481,7 @@ def main(): weight_dtype = torch.bfloat16 # Move unet, vae and text_encoder to device and cast to weight_dtype - unet.to(accelerator.device, dtype=weight_dtype) + # unet.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) @@ -705,10 +705,14 @@ def collate_fn(examples): ) # Prepare everything with our `accelerator`. - unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet_lora_parameters, optimizer, train_dataloader, lr_scheduler + # unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + # unet_lora_parameters, optimizer, train_dataloader, lr_scheduler + # ) + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler ) + # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps:
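
Note: for reference, below is a consolidated sketch of the `replace_linear_cls` helper that patches 02-09 converge on. It is a reading of the diffs above, not the merged implementation: it keeps the module-level `.to()` placement from patch 08 rather than the tensor-level `.to()` calls introduced in patch 09, and it assumes `LoRACompatibleLinear` keeps `nn.Linear`'s `(in_features, out_features, bias)` constructor signature, which is how the patches themselves instantiate it.

import torch

from diffusers.models.lora import LoRACompatibleLinear


@torch.no_grad()
def replace_linear_cls(model):
    # Recursively swap every nn.Linear in `model` for a LoRACompatibleLinear,
    # preserving weights, bias, device and dtype.
    # Assumption: LoRACompatibleLinear subclasses nn.Linear (as in diffusers.models.lora).
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            has_bias = getattr(module, "bias", None) is not None
            new_linear = LoRACompatibleLinear(module.in_features, module.out_features, bias=has_bias)
            # Match the original layer's device/dtype before copying parameters
            # (patch 08's approach).
            new_linear.to(device=module.weight.device, dtype=module.weight.dtype)
            new_linear.weight.copy_(module.weight)
            if has_bias:
                new_linear.bias.copy_(module.bias)
            setattr(model, name, new_linear)
        elif len(list(module.children())) > 0:
            # Recurse into child submodules.
            replace_linear_cls(module)

In the training script the swap is gated on PEFT being installed (patch 02), e.g. `if is_peft_available(): replace_linear_cls(unet)`, and the later patches move that call to after the models have been placed on the accelerator device, so the copied parameters end up where the forward pass expects them.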