
Lcm flax #6051

Open

entrpn wants to merge 8 commits into main
Conversation

@entrpn (Contributor) commented Dec 5, 2023

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@sayakpaul requested a review from pcuenca on December 5, 2023 at 12:36
@pcuenca (Member) commented Dec 5, 2023

Hi @entrpn, great to see this! 🙌 I understand this is in progress, isn't it? Let me know when you want a review, or if you need help with testing or anything :)

@entrpn (Contributor, Author) commented Dec 5, 2023

@pcuenca this is still in progress. I don't think I have the scheduler implemented correctly and images are distorted. My next step is to compare scheduler values between PyTorch and JAX.

I could use your help in verifying if the following makes sense:

  • Take the SDXL PyTorch model and fuse it with the LCM LoRA.
  • Save the model with the fused LoRA and load the UNet into JAX using from_pt=True.

This is the full code snippet:

import os
from diffusers import FlaxStableDiffusionXLPipeline
from diffusers import DiffusionPipeline, LCMScheduler
import torch
import time
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
import numpy as np
import random
from datasets import load_dataset
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache(os.path.expanduser("~/jax_cache"))

from transformers import CLIPTokenizer, FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection

from diffusers import (
    FlaxAutoencoderKL,
    FlaxUNet2DConditionModel,
    FlaxLCMScheduler
)

base_model = "stabilityai/stable-diffusion-xl-base-1.0"
lcm_model = "lcm_sdxl"

pipeline_old = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)
lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
pipeline_old.load_lora_weights(lcm_lora_id)
pipeline_old.scheduler = LCMScheduler.from_config(pipeline_old.scheduler.config)

pipeline_old.fuse_lora(lora_scale=1.0)
pipeline_old.save_pretrained(lcm_model, safe_serialization=False)
del pipeline_old

weight_dtype = jnp.bfloat16
revision= 'refs/pr/95'
tokenizer = CLIPTokenizer.from_pretrained(
    base_model,
    revision=revision,
    subfolder="tokenizer"
)

tokenizer_2 = CLIPTokenizer.from_pretrained(
    base_model,
    subfolder="tokenizer_2",
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
    base_model,
    revision=revision,
    subfolder="text_encoder",
    dtype=weight_dtype
)
text_encoder_2 = FlaxCLIPTextModelWithProjection.from_pretrained(
    base_model,
    revision=revision,
    subfolder="text_encoder_2",
    dtype=weight_dtype
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
    base_model,
    revision=revision,
    subfolder="vae",
    dtype=weight_dtype
)

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    lcm_model,
    from_pt=True,
    subfolder="unet",
    dtype=weight_dtype,
    split_head_dim=True
)

scheduler, scheduler_state = FlaxLCMScheduler.from_pretrained(
    lcm_model,
    subfolder="scheduler",
    dtype=jnp.float32
)

flax_pipeline = FlaxStableDiffusionXLPipeline(
  text_encoder=text_encoder,
  text_encoder_2=text_encoder_2,
  vae=vae,
  tokenizer=tokenizer,
  tokenizer_2=tokenizer_2,
  unet=unet,
  scheduler=scheduler
)

default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = ""
default_seed = 33
default_guidance_scale = 1.0
default_num_steps = 4

def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = flax_pipeline.prepare_inputs(prompt)
    neg_prompt_ids = flax_pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids

NUM_DEVICES = jax.device_count()

params = {}
params["unet"] = unet_params
params["text_encoder"] = text_encoder.params
params["text_encoder_2"] = text_encoder_2.params
params["vae"] = vae_params
params["scheduler"] = scheduler_state

p_params = replicate(params)

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = flax_pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        jit=True,
    ).images

    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return flax_pipeline.numpy_to_pil(np.array(images))

start = time.time()
print(f"Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")

dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts", split="test")

dts = []
i = 0
for x in range(2):
    random_index = random.randint(0, len(dataset) - 1)
    
    start = time.time()
    prompt = dataset[random_index]["Prompt"]
    neg_prompt = ""

    print(f"Prompt: {prompt}")
    images = generate(prompt, neg_prompt)
    t = time.time() - start
    print(f"Inference in {t}")

    dts.append(t)
    for img in images:
        img.save(f'{i:06d}.jpg')
        i += 1    

mean = np.mean(dts)
stdev = np.std(dts)
print(f"batches: {i},  Mean {mean:.2f} sec/batch± {stdev * 1.96 / np.sqrt(len(dts)):.2f} (95%)")

Images right now look as follows:
[image: 000007]

@entrpn (Contributor, Author) commented Dec 5, 2023

@pcuenca making some progress with the scheduler, but images are coming out incomplete. I'm not sure if it's related to my question above (fusing the LoRA, saving it, and loading it in Flax using from_pt=True), or if I'm missing something in the scheduler.

I checked the scheduler values for both pytorch and JAX and they seem correct.

PyTorch
timesteps: tensor([999, 759, 499, 259], device='cuda:0')
timestep: tensor(999, device='cuda:0')
alpha_prod_t: tensor(0.0047)
timestep: tensor(759, device='cuda:0')
alpha_prod_t: tensor(0.0522)
timestep: tensor(499, device='cuda:0')
alpha_prod_t: tensor(0.2777)
timestep: tensor(259, device='cuda:0')
alpha_prod_t: tensor(0.6590)

JAX
timesteps: [999 759 499 259]
timestep: 999
alpha_prod_t: 0.004660099744796753
timestep: 759
alpha_prod_t: 0.052212905138731
timestep: 499
alpha_prod_t: 0.2776695489883423
timestep: 259
alpha_prod_t: 0.6589754223823547
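
For anyone reproducing this check, the comparison can be scripted roughly as follows. The PyTorch calls are standard; the Flax side assumes the FlaxLCMScheduler in this PR mirrors the existing Flax schedulers (from_pretrained returns a state, set_timesteps(state, num_inference_steps, shape), cumulative alphas under state.common.alphas_cumprod), and "lcm_sdxl" is the merged checkpoint saved earlier, so treat it as a sketch rather than the final API:

from diffusers import LCMScheduler, FlaxLCMScheduler

# PyTorch scheduler from the merged checkpoint, 4 inference steps.
pt_sched = LCMScheduler.from_pretrained("lcm_sdxl", subfolder="scheduler")
pt_sched.set_timesteps(4)

# Flax scheduler; state layout assumed to mirror FlaxDDIM/FlaxDDPM.
fx_sched, fx_state = FlaxLCMScheduler.from_pretrained("lcm_sdxl", subfolder="scheduler")
fx_state = fx_sched.set_timesteps(fx_state, num_inference_steps=4, shape=(1, 4, 128, 128))

print("PyTorch timesteps:", pt_sched.timesteps.tolist())
print("JAX timesteps:    ", list(fx_state.timesteps))
for t in pt_sched.timesteps.tolist():
    print(f"t={t}  "
          f"pt alpha_prod_t={pt_sched.alphas_cumprod[t].item():.6f}  "
          f"jax alpha_prod_t={float(fx_state.common.alphas_cumprod[t]):.6f}")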

Image takes form, but looks incomplete.

[image: 000007]

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@entrpn (Contributor, Author) commented Dec 14, 2023

@pcuenca if you have time to take a look, I could use your help. I'm close, but I cannot figure out what part I'm missing. Thanks!

Code to run the pipeline:

import os
from diffusers import FlaxStableDiffusionXLPipeline
from diffusers import DiffusionPipeline, LCMScheduler
import torch
import time
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
import numpy as np
import random
from datasets import load_dataset
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache(os.path.expanduser("~/jax_cache"))

from transformers import CLIPTokenizer, FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection

from diffusers import (
    FlaxAutoencoderKL,
    FlaxUNet2DConditionModel,
    FlaxLCMScheduler
)

base_model = "stabilityai/stable-diffusion-xl-base-1.0"
lcm_model = "lcm_sdxl"

# pipeline_old = DiffusionPipeline.from_pretrained(
#     "stabilityai/stable-diffusion-xl-base-1.0",
#     torch_dtype=torch.float16,
# )
# lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
# pipeline_old.load_lora_weights(lcm_lora_id)
# pipeline_old.scheduler = LCMScheduler.from_config(pipeline_old.scheduler.config)

# pipeline_old.fuse_lora(lora_scale=1.0)
# pipeline_old.save_pretrained(lcm_model, safe_serialization=False)
# del pipeline_old

weight_dtype = jnp.float16
revision= 'refs/pr/95'
tokenizer = CLIPTokenizer.from_pretrained(
    base_model,
    revision=revision,
    subfolder="tokenizer"
)

tokenizer_2 = CLIPTokenizer.from_pretrained(
    base_model,
    subfolder="tokenizer_2",
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
    base_model,
    revision=revision,
    subfolder="text_encoder",
    dtype=weight_dtype
)
text_encoder_2 = FlaxCLIPTextModelWithProjection.from_pretrained(
    base_model,
    revision=revision,
    subfolder="text_encoder_2",
    dtype=weight_dtype
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
    base_model,
    revision=revision,
    subfolder="vae",
    dtype=weight_dtype
)

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    lcm_model,
    from_pt=True,
    subfolder="unet",
    dtype=weight_dtype,
    split_head_dim=False
)

scheduler, scheduler_state = FlaxLCMScheduler.from_pretrained(
    lcm_model,
    subfolder="scheduler",
    dtype=jnp.float32
)

flax_pipeline = FlaxStableDiffusionXLPipeline(
  text_encoder=text_encoder,
  text_encoder_2=text_encoder_2,
  vae=vae,
  tokenizer=tokenizer,
  tokenizer_2=tokenizer_2,
  unet=unet,
  scheduler=scheduler
)

default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = ""
default_seed = 42
default_guidance_scale = 1.0
default_num_steps = 4

def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = flax_pipeline.prepare_inputs(prompt)
    neg_prompt_ids = flax_pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids

NUM_DEVICES = jax.device_count()

params = {}
params["unet"] = unet_params
params["text_encoder"] = text_encoder.params
params["text_encoder_2"] = text_encoder_2.params
params["vae"] = vae_params
params["scheduler"] = scheduler_state

p_params = replicate(params)

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = flax_pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        jit=True,
    ).images
    print("images.shape: ", images.shape)
    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return flax_pipeline.numpy_to_pil(np.array(images))

start = time.time()
print(f"Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")

dts = []
i = 0
for x in range(2):
    
    start = time.time()
    prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
    neg_prompt = ""

    print(f"Prompt: {prompt}")
    images = generate(prompt, neg_prompt)
    t = time.time() - start
    print(f"Inference in {t}")

    dts.append(t)
    for img in images:
        img.save(f'{i:06d}.jpg')
        i += 1    

mean = np.mean(dts)
stdev = np.std(dts)
print(f"batches: {i},  Mean {mean:.2f} sec/batch± {stdev * 1.96 / np.sqrt(len(dts)):.2f} (95%)")

Output image:

[image: 000005]

@pcuenca (Member) commented Jan 5, 2024

My tests so far:

  • I can reproduce your results with float32 and bfloat16; float16 produces black images.
  • The scheduler implementation looks correct, but it's hard to verify due to its stochastic nature. I'm thinking about pre-creating the noise and using the same noise in both implementations.
  • Running the pipeline with 1 inference step from the same random latents produces different results but similar textures.
  • I could reproduce this type of result with the PyTorch pipeline using a high guidance scale. I noticed that classifier-free guidance is always assumed in the Flax pipeline and this line always runs, but it's not an issue when the guidance scale is 1.
  • One pass through the UNet produces slightly different results; I didn't measure the difference but will do so and report it here.

Next I'll focus on debugging the UNet using the same input data and trying to identify where the computation starts to diverge. I'm thinking there's a configuration setting that might not have been applied to the Flax UNet, but there's a lot of work that's been done to the PyTorch UNet in the past few months.
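
For anyone following along, a minimal sketch of that kind of single-pass UNet comparison. It assumes the merged lcm_sdxl checkpoint from earlier in this thread; the random inputs and the SDXL added-conditioning values below are placeholders, and the Flax call mirrors what the Flax SDXL pipeline does internally, so take the exact shapes and kwargs as assumptions rather than a reference implementation:

import numpy as np
import torch
import jax.numpy as jnp
from diffusers import UNet2DConditionModel, FlaxUNet2DConditionModel

pt_unet = UNet2DConditionModel.from_pretrained("lcm_sdxl", subfolder="unet")
fx_unet, fx_params = FlaxUNet2DConditionModel.from_pretrained(
    "lcm_sdxl", subfolder="unet", from_pt=True, dtype=jnp.float32
)

# Identical, fixed inputs for both frameworks (placeholder conditioning values).
rng = np.random.RandomState(0)
sample = rng.randn(1, 4, 128, 128).astype("float32")        # 1024x1024 latents
prompt_embeds = rng.randn(1, 77, 2048).astype("float32")    # concat of both text encoders
text_embeds = rng.randn(1, 1280).astype("float32")           # pooled text_encoder_2 output
time_ids = np.array([[1024, 1024, 0, 0, 1024, 1024]], dtype="float32")
t = 999

with torch.no_grad():
    pt_out = pt_unet(
        torch.from_numpy(sample), t, torch.from_numpy(prompt_embeds),
        added_cond_kwargs={
            "text_embeds": torch.from_numpy(text_embeds),
            "time_ids": torch.from_numpy(time_ids),
        },
    ).sample.numpy()

fx_out = fx_unet.apply(
    {"params": fx_params},
    jnp.array(sample),
    jnp.array(t, dtype=jnp.int32),
    encoder_hidden_states=jnp.array(prompt_embeds),
    added_cond_kwargs={"text_embeds": jnp.array(text_embeds), "time_ids": jnp.array(time_ids)},
).sample

print("max abs diff:", np.abs(pt_out - np.asarray(fx_out)).max())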

@entrpn (Contributor, Author) commented Jan 5, 2024

@pcuenca thank you for looking into it!

@pcuenca (Member) commented Jan 5, 2024

Update

I was misled down the wrong path because of a minor difference between the prompt I was using for JAX and the one for PyTorch 🤦

  • Actually, when using 1 inference step the results are identical between the JAX and PyTorch versions (starting from the same latents). This is to be expected, as the scheduler does not create any new noise when using a single inference step:
    [Screenshot 2024-01-05 at 20:45:55]
  • When using 2 steps, the results differ because the scheduler now generates noise, but the output looks fine:
    [Screenshot 2024-01-05 at 20:47:56]
  • At 3 steps, however, we start seeing saturation in the JAX version:
    [Screenshot 2024-01-05 at 20:48:51]
  • And it becomes much worse when using 4 steps:
    [Screenshot 2024-01-05 at 20:49:28]

This is a very good hint; something must be going out of range for some reason. The good news is that the fused LoRA export looks correct! (I double-checked with a new export too.)

@pcuenca (Member) commented Jan 5, 2024

@entrpn Found it! We have to split the rng key in the loop, otherwise we are always using the same scheduler noise :)

I'll clean up my code (I have several other changes that are most probably unnecessary), retest and submit a PR to your PR tomorrow.
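
To illustrate the point for other readers (a toy stand-in, not the actual change from entrpn#1): inside a jitted fori_loop, reusing one key injects the exact same noise at every step, so a per-step key has to be derived, for example with jax.random.fold_in:

import jax
import jax.numpy as jnp

def denoise(latents, base_key, num_steps):
    def loop_body(step, latents):
        # Derive a fresh key per iteration; reusing `base_key` directly would
        # add the exact same noise at every step.
        step_key = jax.random.fold_in(base_key, step)
        noise = jax.random.normal(step_key, latents.shape, latents.dtype)
        return latents + 0.1 * noise  # stand-in for the real scheduler step
    return jax.lax.fori_loop(0, num_steps, loop_body, latents)

out = denoise(jnp.zeros((1, 4, 8, 8)), jax.random.PRNGKey(42), 4)
print(out.shape)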

@pcuenca (Member) left a comment

This is in pretty good shape! I think it could be merged soon after we decide how to handle the key split and clean it up a bit. A couple of questions:

  • Are you planning to allow LoRA loading, or use merged checkpoints instead? I think the second option would be enough to get started and unblock the fast LCM use-case. In that case, I'd export the merged weights in float32 to a new repo. I can help with that if you want.

  • If we intend to frequently use LCM SDXL with no classifier-free guidance, we could do some simplifications that could further reduce memory, allow larger batches or maybe even increase speed. We could avoid concatenating the unconditional+conditional latents and then splitting them. If you think this is useful, maybe we can do a special code path just for that.
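
Regarding the second point, a rough sketch of what that special code path could look like; unet_apply and the variable names are illustrative stand-ins, not the actual pipeline code:

import jax.numpy as jnp

def unet_apply(latents, t, context):
    # Stand-in for the real UNet call; only the batching structure matters here.
    return latents

def predict_noise(latents, t, prompt_embeds, negative_prompt_embeds,
                  guidance_scale, do_classifier_free_guidance):
    if do_classifier_free_guidance:
        # Current path: batch uncond + cond together, run the UNet once, then split.
        latents_input = jnp.concatenate([latents] * 2)
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
        noise_pred = unet_apply(latents_input, t, context)
        noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    # LCM path: a single conditional pass, roughly half the UNet compute and activations.
    return unet_apply(latents, t, prompt_embeds)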

@@ -253,7 +254,7 @@ def loop_body(step, args):
             noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents, prng_seed).to_tuple()
Review comment (Member):

Another option is to create the key inside the scheduler if it's not passed, like FlaxDDPMScheduler does
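
For illustration, that alternative could look roughly like this; the signature is a sketch, not the final FlaxLCMScheduler API, and the fallback mirrors FlaxDDPMScheduler's behavior of using jax.random.PRNGKey(0) when no key is passed:

import jax

def step(self, state, model_output, timestep, sample, key=None, return_dict=True):
    if key is None:
        # Same fallback FlaxDDPMScheduler uses when the caller does not pass a key.
        key = jax.random.PRNGKey(0)
    noise = jax.random.normal(key, model_output.shape, dtype=model_output.dtype)
    # ... the rest of the LCM update would consume `noise` here ...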

Review comment (Member):

See entrpn#1 for a working denoising loop.

jax.debug.print("**state.step_index: {x}",x=state.step_index)

def get_noise(key, shape, dtype):
jax.debug.print("---------------get_noise")
Review comment (Member):

I wonder if we should split the key here and store it as part of the state. As commented in entrpn#1, I believe that just splitting, like FlaxDDPMScheduler does, is not correct.

src/diffusers/schedulers/scheduling_lcm_flax.py (outdated review thread, resolved)
@entrpn (Contributor, Author) commented Jan 22, 2024

@pcuenca thank you for finding the issue, and apologies for the late reply. I have been out of the office for two weeks and am just getting back to things today.

I think for now let's just upload the merged weights to a new repo with this PR. I was working on LoRA loading with diffusers (https://github.com/entrpn/diffusers/tree/f/flax_lora_loading) a while back but never completed it. Maybe this is something we can work together on afterwards :)

For your second point, I think it's a good idea to avoid concatenating the unconditional+conditional latents.

I'll take a more detailed look this week at the code changes you made and get back to you so we can hopefully merge this week or next.

Once again, thank you!

@pcuenca (Member) commented Jan 22, 2024

@entrpn sounds great! Let's get this to the finish line and then we can work on LoRA loading :)

@entrpn (Contributor, Author) commented Jan 24, 2024

@pcuenca I just pushed some new changes:

  • Created a new key in the loop, as you did in entrpn#1.
  • Added a do_classifier_free_guidance flag to the generation step inside the SDXL Flax pipeline. It speeds up generation by roughly 35%.
  • Updated the schedulers used by the SDXL Flax pipeline so that step() accepts an optional key. This is required for non-LCM models to work.
  • Cleaned up the code and ran linting.

However, I am still seeing images that are not generated correctly. They look better, but not like the PyTorch output. Can you help me review and see if I'm missing anything?

@entrpn (Contributor, Author) commented Jan 25, 2024

@pcuenca forget my last comment about it not working. I think it was because I had merged the LoRA with lora_scale=1.0. I re-did the merging with lora_scale=0.7 and it's working great!

This is ready from a functional perspective. Could you help review any coding styles you prefer, and also help upload the weights to a repo so that it can be loaded directly?

Thank you for all your help!

[image: 000001]

Here's the inference code I used. Using do_classifier_free_guidance=False has some performance improvements as well.

import os
from diffusers import FlaxStableDiffusionXLPipeline
from diffusers import DiffusionPipeline, LCMScheduler
import torch
import time
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
import numpy as np
from datasets import load_dataset
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache(os.path.expanduser("~/jax_cache"))

from transformers import CLIPTokenizer, FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection

from diffusers import (
    FlaxAutoencoderKL,
    FlaxUNet2DConditionModel,
    FlaxLCMScheduler
)

base_model = "stabilityai/stable-diffusion-xl-base-1.0"
lcm_model = "sd_lora_model"

# run once to create merged weights
# pipeline_old = DiffusionPipeline.from_pretrained(
#     "stabilityai/stable-diffusion-xl-base-1.0",
#     torch_dtype=torch.float16, variant="fp16"
# )
# lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
# pipeline_old.load_lora_weights(lcm_lora_id)
# pipeline_old.scheduler = LCMScheduler.from_config(pipeline_old.scheduler.config)

# pipeline_old.fuse_lora(lora_scale=0.7)
# pipeline_old.save_pretrained(lcm_model, safe_serialization=False)
# del pipeline_old

weight_dtype = jnp.bfloat16
revision= 'refs/pr/95'
tokenizer = CLIPTokenizer.from_pretrained(
    base_model,
    revision=revision,
    subfolder="tokenizer"
)

tokenizer_2 = CLIPTokenizer.from_pretrained(
    base_model,
    subfolder="tokenizer_2",
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
    base_model,
    revision=revision,
    subfolder="text_encoder",
    dtype=weight_dtype
)
text_encoder_2 = FlaxCLIPTextModelWithProjection.from_pretrained(
    base_model,
    revision=revision,
    subfolder="text_encoder_2",
    dtype=weight_dtype
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
    base_model,
    revision=revision,
    subfolder="vae",
    dtype=weight_dtype
)

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    lcm_model,
    from_pt=True,
    subfolder="unet",
    dtype=weight_dtype,
    split_head_dim=False
)

scheduler, scheduler_state = FlaxLCMScheduler.from_pretrained(
    lcm_model,
    subfolder="scheduler",
    dtype=jnp.float32
)

flax_pipeline = FlaxStableDiffusionXLPipeline(
  text_encoder=text_encoder,
  text_encoder_2=text_encoder_2,
  vae=vae,
  tokenizer=tokenizer,
  tokenizer_2=tokenizer_2,
  unet=unet,
  scheduler=scheduler
)

default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = ""
default_seed = 42
default_guidance_scale = 1.0
default_num_steps = 4

def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = flax_pipeline.prepare_inputs(prompt)
    neg_prompt_ids = flax_pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids

NUM_DEVICES = jax.device_count()

params = {}
params["unet"] = unet_params
params["text_encoder"] = text_encoder.params
params["text_encoder_2"] = text_encoder_2.params
params["vae"] = vae_params
params["scheduler"] = scheduler_state

p_params = replicate(params)

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = flax_pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        do_classifier_free_guidance=False,
        jit=True,
    ).images
    print("images.shape: ", images.shape)
    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return flax_pipeline.numpy_to_pil(np.array(images))

start = time.time()
print(f"Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")

dts = []
i = 0
for x in range(2):
    
    start = time.time()
    prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
    #prompt = "Beautiful painting of a castle in the woods"
    neg_prompt = ""

    print(f"Prompt: {prompt}")
    images = generate(prompt, neg_prompt)
    t = time.time() - start
    print(f"Inference in {t}")

    dts.append(t)
    for img in images:
        img.save(f'{i:06d}.jpg')
        i += 1    

mean = np.mean(dts)
stdev = np.std(dts)
print(f"batches: {i},  Mean {mean:.2f} sec/batch± {stdev * 1.96 / np.sqrt(len(dts)):.2f} (95%)")

@pcuenca (Member) commented Jan 26, 2024

This is ready from a functional perspective. Could you help review any coding styles you prefer, and also help upload the weights to a repo so that it can be loaded directly?

Sure, I can help with that!

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Feb 19, 2024
@pcuenca pcuenca added wip and removed stale Issues that haven't received updates labels Feb 19, 2024
@pcuenca (Member) commented Feb 19, 2024

Not stale! I'm at an offsite this week, let's aim to have this merged next week!
