Skip to content

Commit

Permalink
Update custom diffusion attn processor (huggingface#5663)
Browse files Browse the repository at this point in the history
update custom diffusion attn processor
  • Loading branch information
DN6 authored and kashif committed Nov 11, 2023
1 parent 0114565 commit 75116b9
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,7 @@ def __call__(
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

return hidden_states


Expand Down Expand Up @@ -1433,8 +1434,11 @@ def __call__(
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states)
value = self.to_v_custom_diffusion(encoder_hidden_states)
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
key = key.to(attn.to_q.weight.dtype)
value = value.to(attn.to_q.weight.dtype)

else:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
Expand Down

0 comments on commit 75116b9

Please sign in to comment.