diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 761391189f8f..d521ee8ca824 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -1376,21 +1376,40 @@ def download_from_original_stable_diffusion_ckpt( num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 if model_type in ["SDXL", "SDXL-Refiner"]: - scheduler_dict = { - "beta_schedule": "scaled_linear", - "beta_start": 0.00085, - "beta_end": 0.012, - "interpolation_type": "linear", - "num_train_timesteps": num_train_timesteps, - "prediction_type": "epsilon", - "sample_max_value": 1.0, - "set_alpha_to_one": False, - "skip_prk_steps": True, - "steps_offset": 1, - "timestep_spacing": "leading", - } - scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) - scheduler_type = "euler" + if "turbo" in checkpoint_path_or_dict: + scheduler_dict = { + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "clip_sample": False, + "interpolation_type": "linear", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "trailing", + "trained_betas": None + } + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_dict) + scheduler_type = "euler-ancestral" + else: + scheduler_dict = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": num_train_timesteps, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", + } + scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) + scheduler_type = "euler" else: beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085