diff --git a/test/test_megablox.py b/test/test_gmm.py
similarity index 70%
rename from test/test_megablox.py
rename to test/test_gmm.py
index b36fa414af2..cd247d6250a 100644
--- a/test/test_megablox.py
+++ b/test/test_gmm.py
@@ -1,5 +1,3 @@
-"""Grouped matrix multiplication kernels for TPU written in Pallas."""
-
 import logging
 import unittest
 
@@ -8,7 +6,7 @@
 import torch
 import torch_xla
 import torch_xla.core.xla_model as xm
-from torch_xla.experimental.custom_kernel import gmm
+from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu
 
@@ -123,6 +121,68 @@ def test_gmm(self):
     np.testing.assert_allclose(
         ref_out, np.array(out[0].cpu()), rtol=rtol, atol=atol)
 
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_make_group_metadata(self):
+    from jax.experimental.pallas.ops.tpu.megablox.gmm import make_group_metadata as jax_make_group_metadata
+    import jax.numpy as jnp  # Used below to build the reference inputs.
+
+    test_grids = [
+        {
+            'group_sizes': [8, 8, 8, 8],
+            'm': 32,
+            'tm': 8
+        },
+        {
+            'group_sizes': [2, 14, 8, 8],
+            'm': 32,
+            'tm': 8
+        },
+        {
+            'group_sizes': [16, 0, 8, 8],
+            'm': 32,
+            'tm': 8
+        },
+        {
+            'group_sizes': [2, 0, 14, 16],
+            'm': 32,
+            'tm': 8
+        },
+        {
+            'group_sizes': [8, 12, 0, 12],
+            'm': 32,
+            'tm': 8
+        },
+        {
+            'group_sizes': [6, 12, 0, 14],
+            'm': 32,
+            'tm': 8
+        },
+        {
+            'group_sizes': [6, 12, 0, 14],
+            'm': 32,
+            'tm': 4
+        },
+    ]
+
+    for test_grid in test_grids:
+      jax_meta, jax_num_tiles = jax_make_group_metadata(
+          group_sizes=jnp.array(test_grid['group_sizes']),
+          m=test_grid['m'],
+          tm=test_grid['tm'],
+          start_group=0,
+          num_nonzero_groups=len(test_grid['group_sizes']),
+      )
+
+      torch_meta = _make_group_metadata(
+          group_sizes=torch.tensor(test_grid['group_sizes']),
+          m=test_grid['m'],
+          tm=test_grid['tm'],
+      )
+
+      for i in range(len(jax_meta)):
+        self.assertTrue(
+            torch.all(torch.from_numpy(np.array(jax_meta[i])) == torch_meta[i]))
+      self.assertEqual(jax_num_tiles, torch_meta[-1].item())
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 5b89abe228c..fc4024a462c 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -24,7 +24,7 @@ python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
 python3 test/test_pallas.py
 python3 test/test_pallas_spmd.py
 python3 test/test_input_output_aliases.py
-python3 test/test_megablox.py
+python3 test/test_gmm.py
 python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
 python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
 python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
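For intuition about what `test_make_group_metadata` verifies, here is one of the grids above worked by hand. The snippet is illustrative only (it is not part of the patch) and assumes the private helper `_make_group_metadata` is called on CPU tensors, as in the test:

```python
import torch
from torch_xla.experimental.custom_kernel import _make_group_metadata

# group_sizes=[2, 14, 8, 8] with m=32, tm=8: group 1 starts at row 2, in the
# middle of tile 0, so tile 0 is visited twice (once per group touching it).
group_offsets, group_ids, m_tile_ids, num_tiles = _make_group_metadata(
    group_sizes=torch.tensor([2, 14, 8, 8]), m=32, tm=8)

assert group_offsets.tolist() == [0, 2, 16, 24, 32]  # CSR-style row offsets.
# Five real grid steps, padded out to m_tiles + num_groups - 1 = 7 entries.
assert group_ids.tolist() == [0, 1, 1, 2, 3, 3, 3]
assert m_tile_ids.tolist() == [0, 0, 1, 2, 3, 3, 3]
assert num_tiles.item() == 5
```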
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 067459e424d..eaaeccb5c59 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -8,7 +8,7 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as xs
 
-from typing import List, Callable
+from typing import Any, List, Callable
 from torch.library import impl
 from torch_xla.core.xla_model import XLA_LIB
@@ -489,6 +489,172 @@ def paged_attention(q,
   return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)
 
 
+def _calculate_num_tiles(x: int, tx: int) -> int:
+  tiles, rem = divmod(x, tx)
+  if rem:
+    raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).")
+  return tiles
+
+
+# This can only be run on CPU for now, as repeat_interleave is not lowered to
+# XLA.
+def _make_group_metadata(
+    *,
+    group_sizes: torch.Tensor,
+    m: int,
+    tm: int,
+    visit_empty_groups: bool = True,
+) -> Any:
+  """Create the metadata needed for grouped matmul computation.
+
+  Args:
+    group_sizes: A 1d torch.Tensor with shape [num_groups] and torch.int32
+      dtype.
+    m: The number of rows in lhs.
+    tm: The m-dimension tile size being used.
+    visit_empty_groups: If True, do not squeeze tiles for empty groups out of
+      the metadata. This is necessary for tgmm, where we at least need to zero
+      the output for each group.
+
+  Returns:
+    tuple of:
+      group_offsets: A 1d torch.Tensor with shape [num_groups + 1] and
+        torch.int32 dtype. group_offsets[i] indicates the row at which group
+        [i] starts in the lhs matrix and group_offsets[num_groups] = m.
+      group_ids: A 1d torch.Tensor with shape [m_tiles + num_groups - 1] and
+        torch.int32 dtype. group_ids[i] indicates which group grid index 'i'
+        will work on.
+      m_tile_ids: A 1d torch.Tensor with shape [m_tiles + num_groups - 1] and
+        torch.int32 dtype. m_tile_ids[i] indicates which m-dimension tile grid
+        index 'i' will work on.
+      num_tiles: The number of m-dimension tiles to execute, including
+        overlapping executions. Not to be confused with m_tiles, which is
+        m // tm.
+  """
+  num_groups = group_sizes.shape[0]
+
+  # Calculate the offset of each group, starting at zero. This metadata is
+  # similar to row offsets in a CSR matrix. The following properties hold:
+  #
+  #   group_offsets.shape = [num_groups + 1]
+  #   group_offsets[0] = 0
+  #   group_offsets[num_groups] = m
+  #
+  # The row at which group 'i' starts is group_offsets[i].
+  group_ends = torch.cumsum(group_sizes, dim=0, dtype=torch.int32)
+  group_offsets = torch.cat([torch.zeros(1, dtype=torch.int32), group_ends])
+
+  # Assign a group id to each grid index.
+  #
+  # If a group starts somewhere other than the start of a tile or ends
+  # somewhere other than the end of a tile we need to compute that full tile.
+  # Calculate the number of tiles for each group by rounding their end up to
+  # the nearest 'tm' and their start down to the nearest 'tm'.
+
+  # (1) Round the group_ends up to the nearest multiple of 'tm'.
+  #
+  # NOTE: This does not change group_offsets[num_groups], which is m
+  # (because we enforce that m is divisible by tm).
+  rounded_group_ends = ((group_ends + tm - 1) // tm * tm).to(torch.int32)
+
+  # (2) Round the group_starts down to the nearest multiple of 'tm'.
+  group_starts = torch.cat(
+      [torch.zeros(1, dtype=torch.int32), group_ends[:-1]])
+  rounded_group_starts = group_starts // tm * tm
+
+  # (3) Calculate the number of rows in each group.
+  #
+  # NOTE: Handle zero-sized groups as a special case. If the start for a
+  # zero-sized group is not divisible by 'tm' its start will be rounded down
+  # and its end will be rounded up such that its size will become 1 tile here.
+  rounded_group_sizes = rounded_group_ends - rounded_group_starts
+  rounded_group_sizes = torch.where(group_sizes == 0, 0, rounded_group_sizes)
+
+  # (4) Convert the group sizes from units of rows to units of 'tm'-sized
+  # tiles.
+  #
+  # An m-dimension tile is 'owned' by group 'i' if the first row of the tile
+  # belongs to group 'i'. In addition to owned tiles, each group can have 0 or
+  # 1 initial partial tiles if its first row does not occur in the first row
+  # of a tile. The '0-th' group never has a partial tile because it always
+  # starts at the 0-th row.
+  #
+  # If no group has a partial tile, the total number of tiles is equal to
+  # 'm // tm'. If every group (except the 0-th) has a partial tile, the total
+  # number of tiles is equal to 'm // tm + num_groups - 1'. Thus we know that
+  #
+  #   tiles_m <= group_tiles.sum() <= tiles_m + num_groups - 1
+  #
+  # where tiles_m = m // tm.
+  #
+  # NOTE: All group sizes are divisible by 'tm' because of the rounding in
+  # steps (1) and (2), so this division is exact.
+  group_tiles = rounded_group_sizes // tm
+
+  if visit_empty_groups:
+    # Insert one tile for empty groups.
+    group_tiles = torch.where(group_sizes == 0, 1, group_tiles)
+
+  # Create the group ids for each grid index based on the tile counts for each
+  # group.
+  #
+  # NOTE: This repeat(...) will pad group_ids with the final group id if
+  # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be
+  # sized such that we only execute the necessary number of tiles.
+  tiles_m = _calculate_num_tiles(m, tm)
+  # TODO (alanwaketan): Lower JAX's version of repeat. This dynamism will
+  # force us to compile many times.
+  group_ids = torch.repeat_interleave(
+      torch.arange(num_groups, dtype=torch.int32),
+      group_tiles,
+  )
+  group_ids = torch.nn.functional.pad(
+      group_ids, (0, tiles_m + num_groups - 1 - group_ids.shape[0]),
+      value=num_groups - 1)
+
+  # Assign an m-dimension tile id to each grid index.
+  #
+  # NOTE: Output tiles can only be re-visited consecutively. The following
+  # procedure guarantees that m-dimension tile indices respect this.
+
+  # (1) Calculate how many times each m-dimension tile will be visited.
+  #
+  # Each tile is guaranteed to be visited once by the group that owns the
+  # tile. The remaining possible visits occur when a group starts inside of a
+  # tile at a position other than the first row. We can calculate which
+  # m-dimension tile each group starts in by floor-dividing its offset with
+  # 'tm' and then count tile visits with a histogram.
+  #
+  # To avoid double counting tile visits from the group that owns the tile,
+  # filter these out by assigning their tile id to 'tiles_m' (one beyond the
+  # max) such that they're ignored by the subsequent histogram. Also filter
+  # out any group which is empty.
+  #
+  # TODO(tgale): Invert the 'partial_tile_mask' predicates to be more clear.
+  partial_tile_mask = torch.logical_or((group_offsets[:-1] % tm) == 0,
+                                       group_sizes == 0)
+
+  # Explicitly enable tiles for zero-sized groups, if specified. This covers
+  # zero-sized groups that start on a tile-aligned row and those that do not.
+  if visit_empty_groups:
+    partial_tile_mask = torch.where(group_sizes == 0, False, partial_tile_mask)
+
+  partial_tile_ids = torch.where(partial_tile_mask, tiles_m,
+                                 group_offsets[:-1] // tm)
+
+  tile_visits = (
+      torch.histc(
+          partial_tile_ids.float(), bins=tiles_m, min=0, max=tiles_m - 1) + 1)
+
+  # Create the m-dimension tile ids for each grid index based on the visit
+  # counts for each tile.
+  # TODO (alanwaketan): Lower JAX's version of repeat. This dynamism will
+  # force us to compile many times.
+  m_tile_ids = torch.repeat_interleave(
+      torch.arange(tiles_m, dtype=torch.int32),
+      tile_visits.type(torch.int32),
+  )
+  m_tile_ids = torch.nn.functional.pad(
+      m_tile_ids, (0, tiles_m + num_groups - 1 - m_tile_ids.shape[0]),
+      value=tiles_m - 1)
+
+  num_tiles = group_tiles.sum(dtype=torch.int32)
+  return group_offsets, group_ids, m_tile_ids, num_tiles
+
+
 def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
         group_sizes: torch.Tensor) -> torch.Tensor:
   """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
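The context lines above show the contract `gmm` implements: one matmul per group's row block of `lhs`. As a mental model for that contract (and for what the metadata ultimately schedules), here is a plain-PyTorch reference; the name `gmm_reference` is hypothetical and the snippet is not part of the patch:

```python
import torch

def gmm_reference(lhs: torch.Tensor, rhs: torch.Tensor,
                  group_sizes: torch.Tensor) -> torch.Tensor:
  """lhs: [m, k], rhs: [num_groups, k, n], group_sizes sums to m."""
  start = 0
  outs = []
  for g, size in enumerate(group_sizes.tolist()):
    # Each group's row slice of lhs is multiplied by that group's rhs matrix.
    outs.append(lhs[start:start + size, :] @ rhs[g])
    start += size
  return torch.cat(outs, dim=0)  # [m, n]
```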
@@ -505,9 +671,7 @@ def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope, which could cause problems for xmp.spawn.
   jax_import_guard()
-  import jax
-  import jax.numpy as jnp
-  from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm, make_group_metadata
+  from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm
 
   payload, _ = trace_pallas(gmm, lhs, rhs, group_sizes)
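The hunk below keeps `group_sizes` as a CPU tensor, and the TODOs in `_make_group_metadata` flag `repeat_interleave` as the reason the helper cannot yet run through XLA: its output length depends on tensor *values*, not just shapes, so every distinct `group_sizes` would force a recompile. A standalone sketch of that dynamism (plain PyTorch, not from the patch):

```python
import torch

# Same input shapes, different values -> different output shapes.
a = torch.repeat_interleave(torch.arange(3), torch.tensor([1, 2, 1]))
b = torch.repeat_interleave(torch.arange(3), torch.tensor([2, 2, 2]))
assert a.tolist() == [0, 1, 1, 2]         # 4 elements
assert b.tolist() == [0, 0, 1, 1, 2, 2]   # 6 elements

# _make_group_metadata hides this by padding group_ids / m_tile_ids out to
# the static worst case, tiles_m + num_groups - 1, repeating the final id.
```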
@@ -516,25 +680,18 @@ def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
   # TODO (alanwaketan): The following assumes group_sizes is a cpu tensor.
   # That means we need to materialize this input in order to use this gmm
   # kernel, and that will introduce graph breaks in the computation.
-  group_sizes = jnp.asarray(group_sizes.numpy())
-  group_metadata, num_active_tiles = make_group_metadata(
+  group_offsets, group_ids, m_tile_ids, num_tiles = _make_group_metadata(
       group_sizes=group_sizes,
       m=lhs.shape[0],
-      tm=128,
-      start_group=0,
-      num_nonzero_groups=rhs.shape[0],
-      visit_empty_groups=False,
+      tm=128  # TODO (alanwaketan): Tune this later.
   )
-  group_metadata0 = torch.from_numpy(np.array(group_metadata[0])).to(
-      torch.int32).to("xla")
-  group_metadata1 = torch.from_numpy(np.array(group_metadata[1])).to("xla")
-  group_metadata2 = torch.from_numpy(np.array(group_metadata[2])).to("xla")
-  num_active_tiles = torch.tensor(np.array(num_active_tiles)).to("xla")
   group_offset_torch = torch.tensor([0], dtype=torch.int32).to("xla")
 
   return torch_xla._XLAC._xla_tpu_custom_call([
-      num_active_tiles, group_metadata0, group_metadata1, group_metadata2,
-      group_offset_torch, lhs, rhs
+      num_tiles.to("xla"),
+      group_offsets.to("xla"),
+      group_ids.to("xla"),
+      m_tile_ids.to("xla"), group_offset_torch, lhs, rhs
   ], payload, [torch.Size([m, n])], [lhs.dtype])
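Finally, a hedged usage sketch of the public wrapper after this change. It assumes a TPU device, an `m` divisible by the fixed `tm=128` tile, and `group_sizes` left on CPU per the TODO above; the shapes and dtype are invented for illustration:

```python
import torch
import torch_xla
from torch_xla.experimental.custom_kernel import gmm

m, k, n, num_groups = 512, 128, 256, 4
lhs = torch.randn(m, k, dtype=torch.bfloat16).to('xla')              # [m, k]
rhs = torch.randn(num_groups, k, n, dtype=torch.bfloat16).to('xla')  # [num_groups, k, n]
group_sizes = torch.tensor([128, 128, 128, 128], dtype=torch.int32)  # CPU; sums to m.

out = gmm(lhs, rhs, group_sizes)  # [m, n], one matmul per group's row block.
```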