diff --git a/models/attention.py b/models/attention.py
index b4e5f2e07f7d..c2f27bd9282d 100644
--- a/models/attention.py
+++ b/models/attention.py
@@ -274,13 +274,8 @@ def forward(self, hidden_states, context=None, mask=None):
         return self.to_out(hidden_states)
 
     def _attention(self, query, key, value):
-        attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
-            query,
-            key.transpose(-1, -2),
-            beta=0,
-            alpha=self.scale,
-        )
+        # TODO: use baddbmm for better performance
+        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)
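
As context for the change above: with beta=0 the baddbmm call ignores the uninitialized torch.empty tensor and computes alpha * (query @ key^T), so replacing it with a plain matmul followed by scaling is numerically equivalent. Below is a minimal standalone sketch (not part of the patch) that checks this equivalence; the tensor shapes and the scale value are arbitrary choices for illustration.

# Sketch only: compares the old baddbmm form with the new matmul * scale form.
import torch

batch, seq_q, seq_k, dim = 2, 4, 6, 8
scale = dim ** -0.5  # illustrative scale; in the module this is self.scale

query = torch.randn(batch, seq_q, dim)
key = torch.randn(batch, seq_k, dim)

# Old form: beta=0 means the empty input tensor contributes nothing,
# so the result is alpha * bmm(query, key^T).
old_scores = torch.baddbmm(
    torch.empty(batch, seq_q, seq_k, dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

# New form from this patch: plain matmul followed by scaling.
new_scores = torch.matmul(query, key.transpose(-1, -2)) * scale

assert torch.allclose(old_scores, new_scores, atol=1e-6)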