Lcm flax #6051
base: main
Conversation
Hi @entrpn, great to see this! 🙌 I understand this is still in progress, right? Let me know when you want a review, or if you need help with testing or anything :)
@pcuenca this is still in progress. I don't think I have the scheduler implemented correctly and images are distorted. My next step is to compare scheduler values between PyTorch and JAX. I could use your help in verifying if the following makes sense:
This is the full code snippet:
import os
from diffusers import FlaxStableDiffusionXLPipeline
from diffusers import DiffusionPipeline, LCMScheduler
import torch
import time
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
import numpy as np
import random
from datasets import load_dataset
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache(os.path.expanduser("~/jax_cache"))
from transformers import CLIPTokenizer, FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection
from diffusers import (
FlaxAutoencoderKL,
FlaxUNet2DConditionModel,
FlaxLCMScheduler
)
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
lcm_model = "lcm_sdxl"
pipeline_old = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
)
lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
pipeline_old.load_lora_weights(lcm_lora_id)
pipeline_old.scheduler = LCMScheduler.from_config(pipeline_old.scheduler.config)
pipeline_old.fuse_lora(lora_scale=1.0)
pipeline_old.save_pretrained(lcm_model, safe_serialization=False)
del pipeline_old
weight_dtype = jnp.bfloat16
revision= 'refs/pr/95'
tokenizer = CLIPTokenizer.from_pretrained(
base_model,
revision=revision,
subfolder="tokenizer"
)
tokenizer_2 = CLIPTokenizer.from_pretrained(
base_model,
subfolder="tokenizer_2",
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
base_model,
revision=revision,
subfolder="text_encoder",
dtype=weight_dtype
)
text_encoder_2 = FlaxCLIPTextModelWithProjection.from_pretrained(
base_model,
revision=revision,
subfolder="text_encoder_2",
dtype=weight_dtype
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
base_model,
revision=revision,
subfolder="vae",
dtype=weight_dtype
)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
lcm_model,
from_pt=True,
subfolder="unet",
dtype=weight_dtype,
split_head_dim=True
)
scheduler, scheduler_state = FlaxLCMScheduler.from_pretrained(
lcm_model,
subfolder="scheduler",
dtype=jnp.float32
)
flax_pipeline = FlaxStableDiffusionXLPipeline(
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
vae=vae,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler
)
default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = ""
default_seed = 33
default_guidance_scale = 1.0
default_num_steps = 4
def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = flax_pipeline.prepare_inputs(prompt)
    neg_prompt_ids = flax_pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids
NUM_DEVICES = jax.device_count()
params = {}
params["unet"] = unet_params
params["text_encoder"] = text_encoder.params
params["text_encoder_2"] = text_encoder_2.params
params["vae"] = vae_params
params["scheduler"] = scheduler_state
p_params = replicate(params)
def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng
def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = flax_pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        jit=True,
    ).images
    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return flax_pipeline.numpy_to_pil(np.array(images))
start = time.time()
print(f"Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")
dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts", split="test")
dts = []
i = 0
for x in range(2):
    random_index = random.randint(0, len(dataset) - 1)
    start = time.time()
    prompt = dataset[random_index]["Prompt"]
    neg_prompt = ""
    print(f"Prompt: {prompt}")
    images = generate(prompt, neg_prompt)
    t = time.time() - start
    print(f"Inference in {t}")
    dts.append(t)
    for img in images:
        img.save(f'{i:06d}.jpg')
        i += 1
mean = np.mean(dts)
stdev = np.std(dts)
print(f"batches: {i}, Mean {mean:.2f} sec/batch± {stdev * 1.96 / np.sqrt(len(dts)):.2f} (95%)") |
@pcuenca making some progress with the scheduler, but images are coming out incomplete. I'm not sure if it's related to my question above (fusing the LoRA, saving it, and loading it in Flax using from_pt=True), or if I'm missing something in the scheduler. I checked the scheduler values for both PyTorch and JAX and they seem correct. (PyTorch and JAX comparison images attached.) The image takes form, but looks incomplete.
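For reference, a quick way to run that comparison is to load the scheduler from the merged checkpoint in both frameworks and diff the timesteps and alphas_cumprod. This is only a sketch: I'm assuming the FlaxLCMScheduler state follows the other Flax schedulers (a set_timesteps that returns a new state and a common.alphas_cumprod field), which may not match this PR exactly.

import numpy as np
import jax.numpy as jnp
from diffusers import LCMScheduler, FlaxLCMScheduler

# load both schedulers from the merged "lcm_sdxl" checkpoint created above
pt_scheduler = LCMScheduler.from_pretrained("lcm_sdxl", subfolder="scheduler")
flax_scheduler, flax_state = FlaxLCMScheduler.from_pretrained(
    "lcm_sdxl", subfolder="scheduler", dtype=jnp.float32
)

pt_scheduler.set_timesteps(4)
# assumed signature, mirroring the other Flax schedulers
flax_state = flax_scheduler.set_timesteps(flax_state, num_inference_steps=4, shape=())

print("torch timesteps:", pt_scheduler.timesteps.numpy())
print("flax timesteps: ", np.asarray(flax_state.timesteps))
# if both were built from the same config, the difference should be around float32 epsilon
diff = np.abs(
    np.asarray(pt_scheduler.alphas_cumprod) - np.asarray(flax_state.common.alphas_cumprod)
)
print("max |alphas_cumprod diff|:", diff.max())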
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@pcuenca if you have time to take a look, I could use your help. I'm close, but I cannot figure out what part I'm missing. Thanks! Code to run the pipeline:
import os
from diffusers import FlaxStableDiffusionXLPipeline
from diffusers import DiffusionPipeline, LCMScheduler
import torch
import time
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
import numpy as np
import random
from datasets import load_dataset
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache(os.path.expanduser("~/jax_cache"))
from transformers import CLIPTokenizer, FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection
from diffusers import (
FlaxAutoencoderKL,
FlaxUNet2DConditionModel,
FlaxLCMScheduler
)
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
lcm_model = "lcm_sdxl"
# pipeline_old = DiffusionPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-base-1.0",
# torch_dtype=torch.float16,
# )
# lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
# pipeline_old.load_lora_weights(lcm_lora_id)
# pipeline_old.scheduler = LCMScheduler.from_config(pipeline_old.scheduler.config)
# pipeline_old.fuse_lora(lora_scale=1.0)
# pipeline_old.save_pretrained(lcm_model, safe_serialization=False)
# del pipeline_old
weight_dtype = jnp.float16
revision= 'refs/pr/95'
tokenizer = CLIPTokenizer.from_pretrained(
base_model,
revision=revision,
subfolder="tokenizer"
)
tokenizer_2 = CLIPTokenizer.from_pretrained(
base_model,
subfolder="tokenizer_2",
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
base_model,
revision=revision,
subfolder="text_encoder",
dtype=weight_dtype
)
text_encoder_2 = FlaxCLIPTextModelWithProjection.from_pretrained(
base_model,
revision=revision,
subfolder="text_encoder_2",
dtype=weight_dtype
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
base_model,
revision=revision,
subfolder="vae",
dtype=weight_dtype
)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
lcm_model,
from_pt=True,
subfolder="unet",
dtype=weight_dtype,
split_head_dim=False
)
scheduler, scheduler_state = FlaxLCMScheduler.from_pretrained(
lcm_model,
subfolder="scheduler",
dtype=jnp.float32
)
flax_pipeline = FlaxStableDiffusionXLPipeline(
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
vae=vae,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler
)
default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = ""
default_seed = 42
default_guidance_scale = 1.0
default_num_steps = 4
def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = flax_pipeline.prepare_inputs(prompt)
    neg_prompt_ids = flax_pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids
NUM_DEVICES = jax.device_count()
params = {}
params["unet"] = unet_params
params["text_encoder"] = text_encoder.params
params["text_encoder_2"] = text_encoder_2.params
params["vae"] = vae_params
params["scheduler"] = scheduler_state
p_params = replicate(params)
def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng
def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = flax_pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        jit=True,
    ).images
    print("images.shape: ", images.shape)
    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return flax_pipeline.numpy_to_pil(np.array(images))
start = time.time()
print(f"Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")
dts = []
i = 0
for x in range(2):
    start = time.time()
    prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
    neg_prompt = ""
    print(f"Prompt: {prompt}")
    images = generate(prompt, neg_prompt)
    t = time.time() - start
    print(f"Inference in {t}")
    dts.append(t)
    for img in images:
        img.save(f'{i:06d}.jpg')
        i += 1
mean = np.mean(dts)
stdev = np.std(dts)
print(f"batches: {i}, Mean {mean:.2f} sec/batch± {stdev * 1.96 / np.sqrt(len(dts)):.2f} (95%)") Output image: |
My tests so far:
Next I'll focus on debugging the UNet using the same input data and trying to identify where the computation starts to diverge. I'm thinking there's a configuration setting that might not have been applied to the Flax UNet, but there's been a lot of work done to the PyTorch UNet in the past few months.
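In case it helps, here is a rough sketch of that check: load the merged "lcm_sdxl" UNet in both frameworks, feed identical inputs, and compare the predicted noise. The dummy embedding shapes, the time_ids, and the added_cond_kwargs argument names are my assumptions about the SDXL UNet interface, so treat this as a starting point rather than verified code for this PR.

import numpy as np
import torch
import jax.numpy as jnp
from diffusers import UNet2DConditionModel, FlaxUNet2DConditionModel

pt_unet = UNet2DConditionModel.from_pretrained(
    "lcm_sdxl", subfolder="unet", torch_dtype=torch.float32
).eval()
flax_unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "lcm_sdxl", subfolder="unet", from_pt=True, dtype=jnp.float32
)

rng = np.random.default_rng(0)
latents = rng.standard_normal((1, 4, 128, 128)).astype(np.float32)
prompt_embeds = rng.standard_normal((1, 77, 2048)).astype(np.float32)  # assumed SDXL embedding dims
text_embeds = rng.standard_normal((1, 1280)).astype(np.float32)
time_ids = np.array([[1024, 1024, 0, 0, 1024, 1024]], dtype=np.float32)
t = 999

with torch.no_grad():
    pt_out = pt_unet(
        torch.from_numpy(latents),
        t,
        encoder_hidden_states=torch.from_numpy(prompt_embeds),
        added_cond_kwargs={
            "text_embeds": torch.from_numpy(text_embeds),
            "time_ids": torch.from_numpy(time_ids),
        },
    ).sample.numpy()

flax_out = flax_unet.apply(
    {"params": unet_params},
    jnp.asarray(latents),
    jnp.array([t], dtype=jnp.int32),
    encoder_hidden_states=jnp.asarray(prompt_embeds),
    added_cond_kwargs={
        "text_embeds": jnp.asarray(text_embeds),
        "time_ids": jnp.asarray(time_ids),
    },
).sample

print("max |noise_pred diff|:", np.abs(pt_out - np.asarray(flax_out)).max())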
@pcuenca thank you for looking into it!
@entrpn Found it! We have to split the rng key in the loop, otherwise we are always using the same scheduler noise :) I'll clean up my code (I have several other changes that are most probably unnecessary), retest, and submit a PR to your PR tomorrow.
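To make the fix concrete, here is a minimal, self-contained toy of the pattern (not the actual pipeline code): the key is carried in the loop state and split once per step, so every scheduler step sees fresh noise. unet_fn and the update rule are placeholders.

import jax
import jax.numpy as jnp

def unet_fn(latents, t):
    # stand-in for the real UNet noise prediction
    return latents * 0.1

def denoise(latents, key, num_steps=4):
    def loop_body(step, args):
        latents, key = args
        key, step_key = jax.random.split(key)  # fresh sub-key every iteration
        noise_pred = unet_fn(latents, step)
        # stand-in for scheduler.step(state, noise_pred, t, latents, step_key)
        noise = jax.random.normal(step_key, latents.shape)
        latents = latents - noise_pred + 0.01 * noise
        return latents, key

    latents, _ = jax.lax.fori_loop(0, num_steps, loop_body, (latents, key))
    return latents

out = denoise(jnp.zeros((1, 4, 8, 8)), jax.random.PRNGKey(0))
print(out.shape)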
This is in pretty good shape! I think it could be merged soon after we decide how to handle the key split and clean it up a bit. A couple of questions:
- Are you planning to allow LoRA loading, or use merged checkpoints instead? I think the second option would be enough to get started and unblock the fast LCM use-case. In that case, I'd export the merged weights in float32 to a new repo. I can help with that if you want.
- If we intend to frequently use LCM SDXL with no classifier-free guidance, we could do some simplifications that could further reduce memory, allow larger batches or maybe even increase speed. We could avoid concatenating the unconditional+conditional latents and then splitting them (see the sketch after this list). If you think this is useful, maybe we can do a special code path just for that.
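Regarding the second point, here is a hypothetical sketch of what that special code path could look like; predict_noise, unet_apply, and the argument layout are illustrative stand-ins, not the pipeline's actual helpers.

import jax.numpy as jnp

def predict_noise(unet_apply, latents, t, prompt_embeds, neg_prompt_embeds,
                  guidance_scale, do_cfg):
    if do_cfg:
        # standard path: one batched UNet call on [uncond, cond], then split
        latents_in = jnp.concatenate([latents, latents], axis=0)
        context = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0)
        noise_uncond, noise_text = jnp.split(unet_apply(latents_in, t, context), 2, axis=0)
        return noise_uncond + guidance_scale * (noise_text - noise_uncond)
    # LCM path with guidance disabled: single conditional call, half the UNet work and memory
    return unet_apply(latents, t, prompt_embeds)

# toy usage with a stand-in UNet; do_cfg would be decided once, outside the jitted loop
fake_unet = lambda x, t, ctx: x * 0.0
latents = jnp.zeros((1, 4, 8, 8))
embeds = jnp.zeros((1, 77, 2048))
print(predict_noise(fake_unet, latents, 999, embeds, embeds, 1.0, do_cfg=False).shape)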
@@ -253,7 +254,7 @@ def loop_body(step, args):
             noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents, prng_seed).to_tuple()
Another option is to create the key inside the scheduler if it's not passed, like FlaxDDPMScheduler does
See entrpn#1 for a working denoising loop.
jax.debug.print("**state.step_index: {x}",x=state.step_index) | ||
|
||
def get_noise(key, shape, dtype): | ||
jax.debug.print("---------------get_noise") |
I wonder if we should split the key here and store it as part of the state. As commented in entrpn#1, I believe that just splitting, like FlaxDDPMScheduler does, is not correct.
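One way to picture that: a toy scheduler state that owns the PRNG key, where every step() consumes a sub-key and writes the remainder back, so repeated calls never reuse noise. Field names and the update rule are illustrative only, not the FlaxLCMScheduler implementation.

import flax
import jax
import jax.numpy as jnp

@flax.struct.dataclass
class ToySchedulerState:
    key: jax.Array
    step_index: int

def create_state(seed=0):
    return ToySchedulerState(key=jax.random.PRNGKey(seed), step_index=0)

def step(state, model_output, sample):
    key, step_key = jax.random.split(state.key)
    noise = jax.random.normal(step_key, sample.shape, dtype=sample.dtype)
    prev_sample = sample - model_output + 0.01 * noise  # placeholder update rule
    new_state = state.replace(key=key, step_index=state.step_index + 1)
    return prev_sample, new_state

state = create_state()
sample = jnp.zeros((1, 4, 8, 8))
sample, state = step(state, jnp.zeros_like(sample), sample)
print(state.step_index)  # 1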
@pcuenca thank you for finding the issue and apologies for the late reply. I have been out of the office for 2 weeks and am just getting back to things today. I think for now let's just upload the merged weights to a new repo with this PR. I was working on LoRA loading with diffusers (https://github.com/entrpn/diffusers/tree/f/flax_lora_loading) a while back but never completed it. Maybe this is something we can work on together afterwards :) For your second point, I think it's a good idea to avoid concatenating the unconditional+conditional latents. I'll take a more detailed look this week at the code changes you made and get back to you so we can hopefully merge this week or next. Once again, thank you!
@entrpn sounds great! Let's get this to the finish line and then we can work on LoRA loading :)
Co-authored-by: Pedro Cuenca <[email protected]>
…d on loop. Linting
@pcuenca I just pushed some new changes:
However, I am still seeing the images not being generated correctly. They look better, but not like the PyTorch output. Can you help me review and see if I'm missing anything?
@pcuenca forget my last comment about it not working. I think it was because I had merged the LoRA with lora_scale=1.0. I re-did the merging with lora_scale=0.7 and it's working great! This is ready from a functional perspective; if you can help review any coding styles you prefer and also help upload the weights to a repo so that it can be loaded directly, that would be great. Thank you for all your help! Here's the inference code I used.
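For reference, the merge that produced the working weights would be the first snippet in this thread with only the LoRA scale changed; this is a sketch based on the code above, not the exact command that was run.

import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.fuse_lora(lora_scale=0.7)  # 1.0 produced broken images; 0.7 works per the comment above
pipe.save_pretrained("lcm_sdxl", safe_serialization=False)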
Sure, I can help with that!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Not stale! I'm at an offsite this week, let's aim to have this merged next week!
What does this PR do?
Fixes # (issue)
Before submitting
documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.