[Refactor] FreeInit for AnimateDiff based pipelines (huggingface#6874)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
DN6 authored Feb 19, 2024
1 parent 779eef9 commit d2fc5eb
Showing 7 changed files with 398 additions and 595 deletions.
376 changes: 64 additions & 312 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Large diffs are not rendered by default.

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -34,6 +34,7 @@
)
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
+from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import AnimateDiffPipelineOutput

@@ -163,7 +164,9 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


-class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
+class AnimateDiffVideoToVideoPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin
+):
r"""
Pipeline for video-to-video generation.
@@ -193,7 +196,7 @@ class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderM
"""

model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
-    _optional_components = ["feature_extractor", "image_encoder"]
+    _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

def __init__(
@@ -215,7 +218,8 @@ def __init__(
image_encoder: CLIPVisionModelWithProjection = None,
):
super().__init__()
-        unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+        if isinstance(unet, UNet2DConditionModel):
+            unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

self.register_modules(
vae=vae,
@@ -584,12 +588,12 @@ def check_inputs(
if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` should be provided")

-    def get_timesteps(self, num_inference_steps, strength, device):
+    def get_timesteps(self, num_inference_steps, timesteps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        timesteps = timesteps[t_start * self.scheduler.order :]

return timesteps, num_inference_steps - t_start
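
A quick numeric sketch (illustrative, not part of the commit) of how the reworked `get_timesteps` trims the schedule by `strength`: with 50 inference steps and strength 0.6, only the final 30 timesteps are kept, so denoising starts from a partially noised version of the input video.

num_inference_steps = 50
strength = 0.6
scheduler_order = 1  # assumption: a first-order scheduler such as DDIM

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 30
t_start = max(num_inference_steps - init_timestep, 0)  # 20

timesteps = list(range(980, -1, -20))  # stand-in for the scheduler's 50 timesteps
trimmed = timesteps[t_start * scheduler_order :]
print(len(trimmed), trimmed[0])  # 30 580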

@@ -876,9 +880,8 @@ def __call__(

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
-        self._num_timesteps = len(timesteps)

# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -901,42 +904,55 @@
# 7. Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

-        # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # predict the noise residual
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=self.cross_attention_kwargs,
-                    added_cond_kwargs=added_cond_kwargs,
-                ).sample
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
-                progress_bar.update()
+        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+        for free_init_iter in range(num_free_init_iters):
+            if self.free_init_enabled:
+                latents, timesteps = self._apply_free_init(
+                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+                )
+                num_inference_steps = len(timesteps)
+                # make sure to readjust timesteps based on strength
+                timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+
+            self._num_timesteps = len(timesteps)
+            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+            # 8. Denoising loop
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=self.cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
+                    ).sample
+
+                    # perform guidance
+                    if self.do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()

if output_type == "latent":
return AnimateDiffPipelineOutput(frames=latents)
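For context, a minimal end-to-end sketch (illustrative; the model IDs, scheduler settings and `input_video` are assumptions, not part of this commit) of using FreeInit with the refactored video-to-video pipeline:

import torch
from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, beta_schedule="linear", clip_sample=False, timestep_spacing="linspace"
)

# After this refactor, FreeInit comes from FreeInitMixin rather than pipeline-local code.
pipe.enable_free_init(num_iters=3, method="butterworth", order=4)
output = pipe(
    prompt="a panda surfing",
    video=input_video,  # assumption: a list of PIL.Image frames
    strength=0.6,
    num_inference_steps=25,
)
pipe.disable_free_init()
export_to_gif(output.frames[0], "animation.gif")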
184 changes: 184 additions & 0 deletions src/diffusers/pipelines/free_init_utils.py
@@ -0,0 +1,184 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Tuple, Union

import torch
import torch.fft as fft

from ..utils.torch_utils import randn_tensor


class FreeInitMixin:
r"""Mixin class for FreeInit."""

def enable_free_init(
self,
num_iters: int = 3,
use_fast_sampling: bool = False,
method: str = "butterworth",
order: int = 4,
spatial_stop_frequency: float = 0.25,
temporal_stop_frequency: float = 0.25,
):
"""Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.
This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).
Args:
num_iters (`int`, *optional*, defaults to `3`):
Number of FreeInit noise re-initialization iterations.
use_fast_sampling (`bool`, *optional*, defaults to `False`):
Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables
the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
method (`str`, *optional*, defaults to `butterworth`):
Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
FreeInit low pass filter.
order (`int`, *optional*, defaults to `4`):
Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
whereas lower values lead to `gaussian` method behaviour.
spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in
the original implementation.
temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in
the original implementation.
"""
self._free_init_num_iters = num_iters
self._free_init_use_fast_sampling = use_fast_sampling
self._free_init_method = method
self._free_init_order = order
self._free_init_spatial_stop_frequency = spatial_stop_frequency
self._free_init_temporal_stop_frequency = temporal_stop_frequency

def disable_free_init(self):
"""Disables the FreeInit mechanism if enabled."""
self._free_init_num_iters = None

@property
def free_init_enabled(self):
return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None
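
A minimal sketch (illustrative) of the mixin's state handling on any pipeline that inherits it: enabling sets the private attributes above, disabling only clears `_free_init_num_iters`, and `free_init_enabled` reports the current state.

pipe.enable_free_init(num_iters=3, use_fast_sampling=False)
assert pipe.free_init_enabled

pipe.disable_free_init()
assert not pipe.free_init_enabled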

def _get_free_init_freq_filter(
self,
shape: Tuple[int, ...],
device: Union[str, torch.dtype],
filter_type: str,
order: float,
spatial_stop_frequency: float,
temporal_stop_frequency: float,
) -> torch.Tensor:
r"""Returns the FreeInit filter based on filter type and other input conditions."""

time, height, width = shape[-3], shape[-2], shape[-1]
mask = torch.zeros(shape)

if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
return mask

if filter_type == "butterworth":

def retrieve_mask(x):
return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
elif filter_type == "gaussian":

def retrieve_mask(x):
return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
elif filter_type == "ideal":

def retrieve_mask(x):
return 1 if x <= spatial_stop_frequency * 2 else 0
else:
raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")

        # Evaluate the filter response at every (t, h, w) frequency coordinate, using the
        # normalized squared distance from the center of the (shifted) frequency volume.
        for t in range(time):
            for h in range(height):
                for w in range(width):
                    d_square = (
                        ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2
                        + (2 * h / height - 1) ** 2
                        + (2 * w / width - 1) ** 2
                    )
                    mask[..., t, h, w] = retrieve_mask(d_square)

return mask.to(device)
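
The triple Python loop above is easy to read but scales as O(time * height * width); an equivalent vectorized construction (a sketch for the `butterworth` case only, not part of this commit) could use `torch.meshgrid`:

import torch

def butterworth_filter_vectorized(shape, d_s=0.25, d_t=0.25, order=4):
    # Same normalized coordinates as the loop above: 2 * i / size - 1 along each axis.
    time, height, width = shape[-3], shape[-2], shape[-1]
    t = 2 * torch.arange(time) / time - 1
    h = 2 * torch.arange(height) / height - 1
    w = 2 * torch.arange(width) / width - 1
    grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
    d_square = ((d_s / d_t) * grid_t) ** 2 + grid_h**2 + grid_w**2
    # Butterworth low-pass response, matching retrieve_mask above.
    return (1 / (1 + (d_square / d_s**2) ** order)).expand(shape)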

def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor:
r"""Noise reinitialization."""
# FFT
x_freq = fft.fftn(x, dim=(-3, -2, -1))
x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))

# frequency mix
high_pass_filter = 1 - low_pass_filter
x_freq_low = x_freq * low_pass_filter
noise_freq_high = noise_freq * high_pass_filter
x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain

# IFFT
x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real

return x_mixed
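
A tiny self-check sketch (illustrative) of the mixing identity: an all-ones low-pass filter returns `x` unchanged, while an all-zeros filter returns the reinitialized noise.

import torch
from diffusers.pipelines.free_init_utils import FreeInitMixin

mixin = FreeInitMixin()
x = torch.randn(1, 4, 8, 16, 16)  # (batch, channels, frames, height, width)
noise = torch.randn_like(x)

assert torch.allclose(mixin._apply_freq_filter(x, noise, torch.ones_like(x)), x, atol=1e-4)
assert torch.allclose(mixin._apply_freq_filter(x, noise, torch.zeros_like(x)), noise, atol=1e-4)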

def _apply_free_init(
self,
latents: torch.Tensor,
free_init_iteration: int,
num_inference_steps: int,
device: torch.device,
dtype: torch.dtype,
generator: torch.Generator,
):
if free_init_iteration == 0:
self._free_init_initial_noise = latents.detach().clone()
return latents, self.scheduler.timesteps

latent_shape = latents.shape

free_init_filter_shape = (1, *latent_shape[1:])
free_init_freq_filter = self._get_free_init_freq_filter(
shape=free_init_filter_shape,
device=device,
filter_type=self._free_init_method,
order=self._free_init_order,
spatial_stop_frequency=self._free_init_spatial_stop_frequency,
temporal_stop_frequency=self._free_init_temporal_stop_frequency,
)

current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()

z_t = self.scheduler.add_noise(
original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
).to(dtype=torch.float32)

z_rand = randn_tensor(
shape=latent_shape,
generator=generator,
device=device,
dtype=torch.float32,
)
latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
latents = latents.to(dtype)

# Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
if self._free_init_use_fast_sampling:
num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
self.scheduler.set_timesteps(num_inference_steps, device=device)

return latents, self.scheduler.timesteps
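
Illustrative arithmetic for the fast-sampling branch: because iteration 0 returns early above with the scheduler's current timesteps, the coarse-to-fine step scaling only applies from the second iteration onward.

num_inference_steps, num_iters = 30, 3
for free_init_iteration in range(1, num_iters):
    print(free_init_iteration, int(num_inference_steps / num_iters * (free_init_iteration + 1)))
# prints: 1 20, then 2 30 -- later iterations refine with progressively more steps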