Skip to content

Commit

Permalink
Remove unused parameters and fixed FutureWarning (huggingface#6317)
Browse files Browse the repository at this point in the history
* Remove unused parameters and fixed `FutureWarning`

* Fixed wrong config instance

* update unittest for `DDIMInverseScheduler`
  • Loading branch information
Justin900429 authored and adrvs committed Jan 2, 2024
1 parent 5b42010 commit fc7f072
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
5 changes: 1 addition & 4 deletions src/diffusers/schedulers/scheduling_ddim_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,6 @@ def step(
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
variance_noise: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Expand Down Expand Up @@ -332,7 +329,7 @@ def step(
# 1. get previous step value (=t+1)
prev_timestep = timestep
timestep = min(
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.num_train_timesteps - 1
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
)

# 2. compute alphas, betas
Expand Down
6 changes: 3 additions & 3 deletions tests/schedulers/test_scheduler_ddim_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class DDIMInverseSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDIMInverseScheduler,)
forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))
forward_default_kwargs = (("num_inference_steps", 50),)

def get_scheduler_config(self, **kwargs):
config = {
Expand All @@ -26,7 +26,7 @@ def full_loop(self, **config):
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)

num_inference_steps, eta = 10, 0.0
num_inference_steps = 10

model = self.dummy_model()
sample = self.dummy_sample_deter
Expand All @@ -35,7 +35,7 @@ def full_loop(self, **config):

for t in scheduler.timesteps:
residual = model(sample, t)
sample = scheduler.step(residual, t, sample, eta).prev_sample
sample = scheduler.step(residual, t, sample).prev_sample

return sample

Expand Down

0 comments on commit fc7f072

Please sign in to comment.