Adding megablox gmm standalone #6940
Conversation
cc @tgale96 for review.
GroupMetadata = Any  # TODO(enriqueps): Clean this up and use a namedtuple

def _make_group_metadata(
I wonder if you could trace the metadata function we have in the library with the GMM to avoid duplicating this tricky bit of code? If not this is fine, just curious.
Tracing doesn't seem to be an option AFAIK - though it would be great if we found a way to call the JAX implementation of this method and make this whole implementation leaner. I suggest we do it as a follow-up PR.
cc @alanwaketan
Why can't we just use the JAX version?
If we use the JAX version to do this compute, let's make sure we pass JAX CPU tensors so that this part of the compute can be done on CPU instead. As far as I can tell only group_sizes is used and it's 1D, so it should be pretty lightweight to compute. We should also benchmark this against the reference_gmm in case this part drastically increases the tracing time. On that note, we should cache the result.
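To illustrate the caching idea, a rough sketch (the helper name, its arguments, and what it returns are placeholders, not the actual JAX metadata implementation):

```python
import functools

import torch


@functools.lru_cache(maxsize=None)
def _cached_group_offsets(group_sizes: tuple[int, ...]) -> torch.Tensor:
  # Placeholder metadata computation: run it once per distinct group_sizes
  # value, on CPU, and reuse the result on later steps.
  sizes = torch.tensor(group_sizes, dtype=torch.int32)  # stays on CPU
  return torch.cumsum(sizes, dim=0)


# The cache key has to be hashable, so we pass the concrete sizes as a tuple.
offsets = _cached_group_offsets((3, 5, 8))
```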
Actually we can't do this, given that group_sizes is data produced in the middle of the graph. It means we would need a graph break.
Picking this up; now rebased from master to fix the conflicts. This PR should be ready to be reviewed/merged. I'll run the TPU CI to verify one more time.
@JackCaoG thanks for the comments, this should be ready for another round of review.
I added the TPUCI tag and reran the CI. Feel free to merge once the v4 test passes.
Hopefully I can take a look tomorrow after going over all the reading materials w.r.t. MoE and megablocks. If I can't get to it, feel free to land it as it is. We can always follow up.
Given that the CI (+ TPU CI) is green, I'll go ahead and merge this. I'll follow up with any fixes if needed.
@wonjoolee95 Do you think we can make a follow-up PR to simplify this before moving to tgmm?
lhs: torch.Tensor,
rhs: torch.Tensor,
group_sizes: torch.Tensor,
preferred_element_type: torch.dtype = torch.float32,
I think we should omit the preferred_element_type, tiling, group_offset, existing_out, transpose_rhs and interpret parameters unless we know for sure that users need them.
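For illustration, a sketch of the trimmed-down public signature that suggestion would leave (the body is elided; only the reduced parameter list is the point here):

```python
import torch


def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
        group_sizes: torch.Tensor) -> torch.Tensor:
  """Grouped matmul: multiplies each row-group of lhs by its rhs slice."""
  # Body intentionally elided; tiling, group_offset, existing_out, etc. would
  # stay internal until a user actually needs them.
  raise NotImplementedError
```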
return (group_offsets, group_ids, m_tile_ids), num_tiles

def _zero_uninitialized_memory(
Why can't we just use the JAX version?
GroupMetadata = Any  # TODO(enriqueps): Clean this up and use a namedtuple

def _make_group_metadata(
Why can't we just use the JAX version?
import numpy as np

def _validate_args(
Why can't we just use the JAX version?
@@ -0,0 +1,22 @@
"""Common utilities for Pallas kernels."""
This file can be deleted if we directly use the helper from JAX.
@@ -0,0 +1 @@
from .gmm import gmm
Once we remove all the duplicated code, we can move this method back to custom_kernel.py.
from jax.experimental import pallas as pl

class MegabloxTest(unittest.TestCase):
Why can't we merge this to test_pallas.py?
group_offset_torch = torch.from_numpy(np.array(group_offset)).to("xla")
output_shape = torch.Size([m, n])
out = torch_xla._XLAC._xla_tpu_custom_call([
    num_active_tiles, group_metadata0, group_metadata1, group_metadata2,
We should only duplicate the logic that gets us these parameters. Anything else can be removed.
    group_offset_torch, lhs, rhs
], payload, [output_shape], [preferred_element_type])

if existing_out is None and num_current_groups < num_total_groups:
As far as I can tell, this is only needed once we have expert parallelism. I still can't tell whether we will get there.
class MegabloxTest(unittest.TestCase):

def _reference_gmm(
Can we just do it in torch instead of np?
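For reference, a minimal torch-only version of what that helper could look like (a sketch, assuming the same lhs [m, k], rhs [num_groups, k, n], group_sizes [num_groups] layout used in the test):

```python
import torch


def _reference_gmm(lhs: torch.Tensor, rhs: torch.Tensor,
                   group_sizes: torch.Tensor) -> torch.Tensor:
  """Naive grouped matmul: each row-group of lhs hits its own rhs slice."""
  out = []
  start = 0
  for g, size in enumerate(group_sizes.tolist()):
    # Rows [start, start + size) of lhs belong to group g.
    out.append(lhs[start:start + size] @ rhs[g])
    start += size
  return torch.cat(out, dim=0)
```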
start += group_sizes[i]
return np.array(np.concatenate(out, axis=0))

def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor:
As far as I can tell, we just need to make sure our plumbing is correct; we don't need to verify that gmm itself is correct. That's JAX's job. So let's remove this and pick one or two cases tuned to our wrapper.
starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
return torch.from_numpy(ends - starts).to(torch.int32)

def _tolerances(self, lhs_dtype: torch.dtype, rhs_dtype: torch.dtype,
If we use torch, we don't need this. We can just use torch.allclose.
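For example (a sketch; `out` and `expected` stand for the kernel result and the reference result, and the tolerances mirror the values in the helper below):

```python
self.assertTrue(torch.allclose(out.cpu(), expected, atol=1e-3, rtol=1e-2))
```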
return 1e-3, 1e-2  # atol, rtol
return 1e-4, 1e-2  # atol, rtol

LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]
What's this?
LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]

def _init_test_cases(self):
We might not need all of these.
lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
This is a CPU tensor!
lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
out = megablox.gmm(lhs, rhs, group_sizes)
We always output fp32 in this test case regardless of the input dtypes.
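One possible adjustment, sketched here under the assumption that the wrapper keeps the preferred_element_type parameter shown earlier in the signature:

```python
# Ask the kernel for outputs in the input dtype instead of the float32
# default, so the test actually exercises the dtype combination it names.
out = megablox.gmm(lhs, rhs, group_sizes, preferred_element_type=lhs_dtype)
```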
Summary: This is an effort to refactor the code from #6940 and aims to remove unnecessary code in that part. It reduces the amount of code from ~400 lines to ~50 lines. However, a bummer is that the original gmm kernel doesn't work at all: it assumes group_sizes is a CPU tensor. That means we need to materialize this input in order to use the gmm kernel, and that will introduce graph breaks in the computation. I will need yet another follow-up to make this code actually functional. Good news is the test cases seem functional.

Test Plan: python test/test_megablox.py
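To make the graph-break concern concrete, a sketch of the failure mode (names like expert_assignments are illustrative, not from this PR):

```python
# group_sizes comes out of the routing computation, i.e. it is an XLA tensor
# produced in the middle of the traced graph.
group_sizes = torch.bincount(expert_assignments, minlength=num_groups)

# If the kernel needs concrete values (a CPU tensor), we have to materialize
# here, which cuts the graph in two and forces an early execution of
# everything computed so far.
group_sizes_cpu = group_sizes.cpu()  # graph break
out = megablox.gmm(lhs, rhs, group_sizes_cpu)
```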
In this PR, we add the megablox gmm kernel. The current implementation adds a new file, `megablox_gmm`. I plan to merge it to `custom_kernel` with the rest of the kernels.
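For completeness, a minimal usage sketch of the kernel as exposed in this PR (the import path is an assumption based on the new `__init__.py`; shapes follow the test, and group_sizes is left on CPU as the test does):

```python
import torch
from torch_xla.experimental import megablox  # import path is an assumption

m, k, n, num_groups = 128, 64, 32, 4
lhs = torch.rand(m, k, dtype=torch.float32).to('xla')              # [m, k] tokens
rhs = torch.rand(num_groups, k, n, dtype=torch.float32).to('xla')  # one [k, n] weight per group
# group_sizes[g] = number of consecutive rows of lhs that belong to group g.
group_sizes = torch.tensor([32, 32, 32, 32], dtype=torch.int32)

out = megablox.gmm(lhs, rhs, group_sizes)  # [m, n]
```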