Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Sigmas] Keep sigmas on CPU #6173

Merged
merged 3 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_consistency_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
self.custom_timesteps = False
self.is_scale_input_called = False
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
Expand Down Expand Up @@ -230,6 +231,7 @@ def set_timesteps(
self.timesteps = torch.from_numpy(timesteps).to(device=device)

self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Modified _convert_to_karras implementation that takes in ramp as argument
def _convert_to_karras(self, ramp):
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_deis_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def __init__(
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

@property
def step_index(self):
Expand Down Expand Up @@ -254,6 +255,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def __init__(
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

@property
def step_index(self):
Expand Down Expand Up @@ -290,6 +291,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc

# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def __init__(
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.use_karras_sigmas = use_karras_sigmas

@property
Expand Down Expand Up @@ -289,6 +290,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc

# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_dpmsolver_sde.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def __init__(
self.noise_sampler = None
self.noise_sampler_seed = noise_sampler_seed
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
Expand Down Expand Up @@ -347,6 +348,7 @@ def set_timesteps(
self.mid_point_sigma = None

self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.noise_sampler = None

# for exp beta schedules, such as the one for `pipeline_shap_e.py`
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def __init__(
self.sample = None
self.order_list = self.get_order_list(num_train_timesteps)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

def get_order_list(self, num_inference_steps: int) -> List[int]:
"""
Expand Down Expand Up @@ -288,6 +289,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def __init__(
self.is_scale_input_called = False

self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

@property
def init_noise_sigma(self):
Expand Down Expand Up @@ -249,6 +250,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def __init__(
self.use_karras_sigmas = use_karras_sigmas

self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

@property
def init_noise_sigma(self):
Expand Down Expand Up @@ -341,6 +342,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_heun_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def __init__(
self.use_karras_sigmas = use_karras_sigmas

self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
Expand Down Expand Up @@ -269,6 +270,7 @@ def set_timesteps(
self.dt = None

self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(
# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
Expand Down Expand Up @@ -295,6 +296,7 @@ def set_timesteps(
self._index_counter = defaultdict(int)

self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
Expand Down Expand Up @@ -284,6 +285,7 @@ def set_timesteps(
self._index_counter = defaultdict(int)

self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

@property
def state_in_first_order(self):
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def __init__(
self.is_scale_input_called = False

self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

@property
def init_noise_sigma(self):
Expand Down Expand Up @@ -279,6 +280,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

self.derivatives = []

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_unipc_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def __init__(
self.solver_p = solver_p
self.last_sample = None
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

@property
def step_index(self):
Expand Down Expand Up @@ -268,6 +269,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
Expand Down
Loading