[Pallas] Make gmm functional (#7117)
Summary:
This should be the last PR needed to make gmm functional. Now everything is executed on XLA devices and produces the correct output.

Test Plan:
python test/test_gmm.py
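
The test exercises the pattern this change enables: inputs are created directly on the XLA device, and the metrics report is checked for aten:: counters to confirm that neither gmm nor its helpers fall back to CPU. A minimal sketch of that pattern, assuming a TPU-backed XLA device is available (the shapes, dtypes, and group sizes below are illustrative, not taken from the test cases):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla.experimental.custom_kernel import gmm

met.clear_all()

# Illustrative problem sizes; group_sizes must sum to m.
num_groups, m, k, n = 4, 512, 128, 256
lhs = torch.rand(m, k, dtype=torch.bfloat16).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=torch.bfloat16).to('xla')
group_sizes = torch.tensor([128, 128, 128, 128], dtype=torch.int32).to('xla')

out = gmm(lhs, rhs, group_sizes)
xm.mark_step()  # force execution so any CPU fallback shows up in the metrics

# A fallback to CPU would surface as an aten:: counter in the report.
assert "aten::" not in met.short_metrics_report()
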
alanwaketan authored May 26, 2024
1 parent 7d31f7d commit a9b4fad
Showing 2 changed files with 36 additions and 33 deletions.
28 changes: 20 additions & 8 deletions test/test_gmm.py
@@ -6,6 +6,7 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram
from torch_xla import runtime as xr
from torch_xla._internal import tpu
@@ -98,6 +99,8 @@ def _init_test_cases(self):

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm(self):
met.clear_all()

self._init_test_cases()
for test_case in self.tests_cases:
num_groups = test_case['num_groups']
@@ -110,20 +113,24 @@ def test_gmm(self):
lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
group_sizes = self._group_sizes_strategy(
m=m, num_groups=num_groups) # This is a cpu tensor!!!!!!!
m=m, num_groups=num_groups).to('xla')
out = gmm(lhs, rhs, group_sizes)

ref_out = self._reference_gmm(lhs.cpu().float().numpy(),
rhs.cpu().float().numpy(),
group_sizes.numpy())
group_sizes.cpu().numpy())

atol, rtol = self._tolerances(lhs_dtype, rhs_dtype, out_dtype)
np.testing.assert_allclose(
ref_out, np.array(out[0].cpu()), rtol=rtol, atol=atol)

# Make sure gmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_make_group_metadata(self):
from jax.experimental.pallas.ops.tpu.megablox.gmm import make_group_metadata as jax_make_group_metadata
met.clear_all()

test_grids = [
{
@@ -173,15 +180,19 @@ def test_make_group_metadata(self):
)

torch_meta = _make_group_metadata(
group_sizes=torch.tensor(test_grid['group_sizes']),
group_sizes=torch.tensor(test_grid['group_sizes']).to("xla"),
m=test_grid['m'],
tm=test_grid['tm'],
)

for i in range(len(jax_meta)):
self.assertTrue(
torch.all(torch.from_numpy(np.array(jax_meta[i])) == torch_meta[i]))
self.assertEqual(jax_num_tiles, torch_meta[-1].item())
torch.all(
torch.from_numpy(np.array(jax_meta[i])) == torch_meta[i].cpu()))
self.assertEqual(jax_num_tiles, torch_meta[-1].cpu().item())

# Make sure _make_group_metadata doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())

def test_histogram(self):
test_grids = [
@@ -215,13 +226,13 @@ def test_histogram(self):
max=test_grid['max'],
)

chart, _ = _histogram(
chart = _histogram(
torch.tensor(test_grid['input'], dtype=torch.int32).to("xla"),
min=test_grid['min'],
max=test_grid['max'],
)

self.assertTrue(torch.all(torch_chart == chart.cpu()))

def test_histogram_raise(self):
with self.assertRaisesRegex(AssertionError,
@@ -232,7 +243,8 @@ def test_histogram_raise(self):
max=5,
)

with self.assertRaisesRegex(AssertionError, "min must be less than max."):
with self.assertRaisesRegex(AssertionError,
"min must be less than or equal to max."):
_histogram(
torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.int32),
min=4,
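
For reference, the _histogram helper exercised above counts how many times each integer in [min, max] occurs in an int32 tensor (bin width 1), and after this change it returns only the counts rather than a (counts, bin_edges) pair. A small CPU-side sketch of the expected behaviour, reusing the sample values visible in the test above with illustrative min/max bounds (the tests run the same helper on the XLA device):

import torch
from torch_xla.experimental.custom_kernel import _histogram

values = torch.tensor([1, 4, 4, 1, 2, 3], dtype=torch.int32)
counts = _histogram(values, min=1, max=4)
# Counts for the integers 1, 2, 3, 4 in order; expected: tensor([2, 1, 1, 2]).
print(counts)
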
41 changes: 16 additions & 25 deletions torch_xla/experimental/custom_kernel.py
@@ -501,18 +501,18 @@ def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
Compute the histogram of an int32 tensor. The bin edges are defined by the min and max values, with step = 1.
"""
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
assert min < max, "min must be less than max."
assert min <= max, "min must be less than or equal to max."

def searchsorted(sorted_sequence: torch.Tensor,
values_to_search: torch.Tensor) -> torch.Tensor:
return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)

bin_edges = torch.linspace(
min, max, max - min + 1, dtype=input.dtype).to(input.device)
return searchsorted(bin_edges, input), bin_edges
return searchsorted(bin_edges, input)


# This can only be ran in cpu now as repeat_interleave is not lowered to xla.
# Reference: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L78
def _make_group_metadata(
*,
group_sizes: torch.Tensor,
@@ -544,6 +544,7 @@ def _make_group_metadata(
num_tiles: The number of m-dimension tiles to execute including overlapping
executions. And don't confuse this with m_tiles which is m // tm.
"""
device = group_sizes.device
num_groups = group_sizes.shape[0]

# Calculate the offset of each group, starting at zero. This metadata is
@@ -555,7 +556,8 @@
#
# The row at which group 'i' starts is group_offsets[i].
group_ends = torch.cumsum(group_sizes, dim=0, dtype=torch.int32)
group_offsets = torch.cat([torch.zeros(1, dtype=torch.int32), group_ends])
group_offsets = torch.cat(
[torch.zeros(1, dtype=torch.int32).to(device), group_ends])

# Assign a group id to each grid index.
#
@@ -571,7 +573,8 @@
rounded_group_ends = ((group_ends + tm - 1) // tm * tm).to(torch.int32)

# (2) Round the group_starts down to the nearest multiple of 'tm'.
group_starts = torch.cat([torch.zeros(1, dtype=torch.int32), group_ends[:-1]])
group_starts = torch.cat(
[torch.zeros(1, dtype=torch.int32).to(device), group_ends[:-1]])
rounded_group_starts = group_starts // tm * tm

# (3) Calculate the number of rows in each group.
@@ -613,14 +616,9 @@
# group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized
# such that we only execute the necessary number of tiles.
tiles_m = _calculate_num_tiles(m, tm)
# TODO (alanwaketan): lower jax's version of repeat. This dynamism will force us to compile many times.
group_ids = torch.repeat_interleave(
torch.arange(num_groups, dtype=torch.int32),
group_tiles,
)
group_ids = torch.nn.functional.pad(
group_ids, (0, tiles_m + num_groups - 1 - group_ids.shape[0]),
value=num_groups - 1)
group_ids = repeat_with_fixed_output_size(
torch.arange(num_groups, dtype=torch.int32).to(device), group_tiles,
tiles_m + num_groups - 1)

# Assign an m-dimension tile id to each grid index.
#
@@ -652,20 +650,13 @@
partial_tile_ids = torch.where(partial_tile_mask, tiles_m,
group_offsets[:-1] // tm)

tile_visits = (
torch.histc(
partial_tile_ids.float(), bins=tiles_m, min=0, max=tiles_m - 1) + 1)
tile_visits = (_histogram(partial_tile_ids, min=0, max=tiles_m - 1) + 1)

# Create the m-dimension tile ids for each grid index based on the visit
# counts for each tile.
# TODO (alanwaketan): lower jax's version of repeat. This dynamism will force us to compile many times.
m_tile_ids = torch.repeat_interleave(
torch.arange(tiles_m, dtype=torch.int32),
tile_visits.type(torch.int32),
)
m_tile_ids = torch.nn.functional.pad(
m_tile_ids, (0, tiles_m + num_groups - 1 - m_tile_ids.shape[0]),
value=tiles_m - 1)
m_tile_ids = repeat_with_fixed_output_size(
torch.arange(tiles_m, dtype=torch.int32).to(device), tile_visits,
tiles_m + num_groups - 1)

num_tiles = group_tiles.sum(dtype=torch.int32)
return group_offsets, group_ids, m_tile_ids, num_tiles
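
The group_offsets returned above can be sanity-checked in isolation: they are simply the cumulative sum of group_sizes with a leading zero, so group i owns rows [group_offsets[i], group_offsets[i + 1]). A small CPU-side sketch of that step with made-up group sizes:

import torch

# Three example groups of 3, 1 and 4 rows (m = 8); not taken from the tests.
group_sizes = torch.tensor([3, 1, 4], dtype=torch.int32)

group_ends = torch.cumsum(group_sizes, dim=0, dtype=torch.int32)  # tensor([3, 4, 8])
group_offsets = torch.cat([torch.zeros(1, dtype=torch.int32), group_ends])
# group_offsets == tensor([0, 3, 4, 8]): group i starts at group_offsets[i].
print(group_offsets)
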
@@ -706,7 +697,7 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor,
# tensor([2, 1, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
block_split_indicators = torch.zeros(
total_repeat_length, dtype=torch.int64, device=device)
block_split_indicators.scatter_add_(0, valid_indices,
block_split_indicators.scatter_add_(0, valid_indices.to(torch.int64),
torch.ones_like(block_split_indicators))
# out_of_bound indices also scatter to index 0, need to offset them
block_split_indicators[0] -= out_of_bound_count
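
The repeat_with_fixed_output_size calls introduced above replace the earlier torch.repeat_interleave plus torch.nn.functional.pad combination, so the output length is fixed up front and the dynamic shape no longer forces repeated recompilation. A CPU-side sketch of the behaviour the new helper is expected to match, written with the old primitives for comparison (padding with the last input element is an assumption, consistent with the previous pad values of num_groups - 1 and tiles_m - 1):

import torch

def repeat_to_fixed_length_reference(input: torch.Tensor, repeats: torch.Tensor,
                                     total_repeat_length: int) -> torch.Tensor:
    # Repeat input[i] repeats[i] times, then pad the tail with the last element
    # so the result always has length total_repeat_length.
    out = torch.repeat_interleave(input, repeats)
    pad = total_repeat_length - out.shape[0]
    return torch.nn.functional.pad(out, (0, pad), value=input[-1].item())

# Example: 3 groups covering 2, 1 and 3 m-dimension tiles, with tiles_m = 5.
num_groups, tiles_m = 3, 5
group_tiles = torch.tensor([2, 1, 3], dtype=torch.int32)
group_ids = repeat_to_fixed_length_reference(
    torch.arange(num_groups, dtype=torch.int32), group_tiles,
    tiles_m + num_groups - 1)
# Expected: tensor([0, 0, 1, 2, 2, 2, 2]); the last group id pads the tail.
print(group_ids)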
