DPMSolverMultistep add rescale_betas_zero_snr
#7097
Conversation
oh thanks
left a question about dtype
```diff
@@ -880,6 +930,11 @@ def step(
         if self.step_index is None:
             self._init_step_index(timestep)

+        # store old dtype because model_output isn't always the same it seems
+        return_dtype = sample.dtype
```
do you mean `model_output.dtype` isn't always the same as `sample.dtype` before the upcast?
Turns out that was because `model_output = self.convert_model_output(model_output, sample=sample)` ends up creating a shadowed tensor cast to the sample's dtype. I moved the sample upcast after this call so `return_dtype` is no longer needed. Outputs are the same.
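The ordering fix described above can be sketched roughly as follows. This is an illustrative NumPy stand-in, not the actual scheduler code: `step_sketch` is a hypothetical name, `convert_model_output` is reduced to its dtype-casting side effect, and the solver update is a placeholder.

```python
import numpy as np

def step_sketch(model_output, sample):
    """Illustrative sketch of the dtype handling discussed above (not the real scheduler)."""
    # convert_model_output casts its result to sample.dtype (stubbed here), so it
    # must run *before* the float32 upcast -- otherwise the re-cast, shadowed
    # tensor silently undoes the upcast.
    model_output = model_output.astype(sample.dtype)

    # With the upcast moved after the conversion, model_output still carries the
    # caller's dtype, so a separate return_dtype variable is unnecessary.
    caller_dtype = model_output.dtype
    sample = sample.astype(np.float32)

    # placeholder for the actual multistep solver update, done in float32
    prev_sample = sample - 0.1 * model_output.astype(np.float32)

    return prev_sample.astype(caller_dtype)  # hand back the caller's dtype
```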
yes separate issue please :)
Avoids having to re-use the dtype.
thanks!
can you fix the quality test?
Alright, I added a single newline.
thank you!
Have we created a separate issue for ZeroSNR + Karras?
Effectively identical to the implementation for EulerDiscrete in #6024

TL;DR:
- `rescale_zero_terminal_snr` function copied from DDIM
- Final `alphas_cumprod` value clamped to 2**-24 (from 0) when using zsnr so it doesn't produce an `inf` sigma
- `sample` upcast to `torch.float32` for the duration of `step()` because the zsnr sigmas are substantially less numerically stable and the performance hit is negligible
- `use_karras_sigmas=True` still produces strange results as it does in EulerDiscrete. Might be worth investigating separately?

Demo images

`use_karras_sigmas` also seems to just generally produce bad/noisy results even without ZSNR, so I might not have the config set up correctly?
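For context, the zero-terminal-SNR rescaling copied from DDIM can be sketched in NumPy as below. This is a hedged sketch, not the diffusers implementation itself: the linear beta schedule is a hypothetical example, and the clamp at the end illustrates why the 2**-24 floor matters.

```python
import numpy as np

def rescale_zero_terminal_snr(betas):
    """Shift/scale sqrt(alphas_cumprod) so the final timestep has zero SNR
    while the first timestep is left unchanged (the DDIM-style rescale)."""
    alphas_bar_sqrt = np.sqrt(np.cumprod(1.0 - betas))

    a_first, a_last = alphas_bar_sqrt[0], alphas_bar_sqrt[-1]
    alphas_bar_sqrt = (alphas_bar_sqrt - a_last) * a_first / (a_first - a_last)

    alphas_bar = alphas_bar_sqrt**2
    # recover per-step alphas from the cumulative product, then betas
    alphas = np.concatenate([alphas_bar[:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas

betas = np.linspace(1e-4, 0.02, 1000)  # example linear schedule
alphas_cumprod = np.cumprod(1.0 - rescale_zero_terminal_snr(betas))

# The terminal alphas_cumprod is now exactly 0, so sigma = sqrt((1 - a) / a)
# would be inf at the last step; flooring at 2**-24 keeps it finite.
alphas_cumprod = np.clip(alphas_cumprod, 2**-24, None)
sigmas = np.sqrt((1.0 - alphas_cumprod) / alphas_cumprod)
```

Without the clip, the last sigma blows up to `inf`, which matches the TL;DR bullet about the clamp.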