[Pallas] Refactor the gmm kernel (#7099)
Summary:
This refactors the code from #6940 and removes the redundant pieces, shrinking that part from ~400 lines to ~50 lines. The catch is that the original gmm kernel doesn't work as-is: it assumes group_sizes is a CPU tensor, so we have to materialize that input before calling the kernel, which introduces graph breaks in the computation. A follow-up is still needed to make this code fully functional. The good news is that the test cases appear to pass.

Test Plan:
python test/test_megablox.py
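
For illustration only (not part of this commit), here is a minimal sketch of the contract described above, with shapes borrowed from the test below. It assumes a TPU host; the group_sizes_xla name in the trailing comment is made up.

import torch
import torch_xla  # registers the 'xla' device
from torch_xla.experimental.custom_kernel import gmm

m, k, n, num_groups = 512, 128, 256, 2
lhs = torch.rand(m, k, dtype=torch.bfloat16).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=torch.bfloat16).to('xla')

# Works today: group_sizes stays on the host, so its values can be read
# eagerly while the tile metadata is built. The sizes here sum to m.
group_sizes = torch.tensor([200, 312], dtype=torch.int32)
out = gmm(lhs, rhs, group_sizes)  # [m, n] on the XLA device

# If group_sizes were instead produced on the device (hypothetical), it
# would first have to be copied back to the host, which materializes the
# pending computation and breaks the traced graph:
# group_sizes = group_sizes_xla.cpu()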
alanwaketan authored and qihqi committed May 29, 2024
1 parent 429b507 commit a1e5f4c
Showing 5 changed files with 54 additions and 449 deletions.
35 changes: 4 additions & 31 deletions test/test_megablox.py
@@ -8,7 +8,7 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.megablox as megablox
from torch_xla.experimental.custom_kernel import gmm
from torch_xla import runtime as xr
from torch_xla._internal import tpu

@@ -97,34 +97,6 @@ def _init_test_cases(self):
'n': 256,
'num_groups': 2
})
self.tests_cases.append({
'dtype': torch.bfloat16,
'm': 128,
'k': 128,
'n': 128,
'num_groups': 1
})
self.tests_cases.append({
'dtype': torch.bfloat16,
'm': 256,
'k': 128,
'n': 128,
'num_groups': 1
})
self.tests_cases.append({
'dtype': torch.bfloat16,
'm': 128,
'k': 256,
'n': 128,
'num_groups': 8
})
self.tests_cases.append({
'dtype': torch.bfloat16,
'm': 512,
'k': 128,
'n': 256,
'num_groups': 2
})

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm(self):
@@ -139,8 +111,9 @@ def test_gmm(self):

lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
out = megablox.gmm(lhs, rhs, group_sizes)
group_sizes = self._group_sizes_strategy(
    m=m, num_groups=num_groups)  # Note: this is a CPU tensor.
out = gmm(lhs, rhs, group_sizes)

ref_out = self._reference_gmm(lhs.cpu().float().numpy(),
rhs.cpu().float().numpy(),
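
The test compares against self._reference_gmm applied to host numpy arrays. Its exact implementation isn't shown in this diff, but a reference grouped matmul can be sketched as below, following the docstring semantics of the new kernel wrapper; the reference_gmm name is illustrative.

import numpy as np

def reference_gmm(lhs: np.ndarray, rhs: np.ndarray,
                  group_sizes: np.ndarray) -> np.ndarray:
  # Row block i of lhs, group_sizes[i] rows tall, is multiplied by rhs[i],
  # and the per-group results are stacked back into an [m, n] output.
  blocks = []
  start = 0
  for g, size in enumerate(group_sizes):
    blocks.append(lhs[start:start + size, :] @ rhs[g])
    start += size
  return np.concatenate(blocks, axis=0)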
50 changes: 50 additions & 0 deletions torch_xla/experimental/custom_kernel.py
@@ -2,6 +2,7 @@
import os
import warnings

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
@@ -488,6 +489,55 @@ def paged_attention(q,
return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)


def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
group_sizes: torch.Tensor) -> torch.Tensor:
"""Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
Args:
lhs: A 2d, jnp.ndarray with shape [m, k].
rhs: A 3d, jnp.ndarray with shape [num_groups, k, n].
group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
preferred_element_type: jnp.dtype, the element type for the output matrix.
Returns:
A 2d, jnp.ndarray with shape [m, n].
"""
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm, make_group_metadata

payload, _ = trace_pallas(gmm, lhs, rhs, group_sizes)

m, n = lhs.shape[0], rhs.shape[2]
# Create the metadata we need for computation.
# TODO (alanwaketan): The following assumes group_sizes is a cpu tensor.
# That means we need to materialize this input in order to use this gmm
# kernel, and that will introduce graph breaks in the computation.
group_sizes = jnp.asarray(group_sizes.numpy())
group_metadata, num_active_tiles = make_group_metadata(
group_sizes=group_sizes,
m=lhs.shape[0],
tm=128,
start_group=0,
num_nonzero_groups=rhs.shape[0],
visit_empty_groups=False,
)
group_metadata0 = torch.from_numpy(np.array(group_metadata[0])).to(
torch.int32).to("xla")
group_metadata1 = torch.from_numpy(np.array(group_metadata[1])).to("xla")
group_metadata2 = torch.from_numpy(np.array(group_metadata[2])).to("xla")
num_active_tiles = torch.tensor(np.array(num_active_tiles)).to("xla")
group_offset_torch = torch.tensor([0], dtype=torch.int32).to("xla")

return torch_xla._XLAC._xla_tpu_custom_call([
num_active_tiles, group_metadata0, group_metadata1, group_metadata2,
group_offset_torch, lhs, rhs
], payload, [torch.Size([m, n])], [lhs.dtype])


def non_xla_attetion(q, k, v, attention_type):
# This will be called when dynamo uses a fake tensor to construct the fake output.
# We need to make sure output tensor's shape is correct.
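
The wrapper above delegates the scheduling to make_group_metadata with a fixed tile size of tm=128 rows, and num_active_tiles is passed to the custom call to tell the kernel how many tiles actually need work. As rough intuition only (this is not the library's exact formula, and upper_bound_active_tiles is an illustrative helper, not part of the commit), each non-empty group contributes one visit per 128-row tile of lhs that it overlaps:

def upper_bound_active_tiles(group_sizes, tm=128):
  # Count (group, m-tile) pairs where the group has at least one row in
  # the tile. Groups that straddle a tile boundary add extra visits.
  visits = 0
  start = 0
  for size in group_sizes:
    if size > 0:
      first_tile = start // tm
      last_tile = (start + size - 1) // tm
      visits += last_tile - first_tile + 1
    start += size
  return visits

For example, with m=512 and group_sizes=[200, 312] there are four 128-row tiles, but tile 1 is shared by both groups, so the count is 5.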
1 change: 0 additions & 1 deletion torch_xla/experimental/megablox/__init__.py

This file was deleted.

22 changes: 0 additions & 22 deletions torch_xla/experimental/megablox/common.py

This file was deleted.
