Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[scheduler] support custom timesteps and sigmas #7817

Merged
merged 27 commits into from
May 9, 2024
Merged

[scheduler] support custom timesteps and sigmas #7817

merged 27 commits into from
May 9, 2024

Conversation

yiyixuxu
Copy link
Collaborator

@yiyixuxu yiyixuxu commented Apr 29, 2024

custom timesteps and sigmas for schedulers

The general logic is: when a custom timesteps is passed, all the scheduler configuration that is used to generate timesteps (e.g. timestep_spacing) will be ignored; sigmas will still be calculated based on the timesteps and relevant config attributes, e.g. interpolation_type, final_sigma_type

When sigmas is passed, all the scheduler config used to generate timesteps and sigmas is ignored. timesteps will be calculated based on sigmas

this PR:

  • adds custom timesteps and sigmas support for more schedulers:
    • add timesteps for dpm multi-step, dpm single-step, and heun scheduler
    • add both timesteps or sigmas for euler
  • refactored the pipeline method retrive_timesteps so it accepts both timesteps and sigmas and will pass these values to the scheduler's set_timesteps method
  • add sigmas argument to more pipelines (to all the pipelines that currently accept timesteps argument)
  • add custom timesteps/sigmas tests for euler, heun, dpm multi and dpm single
  • added a slow test for sd1.5 and sdxl using custom timesteps and sigmas with euler scheduler: since I already added scheduler tests to test custom timesteps and sigmas and the retrieve_timesteps methods are all copied from sd1.5, I think these two pipeline tests are sufficient

supporting AYS

with this PR, you can now use AYS 10 steps(see #7760) with this script

# test ays (example)
import torch

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.schedulers import AysSchedules

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16").to("cuda")

seed = 123

sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
print(f" sampling_schedule: {sampling_schedule}")

prompt = "anthropomorphic capybara wearing a suit and working with a computer"


generator = torch.Generator(device='cuda').manual_seed(seed)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, algorithm_type="sde-dpmsolver++")
print(f"pipe.scheduler: {pipe.scheduler}")

# AYS 10 steps     
image = pipe(
    prompt=prompt, 
    num_inference_steps=10,
    timesteps=sampling_schedule,
    generator=generator).images[0]
print(f" timesteps: {pipe.scheduler.timesteps}")
print(f" sigmas: {pipe.scheduler.sigmas}")
image.save("yiyi_test_8_out_10steps.png")

yiyi_test_8_out_10steps

This PR uses AYS as an example use case, but the goal is to support custom timestep and sigmas in a very general way. Currently, I only updated Euler and DPM for now, but we should extend this to other schedulers whenever it makes sense

Notes on final_sigmas_type

in the scheduler's set_timesteps method, for N steps, we will come up with N timesteps and N+1 sigmas. This is because, at each step, we need both the current sigma and next sigma to calculate the derivative. We let the user decide how to set the final sigma value using the final_sigma_type config: if it is "sigma_min", the final sigma will be calculated based on the beta training schedule for the last timestep; if it is zero, the final sigma is 0

I used the this logic for custom timesteps and sigmas regarding final_sigmas_type

  • if a custom timesteps value is passed, we check to make sure it have the same length as num_inference_steps, and the scheduler will decide the last sigma value based on final_sigmas_type
  • if a custom sigmas value is passed, it has to have the same length as num_inference_steps + 1 - i.e. user has to decide what the last sigma is and includes in the custom sigmas values, sofinal_sigmas_type won't be relevant in this case

Another note is the official AYS schedule corresponds to final_sigmas_type="sigma_min", but in diffusers, we set the default to be "zero" since we think the generation quality is better with this config across all schedulers, especially when the steps are low. see more details on this PR #6477. I run a brief tests in below section to compare these two configurations

Testing across different schedulers and config

I run tests across different schedulers for 2 different configurations: final_sigmas_type = "sigma_min and final_sigmas_type="zero".

testing script

Testing script
# test ays
import torch
import numpy as np
import os
import inspect

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, EulerDiscreteScheduler
from diffusers.utils import make_image_grid
from diffusers.schedulers import AysSchedules

config_min = {"final_sigmas_type":"sigma_min"}
config_zero = {"final_sigmas_type":"zero"}

schedulers = {
    "Euler": {
        "min": (EulerDiscreteScheduler, config_min),
        "zero": (EulerDiscreteScheduler, config_zero),
    },
    "DPMPP_2M": {
        "min": (DPMSolverMultistepScheduler, config_min),
        "zero": (DPMSolverMultistepScheduler, config_zero),
     },
     "DPMPP_2M_SDE": {
        "min": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_min}),
        "zero": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_zero}),
     },
     "DPMSolverSinglestepScheduler": {
        "min": (DPMSolverSinglestepScheduler, config_min),
        "zero": (DPMSolverSinglestepScheduler, config_zero),
     },
     "HeunDiscreteScheduler": {
        "zero": (HeunDiscreteScheduler, {}),
     },
}

#model_id = "runwayml/stable-diffusion-v1-5"
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16").to("cuda")

seed = 123
sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]

prompt = "anthropomorphic capybara wearing a suit and working with a computer"

save_dir = './test_ays'
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

test_call_params = {
    "default10" :{
        "prompt" : [prompt],
        "num_inference_steps" : 10,
        },
    "ays10": {
        "prompt" : [prompt],
        "num_inference_steps" : 10,
        "timesteps" : sampling_schedule,
    },
    "default25": {
        "prompt" : [prompt],
        "num_inference_steps" : 25,
    }
}

for scheduler_name in schedulers.keys():
    print(f" ")
    print(f" scheduler_name: {scheduler_name}")
    out_imgs = []
    test_names_str = ""
    for test_name, params in test_call_params.items():
        test_names_str = test_names_str + "_" + test_name 
        print(" ")
        print(f" test_name: {test_name}")
        scheduler_configs = schedulers[scheduler_name]
        for scheduler_config_name in scheduler_configs.keys():
            generator = torch.Generator(device='cuda').manual_seed(seed)
            scheduler = scheduler_configs[scheduler_config_name][0].from_pretrained(
                    model_id,
                    subfolder="scheduler",
                    **scheduler_configs[scheduler_config_name][1],
                )
            pipe.scheduler = scheduler
            print(f" scheduler_config: {scheduler_config_name}")      
            img = pipe(**params, generator=generator).images[0]
            print(f" timesteps: {pipe.scheduler.timesteps}")
            print(f" sigmas: {pipe.scheduler.sigmas}")
            out_imgs.append(img)
    out_img = make_image_grid(out_imgs, rows=len(test_call_params), cols=len(out_imgs)//len(test_call_params))
    out_img.save(os.path.join(save_dir, f"{scheduler_name}{test_names_str}.png"))    
    print(f"saved image to {save_dir}/{scheduler_name}{test_names_str}.png")

output

You can see the testing output for each scheduler in below sections, each output is an image grid with 3 rows and 2 columns:

  • from row1 to row3: normal 10 timesteps, AYS 10 steps normal 25 steps;
  • left column is final_sigma_type="sigma_min" (this is how it's done in AYS official diffusers example), the column on the right is final_sigma_type="zero" (diffusers default setting)

DPMPP_2M output

DPMPP_2M_default10_ays10_default25

DPMPP_2M_SDE output

DPMPP_2M_SDE_default10_ays10_default25

Euler

Euler_default10_ays10_default25

DPM SingleStep

DPMSolverSinglestepScheduler_default10_ays10_default25

Heun

(heun does not have final_sigmas_type argument so just one column, final_sigmas is always 0 for heun)

HeunDiscreteScheduler_default10_ays10_default25

@yiyixuxu
Copy link
Collaborator Author

cc @asomoza here
do you want to test out the diffusers version? is there any other scheduler you would want to add this change to?
the SDE version does not look bad to me here, no?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@asomoza
Copy link
Member

asomoza commented Apr 29, 2024

@yiyixuxu yeah, SDE looks good here, in fact they all look better than the tests I did, going to try with some photo realistic prompts.

@isidentical
Copy link
Contributor

does this work for SVD?

@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented May 1, 2024

# test ays (svd - default scheduler + default ays steps)
import torch
import numpy as np
import os

from diffusers import DiffusionPipeline
import torch

from diffusers.utils import load_image, export_to_video


model_id = "stabilityai/stable-video-diffusion-img2vid-xt"
pipe = DiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()

# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
image = image.resize((1024, 576))

seed = 123


sampling_schedule10 = [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]


save_dir = './test_ays_default_svd'
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

test_call_params = {
    "default10" :{
        "image" : image,
        "decode_chunk_size": 8,
        "num_inference_steps" : 10,
        },
    "ays10":{
        "image" : image,
        "decode_chunk_size": 8,
        "num_inference_steps" : 10,
        "sigmas" : sampling_schedule10,
    },
}

for test_name, params in test_call_params.items():
    print(" ")
    print(f" test_name: {test_name}")
    generator = torch.Generator(device='cuda').manual_seed(seed)
    print(f" default scheduler_config: {pipe.scheduler}")  
    frames = pipe(**params, generator=generator).frames[0]
    print(f" timesteps: {pipe.scheduler.timesteps}")
    print(f" sigmas: {pipe.scheduler.sigmas}")
    export_to_video(frames, f"{save_dir}/{test_name}.mp4", fps=7)

    print(f"saved image to {save_dir}/{test_name}.mp4")

@asomoza
Copy link
Member

asomoza commented May 1, 2024

I've tested with more prompts and I like the results.

25 10 ays
20240501130625_2487854446 20240501131657_2487854446 20240501131637_2487854446

With SDE I get the same bad results as comfyui:

25 10 ays
20240501133342_2487854446 20240501133358_2487854446 20240501133413_2487854446

Also I know they're not that popular but I really like to use HeunDiscreteScheduler and DPMSolverSinglestepScheduler when I want good quality and these ones really gets the benefit of the speed.

HeunDiscreteScheduler

25 10 ays
comfyui_normal_00008_ comfyui_ays_00008_

DPMSolverSinglestepScheduler

25 10 ays
comfyui_normal_00011_ comfyui_ays_00011_

@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented May 1, 2024

@asomoza
oh thanks! very nice! I will add custom timesteps support to heun and dpm single steps :)

I'm a little bit concerned about SDE - the paper released many good results with the DPM SDE variant, and these steps seem to be optimized for SDE variants. Can you share the diffusers script you for SDE?

Another decision we need to make here is whether to make final_sigma_type="zero" the default for AYS 10, the official AYS schedule corresponding to final_sigma_type="sigma_min" in diffusers.
i.e. for SDXL, the sigma schedule is [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029] - you can see that here we have 11 sigmas for 10 steps, the last sigma is 0.029 instead of 0; in diffusers we set the last sigma value to be 0 by default, because it just generates better images, especially when the steps are low, you can see more details here #6477

from my brief experiments here, #7817 (comment) I think it also improves AYS results, so I think we should keep that default, which is different from the official one; let me know what you think on this too

@asomoza
Copy link
Member

asomoza commented May 1, 2024

@yiyixuxu maybe is that have bad seed luck and found about it, also it seems that it happens more with the finetunes than the base model.

This is the code I use now with your latest commit:

import torch

from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
from diffusers.schedulers import AysSchedules


sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]

pipe = StableDiffusionXLPipeline.from_pretrained(
    "SG161222/RealVisXL_V4.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, algorithm_type="sde-dpmsolver++")

prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"

generator = torch.Generator(device="cpu").manual_seed(2487854446)

image = pipe(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=7.5,
    num_inference_steps=10,
    generator=generator,
    timesteps=sampling_schedule,
).images[0]


image.save("ays_test.png")

For me with that seed I get a bad result with that model, in comfyui I tested it with juggernaut, but I don't seem to get a bad result with the base model, the quality is worse so no many people use it though.

base normal base ays
20240501170808_3997399393 20240501170835_3190962743

I did get a couple of good results with that model (RealVis) and SDE after, but I get bad results more often.

As for the final sigma, I agree with you in making final_sigma_type="zero" the default for ays, more often than not I see artifacts in the images generated with sigma_min. We can really see this if we set the background to a simple uniform color.

sigma_min zero
20240501173427_507752256 20240501173558_507752256

Using sigma_min is not something I would recommend in any case.

@okaris
Copy link

okaris commented May 2, 2024

if self.config.use_karras_sigmas:
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
elif self.config.use_lu_lambdas:
lambdas = np.flip(log_sigmas.copy())
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
sigmas = np.exp(lambdas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()

timesteps are being overwritten here. Is this intended?

@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented May 2, 2024

@okaris, thanks for catching this!
no it is not intended, I think we should throw an error if custom timesteps/sigmas are passed and it is use_karras_sigmas or use_lu_lambdas are set to be True - do you think it makes sense?

@okaris
Copy link

okaris commented May 2, 2024

I believe that manually setting timesteps implies manual control over sigmas as well. It would be helpful if the scheduler could ignore the use_karras_sigmas or use_lu_lambdas options without generating an error. A warning would be beneficial to inform users about potential discrepancies in their expected outcomes.

@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented May 5, 2024

@okaris
I prefer to throw an error here because it is more explicit, even though what you suggested will be easier for the user :)
It is part of our design philosophy, too, see here https://huggingface.co/docs/diffusers/en/conceptual/philosophy#simple-over-easy

@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented May 7, 2024

@dg845 feel free to give this a review if you have time!

@yiyixuxu yiyixuxu requested review from BenjaminBossan and pcuenca May 7, 2024 08:21
Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have enough background knowledge to really judge the implementation of passing sigmas explicitly, so I focused more on the implementation. This generally LGTM, I just had some smaller comments.

Moreover, would it make sense to add an entry to the docs. At least to me, just from reading the docstrings, I wouldn't really know what this new feature does.

@@ -170,6 +173,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it would be better to add a method to the scheduler, like scheduler._accepts_sigmas() or scheduler._check_sigmas(sigmas) instead of inspecting. Same argument for corresponding methods below.

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Outdated Show resolved Hide resolved
@@ -189,6 +192,7 @@ def __init__(
timestep_type: str = "discrete", # can be "discrete" or "continuous"
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type could be changed to Literal["zero", "sigma_min"] to be more precise.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea! I would maybe ask the community to help us do this for the whole code base

timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps -= 1
else:
if timesteps is not None and sigmas is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the next check not already encompass this one?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

basically, these 3 checks together make sure among these 3 variables, one is not None and two are None
is there a better way?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry, I misread the code, it should be fine as is. Another way would be to add them, so:

if (num_inference_steps is None) + (timesteps is None) + (sigmas is None) != 1

but not sure if that's easier to read ;-)

src/diffusers/schedulers/scheduling_heun_discrete.py Outdated Show resolved Hide resolved
@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented May 7, 2024

@stevhliu, where should I add a short introduction to this custom timesteps and sigmas feature for schedulers?

Basically, this allows users to use AYS (#7817 (comment)) out-of-box, but not limited to this feature only

if there is a new sigmas or timesteps scheduler that's not in diffusers yet, you can manually create it and pass it to pipelines as custom timesteps; also this already exists for LCM scheduler so not a completely new feature, but now we are extending it to more schedulers and pipelines

timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps -= 1
else:
if timesteps is not None and sigmas is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry, I misread the code, it should be fine as is. Another way would be to add them, so:

if (num_inference_steps is None) + (timesteps is None) + (sigmas is None) != 1

but not sure if that's easier to read ;-)

@stevhliu
Copy link
Member

stevhliu commented May 8, 2024

where should I add a short introduction to this custom timesteps and sigmas feature for schedulers?

Maybe we can add a section called "Custom schedulers" to Load schedulers and models?

@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented May 8, 2024

@stevhliu We also have features like this that we added to multiple schedulers, it would be nice to have more visibility for them #7097

@@ -165,6 +165,62 @@ image

Most images look very similar and are comparable in quality. Again, it often comes down to your specific use case so a good approach is to run multiple different schedulers and compare the results.

### Custom Timestep Schedules
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a section here @stevhliu

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should probably refactor later and also in the same place, talk about use_karras_sigmas and rescale_zero_terminal_snr
most of the scheduler configs you just should keep the same as the default scheduler, but these are the configurations that the user can play around with

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for now, feel free to give it a review :)

Copy link
Member

@stevhliu stevhliu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I'll refactor this later in a separate PR and create a doc for scheduler features (custom timesteps and sigmas, rescale_zero_terminal_snr, and use_karras_sigmas) 🙂

docs/source/en/using-diffusers/schedulers.md Outdated Show resolved Hide resolved
@yiyixuxu yiyixuxu merged commit b934215 into main May 9, 2024
17 checks passed
@yiyixuxu yiyixuxu deleted the ays branch May 9, 2024 21:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants