Skip to content

Commit

Permalink
Add "Copied from" statements in comments
Browse files Browse the repository at this point in the history
  • Loading branch information
TonyLianLong committed Dec 1, 2023
1 parent 284fc82 commit 4a6e40e
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions examples/community/llm_grounded_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,20 +1308,24 @@ def latent_lmd_guidance(

# Below are methods copied from StableDiffusionPipeline
# The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
Expand All @@ -1330,13 +1334,15 @@ def enable_vae_tiling(self):
"""
self.vae.enable_tiling()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
Expand Down Expand Up @@ -1371,6 +1377,7 @@ def _encode_prompt(

return prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
def encode_prompt(
self,
prompt,
Expand Down Expand Up @@ -1576,6 +1583,7 @@ def encode_prompt(

return prompt_embeds, negative_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype

Expand All @@ -1591,6 +1599,7 @@ def encode_image(self, image, device, num_images_per_prompt):
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
Expand All @@ -1611,6 +1620,7 @@ def run_safety_checker(self, image, device, dtype):
)
return image, has_nsfw_concept

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
deprecate("decode_latents", "1.0.0",
Expand All @@ -1623,6 +1633,7 @@ def decode_latents(self, latents):
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
Expand All @@ -1644,6 +1655,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(
self,
batch_size,
Expand Down Expand Up @@ -1678,6 +1690,7 @@ def prepare_latents(
latents = latents * self.scheduler.init_noise_sigma
return latents

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
Expand All @@ -1700,6 +1713,7 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
Expand Down Expand Up @@ -1733,29 +1747,35 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32
assert emb.shape == (w.shape[0], embedding_dim)
return emb

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
@property
def guidance_scale(self):
return self._guidance_scale

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_rescale
@property
def guidance_rescale(self):
return self._guidance_rescale

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
@property
def clip_skip(self):
return self._clip_skip

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
@property
def num_timesteps(self):
return self._num_timesteps

0 comments on commit 4a6e40e

Please sign in to comment.