Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lcm flax #6051

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@
[
"FlaxDDIMScheduler",
"FlaxDDPMScheduler",
"FlaxLCMScheduler",
"FlaxDPMSolverMultistepScheduler",
"FlaxEulerDiscreteScheduler",
"FlaxKarrasVeScheduler",
Expand Down Expand Up @@ -705,6 +706,7 @@
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxLCMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxEulerDiscreteScheduler,
FlaxKarrasVeScheduler,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"diffusers": {
"ModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
"FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
},
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/pipelines/stable_diffusion_xl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"]
_import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
_import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]
_import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"]

if is_transformers_available() and is_flax_available():
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
FlaxDPMSolverMultistepScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
FlaxLCMScheduler
)
from ..pipeline_flax_utils import FlaxDiffusionPipeline
from .pipeline_output import FlaxStableDiffusionXLPipelineOutput
Expand All @@ -49,7 +50,7 @@ def __init__(
tokenizer_2: CLIPTokenizer,
unet: FlaxUNet2DConditionModel,
scheduler: Union[
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler, FlaxLCMScheduler
],
dtype: jnp.dtype = jnp.float32,
):
Expand Down Expand Up @@ -253,7 +254,7 @@ def loop_body(step, args):
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents, prng_seed).to_tuple()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option is to create the key inside the scheduler if it's not passed, like FlaxDDPMScheduler does

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See entrpn#1 for a working denoising loop.

return latents, scheduler_state

if DEBUG:
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
else:
_import_structure["scheduling_ddim_flax"] = ["FlaxDDIMScheduler"]
_import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"]
_import_structure["scheduling_lcm_flax"] = ["FlaxLCMScheduler"]
_import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"]
_import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
_import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"]
Expand Down Expand Up @@ -165,6 +166,7 @@
from .scheduling_ddim_flax import FlaxDDIMScheduler
from .scheduling_ddpm_flax import FlaxDDPMScheduler
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
from .scheduling_lcm_flax import FlaxLCMScheduler
from .scheduling_euler_discrete_flax import FlaxEulerDiscreteScheduler
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
Expand Down
Loading
Loading