From b934215d4c376ea2e08e28103443686b95ea772c Mon Sep 17 00:00:00 2001 From: YiYi Xu <yixu310@gmail.com> Date: Thu, 9 May 2024 11:07:43 -1000 Subject: [PATCH] [scheduler] support custom `timesteps` and `sigmas` (#7817) * support custom sigmas and timesteps, dpm euler --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/using-diffusers/schedulers.md | 56 +++++++ .../animatediff/pipeline_animatediff_sdxl.py | 30 +++- .../pipeline_animatediff_video2video.py | 34 ++++- .../controlnet/pipeline_controlnet.py | 30 +++- .../alt_diffusion/pipeline_alt_diffusion.py | 26 +++- .../pipeline_alt_diffusion_img2img.py | 26 +++- .../pipeline_latent_consistency_img2img.py | 21 ++- .../pipeline_latent_consistency_text2img.py | 21 ++- .../pixart_alpha/pipeline_pixart_alpha.py | 35 ++++- .../pixart_alpha/pipeline_pixart_sigma.py | 35 ++++- .../pipeline_stable_diffusion.py | 30 +++- .../pipeline_stable_diffusion_img2img.py | 30 +++- .../pipeline_stable_diffusion_inpaint.py | 30 +++- .../pipeline_stable_diffusion_ldm3d.py | 34 ++++- .../pipeline_stable_diffusion_panorama.py | 21 ++- .../pipeline_stable_diffusion_xl.py | 30 +++- .../pipeline_stable_diffusion_xl_img2img.py | 30 +++- .../pipeline_stable_diffusion_xl_inpaint.py | 30 +++- .../pipeline_stable_video_diffusion.py | 68 ++++++++- .../pipeline_stable_diffusion_adapter.py | 30 +++- .../pipeline_stable_diffusion_xl_adapter.py | 30 +++- src/diffusers/schedulers/__init__.py | 4 +- .../scheduling_dpmsolver_multistep.py | 78 ++++++---- .../scheduling_dpmsolver_singlestep.py | 41 ++++-- .../schedulers/scheduling_euler_discrete.py | 137 +++++++++++++----- .../schedulers/scheduling_heun_discrete.py | 58 +++++--- src/diffusers/schedulers/scheduling_utils.py | 9 ++ .../stable_diffusion/test_stable_diffusion.py | 38 +++++ .../test_stable_diffusion_xl.py | 38 +++++ tests/schedulers/test_scheduler_dpm_multi.py | 48 +++++- tests/schedulers/test_scheduler_dpm_single.py | 37 ++++- tests/schedulers/test_scheduler_euler.py | 94 ++++++++++-- tests/schedulers/test_scheduler_heun.py | 62 ++++++-- 33 files changed, 1116 insertions(+), 205 deletions(-) diff --git a/docs/source/en/using-diffusers/schedulers.md b/docs/source/en/using-diffusers/schedulers.md index 01dab2bed7fe..bfc8aa1a2108 100644 --- a/docs/source/en/using-diffusers/schedulers.md +++ b/docs/source/en/using-diffusers/schedulers.md @@ -212,6 +212,62 @@ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True). images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) ``` +## Custom Timestep Schedules + +With all our schedulers, you can choose one of the popular timestep schedules using configurations such as `timestep_spacing`, `interpolation_type`, and `use_karras_sigmas`. Some schedulers also provide the flexibility to use a custom timestep schedule. You can use any list of arbitrary timesteps, we will use the AYS timestep schedule here as example. It is a set of 10-step optimized timestep schedules released by researchers from Nvidia that can achieve significantly better quality compared to the preset timestep schedules. You can read more about their research [here](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/). + +```python +from diffusers.schedulers import AysSchedules +sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"] +print(sampling_schedule) +``` +``` +[999, 845, 730, 587, 443, 310, 193, 116, 53, 13] +``` + +You can then create a pipeline and pass this custom timestep schedule to it as `timesteps`. + +```python +pipe = StableDiffusionXLPipeline.from_pretrained( + "SG161222/RealVisXL_V4.0", + torch_dtype=torch.float16, + variant="fp16", +).to("cuda") + +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, algorithm_type="sde-dpmsolver++") + +prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up" + +generator = torch.Generator(device="cpu").manual_seed(2487854446) + +image = pipe( + prompt=prompt, + negative_prompt="", + generator=generator, + timesteps=sampling_schedule, +).images[0] +``` +The generated image has better quality than the default linear timestep schedule for the same number of steps, and it is similar to the default timestep scheduler when running for 25 steps. + +<div class="flex gap-4"> + <div> + <img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ays.png"/> + <figcaption class="mt-2 text-center text-sm text-gray-500">AYS timestep schedule 10 steps</figcaption> + </div> + <div> + <img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/10.png"/> + <figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 10 steps</figcaption> + </div> + <div> + <img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/25.png"/> + <figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 25 steps</figcaption> + </div> +</div> + +> [!TIP] +> 🤗 Diffusers currently only supports `timesteps` and `sigmas` for a selected list of schedulers and pipelines, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to extend feature to a scheduler and pipeline that does not currently support it! + + ## Models Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of re-downloading them. diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index f15cd5dbebf7..56146b6d05ca 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -156,6 +156,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -171,14 +172,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -189,6 +194,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -865,6 +880,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -923,6 +939,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will @@ -1104,7 +1124,9 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 106fabba721b..67fa8b12a34b 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -137,6 +137,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -152,14 +153,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -170,6 +175,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -750,6 +765,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, guidance_scale: float = 7.5, strength: float = 0.8, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -783,6 +799,14 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality videos at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. strength (`float`, *optional*, defaults to 0.8): Higher strength leads to more differences between original video and generated video. guidance_scale (`float`, *optional*, defaults to 7.5): @@ -912,7 +936,9 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index f099d54e57cd..0bf5f1fc8e21 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -97,6 +97,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -112,14 +113,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -130,6 +135,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -892,6 +907,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -941,6 +957,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -1162,7 +1182,9 @@ def __call__( assert False # 5. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) self._num_timesteps = len(timesteps) # 6. Prepare latent variables diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py index 92de940332c4..d1be4bbe6379 100644 --- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py @@ -79,6 +79,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -94,14 +95,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -112,6 +117,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -673,6 +688,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -848,7 +864,9 @@ def __call__( image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py index 48b3b96483d5..53d67a1b2d33 100644 --- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -119,6 +119,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -134,14 +135,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -152,6 +157,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -753,6 +768,7 @@ def __call__( strength: float = 0.8, num_inference_steps: Optional[int] = 50, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -919,7 +935,9 @@ def __call__( image = self.image_processor.preprocess(image) # 5. set timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index fce694d1d0bd..a84263eaeeb6 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -63,6 +63,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -78,14 +79,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -96,6 +101,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py index a69d49f1ffbb..1bbb1bd91bb8 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -67,6 +67,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -82,14 +83,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -100,6 +105,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 22f0c507e5ee..c1bb5b4584ea 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -175,6 +175,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -190,14 +191,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -208,6 +213,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -672,6 +687,7 @@ def __call__( negative_prompt: str = "", num_inference_steps: int = 20, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 4.5, num_images_per_prompt: Optional[int] = 1, height: Optional[int] = None, @@ -707,8 +723,13 @@ def __call__( The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` - timesteps are used. Must be in descending order. + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 4.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -837,7 +858,9 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # 5. Prepare latents. latent_channels = self.transformer.config.in_channels diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py index e3dd3ddc0404..0389ac06f363 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py @@ -119,6 +119,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -134,14 +135,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -152,6 +157,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -599,6 +614,7 @@ def __call__( negative_prompt: str = "", num_inference_steps: int = 20, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 4.5, num_images_per_prompt: Optional[int] = 1, height: Optional[int] = None, @@ -634,8 +650,13 @@ def __call__( The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` - timesteps are used. Must be in descending order. + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 4.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -763,7 +784,9 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # 5. Prepare latents. latent_channels = self.transformer.config.in_channels diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index ff933237a6ab..3612039f09cb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -75,6 +75,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -90,14 +91,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -108,6 +113,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -744,6 +759,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -780,6 +796,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -929,7 +949,9 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 1b31c099b177..daded18ebd85 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -115,6 +115,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -130,14 +131,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -148,6 +153,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -833,6 +848,7 @@ def __call__( strength: float = 0.8, num_inference_steps: Optional[int] = 50, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -875,6 +891,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -1009,7 +1029,9 @@ def __call__( image = self.image_processor.preprocess(image) # 5. set timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 1d410b36a600..ab2a0eae7313 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -179,6 +179,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -194,14 +195,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -212,6 +217,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -984,6 +999,7 @@ def __call__( strength: float = 1.0, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -1046,6 +1062,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -1220,7 +1240,9 @@ def __call__( ) # 4. set timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps=num_inference_steps, strength=strength, device=device ) diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py index 880d6dcf401a..fcb6db4574f9 100644 --- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py @@ -80,6 +80,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -95,14 +96,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -113,6 +118,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -719,6 +734,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 49, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -751,6 +767,14 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 5.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -888,7 +912,9 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index 514c8643ba36..31af62b609a1 100644 --- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -80,6 +80,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -95,14 +96,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -113,6 +118,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 5af29b5719ce..b7bc832ec3d3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -107,6 +107,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -122,14 +123,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -140,6 +145,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -820,6 +835,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -876,6 +892,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will @@ -1075,7 +1095,9 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index b98ea279c1a2..ada981d4dec9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -124,6 +124,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -139,14 +140,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -157,6 +162,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -964,6 +979,7 @@ def __call__( strength: float = 0.3, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, denoising_start: Optional[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, @@ -1022,6 +1038,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. denoising_start (`float`, *optional*): When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and @@ -1237,7 +1257,9 @@ def __call__( def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 26c1484e9ae2..9361cc6a13af 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -269,6 +269,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -284,14 +285,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -302,6 +307,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -1199,6 +1214,7 @@ def __call__( strength: float = 0.9999, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, denoising_start: Optional[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 7.5, @@ -1281,6 +1297,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. denoising_start (`float`, *optional*): When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and @@ -1498,7 +1518,9 @@ def __call__( def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index da6832cebd4d..b5b048d5cbe0 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -83,6 +83,66 @@ def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: s return outputs +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + @dataclass class StableVideoDiffusionPipelineOutput(BaseOutput): r""" @@ -343,6 +403,7 @@ def __call__( width: int = 1024, num_frames: Optional[int] = None, num_inference_steps: int = 25, + sigmas: Optional[List[float]] = None, min_guidance_scale: float = 1.0, max_guidance_scale: float = 3.0, fps: int = 7, @@ -374,6 +435,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 25): The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. This parameter is modulated by `strength`. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. min_guidance_scale (`float`, *optional*, defaults to 1.0): The minimum guidance scale. Used for the classifier free guidance with first frame. max_guidance_scale (`float`, *optional*, defaults to 3.0): @@ -492,8 +557,7 @@ def __call__( added_time_ids = added_time_ids.to(device) # 6. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas) # 7. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index fd5b1bb454b7..5b8dbafdec06 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -124,6 +124,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -139,14 +140,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -157,6 +162,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -669,6 +684,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -707,6 +723,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -816,7 +836,9 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index fdf938b248e6..6d742b4b7c98 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -140,6 +140,7 @@ def retrieve_timesteps( num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -155,14 +156,18 @@ def retrieve_timesteps( device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -173,6 +178,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps @@ -841,6 +856,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, + sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -900,6 +916,10 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will @@ -1101,7 +1121,9 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 720d8ea25e29..da162ed14419 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -68,7 +68,7 @@ _import_structure["scheduling_tcd"] = ["TCDScheduler"] _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"] _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"] - _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"] + _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"] _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"] try: @@ -163,7 +163,7 @@ from .scheduling_tcd import TCDScheduler from .scheduling_unclip import UnCLIPScheduler from .scheduling_unipc_multistep import UniPCMultistepScheduler - from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler try: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 7e0939e0d927..cd54148b4370 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -303,7 +303,12 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -312,33 +317,54 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated + based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` + must be `None`, and `timestep_spacing` attribute will be ignored. """ - # Clipping the minimum of all lambda(t) for numerical stability. - # This is critical for cosine (squaredcos_cap_v2) noise schedule. - clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) - last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() - - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = ( - np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) - ) - elif self.config.timestep_spacing == "leading": - step_ratio = last_timestep // (num_inference_steps + 1) - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) - timesteps -= 1 + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") + if timesteps is not None and self.config.use_lu_lambdas: + raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`") + + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.int64) else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." - ) + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index d7a073c2383e..25118f432925 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -274,7 +274,12 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -283,17 +288,33 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is + passed, `num_inference_steps` must be `None`. """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.") + + num_inference_steps = num_inference_steps or len(timesteps) self.num_inference_steps = num_inference_steps - # Clipping the minimum of all lambda(t) for numerical stability. - # This is critical for cosine (squaredcos_cap_v2) noise schedule. - clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.int64) + else: + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index be5bbc235878..476940d0e4f1 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -167,6 +167,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -189,6 +192,7 @@ def __init__( timestep_type: str = "discrete", # can be "discrete" or "continuous" steps_offset: int = 0, rescale_betas_zero_snr: bool = False, + final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -296,7 +300,13 @@ def scale_model_input( self.is_scale_input_called = True return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -305,60 +315,111 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated + based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` + must be `None`, and `timestep_spacing` attribute will be ignored. + sigmas (`List[float]`, *optional*): + Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas + will be generated based on the relevant scheduler attributes. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the + custom sigmas schedule. """ - self.num_inference_steps = num_inference_steps - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ - ::-1 - ].copy() - elif self.config.timestep_spacing == "leading": - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) - timesteps -= 1 - else: + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` should be set.") + if num_inference_steps is None and timesteps is None and sigmas is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.") + if num_inference_steps is not None and (timesteps is not None or sigmas is not None): + raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.") + if ( + timesteps is not None + and self.config.timestep_type == "continuous" + and self.config.prediction_type == "v_prediction" + ): raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." ) - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - log_sigmas = np.log(sigmas) + if num_inference_steps is None: + num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1 + self.num_inference_steps = num_inference_steps - if self.config.interpolation_type == "linear": - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - elif self.config.interpolation_type == "log_linear": - sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy() - else: - raise ValueError( - f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either" - " 'linear' or 'log_linear'" - ) + if sigmas is not None: + log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)) + sigmas = np.array(sigmas).astype(np.float32) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]]) - if self.config.use_karras_sigmas: - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + else: + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.float32) + else: + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace( + 0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32 + )[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) + ) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + if self.config.interpolation_type == "linear": + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + elif self.config.interpolation_type == "log_linear": + sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy() + else: + raise ValueError( + f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either" + " 'linear' or 'log_linear'" + ) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) # TODO: Support the full EDM scalings for all prediction types and timestep types if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction": - self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device) + self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device) else: self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) - self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) self._step_index = None self._begin_index = None - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication def _sigma_to_t(self, sigma, log_sigmas): # get log sigma diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index fb8f28f73d58..34b2bd11606b 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -224,9 +224,10 @@ def scale_model_input( def set_timesteps( self, - num_inference_steps: int, + num_inference_steps: Optional[int] = None, device: Union[str, torch.device] = None, num_train_timesteps: Optional[int] = None, + timesteps: Optional[List[int]] = None, ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -236,30 +237,47 @@ def set_timesteps( The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + num_train_timesteps (`int`, *optional*): + The number of diffusion steps used when training the model. If `None`, the default + `num_train_timesteps` attribute is used. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be + generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` + must be `None`, and `timestep_spacing` attribute will be ignored. """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.") + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") + + num_inference_steps = num_inference_steps or len(timesteps) self.num_inference_steps = num_inference_steps - num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy() - elif self.config.timestep_spacing == "leading": - step_ratio = num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = num_train_timesteps / self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) - timesteps -= 1 + if timesteps is not None: + timesteps = np.array(timesteps, dtype=np.float32) else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." - ) + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index dcdce6c51f05..6faf9ee38cd5 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -48,6 +48,15 @@ class KarrasDiffusionSchedulers(Enum): EDMEulerScheduler = 15 +AysSchedules = { + "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24], + "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0], + "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13], + "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0], + "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0], +} + + @dataclass class SchedulerOutput(BaseOutput): """ diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index a4fe12fa6d17..a0fc8c5ef0c0 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -259,6 +259,44 @@ def test_stable_diffusion_lcm_custom_timesteps(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_ays(self): + from diffusers.schedulers import AysSchedules + + timestep_schedule = AysSchedules["StableDiffusionTimesteps"] + sigma_schedule = AysSchedules["StableDiffusionSigmas"] + + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe.scheduler = EulerDiscreteScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 10 + output = sd_pipe(**inputs).images + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = None + inputs["timesteps"] = timestep_schedule + output_ts = sd_pipe(**inputs).images + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = None + inputs["sigmas"] = sigma_schedule + output_sigmas = sd_pipe(**inputs).images + + assert ( + np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3 + ), "ays timesteps and ays sigmas should have the same outputs" + assert ( + np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3 + ), "use ays timesteps should have different outputs" + assert ( + np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3 + ), "use ays sigmas should have different outputs" + def test_stable_diffusion_prompt_embeds(self): components = self.get_dummy_components() sd_pipe = StableDiffusionPipeline(**components) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 324ed52d9901..50f3eb99f016 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -214,6 +214,44 @@ def test_stable_diffusion_xl_euler_lcm_custom_timesteps(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_ays(self): + from diffusers.schedulers import AysSchedules + + timestep_schedule = AysSchedules["StableDiffusionXLTimesteps"] + sigma_schedule = AysSchedules["StableDiffusionXLSigmas"] + + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionXLPipeline(**components) + sd_pipe.scheduler = EulerDiscreteScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 10 + output = sd_pipe(**inputs).images + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = None + inputs["timesteps"] = timestep_schedule + output_ts = sd_pipe(**inputs).images + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = None + inputs["sigmas"] = sigma_schedule + output_sigmas = sd_pipe(**inputs).images + + assert ( + np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3 + ), "ays timesteps and ays sigmas should have the same outputs" + assert ( + np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3 + ), "use ays timesteps should have different outputs" + assert ( + np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3 + ), "use ays sigmas should have different outputs" + def test_stable_diffusion_xl_prompt_embeds(self): components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**components) diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index fcf8881bc820..ef407eaa3dc9 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -111,9 +111,33 @@ def full_loop(self, scheduler=None, **config): sample = self.dummy_sample_deter scheduler.set_timesteps(num_inference_steps) + generator = torch.manual_seed(0) + for i, t in enumerate(scheduler.timesteps): residual = model(sample, t) - sample = scheduler.step(residual, t, sample).prev_sample + sample = scheduler.step(residual, t, sample, generator=generator).prev_sample + + return sample + + def full_loop_custom_timesteps(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + scheduler.set_timesteps(num_inference_steps) + timesteps = scheduler.timesteps + # reset the timesteps using `timesteps` + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps=None, timesteps=timesteps) + + generator = torch.manual_seed(0) + model = self.dummy_model() + sample = self.dummy_sample_deter + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample, generator=generator).prev_sample return sample @@ -309,10 +333,28 @@ def test_fp16_support(self): assert sample.dtype == torch.float16 - def test_duplicated_timesteps(self, **config): + def test_duplicated_timesteps(self): for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config(**config) + scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(scheduler.config.num_train_timesteps) assert len(scheduler.timesteps) == scheduler.num_inference_steps + + def test_custom_timesteps(self): + for algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + for prediction_type in ["epsilon", "sample", "v_prediction"]: + for final_sigmas_type in ["sigma_min", "zero"]: + sample = self.full_loop( + algorithm_type=algorithm_type, + prediction_type=prediction_type, + final_sigmas_type=final_sigmas_type, + ) + sample_custom_timesteps = self.full_loop_custom_timesteps( + algorithm_type=algorithm_type, + prediction_type=prediction_type, + final_sigmas_type=final_sigmas_type, + ) + assert ( + torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5 + ), f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}" diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index 251a1506f39b..ea43c210d650 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -103,14 +103,31 @@ def full_loop(self, scheduler=None, **config): scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + return sample + + def full_loop_custom_timesteps(self, **config): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) num_inference_steps = 10 + scheduler.set_timesteps(num_inference_steps) + timesteps = scheduler.timesteps + # reset the timesteps using`timesteps` + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps=None, timesteps=timesteps) + model = self.dummy_model() sample = self.dummy_sample_deter - scheduler.set_timesteps(num_inference_steps) for i, t in enumerate(scheduler.timesteps): residual = model(sample, t) @@ -307,3 +324,21 @@ def test_full_loop_with_noise(self): assert abs(result_sum.item() - 269.2187) < 1e-2, f" expected result sum 269.2187, but get {result_sum}" assert abs(result_mean.item() - 0.3505) < 1e-3, f" expected result mean 0.3505, but get {result_mean}" + + def test_custom_timesteps(self): + for prediction_type in ["epsilon", "sample", "v_prediction"]: + for lower_order_final in [True, False]: + for final_sigmas_type in ["sigma_min", "zero"]: + sample = self.full_loop( + prediction_type=prediction_type, + lower_order_final=lower_order_final, + final_sigmas_type=final_sigmas_type, + ) + sample_custom_timesteps = self.full_loop_custom_timesteps( + prediction_type=prediction_type, + lower_order_final=lower_order_final, + final_sigmas_type=final_sigmas_type, + ) + assert ( + torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5 + ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}" diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py index 41c418c5064c..fbb49b164165 100644 --- a/tests/schedulers/test_scheduler_euler.py +++ b/tests/schedulers/test_scheduler_euler.py @@ -49,12 +49,13 @@ def test_rescale_betas_zero_snr(self): for rescale_betas_zero_snr in [True, False]: self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) - def test_full_loop_no_noise(self): + def full_loop(self, **config): scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() + scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) - scheduler.set_timesteps(self.num_inference_steps) + num_inference_steps = self.num_inference_steps + scheduler.set_timesteps(num_inference_steps) generator = torch.manual_seed(0) @@ -69,19 +70,46 @@ def test_full_loop_no_noise(self): output = scheduler.step(model_output, t, sample, generator=generator) sample = output.prev_sample + return sample - result_sum = torch.sum(torch.abs(sample)) - result_mean = torch.mean(torch.abs(sample)) + def full_loop_custom_timesteps(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) - assert abs(result_sum.item() - 10.0807) < 1e-2 - assert abs(result_mean.item() - 0.0131) < 1e-3 + num_inference_steps = self.num_inference_steps + scheduler.set_timesteps(num_inference_steps) + timesteps = scheduler.timesteps + # reset the timesteps using `timesteps` + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps=None, timesteps=timesteps) - def test_full_loop_with_v_prediction(self): + generator = torch.manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + return sample + + def full_loop_custom_sigmas(self, **config): scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(prediction_type="v_prediction") + scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) - scheduler.set_timesteps(self.num_inference_steps) + num_inference_steps = self.num_inference_steps + scheduler.set_timesteps(num_inference_steps) + sigmas = scheduler.sigmas + # reset the timesteps using `sigmas` + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps=None, sigmas=sigmas) generator = torch.manual_seed(0) @@ -96,6 +124,19 @@ def test_full_loop_with_v_prediction(self): output = scheduler.step(model_output, t, sample, generator=generator) sample = output.prev_sample + return sample + + def test_full_loop_no_noise(self): + sample = self.full_loop() + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 10.0807) < 1e-2 + assert abs(result_mean.item() - 0.0131) < 1e-3 + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) @@ -189,3 +230,36 @@ def test_full_loop_with_noise(self): assert abs(result_sum.item() - 57062.9297) < 1e-2, f" expected result sum 57062.9297, but get {result_sum}" assert abs(result_mean.item() - 74.3007) < 1e-3, f" expected result mean 74.3007, but get {result_mean}" + + def test_custom_timesteps(self): + for prediction_type in ["epsilon", "sample", "v_prediction"]: + for interpolation_type in ["linear", "log_linear"]: + for final_sigmas_type in ["sigma_min", "zero"]: + sample = self.full_loop( + prediction_type=prediction_type, + interpolation_type=interpolation_type, + final_sigmas_type=final_sigmas_type, + ) + sample_custom_timesteps = self.full_loop_custom_timesteps( + prediction_type=prediction_type, + interpolation_type=interpolation_type, + final_sigmas_type=final_sigmas_type, + ) + assert ( + torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5 + ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}" + + def test_custom_sigmas(self): + for prediction_type in ["epsilon", "sample", "v_prediction"]: + for final_sigmas_type in ["sigma_min", "zero"]: + sample = self.full_loop( + prediction_type=prediction_type, + final_sigmas_type=final_sigmas_type, + ) + sample_custom_timesteps = self.full_loop_custom_sigmas( + prediction_type=prediction_type, + final_sigmas_type=final_sigmas_type, + ) + assert ( + torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5 + ), f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}" diff --git a/tests/schedulers/test_scheduler_heun.py b/tests/schedulers/test_scheduler_heun.py index df2b62d97191..a3689ef2ea63 100644 --- a/tests/schedulers/test_scheduler_heun.py +++ b/tests/schedulers/test_scheduler_heun.py @@ -41,12 +41,13 @@ def test_prediction_type(self): for prediction_type in ["epsilon", "v_prediction", "sample"]: self.check_over_configs(prediction_type=prediction_type) - def test_full_loop_no_noise(self): + def full_loop(self, **config): scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() + scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) - scheduler.set_timesteps(self.num_inference_steps) + num_inference_steps = self.num_inference_steps + scheduler.set_timesteps(num_inference_steps) model = self.dummy_model() sample = self.dummy_sample_deter * scheduler.init_noise_sigma @@ -60,23 +61,20 @@ def test_full_loop_no_noise(self): output = scheduler.step(model_output, t, sample) sample = output.prev_sample - result_sum = torch.sum(torch.abs(sample)) - result_mean = torch.mean(torch.abs(sample)) - - if torch_device in ["cpu", "mps"]: - assert abs(result_sum.item() - 0.1233) < 1e-2 - assert abs(result_mean.item() - 0.0002) < 1e-3 - else: - # CUDA - assert abs(result_sum.item() - 0.1233) < 1e-2 - assert abs(result_mean.item() - 0.0002) < 1e-3 + return sample - def test_full_loop_with_v_prediction(self): + def full_loop_custom_timesteps(self, **config): scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(prediction_type="v_prediction") + scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) - scheduler.set_timesteps(self.num_inference_steps) + num_inference_steps = self.num_inference_steps + scheduler.set_timesteps(num_inference_steps) + timesteps = scheduler.timesteps + timesteps = torch.cat([timesteps[:1], timesteps[1::2]]) + # reset the timesteps using `timesteps` + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps=None, timesteps=timesteps) model = self.dummy_model() sample = self.dummy_sample_deter * scheduler.init_noise_sigma @@ -90,6 +88,23 @@ def test_full_loop_with_v_prediction(self): output = scheduler.step(model_output, t, sample) sample = output.prev_sample + return sample + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["cpu", "mps"]: + assert abs(result_sum.item() - 0.1233) < 1e-2 + assert abs(result_mean.item() - 0.0002) < 1e-3 + else: + # CUDA + assert abs(result_sum.item() - 0.1233) < 1e-2 + assert abs(result_mean.item() - 0.0002) < 1e-3 + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) @@ -189,3 +204,18 @@ def test_full_loop_with_noise(self): assert abs(result_sum.item() - 75074.8906) < 1e-2, f" expected result sum 75074.8906, but get {result_sum}" assert abs(result_mean.item() - 97.7538) < 1e-3, f" expected result mean 97.7538, but get {result_mean}" + + def test_custom_timesteps(self): + for prediction_type in ["epsilon", "sample", "v_prediction"]: + for timestep_spacing in ["linspace", "leading"]: + sample = self.full_loop( + prediction_type=prediction_type, + timestep_spacing=timestep_spacing, + ) + sample_custom_timesteps = self.full_loop_custom_timesteps( + prediction_type=prediction_type, + timestep_spacing=timestep_spacing, + ) + assert ( + torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5 + ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"