From acf5175c6f745910166b265976540dc3ec3542af Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 7 Jan 2024 16:33:40 -0800 Subject: [PATCH] Allow noise_scheduler to be configured between ddpm and euler. --- examples/add/train_add_distill_lora_sd_wds.py | 25 +++++++++++++++++-- .../add/train_add_distill_lora_sdxl_wds.py | 25 +++++++++++++++++-- examples/add/train_add_distill_sd_wds.py | 25 +++++++++++++++++-- examples/add/train_add_distill_sdxl_wds.py | 25 +++++++++++++++++-- 4 files changed, 92 insertions(+), 8 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index 1ff40ac595a6..188fb42ba192 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -60,6 +60,7 @@ from diffusers import ( AutoencoderKL, DDPMScheduler, + EulerDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel, ) @@ -1142,6 +1143,15 @@ def parse_args(): " timestep T (`noise_scheduler.config.num_train_timesteps - 1`)." ), ) + parser.add_argument( + "--noise_scheduler_type", + type=str, + default="ddpm", + help=( + "The scheduler class to use for the noise scheduler during training. This affects how noise is added to" + " the latents (the forward process). Choose between `ddpm` and `euler`." + ), + ) # LoRA Arguments parser.add_argument( "--lora_rank", @@ -1353,11 +1363,22 @@ def main(args): # 1. Create the noise scheduler and the desired noise schedule. # Enforce zero terminal SNR (see section 3.1 of ADD paper) - teacher_scheduler = DDPMScheduler.from_pretrained( + if args.noise_scheduler_type == "ddpm": + noise_scheduler_cls = DDPMScheduler + elif args.noise_scheduler_type == "euler": + noise_scheduler_cls = EulerDiscreteScheduler + else: + raise ValueError( + f"Noise scheduler type {args.noise_scheduler_type} is not supported. Supported scheduler types are `ddpm`" + f" and `euler`." + ) + teacher_scheduler = noise_scheduler_cls.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True - noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr) + noise_scheduler = noise_scheduler_cls.from_config( + teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr + ) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index d8d528badfab..17310c31c95e 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -60,6 +60,7 @@ from diffusers import ( AutoencoderKL, DDPMScheduler, + EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, ) @@ -1167,6 +1168,15 @@ def parse_args(): " timestep T (`noise_scheduler.config.num_train_timesteps - 1`)." ), ) + parser.add_argument( + "--noise_scheduler_type", + type=str, + default="ddpm", + help=( + "The scheduler class to use for the noise scheduler during training. This affects how noise is added to" + " the latents (the forward process). Choose between `ddpm` and `euler`." + ), + ) # LoRA Arguments parser.add_argument( "--lora_rank", @@ -1386,11 +1396,22 @@ def main(args): # 1. Create the noise scheduler and the desired noise schedule. # Enforce zero terminal SNR (see section 3.1 of ADD paper) - teacher_scheduler = DDPMScheduler.from_pretrained( + if args.noise_scheduler_type == "ddpm": + noise_scheduler_cls = DDPMScheduler + elif args.noise_scheduler_type == "euler": + noise_scheduler_cls = EulerDiscreteScheduler + else: + raise ValueError( + f"Noise scheduler type {args.noise_scheduler_type} is not supported. Supported scheduler types are `ddpm`" + f" and `euler`." + ) + teacher_scheduler = noise_scheduler_cls.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True - noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr) + noise_scheduler = noise_scheduler_cls.from_config( + teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr + ) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 32bb432f43f5..2bfb833e9ddf 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -59,6 +59,7 @@ from diffusers import ( AutoencoderKL, DDPMScheduler, + EulerDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel, ) @@ -1120,6 +1121,15 @@ def parse_args(): " timestep T (`noise_scheduler.config.num_train_timesteps - 1`)." ), ) + parser.add_argument( + "--noise_scheduler_type", + type=str, + default="ddpm", + help=( + "The scheduler class to use for the noise scheduler during training. This affects how noise is added to" + " the latents (the forward process). Choose between `ddpm` and `euler`." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights." @@ -1300,11 +1310,22 @@ def main(args): # 1. Create the noise scheduler and the desired noise schedule. # Enforce zero terminal SNR (see section 3.1 of ADD paper) - teacher_scheduler = DDPMScheduler.from_pretrained( + if args.noise_scheduler_type == "ddpm": + noise_scheduler_cls = DDPMScheduler + elif args.noise_scheduler_type == "euler": + noise_scheduler_cls = EulerDiscreteScheduler + else: + raise ValueError( + f"Noise scheduler type {args.noise_scheduler_type} is not supported. Supported scheduler types are `ddpm`" + f" and `euler`." + ) + teacher_scheduler = noise_scheduler_cls.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True - noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr) + noise_scheduler = noise_scheduler_cls.from_config( + teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr + ) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 97805eb3534d..95f7adf403fd 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -59,6 +59,7 @@ from diffusers import ( AutoencoderKL, DDPMScheduler, + EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, ) @@ -1145,6 +1146,15 @@ def parse_args(): " timestep T (`noise_scheduler.config.num_train_timesteps - 1`)." ), ) + parser.add_argument( + "--noise_scheduler_type", + type=str, + default="ddpm", + help=( + "The scheduler class to use for the noise scheduler during training. This affects how noise is added to" + " the latents (the forward process). Choose between `ddpm` and `euler`." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights." @@ -1333,11 +1343,22 @@ def main(args): # 1. Create the noise scheduler and the desired noise schedule. # Enforce zero terminal SNR (see section 3.1 of ADD paper) - teacher_scheduler = DDPMScheduler.from_pretrained( + if args.noise_scheduler_type == "ddpm": + noise_scheduler_cls = DDPMScheduler + elif args.noise_scheduler_type == "euler": + noise_scheduler_cls = EulerDiscreteScheduler + else: + raise ValueError( + f"Noise scheduler type {args.noise_scheduler_type} is not supported. Supported scheduler types are `ddpm`" + f" and `euler`." + ) + teacher_scheduler = noise_scheduler_cls.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True - noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr) + noise_scheduler = noise_scheduler_cls.from_config( + teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr + ) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps