diff --git a/schedulers/scheduling_deis_multistep.py b/schedulers/scheduling_deis_multistep.py index 7cf6a9b33b37..bd44d2444154 100644 --- a/schedulers/scheduling_deis_multistep.py +++ b/schedulers/scheduling_deis_multistep.py @@ -734,7 +734,16 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/schedulers/scheduling_dpmsolver_multistep.py b/schedulers/scheduling_dpmsolver_multistep.py index beab985e3350..086505c5052b 100644 --- a/schedulers/scheduling_dpmsolver_multistep.py +++ b/schedulers/scheduling_dpmsolver_multistep.py @@ -896,7 +896,16 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/schedulers/scheduling_dpmsolver_multistep_inverse.py b/schedulers/scheduling_dpmsolver_multistep_inverse.py index 61d6810ce286..cfb53c943cea 100644 --- a/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -891,7 +891,16 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/schedulers/scheduling_dpmsolver_singlestep.py b/schedulers/scheduling_dpmsolver_singlestep.py index 0f1175472f3e..7e8149ab55c4 100644 --- a/schedulers/scheduling_dpmsolver_singlestep.py +++ b/schedulers/scheduling_dpmsolver_singlestep.py @@ -897,7 +897,16 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/schedulers/scheduling_unipc_multistep.py b/schedulers/scheduling_unipc_multistep.py index 1d58ab5259ef..eaa6273e2768 100644 --- a/schedulers/scheduling_unipc_multistep.py +++ b/schedulers/scheduling_unipc_multistep.py @@ -828,7 +828,16 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [] + for timestep in timesteps: + index_candidates = (schedule_timesteps == timestep).nonzero() + if len(index_candidates) == 0: + step_index = len(schedule_timesteps) - 1 + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + step_indices.append(step_index) sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape):