Skip to content

Commit

Permalink
Interruptable Pipelines (huggingface#5867)
Browse files Browse the repository at this point in the history
* add interruptable pipelines

* add tests

* updatemsmq

* add interrupt property

* make fix copies

* Revert "make fix copies"

This reverts commit 914b353.

* add docs

* add tutorial

* Update docs/source/en/tutorials/interrupting_diffusion_process.md

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/tutorials/interrupting_diffusion_process.md

Co-authored-by: Steven Liu <[email protected]>

* update

* fix quality issues

* fix

* update

---------

Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
4 people authored and donhardman committed Dec 29, 2023
1 parent cd088a4 commit 00eff83
Show file tree
Hide file tree
Showing 13 changed files with 422 additions and 0 deletions.
39 changes: 39 additions & 0 deletions docs/source/en/using-diffusers/callback.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,42 @@ With callbacks, you can implement features such as dynamic CFG without having to
🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!

</Tip>


## Using Callbacks to interrupt the Diffusion Process

The following Pipelines support interrupting the diffusion process via callback

- [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview.md)
- [StableDiffusionImg2ImgPipeline](..api/pipelines/stable_diffusion/img2img.md)
- [StableDiffusionInpaintPipeline](..api/pipelines/stable_diffusion/inpaint.md)
- [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLImg2ImgPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLInpaintPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)

Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.

This callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.

In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()
num_inference_steps = 50

def interrupt_callback(pipe, i, t, callback_kwargs):
stop_idx = 10
if i == stop_idx:
pipe._interrupt = True

return callback_kwargs

pipe(
"A photo of a cat",
num_inference_steps=num_inference_steps,
callback_on_step_end=interrupt_callback,
)
```
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,10 @@ def cross_attention_kwargs(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -909,6 +913,7 @@ def __call__(
self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -986,6 +991,9 @@ def __call__(
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,10 @@ def cross_attention_kwargs(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -963,6 +967,7 @@ def __call__(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1041,6 +1046,9 @@ def __call__(
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,10 @@ def cross_attention_kwargs(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
def __call__(
self,
Expand Down Expand Up @@ -1144,6 +1148,7 @@ def __call__(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1288,6 +1293,9 @@ def __call__(
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,10 @@ def denoising_end(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -1067,6 +1071,7 @@ def __call__(
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1196,6 +1201,9 @@ def __call__(
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,10 @@ def denoising_start(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -1221,6 +1225,7 @@ def __call__(
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._denoising_start = denoising_start
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1376,6 +1381,9 @@ def denoising_value_valid(dnv):
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,10 @@ def denoising_start(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -1462,6 +1466,7 @@ def __call__(
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._denoising_start = denoising_start
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1684,6 +1689,8 @@ def denoising_value_valid(dnv):
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

Expand Down
52 changes: 52 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,58 @@ def test_fused_qkv_projections(self):
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."

def test_pipeline_interrupt(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

prompt = "hey"
num_inference_steps = 3

# store intermediate latents from the generation process
class PipelineState:
def __init__(self):
self.state = []

def apply(self, pipe, i, t, callback_kwargs):
self.state.append(callback_kwargs["latents"])
return callback_kwargs

pipe_state = PipelineState()
sd_pipe(
prompt,
num_inference_steps=num_inference_steps,
output_type="np",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=pipe_state.apply,
).images

# interrupt generation at step index
interrupt_step_idx = 1

def callback_on_step_end(pipe, i, t, callback_kwargs):
if i == interrupt_step_idx:
pipe._interrupt = True

return callback_kwargs

output_interrupted = sd_pipe(
prompt,
num_inference_steps=num_inference_steps,
output_type="latent",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=callback_on_step_end,
).images

# fetch intermediate latents at the interrupted step
# from the completed generation process
intermediate_latent = pipe_state.state[interrupt_step_idx]

# compare the intermediate latent to the output of the interrupted process
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


@slow
@require_torch_gpu
Expand Down
56 changes: 56 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,62 @@ def test_inference_batch_single_identical(self):
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=5e-1)

def test_pipeline_interrupt(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(torch_device)

prompt = "hey"
num_inference_steps = 3

# store intermediate latents from the generation process
class PipelineState:
def __init__(self):
self.state = []

def apply(self, pipe, i, t, callback_kwargs):
self.state.append(callback_kwargs["latents"])
return callback_kwargs

pipe_state = PipelineState()
sd_pipe(
prompt,
image=inputs["image"],
num_inference_steps=num_inference_steps,
output_type="np",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=pipe_state.apply,
).images

# interrupt generation at step index
interrupt_step_idx = 1

def callback_on_step_end(pipe, i, t, callback_kwargs):
if i == interrupt_step_idx:
pipe._interrupt = True

return callback_kwargs

output_interrupted = sd_pipe(
prompt,
image=inputs["image"],
num_inference_steps=num_inference_steps,
output_type="latent",
generator=torch.Generator("cpu").manual_seed(0),
callback_on_step_end=callback_on_step_end,
).images

# fetch intermediate latents at the interrupted step
# from the completed generation process
intermediate_latent = pipe_state.state[interrupt_step_idx]

# compare the intermediate latent to the output of the interrupted process
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


@slow
@require_torch_gpu
Expand Down
Loading

0 comments on commit 00eff83

Please sign in to comment.