Add callback parameters for Stable Diffusion pipelines (huggingface#521)
* Add callback parameters for Stable Diffusion pipelines

Signed-off-by: James R T <[email protected]>

* Lint code with `black --preview`

Signed-off-by: James R T <[email protected]>

* Refactor callback implementation for Stable Diffusion pipelines

* Fix missing imports

Signed-off-by: James R T <[email protected]>

* Fix documentation format

Signed-off-by: James R T <[email protected]>

* Add kwargs parameter to standardize with other pipelines

Signed-off-by: James R T <[email protected]>

* Modify Stable Diffusion pipeline callback parameters

Signed-off-by: James R T <[email protected]>

* Remove useless imports

Signed-off-by: James R T <[email protected]>

* Change types for timestep and onnx latents

* Fix docstring style

* Return decode_latents and run_safety_checker back into __call__

* Remove unused imports

* Add intermediate state tests for Stable Diffusion pipelines

Signed-off-by: James R T <[email protected]>

* Fix intermediate state tests for Stable Diffusion pipelines

Signed-off-by: James R T <[email protected]>

Signed-off-by: James R T <[email protected]>
jamestiotio authored Oct 2, 2022
1 parent 026f309 commit 86ef695
Showing 4 changed files with 85 additions and 12 deletions.
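
The new `callback` and `callback_steps` arguments let user code observe the denoising loop. A minimal usage sketch follows; the model id, device, and prompt are illustrative and not part of this commit:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")

def log_progress(step: int, timestep: int, latents: torch.FloatTensor) -> None:
    # Invoked every `callback_steps` denoising steps with the current latents.
    print(f"step={step} timestep={int(timestep)} latents shape={tuple(latents.shape)}")

image = pipe(
    "a photograph of an astronaut riding a horse",
    callback=log_progress,
    callback_steps=5,  # call the callback every 5 steps; defaults to 1 (every step)
).images[0]
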
24 changes: 21 additions & 3 deletions pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -1,6 +1,6 @@
import inspect
import warnings
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

import torch

@@ -122,6 +122,8 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
@@ -159,6 +161,12 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -178,6 +186,14 @@ def __call__(
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
@@ -277,14 +293,16 @@ def __call__(
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# scale and decode the image latents with vae
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
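
Because the callback receives the raw latents, it can also be used to render intermediate previews. A sketch under the assumption that `pipe` is a loaded `StableDiffusionPipeline` instance; the scaling constant and decode call simply mirror the pipeline code above:

import torch

def preview_callback(step: int, timestep: int, latents: torch.FloatTensor) -> None:
    # Decode the intermediate latents the same way the pipeline decodes the final ones.
    with torch.no_grad():
        scaled = 1 / 0.18215 * latents
        image = pipe.vae.decode(scaled).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        pipe.numpy_to_pil(image)[0].save(f"step_{step:03d}.png")
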
26 changes: 23 additions & 3 deletions pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -1,6 +1,6 @@
import inspect
import warnings
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

import numpy as np
import torch
@@ -133,6 +133,9 @@ def __call__(
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -170,6 +173,12 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -188,6 +197,14 @@ def __call__(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

@@ -265,6 +282,7 @@ def __call__(
latents = init_latents

t_start = max(num_inference_steps - init_timestep + offset, 0)

# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to the correct device beforehand
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
@@ -295,14 +313,16 @@ def __call__(
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# scale and decode the image latents with vae
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

27 changes: 24 additions & 3 deletions pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -1,6 +1,6 @@
import inspect
import warnings
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

import numpy as np
import torch
@@ -149,6 +149,9 @@ def __call__(
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -190,6 +193,12 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -208,6 +217,14 @@ def __call__(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

@@ -297,7 +314,9 @@ def __call__(
extra_step_kwargs["eta"] = eta

latents = init_latents

t_start = max(num_inference_steps - init_timestep + offset, 0)

# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to the correct device beforehand
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
@@ -331,14 +350,16 @@ def __call__(

latents = (init_latents_proper * mask) + (latents * (1 - mask))

# scale and decode the image latents with vae
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

20 changes: 17 additions & 3 deletions pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
@@ -1,5 +1,5 @@
import inspect
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

import numpy as np

@@ -56,6 +56,8 @@ def __call__(
latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
if isinstance(prompt, str):
@@ -68,6 +70,14 @@
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
@@ -151,14 +161,18 @@ def __call__(
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# scale and decode the image latents with vae
latents = np.array(latents)

# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

latents = 1 / 0.18215 * latents
image = self.vae_decoder(latent_sample=latents)[0]

image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))

# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)

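
For the ONNX pipeline the callback contract is the same, except that the latents are converted with `np.array(latents)` before the callback fires, so a callback written for this pipeline should expect `np.ndarray` rather than `torch.FloatTensor`. A minimal sketch:

import numpy as np

def onnx_progress(step: int, timestep: int, latents: np.ndarray) -> None:
    # The ONNX pipeline hands over latents as a NumPy array.
    print(f"step={step} timestep={int(timestep)} latent mean={float(latents.mean()):.4f}")
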
