Skip to content

Commit

Permalink
Fix slow tests (huggingface#689)
Browse files Browse the repository at this point in the history
* revert using baddbmm in attention
- to fix `test_stable_diffusion_memory_chunking` test

* styling
  • Loading branch information
NouamaneTazi authored Sep 30, 2022
1 parent a0e4da6 commit 026f309
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,13 +274,8 @@ def forward(self, hidden_states, context=None, mask=None):
return self.to_out(hidden_states)

def _attention(self, query, key, value):
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
# TODO: use baddbmm for better performance
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attention_probs = attention_scores.softmax(dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value)
Expand Down

0 comments on commit 026f309

Please sign in to comment.