diff --git a/docs/source/en/using-diffusers/callback.md b/docs/source/en/using-diffusers/callback.md index 690d86c17a54e..ab6fb3779b708 100644 --- a/docs/source/en/using-diffusers/callback.md +++ b/docs/source/en/using-diffusers/callback.md @@ -63,3 +63,42 @@ With callbacks, you can implement features such as dynamic CFG without having to 🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point! + + +## Using Callbacks to interrupt the Diffusion Process + +The following Pipelines support interrupting the diffusion process via callback + +- [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview.md) +- [StableDiffusionImg2ImgPipeline](..api/pipelines/stable_diffusion/img2img.md) +- [StableDiffusionInpaintPipeline](..api/pipelines/stable_diffusion/inpaint.md) +- [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md) +- [StableDiffusionXLImg2ImgPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md) +- [StableDiffusionXLInpaintPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md) + +Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback. + +This callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback. + +In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50. + +```python +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") +pipe.enable_model_cpu_offload() +num_inference_steps = 50 + +def interrupt_callback(pipe, i, t, callback_kwargs): + stop_idx = 10 + if i == stop_idx: + pipe._interrupt = True + + return callback_kwargs + +pipe( + "A photo of a cat", + num_inference_steps=num_inference_steps, + callback_on_step_end=interrupt_callback, +) +``` diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index b05d0b17dd5ae..dc4ad60ce0910 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -768,6 +768,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -909,6 +913,7 @@ def __call__( self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -986,6 +991,9 @@ def __call__( self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index d2538749f3beb..45dbd1128df09 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -832,6 +832,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -963,6 +967,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1041,6 +1046,9 @@ def __call__( self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index bc6c65f4a6546..3733102f365a8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -958,6 +958,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() def __call__( self, @@ -1144,6 +1148,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1288,6 +1293,9 @@ def __call__( self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 569668a1686d5..f9bafc973307f 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -849,6 +849,10 @@ def denoising_end(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1067,6 +1071,7 @@ def __call__( self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1196,6 +1201,9 @@ def __call__( self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 4f75ce6878ad6..1c22affba1aa8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -990,6 +990,10 @@ def denoising_start(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1221,6 +1225,7 @@ def __call__( self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._denoising_start = denoising_start + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1376,6 +1381,9 @@ def denoising_value_valid(dnv): self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 751823ea4b101..2f02a213b8947 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1210,6 +1210,10 @@ def denoising_start(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1462,6 +1466,7 @@ def __call__( self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._denoising_start = denoising_start + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1684,6 +1689,8 @@ def denoising_value_valid(dnv): self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index ac105d22fa82b..8854b482dec77 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -692,6 +692,58 @@ def test_fused_qkv_projections(self): original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 ), "Original outputs should match when fused QKV projections are disabled." + def test_pipeline_interrupt(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "hey" + num_inference_steps = 3 + + # store intermediate latents from the generation process + class PipelineState: + def __init__(self): + self.state = [] + + def apply(self, pipe, i, t, callback_kwargs): + self.state.append(callback_kwargs["latents"]) + return callback_kwargs + + pipe_state = PipelineState() + sd_pipe( + prompt, + num_inference_steps=num_inference_steps, + output_type="np", + generator=torch.Generator("cpu").manual_seed(0), + callback_on_step_end=pipe_state.apply, + ).images + + # interrupt generation at step index + interrupt_step_idx = 1 + + def callback_on_step_end(pipe, i, t, callback_kwargs): + if i == interrupt_step_idx: + pipe._interrupt = True + + return callback_kwargs + + output_interrupted = sd_pipe( + prompt, + num_inference_steps=num_inference_steps, + output_type="latent", + generator=torch.Generator("cpu").manual_seed(0), + callback_on_step_end=callback_on_step_end, + ).images + + # fetch intermediate latents at the interrupted step + # from the completed generation process + intermediate_latent = pipe_state.state[interrupt_step_idx] + + # compare the intermediate latent to the output of the interrupted process + # they should be the same + assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4) + @slow @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index fb56d868f1ccc..cd69b56e024fb 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -320,6 +320,62 @@ def test_inference_batch_single_identical(self): def test_float16_inference(self): super().test_float16_inference(expected_max_diff=5e-1) + def test_pipeline_interrupt(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionImg2ImgPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + prompt = "hey" + num_inference_steps = 3 + + # store intermediate latents from the generation process + class PipelineState: + def __init__(self): + self.state = [] + + def apply(self, pipe, i, t, callback_kwargs): + self.state.append(callback_kwargs["latents"]) + return callback_kwargs + + pipe_state = PipelineState() + sd_pipe( + prompt, + image=inputs["image"], + num_inference_steps=num_inference_steps, + output_type="np", + generator=torch.Generator("cpu").manual_seed(0), + callback_on_step_end=pipe_state.apply, + ).images + + # interrupt generation at step index + interrupt_step_idx = 1 + + def callback_on_step_end(pipe, i, t, callback_kwargs): + if i == interrupt_step_idx: + pipe._interrupt = True + + return callback_kwargs + + output_interrupted = sd_pipe( + prompt, + image=inputs["image"], + num_inference_steps=num_inference_steps, + output_type="latent", + generator=torch.Generator("cpu").manual_seed(0), + callback_on_step_end=callback_on_step_end, + ).images + + # fetch intermediate latents at the interrupted step + # from the completed generation process + intermediate_latent = pipe_state.state[interrupt_step_idx] + + # compare the intermediate latent to the output of the interrupted process + # they should be the same + assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4) + @slow @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index a69edb8696412..fe664b21e2712 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -319,6 +319,64 @@ def test_stable_diffusion_inpaint_mask_latents(self): out_1 = sd_pipe(**inputs).images assert np.abs(out_0 - out_1).max() < 1e-2 + def test_pipeline_interrupt(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + prompt = "hey" + num_inference_steps = 3 + + # store intermediate latents from the generation process + class PipelineState: + def __init__(self): + self.state = [] + + def apply(self, pipe, i, t, callback_kwargs): + self.state.append(callback_kwargs["latents"]) + return callback_kwargs + + pipe_state = PipelineState() + sd_pipe( + prompt, + image=inputs["image"], + mask_image=inputs["mask_image"], + num_inference_steps=num_inference_steps, + output_type="np", + generator=torch.Generator("cpu").manual_seed(0), + callback_on_step_end=pipe_state.apply, + ).images + + # interrupt generation at step index + interrupt_step_idx = 1 + + def callback_on_step_end(pipe, i, t, callback_kwargs): + if i == interrupt_step_idx: + pipe._interrupt = True + + return callback_kwargs + + output_interrupted = sd_pipe( + prompt, + image=inputs["image"], + mask_image=inputs["mask_image"], + num_inference_steps=num_inference_steps, + output_type="latent", + generator=torch.Generator("cpu").manual_seed(0), + callback_on_step_end=callback_on_step_end, + ).images + + # fetch intermediate latents at the interrupted step + # from the completed generation process + intermediate_latent = pipe_state.state[interrupt_step_idx] + + # compare the intermediate latent to the output of the interrupted process + # they should be the same + assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4) + class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests): pipeline_class = StableDiffusionInpaintPipeline diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 280030d94b7c3..80bff3663a98a 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -969,6 +969,58 @@ def test_stable_diffusion_xl_with_fused_qkv_projections(self): original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 ), "Original outputs should match when fused QKV projections are disabled." + def test_pipeline_interrupt(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "hey" + num_inference_steps = 3 + + # store intermediate latents from the generation process + class PipelineState: + def __init__(self): + self.state = [] + + def apply(self, pipe, i, t, callback_kwargs): + self.state.append(callback_kwargs["latents"]) + return callback_kwargs + + pipe_state = PipelineState() + sd_pipe( + prompt, + num_inference_steps=num_inference_steps, + output_type="np", + generator=torch.Generator("cpu").manual_seed(0), + callback_on_step_end=pipe_state.apply, + ).images + + # interrupt generation at step index + interrupt_step_idx = 1 + + def callback_on_step_end(pipe, i, t, callback_kwargs): + if i == interrupt_step_idx: + pipe._interrupt = True + + return callback_kwargs + + output_interrupted = sd_pipe( + prompt, + num_inference_steps=num_inference_steps, + output_type="latent", + generator=torch.Generator("cpu").manual_seed(0), + callback_on_step_end=callback_on_step_end, + ).images + + # fetch intermediate latents at the interrupted step + # from the completed generation process + intermediate_latent = pipe_state.state[interrupt_step_idx] + + # compare the intermediate latent to the output of the interrupted process + # they should be the same + assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4) + @slow class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase): diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index 7cad3fff0a475..0a7d4d0de4caf 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -439,6 +439,64 @@ def test_stable_diffusion_xl_img2img_negative_conditions(self): > 1e-4 ) + def test_pipeline_interrupt(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionXLImg2ImgPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + prompt = "hey" + num_inference_steps = 5 + + # store intermediate latents from the generation process + class PipelineState: + def __init__(self): + self.state = [] + + def apply(self, pipe, i, t, callback_kwargs): + self.state.append(callback_kwargs["latents"]) + return callback_kwargs + + pipe_state = PipelineState() + sd_pipe( + prompt, + image=inputs["image"], + strength=0.8, + num_inference_steps=num_inference_steps, + output_type="np", + generator=torch.Generator("cpu").manual_seed(0), + callback_on_step_end=pipe_state.apply, + ).images + + # interrupt generation at step index + interrupt_step_idx = 1 + + def callback_on_step_end(pipe, i, t, callback_kwargs): + if i == interrupt_step_idx: + pipe._interrupt = True + + return callback_kwargs + + output_interrupted = sd_pipe( + prompt, + image=inputs["image"], + strength=0.8, + num_inference_steps=num_inference_steps, + output_type="latent", + generator=torch.Generator("cpu").manual_seed(0), + callback_on_step_end=callback_on_step_end, + ).images + + # fetch intermediate latents at the interrupted step + # from the completed generation process + intermediate_latent = pipe_state.state[interrupt_step_idx] + + # compare the intermediate latent to the output of the interrupted process + # they should be the same + assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4) + class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests( PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py index 4a2798b3edf4e..27fb224fb07a2 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py @@ -746,3 +746,63 @@ def test_stable_diffusion_xl_inpaint_2_images(self): image_slice1 = images[0, -3:, -3:, -1] image_slice2 = images[1, -3:, -3:, -1] assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() > 1e-2 + + def test_pipeline_interrupt(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionXLInpaintPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + prompt = "hey" + num_inference_steps = 5 + + # store intermediate latents from the generation process + class PipelineState: + def __init__(self): + self.state = [] + + def apply(self, pipe, i, t, callback_kwargs): + self.state.append(callback_kwargs["latents"]) + return callback_kwargs + + pipe_state = PipelineState() + sd_pipe( + prompt, + image=inputs["image"], + mask_image=inputs["mask_image"], + strength=0.8, + num_inference_steps=num_inference_steps, + output_type="np", + generator=torch.Generator("cpu").manual_seed(0), + callback_on_step_end=pipe_state.apply, + ).images + + # interrupt generation at step index + interrupt_step_idx = 1 + + def callback_on_step_end(pipe, i, t, callback_kwargs): + if i == interrupt_step_idx: + pipe._interrupt = True + + return callback_kwargs + + output_interrupted = sd_pipe( + prompt, + image=inputs["image"], + mask_image=inputs["mask_image"], + strength=0.8, + num_inference_steps=num_inference_steps, + output_type="latent", + generator=torch.Generator("cpu").manual_seed(0), + callback_on_step_end=callback_on_step_end, + ).images + + # fetch intermediate latents at the interrupted step + # from the completed generation process + intermediate_latent = pipe_state.state[interrupt_step_idx] + + # compare the intermediate latent to the output of the interrupted process + # they should be the same + assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)