Optimize Stable Diffusion (huggingface#371)
* initial commit

* make UNet stream capturable

* try to fix noise_pred value

* remove cuda graph and keep NB

* non-blocking unet with PNDMScheduler

* make timesteps np arrays for pndm scheduler
because lists don't get formatted to tensors in `self.set_format`

* make max async in pndm

* use channel last format in unet

* avoid moving timesteps device in each unet call

* avoid memcpy op in `get_timestep_embedding`

* add `channels_last` kwarg to `DiffusionPipeline.from_pretrained`

* update TODO

* replace `channels_last` kwarg with `memory_format` for more generality

* revert the channels_last changes to leave it for another PR

* remove non_blocking when moving input ids to device

* remove blocking from all .to() operations at beginning of pipeline

* fix merging

* fix merging

* model can run in other precisions without autocast

* attn refactoring

* Revert "attn refactoring"

This reverts commit 0c70c0e.

* remove restriction to run conv_norm in fp32

* use `baddbmm` instead of `matmul` in attention for better perf

* removing all reshapes to test perf

* Revert "removing all reshapes to test perf"

This reverts commit 006ccb8.

* add shapes comments

* hardcode what's needed for jitting

* Revert "hardcode what's needed for jitting"

This reverts commit 2fa9c69.

* Revert "remove restriction to run conv_norm in fp32"

This reverts commit cec5928.

* revert using baddbmm in attention's forward

* cleanup comment

* remove restriction to run conv_norm in fp32; no quality loss was noticed

This reverts commit cc9bc13.

* add more optimizations techniques to docs

* Revert "add shapes comments"

This reverts commit 31c58ea.

* apply suggestions

* make quality

* apply suggestions

* styling

* `scheduler.timesteps` are now arrays, so we don't need `.to()`

* remove useless .type()

* use mean instead of max in `test_stable_diffusion_inpaint_pipeline_k_lms`

* move scheduler timesteps to the correct device if they are tensors

* add device to `set_timesteps` in LMSD scheduler

* `self.scheduler.set_timesteps` now uses device arg for schedulers that accept it

* quick fix

* styling

* remove kwargs from schedulers `set_timesteps`

* revert to using max in K-LMS inpaint pipeline test

* Revert "`self.scheduler.set_timesteps` now uses device arg for schedulers that accept it"

This reverts commit 00d5a51.

* move timesteps to correct device before loop in SD pipeline

* apply previous fix to other SD pipelines

* UNet now accepts tensor timesteps even on the wrong device, to avoid errors
- it shouldn't affect performance if timesteps are already on the correct device
- it does slow down performance if they're on the wrong device

* fix pipeline when timesteps are arrays with strides
NouamaneTazi authored Sep 30, 2022
1 parent 26952e3 commit 88cb510
Showing 8 changed files with 49 additions and 20 deletions.
15 changes: 11 additions & 4 deletions models/attention.py
@@ -72,8 +72,7 @@ def forward(self, hidden_states):
 
         # get scores
         scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
-
-        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)  # TODO: use baddbmm
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
 
         # compute attention output
@@ -275,7 +274,13 @@ def forward(self, hidden_states, context=None, mask=None):
         return self.to_out(hidden_states)
 
     def _attention(self, query, key, value):
-        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)
@@ -292,7 +297,9 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+            attn_slice = (
+                torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+            )  # TODO: use baddbmm for better performance
             attn_slice = attn_slice.softmax(dim=-1)
             attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
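The `_attention` rewrite above folds the `* self.scale` multiply into a single batched GEMM. A minimal standalone sketch (shapes are illustrative, not taken from the model) showing that `baddbmm` with `beta=0` matches the original matmul-then-scale form:

```python
import torch

batch, seq, dim = 2, 77, 64  # made-up sizes for illustration
scale = dim ** -0.5
query = torch.randn(batch, seq, dim)
key = torch.randn(batch, seq, dim)

# original form: batched matmul followed by a separate elementwise scale
scores_matmul = torch.matmul(query, key.transpose(-1, -2)) * scale

# optimized form: alpha * (query @ key^T) in one fused batched GEMM
scores_baddbmm = torch.baddbmm(
    torch.empty(batch, seq, seq, dtype=query.dtype, device=query.device),  # ignored since beta=0
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

torch.testing.assert_close(scores_matmul, scores_baddbmm)
```

With `beta=0` the values in the `input` tensor are ignored (per the PyTorch docs, even NaN/inf are not propagated), so passing an uninitialized `torch.empty` buffer is safe.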
6 changes: 4 additions & 2 deletions models/embeddings.py
@@ -37,10 +37,12 @@ def get_timestep_embedding(
     assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
 
     half_dim = embedding_dim // 2
-    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
+    exponent = -math.log(max_period) * torch.arange(
+        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+    )
     exponent = exponent / (half_dim - downscale_freq_shift)
 
-    emb = torch.exp(exponent).to(device=timesteps.device)
+    emb = torch.exp(exponent)
     emb = timesteps[:, None].float() * emb[None, :]
 
     # scale embeddings
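This is the "avoid memcpy op in `get_timestep_embedding`" commit: the exponent is created directly on the timesteps' device rather than allocated on CPU and copied over afterwards. A small sketch of the general pattern (device choice and size are illustrative):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# one allocation, directly on the target device
on_device = torch.arange(0, 160, dtype=torch.float32, device=device)

# CPU allocation followed by a host-to-device copy (the memcpy being avoided)
via_copy = torch.arange(0, 160, dtype=torch.float32).to(device)

assert torch.equal(on_device, via_copy)  # same values, one fewer transfer
```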
4 changes: 2 additions & 2 deletions models/resnet.py
@@ -331,7 +331,7 @@ def forward(self, x, temb):
 
         # make sure hidden states is in float32
         # when running in half-precision
-        hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
+        hidden_states = self.norm1(hidden_states)
         hidden_states = self.nonlinearity(hidden_states)
 
         if self.upsample is not None:
@@ -349,7 +349,7 @@ def forward(self, x, temb):
 
         # make sure hidden states is in float32
         # when running in half-precision
-        hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
+        hidden_states = self.norm2(hidden_states)
         hidden_states = self.nonlinearity(hidden_states)
 
         hidden_states = self.dropout(hidden_states)
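Dropping the trailing `.type(hidden_states.dtype)` is safe because the normalization layer returns its input dtype, so the cast was a no-op once the fp32 restriction was lifted. A hedged sketch, not from the PR itself (layer sizes are made up), verifying the dtype round-trip:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
norm = torch.nn.GroupNorm(num_groups=8, num_channels=32).to(device)
x = torch.randn(1, 32, 16, 16, device=device)

if device == "cuda":  # fp16 GroupNorm is a CUDA path; stay fp32 on CPU
    norm = norm.half()
    x = x.half()

out = norm(x)
assert out.dtype == x.dtype  # dtype preserved; no explicit cast needed
```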
8 changes: 4 additions & 4 deletions models/unet_2d_condition.py
@@ -230,16 +230,16 @@ def forward(
         # 1. time
         timesteps = timestep
         if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
-            timesteps = timesteps.to(dtype=torch.float32)
-            timesteps = timesteps[None].to(device=sample.device)
+            timesteps = timesteps[None].to(sample.device)
 
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timesteps = timesteps.expand(sample.shape[0])
 
         t_emb = self.time_proj(timesteps)
-        emb = self.time_embedding(t_emb)
+        emb = self.time_embedding(t_emb.to(self.dtype))
 
         # 2. pre-process
         sample = self.conv_in(sample)
@@ -279,7 +279,7 @@ def forward(
         # 6. post-process
         # make sure hidden states is in float32
         # when running in half-precision
-        sample = self.conv_norm_out(sample.float()).type(sample.dtype)
+        sample = self.conv_norm_out(sample)
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
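The new TODO flags the cost of passing plain Python numbers: building `torch.tensor([timesteps], ..., device=sample.device)` inside every call forces a host-to-device sync. A sketch of caller-side usage (the commented-out `unet` call and the shapes are placeholders, not the pipeline's exact code):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
sample = torch.randn(1, 4, 64, 64, device=device)

# a Python int makes the UNet build a tensor on `sample.device` every step
t_int = 999

# a 0-dim tensor already on the right device skips that round trip
t_gpu = torch.tensor(999, device=device)

# noise_pred = unet(sample, t_gpu, encoder_hidden_states=text_embeddings).sample
```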
16 changes: 13 additions & 3 deletions pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -225,15 +225,23 @@ def __call__(
                 latents_shape,
                 generator=generator,
                 device=latents_device,
+                dtype=text_embeddings.dtype,
             )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+            latents = latents.to(latents_device)
 
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to the correct device beforehand
+        if torch.is_tensor(self.scheduler.timesteps):
+            timesteps_tensor = self.scheduler.timesteps.to(self.device)
+        else:
+            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
+
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
             latents = latents * self.scheduler.sigmas[0]
@@ -247,7 +255,7 @@ def __call__(
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -278,7 +286,9 @@ def __call__(
 
         # run safety checker
         safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+        )
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)
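The array branch above also explains the final "fix pipeline when timesteps are arrays with strides" commit: schedulers like PNDM build their timestep arrays with `[::-1]`, which yields a NumPy view with negative strides that torch refuses to wrap. A standalone sketch of the failure mode and the `.copy()` workaround:

```python
import numpy as np
import torch

# reversed arrays are views with negative strides
timesteps = np.linspace(0, 999, 50, dtype=np.int64)[::-1]
print(timesteps.strides)  # negative stride from the reversal

# torch.from_numpy(timesteps)  # ValueError: negative strides are not supported
timesteps_tensor = torch.tensor(timesteps.copy())  # contiguous copy works
```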
@@ -265,7 +265,11 @@ def __call__(
         latents = init_latents
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to the correct device beforehand
+        timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
+
+        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             t_index = t_start + i
 
             # expand the latents if we are doing classifier free guidance
@@ -298,7 +298,11 @@ def __call__(
 
         latents = init_latents
         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to the correct device beforehand
+        timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
+
+        for i, t in tqdm(enumerate(timesteps_tensor)):
             t_index = t_start + i
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
8 changes: 5 additions & 3 deletions schedulers/scheduling_lms_discrete.py
@@ -131,22 +131,24 @@ def lms_derivative(tau):
 
         return integrated_coeff
 
-    def set_timesteps(self, num_inference_steps: int):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
 
         Args:
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
+            device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
         """
         self.num_inference_steps = num_inference_steps
 
         timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
-        self.sigmas = torch.from_numpy(sigmas)
-        self.timesteps = torch.from_numpy(timesteps)
+        self.sigmas = torch.from_numpy(sigmas).to(device=device)
+        self.timesteps = torch.from_numpy(timesteps).to(device=device)
 
         self.derivatives = []
 
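With the new `device` argument, callers can place `sigmas` and `timesteps` on the target device once at setup time instead of moving them inside the denoising loop. A hedged usage sketch (default scheduler config; CUDA assumed available):

```python
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(50, device="cuda")

# both buffers now live on the GPU; no per-step .to() needed in the loop
assert scheduler.timesteps.device.type == "cuda"
assert scheduler.sigmas.device.type == "cuda"
```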
