[Pallas] Introduce _make_group_metadata (#7107)
Summary:
_make_group_metadata is a helper function that assists gmm. Previously we used the JAX version, which cannot be stitched into our HLO. With this new torch version, the metadata computation is contained in our HLO. However, we still need to lower two ops: pytorch.org/docs/stable/generated/torch.repeat_interleave.html and pytorch.org/docs/stable/generated/torch.histc.html. We also need the JAX version of repeat so that the op shape stays static: jax.readthedocs.io/en/latest/_autosummary/jax.numpy.repeat.html.
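
For context, a minimal sketch of the shape issue (illustrative only, not part of this change; the values are made up):

import torch
import jax.numpy as jnp

repeats = [1, 2, 1, 1]

# torch.repeat_interleave: the output length depends on sum(repeats), so the
# traced shape changes whenever the group sizes change.
dynamic = torch.repeat_interleave(
    torch.arange(4, dtype=torch.int32), torch.tensor(repeats))  # shape [5]

# jnp.repeat with total_repeat_length pins the output shape at trace time,
# padding with the final value when sum(repeats) is smaller.
static = jnp.repeat(
    jnp.arange(4), jnp.array(repeats), total_repeat_length=7)  # always shape [7]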

Test Plan:
python test/test_gmm.py
alanwaketan authored May 24, 2024
1 parent cb8533b commit 22e912e
Showing 3 changed files with 238 additions and 21 deletions.
66 changes: 63 additions & 3 deletions test/test_megablox.py → test/test_gmm.py
@@ -1,5 +1,3 @@
"""Grouped matrix multiplication kernels for TPU written in Pallas."""

import logging
import unittest

@@ -8,7 +6,7 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import gmm
from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata
from torch_xla import runtime as xr
from torch_xla._internal import tpu

@@ -123,6 +121,68 @@ def test_gmm(self):
np.testing.assert_allclose(
ref_out, np.array(out[0].cpu()), rtol=rtol, atol=atol)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_make_group_metadata(self):
from jax.experimental.pallas.ops.tpu.megablox.gmm import make_group_metadata as jax_make_group_metadata

test_grids = [
{
'group_sizes': [8, 8, 8, 8],
'm': 32,
'tm': 8
},
{
'group_sizes': [2, 14, 8, 8],
'm': 32,
'tm': 8
},
{
'group_sizes': [16, 0, 8, 8],
'm': 32,
'tm': 8
},
{
'group_sizes': [2, 0, 14, 16],
'm': 32,
'tm': 8
},
{
'group_sizes': [8, 12, 0, 12],
'm': 32,
'tm': 8
},
{
'group_sizes': [6, 12, 0, 14],
'm': 32,
'tm': 8
},
{
'group_sizes': [6, 12, 0, 14],
'm': 32,
'tm': 4
},
]

for test_grid in test_grids:
jax_meta, jax_num_tiles = jax_make_group_metadata(
group_sizes=jnp.array(test_grid['group_sizes']),
m=test_grid['m'],
tm=test_grid['tm'],
start_group=0,
num_nonzero_groups=len(test_grid['group_sizes']),
)

torch_meta = _make_group_metadata(
group_sizes=torch.tensor(test_grid['group_sizes']),
m=test_grid['m'],
tm=test_grid['tm'],
)

for i in range(len(jax_meta)):
self.assertTrue(
torch.all(torch.from_numpy(np.array(jax_meta[i])) == torch_meta[i]))
self.assertEqual(jax_num_tiles, torch_meta[-1].item())


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
2 changes: 1 addition & 1 deletion test/tpu/run_tests.sh
@@ -24,7 +24,7 @@ python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
python3 test/test_pallas.py
python3 test/test_pallas_spmd.py
python3 test/test_input_output_aliases.py
python3 test/test_megablox.py
python3 test/test_gmm.py
python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
191 changes: 174 additions & 17 deletions torch_xla/experimental/custom_kernel.py
@@ -8,7 +8,7 @@
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

from typing import List, Callable
from typing import Any, List, Callable
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB

@@ -489,6 +489,172 @@ def paged_attention(q,
return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)


def _calculate_num_tiles(x: int, tx: int) -> int:
tiles, rem = divmod(x, tx)
if rem:
raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).")
return tiles


# For now, this can only be run on CPU, as repeat_interleave is not lowered to XLA.
def _make_group_metadata(
*,
group_sizes: torch.Tensor,
m: int,
tm: int,
visit_empty_groups: bool = True,
) -> Any:
"""Create the metadata needed for grouped matmul computation.
Args:
group_sizes: A 1d torch.Tensor with shape [num_groups] and torch.int32 dtype.
m: The number of rows in lhs.
tm: The m-dimension tile size being used.
visit_empty_groups: If True, do not squeeze tiles for empty groups out of
the metadata. This is necessary for tgmm, where we at least need to zero
the output for each group.
Returns:
tuple of:
group_offsets: A 1d torch.Tensor with shape [num_groups + 1] and torch.int32
dtype. group_offsets[i] indicates the row at which group 'i' starts in
the lhs matrix, and group_offsets[num_groups] = m.
group_ids: A 1d torch.Tensor with shape [m_tiles + num_groups - 1] and
torch.int32 dtype. group_ids[i] indicates which group grid index 'i' will
work on.
m_tile_ids: A 1d torch.Tensor with shape [m_tiles + num_groups - 1] and
torch.int32 dtype. m_tile_ids[i] indicates which m-dimension tile grid
index 'i' will work on.
num_tiles: The number of m-dimension tiles to execute, including overlapping
executions. Not to be confused with m_tiles, which is m // tm.
"""
num_groups = group_sizes.shape[0]

# Calculate the offset of each group, starting at zero. This metadata is
# similar to row offsets in a CSR matrix. The following properties hold:
#
# group_offsets.shape = [num_groups + 1]
# group_offsets[0] = 0
# group_offsets[num_groups] = m
#
# The row at which group 'i' starts is group_offsets[i].
group_ends = torch.cumsum(group_sizes, dim=0, dtype=torch.int32)
group_offsets = torch.cat([torch.zeros(1, dtype=torch.int32), group_ends])
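
# For example, with group_sizes = [2, 14, 8, 8] and m = 32 (one of the grids in
# test_gmm.py): group_ends = [2, 16, 24, 32] and group_offsets = [0, 2, 16, 24, 32].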

# Assign a group id to each grid index.
#
# If a group starts somewhere other than the start of a tile or ends somewhere
# other than the end of a tile we need to compute that full tile. Calculate
# the number of tiles for each group by rounding their end up to the nearest
# 'tm' and their start down to the nearest 'tm'.

# (1) Round the group_ends up to the nearest multiple of 'tm'.
#
# NOTE: This does not change group_offsets[num_groups], which is m
# (because we enforce m is divisible by tm).
rounded_group_ends = ((group_ends + tm - 1) // tm * tm).to(torch.int32)

# (2) Round the group_starts down to the nearest multiple of 'tm'.
group_starts = torch.cat([torch.zeros(1, dtype=torch.int32), group_ends[:-1]])
rounded_group_starts = group_starts // tm * tm

# (3) Calculate the number of rows in each group.
#
# NOTE: Handle zero-sized groups as a special case. If the start for a
# zero-sized group is not divisible by 'tm' its start will be rounded down and
# its end will be rounded up such that its size will become 1 tile here.
rounded_group_sizes = rounded_group_ends - rounded_group_starts
rounded_group_sizes = torch.where(group_sizes == 0, 0, rounded_group_sizes)
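
# Continuing the example with tm = 8: rounded_group_ends = [8, 16, 24, 32],
# group_starts = [0, 2, 16, 24], rounded_group_starts = [0, 0, 16, 24], and
# rounded_group_sizes = [8, 16, 8, 8].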

# (4) Convert the group sizes from units of rows to units of 'tm'-sized tiles.
#
# An m-dimension tile is 'owned' by group 'i' if the first row of the tile
# belongs to group 'i'. In addition to owned tiles, each group can have 0 or 1
initial partial tiles if its first row does not occur in the first row of a
# tile. The '0-th' group never has a partial tile because it always starts at
# the 0-th row.
#
# If no group has a partial tile, the total number of tiles is equal to
'm // tm'. If every group except the 0-th has a partial tile, the total
# number of tiles is equal to 'm // tm + num_groups - 1'. Thus we know that
#
# tiles_m <= group_tiles.sum() <= tiles_m + num_groups - 1
#
# Where tiles_m = m // tm.
#
# NOTE: All group sizes are divisible by 'tm' because of the rounding in steps
# (1) and (2) so this division is exact.
group_tiles = rounded_group_sizes // tm

if visit_empty_groups:
# Insert one tile for empty groups.
group_tiles = torch.where(group_sizes == 0, 1, group_tiles)
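
# Continuing the example: group_tiles = [1, 2, 1, 1]. Group 1 starts at row 2,
# inside the tile owned by group 0, so it contributes one partial tile on top of
# m // tm = 32 // 8 = 4 owned tiles, for a total of 5.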

# Create the group ids for each grid index based on the tile counts for each
# group.
#
# NOTE: This repeat(...) will pad group_ids with the final group id if
# group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized
# such that we only execute the necessary number of tiles.
tiles_m = _calculate_num_tiles(m, tm)
# TODO (alanwaketan): lower jax's version of repeat. This dynamism will force us to compile many times.
group_ids = torch.repeat_interleave(
torch.arange(num_groups, dtype=torch.int32),
group_tiles,
)
group_ids = torch.nn.functional.pad(
group_ids, (0, tiles_m + num_groups - 1 - group_ids.shape[0]),
value=num_groups - 1)
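
# Continuing the example: group_ids = [0, 1, 1, 2, 3] before padding and
# [0, 1, 1, 2, 3, 3, 3] after padding to tiles_m + num_groups - 1 = 7 entries.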

# Assign an m-dimension tile id to each grid index.
#
# NOTE: Output tiles can only be re-visited consecutively. The following
# procedure guarantees that m-dimension tile indices respect this.

# (1) Calculate how many times each m-dimension tile will be visited.
#
# Each tile is guaranteed to be visited once by the group that owns the tile.
# The remaining possible visits occur when a group starts inside of a tile at
# a position other than the first row. We can calculate which m-dimension tile
# each group starts in by floor-dividing its offset with `tm` and then count
# tile visits with a histogram.
#
# To avoid double counting tile visits from the group that owns the tile,
filter these out by assigning their tile id to `tiles_m` (one beyond the max)
# such that they're ignored by the subsequent histogram. Also filter out any
# group which is empty.
#
# TODO(tgale): Invert the 'partial_tile_mask' predicates to be more clear.
partial_tile_mask = torch.logical_or((group_offsets[:-1] % tm) == 0,
group_sizes == 0)

# Explicitly enable tiles for zero sized groups, if specified. This covers
# zero sized groups that start on a tile-aligned row and those that do not.
if visit_empty_groups:
partial_tile_mask = torch.where(group_sizes == 0, False, partial_tile_mask)

partial_tile_ids = torch.where(partial_tile_mask, tiles_m,
group_offsets[:-1] // tm)

tile_visits = (
torch.histc(
partial_tile_ids.float(), bins=tiles_m, min=0, max=tiles_m - 1) + 1)
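
# Continuing the example: group_offsets[:-1] % tm = [0, 2, 0, 0], so only group 1
# starts mid-tile. partial_tile_ids = [4, 0, 4, 4] (entries equal to tiles_m fall
# outside the histogram range and are ignored), so tile_visits = [2, 1, 1, 1]:
# tile 0 is visited by group 0, its owner, and again by group 1.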

# Create the m-dimension tile ids for each grid index based on the visit
# counts for each tile.
# TODO (alanwaketan): lower jax's version of repeat. This dynamism will force us to compile many times.
m_tile_ids = torch.repeat_interleave(
torch.arange(tiles_m, dtype=torch.int32),
tile_visits.type(torch.int32),
)
m_tile_ids = torch.nn.functional.pad(
m_tile_ids, (0, tiles_m + num_groups - 1 - m_tile_ids.shape[0]),
value=tiles_m - 1)
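
# Continuing the example: m_tile_ids = [0, 0, 1, 2, 3] before padding and
# [0, 0, 1, 2, 3, 3, 3] after padding; the num_tiles computed below is 5.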

num_tiles = group_tiles.sum(dtype=torch.int32)
return group_offsets, group_ids, m_tile_ids, num_tiles


def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
group_sizes: torch.Tensor) -> torch.Tensor:
"""Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
@@ -505,9 +671,7 @@ def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm, make_group_metadata
from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm

payload, _ = trace_pallas(gmm, lhs, rhs, group_sizes)

@@ -516,25 +680,18 @@ def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
# TODO (alanwaketan): The following assumes group_sizes is a CPU tensor.
# That means we need to materialize this input in order to use this gmm
# kernel, and that will introduce graph breaks in the computation.
group_sizes = jnp.asarray(group_sizes.numpy())
group_metadata, num_active_tiles = make_group_metadata(
group_offsets, group_ids, m_tile_ids, num_tiles = _make_group_metadata(
group_sizes=group_sizes,
m=lhs.shape[0],
tm=128,
start_group=0,
num_nonzero_groups=rhs.shape[0],
visit_empty_groups=False,
tm=128 # TODO (alanwaketan): Tune this later.
)
group_metadata0 = torch.from_numpy(np.array(group_metadata[0])).to(
torch.int32).to("xla")
group_metadata1 = torch.from_numpy(np.array(group_metadata[1])).to("xla")
group_metadata2 = torch.from_numpy(np.array(group_metadata[2])).to("xla")
num_active_tiles = torch.tensor(np.array(num_active_tiles)).to("xla")
group_offset_torch = torch.tensor([0], dtype=torch.int32).to("xla")

return torch_xla._XLAC._xla_tpu_custom_call([
num_active_tiles, group_metadata0, group_metadata1, group_metadata2,
group_offset_torch, lhs, rhs
num_tiles.to("xla"),
group_offsets.to("xla"),
group_ids.to("xla"),
m_tile_ids.to("xla"), group_offset_torch, lhs, rhs
], payload, [torch.Size([m, n])], [lhs.dtype])

