Lcm flax #6051

Open · wants to merge 8 commits into base: main
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -386,6 +386,7 @@
             "FlaxDPMSolverMultistepScheduler",
             "FlaxEulerDiscreteScheduler",
             "FlaxKarrasVeScheduler",
+            "FlaxLCMScheduler",
             "FlaxLMSDiscreteScheduler",
             "FlaxPNDMScheduler",
             "FlaxSchedulerMixin",
@@ -708,6 +709,7 @@
            FlaxDPMSolverMultistepScheduler,
            FlaxEulerDiscreteScheduler,
            FlaxKarrasVeScheduler,
+           FlaxLCMScheduler,
            FlaxLMSDiscreteScheduler,
            FlaxPNDMScheduler,
            FlaxSchedulerMixin,
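With the new export in place, the scheduler can be imported from the package root like the other Flax schedulers. A minimal sketch (the constructor arguments and the `create_state`/`set_timesteps` calls are assumptions based on the `FlaxSchedulerMixin` convention; only the export itself is shown in this diff):

```python
from diffusers import FlaxLCMScheduler

# Flax schedulers keep their mutable state in a separate dataclass
# (assumed here to follow the FlaxSchedulerMixin convention).
scheduler = FlaxLCMScheduler()
state = scheduler.create_state()
state = scheduler.set_timesteps(state, num_inference_steps=4, shape=(1, 4, 128, 128))
```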
1 change: 1 addition & 0 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -86,6 +86,7 @@
     "diffusers": {
         "ModelMixin": ["save_pretrained", "from_pretrained"],
         "SchedulerMixin": ["save_pretrained", "from_pretrained"],
+        "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
         "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
         "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
     },
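For background, `LOADABLE_CLASSES` maps component base classes to the save/load methods `DiffusionPipeline.save_pretrained`/`from_pretrained` should call on them; registering `FlaxSchedulerMixin` is what lets a pipeline round-trip a Flax scheduler to disk. A simplified sketch of the lookup this enables (illustrative only, not the actual diffusers implementation):

```python
# Illustrative sketch of how the registry is consulted, not diffusers code.
import diffusers

def registered_save_method(component, loadable_classes):
    """Return the save method name registered for a component's base class."""
    for class_name, methods in loadable_classes["diffusers"].items():
        base_cls = getattr(diffusers, class_name, None)
        if base_cls is not None and isinstance(component, base_cls):
            return methods[0]  # e.g. "save_pretrained"
    return None
```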
1 change: 1 addition & 0 deletions src/diffusers/pipelines/stable_diffusion_xl/__init__.py
@@ -25,6 +25,7 @@

     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
+    _import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"]
     _import_structure["pipeline_stable_diffusion_xl"] = ["StableDiffusionXLPipeline"]
     _import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"]
     _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
@@ -26,6 +26,7 @@
 from ...schedulers import (
     FlaxDDIMScheduler,
     FlaxDPMSolverMultistepScheduler,
+    FlaxLCMScheduler,
     FlaxLMSDiscreteScheduler,
     FlaxPNDMScheduler,
 )
@@ -49,7 +50,11 @@ def __init__(
         tokenizer_2: CLIPTokenizer,
         unet: FlaxUNet2DConditionModel,
         scheduler: Union[
-            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
+            FlaxDDIMScheduler,
+            FlaxPNDMScheduler,
+            FlaxLMSDiscreteScheduler,
+            FlaxDPMSolverMultistepScheduler,
+            FlaxLCMScheduler,
         ],
         dtype: jnp.dtype = jnp.float32,
     ):
@@ -98,6 +103,7 @@ def __call__(
         neg_prompt_ids: jnp.array = None,
         return_dict: bool = True,
         output_type: str = None,
+        do_classifier_free_guidance: bool = True,
         jit: bool = False,
     ):
         # 0. Default height and width to unet
@@ -124,6 +130,7 @@
                 latents,
                 neg_prompt_ids,
                 return_latents,
+                do_classifier_free_guidance,
             )
         else:
             images = self._generate(
@@ -137,6 +144,7 @@
                 latents,
                 neg_prompt_ids,
                 return_latents,
+                do_classifier_free_guidance,
             )

         if not return_dict:
@@ -178,29 +186,34 @@ def _generate(
         latents: Optional[jnp.array] = None,
         neg_prompt_ids: Optional[jnp.array] = None,
         return_latents=False,
+        do_classifier_free_guidance=True,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

         # Encode input prompt
         prompt_embeds, pooled_embeds = self.get_embeddings(prompt_ids, params)

-        # Get unconditional embeddings
-        batch_size = prompt_embeds.shape[0]
-        if neg_prompt_ids is None:
-            neg_prompt_embeds = jnp.zeros_like(prompt_embeds)
-            negative_pooled_embeds = jnp.zeros_like(pooled_embeds)
-        else:
-            neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)
-
         add_time_ids = self._get_add_time_ids(
             (height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype
         )

-        prompt_embeds = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0)  # (2, 77, 2048)
-        add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
-        add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)
-
+        # Get unconditional embeddings
+        batch_size = prompt_embeds.shape[0]
+
+        if do_classifier_free_guidance:
+            if neg_prompt_ids is None:
+                neg_prompt_embeds = jnp.zeros_like(prompt_embeds)
+                negative_pooled_embeds = jnp.zeros_like(pooled_embeds)
+            else:
+                neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)
+
+            prompt_embeds = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0)  # (2, 77, 2048)
+            add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
+            add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)
+
+        else:
+            add_text_embeds = pooled_embeds
         # Ensure model output will be `float32` before going into the scheduler
         guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)
@@ -229,11 +242,14 @@ def _generate(

         # Denoising loop
         def loop_body(step, args):
-            latents, scheduler_state = args
+            latents, scheduler_state, prng_seed = args
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
-            latents_input = jnp.concatenate([latents] * 2)
+            if do_classifier_free_guidance:
+                latents_input = jnp.concatenate([latents] * 2)
+            else:
+                latents_input = latents

             t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
             timestep = jnp.broadcast_to(t, latents_input.shape[0])
@@ -248,20 +264,23 @@ def loop_body(step, args):
                 encoder_hidden_states=prompt_embeds,
                 added_cond_kwargs=added_cond_kwargs,
             ).sample
-            # perform guidance
-            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
-            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
             # compute the previous noisy sample x_t -> x_t-1
-            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
-            return latents, scheduler_state
+            latents, scheduler_state = self.scheduler.step(
+                scheduler_state, noise_pred, t, latents, prng_seed
+            ).to_tuple()
+            prng_seed = jax.random.split(prng_seed)[0]
+            return latents, scheduler_state, prng_seed

         if DEBUG:
             # run with python for loop
             for i in range(num_inference_steps):
-                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
+                latents, scheduler_state, prng_seed = loop_body(i, (latents, scheduler_state, prng_seed))
         else:
-            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
+            latents, _, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state, prng_seed))

         if return_latents:
             return latents
@@ -278,8 +297,8 @@
 # Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
 @partial(
     jax.pmap,
-    in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0, None),
-    static_broadcasted_argnums=(0, 4, 5, 6, 10),
+    in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0, None, None),
+    static_broadcasted_argnums=(0, 4, 5, 6, 10, 11),
 )
 def _p_generate(
     pipe,
@@ -293,6 +312,7 @@ def _p_generate(
     latents,
     neg_prompt_ids,
     return_latents,
+    do_classifier_free_guidance,
 ):
     return pipe._generate(
         prompt_ids,
@@ -305,4 +325,5 @@ def _p_generate(
         latents,
         neg_prompt_ids,
         return_latents,
+        do_classifier_free_guidance,
     )
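Taken together, these changes let a caller run an LCM-distilled SDXL checkpoint without classifier-free guidance: the UNet then sees a single (not doubled) batch per step, and the scheduler's `step` receives a PRNG key that `loop_body` re-splits every iteration. A usage sketch under stated assumptions: the checkpoint id is a placeholder, and the top-level `FlaxStableDiffusionXLPipeline` import, `from_pretrained`, and `prepare_inputs` are assumed to mirror the Flax Stable Diffusion pipeline.

```python
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionXLPipeline

# Placeholder checkpoint: substitute a real LCM-distilled Flax SDXL repo id.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "path/to/lcm-sdxl-flax", dtype=jnp.bfloat16
)

prompt_ids = pipeline.prepare_inputs(["a photo of an astronaut riding a horse"])
rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

images = pipeline(
    shard(prompt_ids),                   # one shard per device
    replicate(params),
    rng,
    num_inference_steps=4,               # LCM needs only a handful of steps
    guidance_scale=1.0,                  # ignored when CFG is disabled
    do_classifier_free_guidance=False,   # single forward pass per step
    jit=True,
).images
```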
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/__init__.py
@@ -80,6 +80,7 @@
     _import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"]
     _import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
     _import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"]
+    _import_structure["scheduling_lcm_flax"] = ["FlaxLCMScheduler"]
     _import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"]
     _import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"]
     _import_structure["scheduling_sde_ve_flax"] = ["FlaxScoreSdeVeScheduler"]
@@ -167,6 +168,7 @@
         from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
         from .scheduling_euler_discrete_flax import FlaxEulerDiscreteScheduler
         from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
+        from .scheduling_lcm_flax import FlaxLCMScheduler
         from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
         from .scheduling_pndm_flax import FlaxPNDMScheduler
         from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -19,6 +19,7 @@
 from typing import Optional, Tuple, Union

 import flax
+import jax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
@@ -202,6 +203,7 @@ def step(
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
+        key: Optional[jax.Array] = None,
         eta: float = 0.0,
         return_dict: bool = True,
     ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
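The `key` argument added here (and to the DPM-Solver and Euler schedulers below) is accepted but unused by these deterministic schedulers; it exists so all Flax schedulers share one `step` signature and the pipeline's denoising loop can pass its PRNG key positionally, whatever scheduler is plugged in. A small self-contained restatement of that calling convention, mirroring the loop in the pipeline diff above:

```python
import jax

def scheduler_step(scheduler, scheduler_state, noise_pred, t, latents, key):
    # Works for any Flax scheduler after this PR: deterministic schedulers
    # ignore `key`; a stochastic one (e.g. FlaxLCMScheduler) consumes it.
    latents, scheduler_state = scheduler.step(
        scheduler_state, noise_pred, t, latents, key
    ).to_tuple()
    key = jax.random.split(key)[0]  # fresh key for the next step
    return latents, scheduler_state, key
```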
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
@@ -481,6 +481,7 @@ def step(
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
+        key: Optional[jax.Array] = None,
         return_dict: bool = True,
     ) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]:
         """
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_euler_discrete_flax.py
@@ -16,6 +16,7 @@
 from typing import Optional, Tuple, Union

 import flax
+import jax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
@@ -192,6 +193,7 @@ def step(
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
+        key: Optional[jax.Array] = None,
         return_dict: bool = True,
     ) -> Union[FlaxEulerDiscreteSchedulerOutput, Tuple]:
         """