Skip to content

Commit

Permalink
Add SDXL-Turbo default scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
tolgacangoz committed Dec 10, 2023
1 parent 08b453e commit 8b1a0dc
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,21 +1376,40 @@ def download_from_original_stable_diffusion_ckpt(
num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000

if model_type in ["SDXL", "SDXL-Refiner"]:
scheduler_dict = {
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"beta_end": 0.012,
"interpolation_type": "linear",
"num_train_timesteps": num_train_timesteps,
"prediction_type": "epsilon",
"sample_max_value": 1.0,
"set_alpha_to_one": False,
"skip_prk_steps": True,
"steps_offset": 1,
"timestep_spacing": "leading",
}
scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
scheduler_type = "euler"
if "turbo" in checkpoint_path_or_dict:
scheduler_dict = {
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": False,
"interpolation_type": "linear",
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"sample_max_value": 1.0,
"set_alpha_to_one": False,
"skip_prk_steps": True,
"steps_offset": 1,
"timestep_spacing": "trailing",
"trained_betas": None
}
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_dict)
scheduler_type = "euler-ancestral"
else:
scheduler_dict = {
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"beta_end": 0.012,
"interpolation_type": "linear",
"num_train_timesteps": num_train_timesteps,
"prediction_type": "epsilon",
"sample_max_value": 1.0,
"set_alpha_to_one": False,
"skip_prk_steps": True,
"steps_offset": 1,
"timestep_spacing": "leading",
}
scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
scheduler_type = "euler"
else:
beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
Expand Down

0 comments on commit 8b1a0dc

Please sign in to comment.