
[Refactor] FreeInit for AnimateDiff based pipelines #6874

Merged (12 commits) on Feb 19, 2024
Conversation

@DN6 (Collaborator) commented Feb 6, 2024

What does this PR do?

The current FreeInit implementation isn't ideal. Based on the discussions in PR #6644, this change moves the FreeInit utilities into a Mixin so they can be reused, making it easy to experiment with adding the feature to other video pipelines while avoiding a second function for running the denoising loop. This preserves the existing denoising-loop semantics, so the video pipelines can still be read and understood like any other pipeline in diffusers, and adding FreeInit to a video pipeline does not require a lot of boilerplate.

This PR refactors FreeInit for:
- AnimateDiff
- PIA

and adds it to:
- AnimateDiffVideotoVideo


Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6 mentioned this pull request Feb 6, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w (Member) left a comment:

Thanks! Looks great to me since we're mostly back to how this was first added. Also, maybe AnimateDiffControlnetPipeline in community pipelines could benefit from these changes as well, wdyt? I will do some more testing with SVD and TextToVideoSynth to see if we can easily add it there as well sometime in the near future.

@a-r-r-o-w (Member) left a comment:

Thanks!

src/diffusers/pipelines/animatediff/freeinit_utils.py (5 resolved review threads)
src/diffusers/pipelines/pia/pipeline_pia.py (resolved review thread)
@yiyixuxu (Collaborator) left a comment:

Thanks! I left some feedback :)
I'm ok with the nested loop here; it is pretty easy to understand now.
The one problem I can think of is if we want to incorporate other techniques that involve nested loops, but we can handle that later when we encounter it.

src/diffusers/pipelines/animatediff/freeinit_utils.py (resolved review thread)
latents = self._denoise_loop(**denoise_args)

video = self._retrieve_video_frames(latents, output_type, return_dict)
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
Collaborator:

I can live with 4 extra lines 😬.
Does it make sense to add a progress bar for the free_init loops?

@DN6 (Collaborator Author):

I thought about it, but since we always set self._free_init_iters even if free_init isn't enabled, we would get an extra progress bar all the time. We could introduce additional checks to display the progress bar only when free_init is enabled, but it might not be worth it IMO.

src/diffusers/pipelines/animatediff/freeinit_utils.py (resolved review thread)
self.temporal_stop_frequency,
)

def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor:
Collaborator:

Same here - I don't think we need a method here.

@DN6 (Collaborator Author):
I moved these functions into the FreeInitMixin. I don't think they really need to be accessed outside the Mixin.
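For context, the core operation these methods implement is FreeInit's frequency-domain blend: keep the low-frequency component of the current latents and replace the high frequencies with fresh noise. A minimal standalone sketch (the function name, signature, and FFT dims are illustrative, not the exact diffusers code):

```python
import torch
import torch.fft as fft


def apply_freq_filter(x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor:
    """Blend low frequencies of `x` with high frequencies of `noise`.

    `low_pass_filter` is a mask in (shifted) frequency space, near 1 at
    low frequencies and near 0 at high frequencies.
    """
    dims = (-3, -2, -1)  # frames, height, width (illustrative choice)
    x_freq = fft.fftshift(fft.fftn(x, dim=dims), dim=dims)
    noise_freq = fft.fftshift(fft.fftn(noise, dim=dims), dim=dims)
    # Low frequencies come from x, high frequencies from the fresh noise.
    mixed = x_freq * low_pass_filter + noise_freq * (1 - low_pass_filter)
    mixed = fft.ifftshift(mixed, dim=dims)
    return fft.ifftn(mixed, dim=dims).real
```

Sanity check: an all-ones mask returns `x` unchanged and an all-zeros mask returns `noise`, which makes the blending behavior easy to verify in a test.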

src/diffusers/pipelines/animatediff/freeinit_utils.py (resolved review thread)
shape=latent_shape,
generator=self._free_init_generator,
device=device,
dtype=torch.float32,
Collaborator:

Any reason we use float32 here instead of dtype?

@DN6 (Collaborator Author):

This was copied over from the original codebase. I believe the filtering runs in float32 and then the results are cast back to the dtype the latents were passed in as. Although, I haven't checked what the results look like when we run the filtering in fp16.

I'll run a quick check, and if there's no difference between running the filtering in fp16 and fp32, we can just use dtype.
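If float32 is kept, the pattern is just a temporary upcast around the FFT. A hedged sketch (the function and the placeholder filtering step are illustrative, not the actual pipeline code):

```python
import torch


def filter_in_fp32(latents: torch.Tensor) -> torch.Tensor:
    # Upcast to float32 for the FFT (typically better numerical behavior
    # than fp16), then cast the result back to the caller's dtype.
    orig_dtype = latents.dtype
    x = latents.to(torch.float32)
    x_freq = torch.fft.fftn(x, dim=(-2, -1))
    # ... frequency-domain filtering would happen here ...
    out = torch.fft.ifftn(x_freq, dim=(-2, -1)).real
    return out.to(orig_dtype)
```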

src/diffusers/pipelines/animatediff/freeinit_utils.py (resolved review thread)
@patrickvonplaten (Contributor) left a comment:

Very nice refactor overall!

@yiyixuxu (Collaborator) left a comment:

Thanks!
Do we need a FreeInitTesterMixin? If it makes sense, we can put out an issue for the community to take up; but if it's just going to be added to 4 or 5 pipelines, it might not be needed.

@DN6 (Collaborator Author) commented Feb 16, 2024:

@yiyixuxu I think we should limit it to just the AnimateDiff based pipelines for now, until we see some usage/requests from the community to add it to other pipelines. We can add the testing mixin if this expands to >5 pipelines?

@a-r-r-o-w (Member) commented:

@DN6 Btw does it make sense to also add to AnimateDiff ControlNet (community)?

@DN6 (Collaborator Author) commented Feb 17, 2024:

@a-r-r-o-w Yeah sure thing 👍🏽.

BTW @a-r-r-o-w, AnimateLCM checkpoints are available in diffusers format now:
https://huggingface.co/wangfuyun/AnimateLCM

I tried running FreeInit with AnimateLCM and ended up with a bunch of noise in the final output. If you're up to it, would you like to run some tests here to pinpoint the issue? I suspect it has to do with the LCM Scheduler, but I don't have the time to really dig into it. LCM + FreeInit could be very nice for better-quality videos.

):
if free_init_iteration == 0:
self._free_init_initial_noise = latents.detach().clone()
return latents, self.scheduler.timesteps
@a-r-r-o-w (Member) commented Feb 18, 2024:

@DN6 This is incorrect and seems like a regression from the old implementation. I was trying to debug why AnimateLCM was failing to produce good results and stumbled upon this other issue (AnimateLCM does produce good results, by the way, except when use_fast_sampling=False; setting it to True seems to give good results).

Copy the FreeInit code from here and execute it. You will see that the first iteration runs for 20 steps, the second iteration runs for 13 steps, and the third iteration runs for 20 steps. This is incorrect: when use_fast_sampling=True, it should be 7, 13, and 20, but we return here without the fast-sampling check.
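For reference, the 7/13/20 counts are consistent with a coarse-to-fine schedule along these lines (a guess at the intended formula from the numbers above, not necessarily the exact diffusers code):

```python
def fast_sampling_steps(num_inference_steps: int, num_iters: int) -> list:
    # Each FreeInit iteration gets proportionally more denoising steps,
    # with the final iteration running the full count.
    return [
        max(1, round(num_inference_steps * (it + 1) / num_iters))
        for it in range(num_iters)
    ]


print(fast_sampling_steps(20, 3))  # → [7, 13, 20]
```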

@a-r-r-o-w (Member):

@DN6 Could I open a PR fixing this behavior since this has been merged already?

@DN6 (Collaborator Author):

Hi @a-r-r-o-w missed this. Yes please feel free to open a PR.

@a-r-r-o-w (Member) commented Feb 18, 2024:

> @a-r-r-o-w Yeah sure thing 👍🏽.
>
> BTW @a-r-r-o-w AnimateLCM checkpoints are available in diffusers format now. https://huggingface.co/wangfuyun/AnimateLCM
>
> I tried running FreeInit with AnimateLCM and I ended up with a bunch of noise in the final output. If you're up to it, would you like to run some tests here to pinpoint the issue? I suspect it has to do with the LCM Scheduler, but don't have the time to really dig into it. LCM + FreeInit could be very nice for better quality videos.

After some testing, it seems that when one uses use_fast_sampling=False, the results are all noisy no matter which underlying FreeInit filter method is used. Setting it to True always gives me good results. Note that the results below were generated after fixing the issue mentioned here. IMO this doesn't look like an LCMScheduler issue; I will dig deeper soon.

Code
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, LCMScheduler
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16).to("cuda")
pipe.scheduler = LCMScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    clip_sample=False,
    timestep_spacing="linspace",
    steps_offset=1
)

pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="sd15_lora_beta.safetensors", adapter_name="lcm-lora")
pipe.set_adapters(["lcm-lora"], [0.8])

pipe.enable_vae_slicing()
pipe.enable_free_init(num_iters=3, method="gaussian", use_fast_sampling=True)

output = pipe(
    prompt="a panda playing a guitar, on a boat, in the ocean, high quality",
    negative_prompt="bad quality, worse quality",
    num_frames=16,
    guidance_scale=2.5,
    num_inference_steps=6,
    generator=torch.Generator("cpu").manual_seed(666),
)
frames = output.frames[0]
export_to_gif(frames, "animation.gif")

@DN6 merged commit d2fc5eb into main on Feb 19, 2024. 15 checks passed.
@sayakpaul deleted the refactor-freeinit branch on December 3, 2024 at 10:24.