Skip to content

Commit

Permalink
Flax: add shape argument to set_timesteps (huggingface#690)
Browse files Browse the repository at this point in the history
* Flax: add shape argument to set_timesteps

* style
  • Loading branch information
pcuenca authored Oct 3, 2022
1 parent 86ef695 commit 5097037
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion schedulers/scheduling_ddim_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _get_variance(self, timestep, prev_timestep, alphas_cumprod):

return variance

def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState:
def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Expand Down
2 changes: 1 addition & 1 deletion schedulers/scheduling_ddpm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(

self.variance_type = variance_type

def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState:
def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Expand Down
4 changes: 3 additions & 1 deletion schedulers/scheduling_karras_ve_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def __init__(
):
self.state = KarrasVeSchedulerState.create()

def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState:
def set_timesteps(
self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
) -> KarrasVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
Expand Down
4 changes: 3 additions & 1 deletion schedulers/scheduling_lms_discrete_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def lms_derivative(tau):

return integrated_coeff

def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState:
def set_timesteps(
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
) -> LMSDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Expand Down
2 changes: 1 addition & 1 deletion schedulers/scheduling_pndm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def __init__(
def create_state(self):
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState:
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Expand Down
2 changes: 1 addition & 1 deletion schedulers/scheduling_sde_ve_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)

def set_timesteps(
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
) -> ScoreSdeVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
Expand Down

0 comments on commit 5097037

Please sign in to comment.