From c4d20110774dbcc6ada7e62b290500f8ad777fba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Tue, 13 Feb 2024 02:53:22 -0600 Subject: [PATCH] Add yield callback to prior pipeline --- .../stable_cascade/pipeline_stable_cascade_prior.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py index 2a5cb49bcd84..0d584012dc6b 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py @@ -615,7 +615,8 @@ def __call__( if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) + r = callback(step_idx, t, latents) + yield r # Offload all models self.maybe_free_model_hooks() @@ -626,4 +627,4 @@ def __call__( if not return_dict: return (latents,) - return WuerstchenPriorPipelineOutput(latents) + yield WuerstchenPriorPipelineOutput(latents)