diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index abd169b8bc97..29fe2744ad7a 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -225,6 +225,12 @@ def parse_args(input_args=None): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--dataset_name", type=str, @@ -1064,6 +1070,7 @@ def main(args): args.pretrained_model_name_or_path, torch_dtype=torch_dtype, revision=args.revision, + variant=args.variant, ) pipeline.set_progress_bar_config(disable=True) @@ -1102,10 +1109,18 @@ def main(args): # Load the tokenizers tokenizer_one = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + variant=args.variant, + use_fast=False, ) tokenizer_two = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + variant=args.variant, + use_fast=False, ) # import correct text encoder classes @@ -1119,10 +1134,10 @@ def main(args): # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) vae_path = ( args.pretrained_model_name_or_path @@ -1130,10 +1145,13 @@ def main(args): else args.pretrained_vae_model_name_or_path ) vae = AutoencoderKL.from_pretrained( - vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) if args.train_text_encoder_ti: @@ -1843,10 +1861,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # create pipeline if freeze_text_encoder: text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, ) text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, ) pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -1855,6 +1879,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_2=accelerator.unwrap_model(text_encoder_two), unet=accelerator.unwrap_model(unet), revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) @@ -1932,10 +1957,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, ) # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it