
Commit

Merge branch 'main' into ddpm-scheduler-rescale-zero-snr
sayakpaul authored Dec 24, 2023
2 parents 1660394 + 90b9479 commit d260c4a
Showing 5 changed files with 31 additions and 8 deletions.
examples/dreambooth/train_dreambooth_lora.py (6 changes: 5 additions & 1 deletion)
@@ -827,6 +827,7 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     unet_lora_config = LoraConfig(
         r=args.rank,
+        lora_alpha=args.rank,
         init_lora_weights="gaussian",
         target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
     )
@@ -835,7 +836,10 @@ def main(args):
     # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
-            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
         text_encoder.add_adapter(text_lora_config)

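For context on what the added lora_alpha does: in PEFT, the LoRA update applied to each targeted layer is scaled by lora_alpha / r, and LoraConfig falls back to lora_alpha=8 when the argument is omitted. Leaving that default while varying --rank therefore changes the effective strength of the adapter; passing lora_alpha=args.rank, as this commit does, keeps the scaling factor at 1 for any rank. A minimal standalone sketch (not part of the diff; the rank value is made up for illustration):

from peft import LoraConfig

rank = 16  # hypothetical value, standing in for args.rank

# Without lora_alpha the config falls back to PEFT's default of 8.
old_config = LoraConfig(
    r=rank,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)

# With the change in this commit, lora_alpha tracks the rank.
new_config = LoraConfig(
    r=rank,
    lora_alpha=rank,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)

print(old_config.lora_alpha / old_config.r)  # 0.5, so adapter strength drifts as the rank changes
print(new_config.lora_alpha / new_config.r)  # 1.0, independent of the chosen rank
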
examples/dreambooth/train_dreambooth_lora_sdxl.py (10 changes: 8 additions & 2 deletions)
@@ -978,15 +978,21 @@ def main(args):

     # now we will add new LoRA weights to the attention layers
     unet_lora_config = LoraConfig(
-        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     unet.add_adapter(unet_lora_config)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
-            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
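In the SDXL script the same text-encoder config is attached to both CLIP text encoders. If one wants to confirm which modules add_adapter actually wrapped, the injected layers can be spotted by their lora_A sub-module; a rough sketch, assuming text_encoder_one is a loaded CLIPTextModel and the add_adapter calls above have already run:

# Rough sketch: list the attention projections that now carry LoRA weights.
wrapped = [name for name, module in text_encoder_one.named_modules() if hasattr(module, "lora_A")]
print(len(wrapped))  # one entry per q_proj / k_proj / v_proj / out_proj in every attention block
print(wrapped[:2])   # e.g. 'text_model.encoder.layers.0.self_attn.q_proj', '...self_attn.k_proj'
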
examples/text_to_image/train_text_to_image_lora.py (5 changes: 4 additions & 1 deletion)
@@ -452,7 +452,10 @@ def main():
         param.requires_grad_(False)

     unet_lora_config = LoraConfig(
-        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )

     # Move unet, vae and text_encoder to device and cast to weight_dtype
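The ordering in this script is freeze first, then attach: the loop sets requires_grad to False on every UNet parameter, and the later unet.add_adapter(unet_lora_config) call leaves only the newly injected LoRA matrices trainable. A small sketch of how that could be verified, assuming unet and unet_lora_config exactly as in the script:

# Sketch: after freezing the base weights and attaching the adapter,
# the only trainable parameters should be the LoRA matrices.
unet.requires_grad_(False)
unet.add_adapter(unet_lora_config)

trainable = [name for name, param in unet.named_parameters() if param.requires_grad]
print(len(trainable))                             # > 0: the injected lora_A / lora_B weights
print(all("lora" in name for name in trainable))  # True: nothing from the base UNet is trainable
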
examples/text_to_image/train_text_to_image_lora_sdxl.py (10 changes: 8 additions & 2 deletions)
@@ -609,7 +609,10 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     # Set correct lora layers
     unet_lora_config = LoraConfig(
-        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )

     unet.add_adapter(unet_lora_config)
@@ -618,7 +621,10 @@ def main(args):
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
         text_lora_config = LoraConfig(
-            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
tests/lora/test_lora_layers_peft.py (8 changes: 6 additions & 2 deletions)
@@ -111,6 +111,7 @@ class PeftLoraLoaderMixinTests:

     def get_dummy_components(self, scheduler_cls=None):
         scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
+        rank = 4

         torch.manual_seed(0)
         unet = UNet2DConditionModel(**self.unet_kwargs)
@@ -125,11 +126,14 @@ def get_dummy_components(self, scheduler_cls=None):
         tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")

         text_lora_config = LoraConfig(
-            r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
+            r=rank,
+            lora_alpha=rank,
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+            init_lora_weights=False,
         )

         unet_lora_config = LoraConfig(
-            r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
+            r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
         )

         unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
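One detail in the test config worth spelling out: init_lora_weights=False skips PEFT's usual initialization (where lora_B starts at zero so a fresh adapter is a no-op) and initializes both LoRA matrices randomly, which is what these dummy components rely on so that attaching an adapter visibly changes the model output. A toy sketch of the difference, not taken from the test file and using made-up module names:

import copy
import torch
from peft import LoraConfig, get_peft_model

class Toy(torch.nn.Module):
    # toy stand-in for the real sub-models, only to show the initialization difference
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

x = torch.randn(2, 8)
base = Toy()
reference = base(x)

# Default init: lora_B is zero, so the freshly attached adapter leaves outputs unchanged.
zero_init = get_peft_model(copy.deepcopy(base), LoraConfig(r=4, lora_alpha=4, target_modules=["proj"]))
print(torch.allclose(zero_init(x), reference))  # True

# init_lora_weights=False: both matrices are random, so outputs differ immediately,
# which lets the tests assert that the adapter is actually active.
rand_init = get_peft_model(
    copy.deepcopy(base), LoraConfig(r=4, lora_alpha=4, target_modules=["proj"], init_lora_weights=False)
)
print(torch.allclose(rand_init(x), reference))  # False
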

