[Training - don't merge] explicitly set linear_cls to LoRACompatibleLinear when peft is installed #6045

Closed
wants to merge 13 commits
19 changes: 14 additions & 5 deletions examples/text_to_image/train_text_to_image_lora.py
@@ -42,8 +42,8 @@
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.training_utils import compute_snr, replace_linear_cls
+from diffusers.utils import check_min_version, is_peft_available, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -466,6 +466,7 @@ def main():
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
    )
+
    # freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
@@ -480,10 +481,14 @@
        weight_dtype = torch.bfloat16

    # Move unet, vae and text_encoder to device and cast to weight_dtype
-    unet.to(accelerator.device, dtype=weight_dtype)
+    # unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

+    # Replace the `nn.Linear` layers with `LoRACompatibleLinear` layers.
+    if is_peft_available():
+        replace_linear_cls(unet)
+
    # now we will add new LoRA weights to the attention layers
    # It's important to realize here how many attention weights will be added and of which sizes
    # The sizes of the attention layers consist only of two different variables:
@@ -700,10 +705,14 @@ def collate_fn(examples):
    )

    # Prepare everything with our `accelerator`.
-    unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+    # unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+    #     unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+    # )
+    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet, optimizer, train_dataloader, lr_scheduler
    )


    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
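
As a sanity check for this training experiment, swapping the linear classes should leave the UNet's forward pass unchanged until LoRA layers are actually attached. A rough sketch of such a check (not part of this PR; the latent and prompt-embedding shapes are assumed Stable-Diffusion defaults, and `unet` / `replace_linear_cls` are the objects used in the script above):

import copy

import torch

with torch.no_grad():
    sample = torch.randn(1, unet.config.in_channels, 64, 64)
    timestep = torch.tensor([10])
    encoder_hidden_states = torch.randn(1, 77, unet.config.cross_attention_dim)

    reference = copy.deepcopy(unet)  # keep the original nn.Linear version
    replace_linear_cls(unet)         # swap in LoRACompatibleLinear layers

    out_ref = reference(sample, timestep, encoder_hidden_states).sample
    out_new = unet(sample, timestep, encoder_hidden_states).sample
    # Without any LoRA layer attached, LoRACompatibleLinear should behave exactly
    # like nn.Linear, so the two outputs are expected to match.
    assert torch.allclose(out_ref, out_new, atol=1e-5)
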
2 changes: 1 addition & 1 deletion src/diffusers/models/__init__.py
@@ -33,8 +33,8 @@
    _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
    _import_structure["controlnet"] = ["ControlNetModel"]
    _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
-    _import_structure["modeling_utils"] = ["ModelMixin"]
    _import_structure["embeddings"] = ["ImageProjection"]
+    _import_structure["modeling_utils"] = ["ModelMixin"]
    _import_structure["prior_transformer"] = ["PriorTransformer"]
    _import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
    _import_structure["transformer_2d"] = ["Transformer2DModel"]
1 change: 1 addition & 0 deletions src/diffusers/models/activations.py
@@ -88,6 +88,7 @@ class GEGLU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+        self.linear_cls = linear_cls

        self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
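
Storing `self.linear_cls` records which projection class the module was built with, so callers can tell whether the PEFT backend path was taken. A small inspection sketch (hypothetical usage with arbitrary dimensions, not part of this PR):

from diffusers.models.activations import GEGLU

geglu = GEGLU(dim_in=320, dim_out=1280)
# Without the PEFT backend, `linear_cls` is LoRACompatibleLinear; with it, plain nn.Linear.
print(geglu.linear_cls)
assert isinstance(geglu.proj, geglu.linear_cls)
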
7 changes: 2 additions & 5 deletions src/diffusers/models/attention_processor.py
@@ -175,11 +175,8 @@ def __init__(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

-        if USE_PEFT_BACKEND:
-            linear_cls = nn.Linear
-        else:
-            linear_cls = LoRACompatibleLinear
-
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.linear_cls = linear_cls
        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)

        if not self.only_cross_attention:
1 change: 1 addition & 0 deletions src/diffusers/models/embeddings.py
@@ -200,6 +200,7 @@ def __init__(
    ):
        super().__init__()
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.linear_cls = linear_cls

        self.linear_1 = linear_cls(in_channels, time_embed_dim)
1 change: 1 addition & 0 deletions src/diffusers/models/resnet.py
@@ -649,6 +649,7 @@ def __init__(
        self.skip_time_act = skip_time_act

        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.linear_cls = linear_cls
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        if groups_out is None:
1 change: 1 addition & 0 deletions src/diffusers/models/transformer_2d.py
@@ -107,6 +107,7 @@ def __init__(

        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.linear_cls = linear_cls

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
@@ -35,6 +35,7 @@ class TimestepBlock(nn.Module):
    def __init__(self, c, c_timestep):
        super().__init__()
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.linear_cls = linear_cls
        self.mapper = linear_cls(c_timestep, c * 2)

    def forward(self, x, t):
@@ -43,6 +43,7 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro
        super().__init__()
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.linear_cls = linear_cls

        self.c_r = c_r
        self.projection = conv_cls(c_in, c, kernel_size=1)
19 changes: 19 additions & 0 deletions src/diffusers/training_utils.py
@@ -7,6 +7,7 @@
import torch

from .models import UNet2DConditionModel
+from .models.lora import LoRACompatibleLinear
from .utils import deprecate, is_transformers_available


@@ -53,6 +54,24 @@ def compute_snr(noise_scheduler, timesteps):
    return snr


+@torch.no_grad()
+def replace_linear_cls(model):
+    for name, module in model.named_children():
+        if isinstance(module, torch.nn.Linear):
+            bias = module.bias is not None
+            new_linear = LoRACompatibleLinear(module.in_features, module.out_features, bias=bias)
+            # Match the original layer's device and dtype, then copy its parameters over.
+            new_linear.to(device=module.weight.device, dtype=module.weight.dtype)
+            new_linear.weight.copy_(module.weight)
+            if bias:
+                new_linear.bias.copy_(module.bias)
+            setattr(model, name, new_linear)
+
+        elif len(list(module.children())) > 0:
+            # Recursively apply the same operation to child modules
+            replace_linear_cls(module)


def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
    r"""
    Returns:
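
The new `replace_linear_cls` helper can also be exercised on its own. A minimal sketch (not part of this PR; it assumes this branch is installed so the helper is importable from `diffusers.training_utils`) that swaps the layers of a small nested module and checks that the weights survive the copy:

import torch
from torch import nn

from diffusers.models.lora import LoRACompatibleLinear
from diffusers.training_utils import replace_linear_cls

model = nn.Sequential(
    nn.Linear(8, 16),
    nn.Sequential(nn.Linear(16, 4, bias=False)),
)
original_weight = model[0].weight.detach().clone()

replace_linear_cls(model)

# Every plain nn.Linear, including the nested one, has been replaced ...
assert isinstance(model[0], LoRACompatibleLinear)
assert isinstance(model[1][0], LoRACompatibleLinear)
# ... and the original weights were copied over unchanged.
assert torch.equal(model[0].weight, original_weight)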