Skip to content

Commit

Permalink
Allow noise_scheduler to be configured between ddpm and euler.
Browse files Browse the repository at this point in the history
  • Loading branch information
dg845 committed Jan 8, 2024
1 parent 5cfd137 commit acf5175
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 8 deletions.
25 changes: 23 additions & 2 deletions examples/add/train_add_distill_lora_sd_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from diffusers import (
AutoencoderKL,
DDPMScheduler,
EulerDiscreteScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
Expand Down Expand Up @@ -1142,6 +1143,15 @@ def parse_args():
" timestep T (`noise_scheduler.config.num_train_timesteps - 1`)."
),
)
parser.add_argument(
"--noise_scheduler_type",
type=str,
default="ddpm",
help=(
"The scheduler class to use for the noise scheduler during training. This affects how noise is added to"
" the latents (the forward process). Choose between `ddpm` and `euler`."
),
)
# LoRA Arguments
parser.add_argument(
"--lora_rank",
Expand Down Expand Up @@ -1353,11 +1363,22 @@ def main(args):

# 1. Create the noise scheduler and the desired noise schedule.
# Enforce zero terminal SNR (see section 3.1 of ADD paper)
teacher_scheduler = DDPMScheduler.from_pretrained(
if args.noise_scheduler_type == "ddpm":
noise_scheduler_cls = DDPMScheduler
elif args.noise_scheduler_type == "euler":
noise_scheduler_cls = EulerDiscreteScheduler
else:
raise ValueError(
f"Noise scheduler type {args.noise_scheduler_type} is not supported. Supported scheduler types are `ddpm`"
f" and `euler`."
)
teacher_scheduler = noise_scheduler_cls.from_pretrained(
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True
noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr)
noise_scheduler = noise_scheduler_cls.from_config(
teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr
)

# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
# Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps
Expand Down
25 changes: 23 additions & 2 deletions examples/add/train_add_distill_lora_sdxl_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from diffusers import (
AutoencoderKL,
DDPMScheduler,
EulerDiscreteScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
Expand Down Expand Up @@ -1167,6 +1168,15 @@ def parse_args():
" timestep T (`noise_scheduler.config.num_train_timesteps - 1`)."
),
)
parser.add_argument(
"--noise_scheduler_type",
type=str,
default="ddpm",
help=(
"The scheduler class to use for the noise scheduler during training. This affects how noise is added to"
" the latents (the forward process). Choose between `ddpm` and `euler`."
),
)
# LoRA Arguments
parser.add_argument(
"--lora_rank",
Expand Down Expand Up @@ -1386,11 +1396,22 @@ def main(args):

# 1. Create the noise scheduler and the desired noise schedule.
# Enforce zero terminal SNR (see section 3.1 of ADD paper)
teacher_scheduler = DDPMScheduler.from_pretrained(
if args.noise_scheduler_type == "ddpm":
noise_scheduler_cls = DDPMScheduler
elif args.noise_scheduler_type == "euler":
noise_scheduler_cls = EulerDiscreteScheduler
else:
raise ValueError(
f"Noise scheduler type {args.noise_scheduler_type} is not supported. Supported scheduler types are `ddpm`"
f" and `euler`."
)
teacher_scheduler = noise_scheduler_cls.from_pretrained(
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True
noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr)
noise_scheduler = noise_scheduler_cls.from_config(
teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr
)

# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
# Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps
Expand Down
25 changes: 23 additions & 2 deletions examples/add/train_add_distill_sd_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from diffusers import (
AutoencoderKL,
DDPMScheduler,
EulerDiscreteScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
Expand Down Expand Up @@ -1120,6 +1121,15 @@ def parse_args():
" timestep T (`noise_scheduler.config.num_train_timesteps - 1`)."
),
)
parser.add_argument(
"--noise_scheduler_type",
type=str,
default="ddpm",
help=(
"The scheduler class to use for the noise scheduler during training. This affects how noise is added to"
" the latents (the forward process). Choose between `ddpm` and `euler`."
),
)
# ----Exponential Moving Average (EMA)----
parser.add_argument(
"--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights."
Expand Down Expand Up @@ -1300,11 +1310,22 @@ def main(args):

# 1. Create the noise scheduler and the desired noise schedule.
# Enforce zero terminal SNR (see section 3.1 of ADD paper)
teacher_scheduler = DDPMScheduler.from_pretrained(
if args.noise_scheduler_type == "ddpm":
noise_scheduler_cls = DDPMScheduler
elif args.noise_scheduler_type == "euler":
noise_scheduler_cls = EulerDiscreteScheduler
else:
raise ValueError(
f"Noise scheduler type {args.noise_scheduler_type} is not supported. Supported scheduler types are `ddpm`"
f" and `euler`."
)
teacher_scheduler = noise_scheduler_cls.from_pretrained(
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True
noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr)
noise_scheduler = noise_scheduler_cls.from_config(
teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr
)

# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
# Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps
Expand Down
25 changes: 23 additions & 2 deletions examples/add/train_add_distill_sdxl_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from diffusers import (
AutoencoderKL,
DDPMScheduler,
EulerDiscreteScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
Expand Down Expand Up @@ -1145,6 +1146,15 @@ def parse_args():
" timestep T (`noise_scheduler.config.num_train_timesteps - 1`)."
),
)
parser.add_argument(
"--noise_scheduler_type",
type=str,
default="ddpm",
help=(
"The scheduler class to use for the noise scheduler during training. This affects how noise is added to"
" the latents (the forward process). Choose between `ddpm` and `euler`."
),
)
# ----Exponential Moving Average (EMA)----
parser.add_argument(
"--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights."
Expand Down Expand Up @@ -1333,11 +1343,22 @@ def main(args):

# 1. Create the noise scheduler and the desired noise schedule.
# Enforce zero terminal SNR (see section 3.1 of ADD paper)
teacher_scheduler = DDPMScheduler.from_pretrained(
if args.noise_scheduler_type == "ddpm":
noise_scheduler_cls = DDPMScheduler
elif args.noise_scheduler_type == "euler":
noise_scheduler_cls = EulerDiscreteScheduler
else:
raise ValueError(
f"Noise scheduler type {args.noise_scheduler_type} is not supported. Supported scheduler types are `ddpm`"
f" and `euler`."
)
teacher_scheduler = noise_scheduler_cls.from_pretrained(
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True
noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr)
noise_scheduler = noise_scheduler_cls.from_config(
teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr
)

# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
# Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps
Expand Down

0 comments on commit acf5175

Please sign in to comment.