Skip to content

Commit

Permalink
[Pallas] Make gmm output a tensor (#7120)
Browse files Browse the repository at this point in the history
Summary:
Somehow gmm outputs a list intead of a tensor. Let's output a tensor.

Test Plan:
python test/test_gmm.py
  • Loading branch information
alanwaketan authored May 28, 2024
1 parent 1a8c2fe commit 65b5ace
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion test/test_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_gmm(self):

atol, rtol = self._tolerances(lhs_dtype, rhs_dtype, out_dtype)
np.testing.assert_allclose(
ref_out, np.array(out[0].cpu()), rtol=rtol, atol=atol)
ref_out, np.array(out.cpu()), rtol=rtol, atol=atol)

# Make sure gmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
group_offsets.to("xla"),
group_ids.to("xla"),
m_tile_ids.to("xla"), group_offset_torch, lhs, rhs
], payload, [torch.Size([m, n])], [lhs.dtype])
], payload, [torch.Size([m, n])], [lhs.dtype])[0]


def non_xla_attetion(q, k, v, attention_type):
Expand Down

0 comments on commit 65b5ace

Please sign in to comment.