diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py index dcf790186..5d5ccc141 100644 --- a/optimizedSD/ddpm.py +++ b/optimizedSD/ddpm.py @@ -526,7 +526,9 @@ def sample(self, ) elif sampler == "ddim": - samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale, + samples = self.ddim_sampling(x_latent, conditioning, S, + callback=callback, img_callback=img_callback, + unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, mask = mask,init_latent=x_T,use_original_steps=False) @@ -687,7 +689,8 @@ def add_noise(self, x0, t): @torch.no_grad() def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - mask = None,init_latent=None,use_original_steps=False): + mask = None,init_latent=None,use_original_steps=False, + callback=None, img_callback=None): timesteps = self.ddim_timesteps timesteps = timesteps[:t_start] @@ -707,10 +710,12 @@ def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1. x0_noisy = x0 x_dec = x0_noisy* mask + (1. - mask) * x_dec - x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + x_dec, pred_x0 = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) - + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + if mask is not None: return x0 * mask + (1. - mask) * x_dec @@ -756,7 +761,7 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F if noise_dropout > 0.: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev + return x_prev, pred_x0 def append_zero(self, x):