DPM/UniPC schedulers (possibly more) seem to have massive stability issues on model KohakuXL vs ComfyUI #5646

Closed
Xynonners opened this issue Nov 4, 2023 · 14 comments
Labels: bug, scheduler, stale

@Xynonners
Describe the bug

In diffusers, the KohakuXL model runs into a massive noise/stability issue with the UniPC/DPM schedulers (possibly others). The issue doesn't seem to be present in ComfyUI/AUTO1111, and it doesn't go away at high step counts.
At 15 steps:
[image: test]
At 30 steps:
[image: test]
Euler A, 30 steps:
[image: test]

Reproduction

import torch
from accelerate.utils import ProjectConfiguration
from accelerate import Accelerator
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler
accelerator_config = ProjectConfiguration(
    project_dir="test",
    automatic_checkpoint_naming=True,
    total_limit=10,
)
accelerator = Accelerator(
    log_with="aim",
    mixed_precision="bf16",
    project_config=accelerator_config,
    gradient_accumulation_steps=16,
)
pipeline = StableDiffusionXLPipeline.from_single_file("models/kohakuXLBeta_beta71.safetensors").to(accelerator.device)
# Swap in DPM-Solver++ multistep, reusing the pipeline's existing scheduler config
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device=accelerator.device).manual_seed(42)
# Note the high guidance_scale=14.0; later comments identify it as the trigger
ims = pipeline(width=1024, height=1024, guidance_scale=14.0, prompt="realistic car 3 d render sci - fi car and sci - fi robotic factory structure in the coronation of napoleon painting and digital billboard with point cloud in the middle, unreal engine 5, keyshot, octane, artstation trending, ultra high detail, ultra realistic, cinematic, 8 k, 1 6 k, in style of zaha hadid, in style of nanospace,", num_inference_steps=15, generator=generator)
for img in ims.images:
    img.save("test.png")

Logs

Not Applicable

System Info

  • diffusers version: 0.22.0.dev0
  • Platform: Linux-6.5.6-273-tkg-tt-x86_64-with-glibc2.38
  • Python version: 3.11.5
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • Huggingface_hub version: 0.17.3
  • Transformers version: 4.34.1
  • Accelerate version: 0.24.0
  • xFormers version: not installed
  • Using GPU in script?: RTX6000
  • Using distributed or parallel set-up in script?: yes in accelerate config, no in this script

Who can help?

@yiyixuxu @patrickvonplaten

Xynonners added the bug label Nov 4, 2023
@patrickvonplaten
Contributor

Can you check whether the DPM fix in #5541 helps?

@nhnt11

nhnt11 commented Nov 7, 2023

In my experience, #5541 helps but not completely. I just filed #5689 to document another issue we found.

@Xynonners
Author

@patrickvonplaten using euler_at_final together with the default sigmas, use_karras_sigmas, or use_lu_lambdas helps slightly, but the issue is still very apparent.

use_karras_sigmas tends to be less noisy than the default, while use_lu_lambdas tends to produce fewer but larger splotches.
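The options above change how the inference sigmas are spaced. As a rough, non-authoritative illustration, the pure-Python sketch below builds training sigmas from a `scaled_linear` beta schedule (assuming SDXL-like values `beta_start=0.00085`, `beta_end=0.012` and 1000 training steps), then spaces 15 inference sigmas two ways: Karras et al. (2022) spacing with `rho=7`, which is what `use_karras_sigmas` selects, and uniform spacing in log(sigma) (uniform in the log-SNR `lambda`) as an approximation of `use_lu_lambdas`. All helper names are hypothetical; this is not the diffusers implementation.

```python
import math


def scaled_linear_sigmas(beta_start=0.00085, beta_end=0.012, num_train_timesteps=1000):
    """Training sigmas sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) for a
    'scaled_linear' schedule (betas linear in sqrt space). Values are assumed SDXL-like."""
    start, end = math.sqrt(beta_start), math.sqrt(beta_end)
    sigmas, alpha_bar = [], 1.0
    for i in range(num_train_timesteps):
        beta = (start + (end - start) * i / (num_train_timesteps - 1)) ** 2
        alpha_bar *= 1.0 - beta  # cumulative product of alphas
        sigmas.append(math.sqrt((1.0 - alpha_bar) / alpha_bar))
    return sigmas  # ascending: least noise first, most noise last


def karras_sigmas(sigma_min, sigma_max, n, rho=7.0):
    """Karras et al. (2022): interpolate linearly in sigma ** (1 / rho), sigma_max down to sigma_min."""
    lo, hi = sigma_min ** (1.0 / rho), sigma_max ** (1.0 / rho)
    return [(hi + i / (n - 1) * (lo - hi)) ** rho for i in range(n)]


def log_uniform_sigmas(sigma_min, sigma_max, n):
    """Uniform steps in log(sigma), i.e. uniform in lambda = -log(sigma) (hypothetical helper)."""
    lo, hi = math.log(sigma_min), math.log(sigma_max)
    return [math.exp(hi + i / (n - 1) * (lo - hi)) for i in range(n)]


train = scaled_linear_sigmas()
sigma_min, sigma_max = train[0], train[-1]
for name, schedule in [
    ("karras", karras_sigmas(sigma_min, sigma_max, 15)),
    ("log-uniform", log_uniform_sigmas(sigma_min, sigma_max, 15)),
]:
    print(f"{name:>12}: " + " ".join(f"{s:.3f}" for s in schedule))
```

Comparing the two printed lists shows how each option distributes the same 15-step budget across noise levels.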

@yiyixuxu
Collaborator

yiyixuxu commented Nov 17, 2023

Hi @Xynonners,

Sorry, I can't reproduce this. I tried the script below. I think Euler does look better, but I would say the DPM output is more or less within the expected range, unlike the output you provided. Did I miss anything here?

import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler

pipeline = StableDiffusionXLPipeline.from_pretrained("KBlueLeaf/kohaku-xl-beta7.1", torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()
generator = torch.Generator(device="cuda").manual_seed(42)
ims = pipeline(prompt="realistic car 3 d render sci - fi car and sci - fi robotic factory structure in the coronation of napoleon painting and digital billboard with point cloud in the middle, unreal engine 5, keyshot, octane, artstation trending, ultra high detail, ultra realistic, cinematic, 8 k, 1 6 k, in style of zaha hadid, in style of nanospace,", num_inference_steps=30, generator=generator)
ims.images[0].save("euler.png")

# use_lu_lambdas and euler_at_final are DPMSolverMultistepScheduler config options, so they go to from_config(), not to pipeline()
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_lu_lambdas=True, euler_at_final=True)
generator = torch.Generator(device="cuda").manual_seed(42)
ims = pipeline(prompt="realistic car 3 d render sci - fi car and sci - fi robotic factory structure in the coronation of napoleon painting and digital billboard with point cloud in the middle, unreal engine 5, keyshot, octane, artstation trending, ultra high detail, ultra realistic, cinematic, 8 k, 1 6 k, in style of zaha hadid, in style of nanospace,", num_inference_steps=30, generator=generator)

ims.images[0].save("dpm.png")

euler:
[image: yiyi_test_5_out_euler]

dpm:
[image: yiyi_test_5_out_dpm]

@Xynonners
Author

@yiyixuxu I narrowed it down to the guidance_scale parameter: DPM seems to be unstable at high guidance scales.

import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler

pipeline = StableDiffusionXLPipeline.from_single_file("UpDraft/models/kohakuXLBeta_beta71.safetensors", torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()
generator = torch.Generator(device="cuda").manual_seed(42)
ims = pipeline(guidance_scale=14.0, prompt="realistic car 3 d render sci - fi car and sci - fi robotic factory structure in the coronation of napoleon painting and digital billboard with point cloud in the middle, unreal engine 5, keyshot, octane, artstation trending, ultra high detail, ultra realistic, cinematic, 8 k, 1 6 k, in style of zaha hadid, in style of nanospace,", num_inference_steps=30, generator=generator)
ims.images[0].save("euler.png")

# use_lu_lambdas and euler_at_final are DPMSolverMultistepScheduler config options, so they go to from_config(), not to pipeline()
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_lu_lambdas=True, euler_at_final=True)
generator = torch.Generator(device="cuda").manual_seed(42)
ims = pipeline(guidance_scale=14.0, prompt="realistic car 3 d render sci - fi car and sci - fi robotic factory structure in the coronation of napoleon painting and digital billboard with point cloud in the middle, unreal engine 5, keyshot, octane, artstation trending, ultra high detail, ultra realistic, cinematic, 8 k, 1 6 k, in style of zaha hadid, in style of nanospace,", num_inference_steps=30, generator=generator)
ims.images[0].save("dpm.png")

Also, adding accelerate makes the model much slower and seems to destabilize it even more.
dpm:
[image: dpm]
dpm + accelerate:
[image: test]
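Since `guidance_scale` turned out to be the trigger, it helps to recall what it does. With classifier-free guidance, the SDXL pipeline combines the unconditional and text-conditioned noise predictions as `uncond + guidance_scale * (cond - uncond)`, so the guided prediction sits `guidance_scale` times as far from the unconditional one. The sketch below uses made-up four-element "predictions" (hypothetical values, not model outputs) to compare 5.0, the default `guidance_scale` of `StableDiffusionXLPipeline` that the first script from @yiyixuxu did not override, with the 14.0 used in these reports. The helper names are hypothetical.

```python
def cfg_combine(uncond, cond, guidance_scale):
    """Classifier-free guidance, elementwise: uncond + w * (cond - uncond)."""
    return [u + guidance_scale * (c - u) for u, c in zip(uncond, cond)]


def rms(values):
    """Root-mean-square magnitude of a list of floats."""
    return (sum(v * v for v in values) / len(values)) ** 0.5


# Toy 4-element "noise predictions" (hypothetical values, not real model outputs).
uncond = [0.10, -0.20, 0.05, 0.30]
cond = [0.15, -0.10, 0.00, 0.35]

for w in (1.0, 5.0, 14.0):
    guided = cfg_combine(uncond, cond, w)
    print(f"guidance_scale={w:>4}: rms(guided) = {rms(guided):.3f}")
```

The deviation from the unconditional prediction scales linearly with `guidance_scale`, so going from 5.0 to 14.0 nearly triples it.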

@Xynonners
Author

Xynonners commented Nov 17, 2023

Lowered to 15 steps; at this point even Euler seems to have noise issues (CFG 14).
dpm:
[image: dpm]
euler:
[image: euler]

github-actions bot added the stale label Dec 27, 2023
patrickvonplaten removed the stale label Jan 2, 2024
huggingface deleted a comment from github-actions bot Jan 2, 2024
@patrickvonplaten
Contributor

@yiyixuxu can you look into this to see if it has been fixed?

github-actions bot added the stale label Feb 3, 2024
huggingface deleted a comment from github-actions bot Feb 9, 2024
@patrickvonplaten
Contributor

Gentle re-ping

yiyixuxu removed the stale label Feb 10, 2024

github-actions bot commented Mar 5, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label Mar 5, 2024
yiyixuxu removed the stale label Mar 9, 2024

github-actions bot commented Apr 2, 2024, repeating the stale notice above.

github-actions bot added the stale label Apr 2, 2024
yiyixuxu removed the stale label Apr 3, 2024
@yiyixuxu
Collaborator

yiyixuxu commented Apr 3, 2024

cc @Beinsezii, since you have been testing our schedulers: is this still an issue?

@Beinsezii
Contributor

DPMSolverMultistep has been my default ever since I added ZSNR, and I've never once seen images like that across a fair variety of XL models. The closest I can think of is running an unusual noise schedule directly in bf16, but that should be fixed for good since I added an upcast in step() with #7097.

@Xynonners should probably try again with the latest PyTorch/diffusers.
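The bf16 precision issue mentioned above can be illustrated without diffusers. bfloat16 keeps only 8 significant bits (7 stored mantissa bits), so two nearly equal values can round to the same bf16 number. A DPM-Solver++(2M)-style update uses the difference between the current and previous data predictions as its second-order correction, so holding those values in bf16 can erase the correction entirely. The sketch below emulates bf16 by rounding a float32 bit pattern (round-to-nearest-even); the helper `to_bf16` and the toy values are hypothetical, and this only illustrates the mechanism rather than reproducing what #7097 changed.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (ties to even).

    bfloat16 is the top 16 bits of an IEEE-754 float32, so we round the float32
    bit pattern to a multiple of 2**16. NaN/inf handling is omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                       # lowest kept bit, for ties-to-even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000    # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Toy data predictions from two consecutive solver steps (hypothetical values).
# A second-order multistep update uses their difference as a correction term.
m_prev, m_curr = 0.8110, 0.8123

diff_fp = m_curr - m_prev                                # full precision
diff_bf16 = to_bf16(to_bf16(m_curr) - to_bf16(m_prev))   # everything held in bf16

print(f"full precision: {diff_fp:.4f}  bf16: {diff_bf16:.4f}")
# prints: full precision: 0.0013  bf16: 0.0000
```

Both inputs round to 0.8125 in bf16, so the correction vanishes; computing the step in float32 and casting back only at the end avoids that particular loss.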

github-actions bot added the stale label Apr 27, 2024
@yiyixuxu
Collaborator

Closing this for now!
Feel free to re-open if you continue to see the issue.

5 participants