Adding megablox gmm standalone #6940

Merged
merged 18 commits into from
May 10, 2024
Conversation

miladm
Collaborator

@miladm miladm commented Apr 18, 2024

In this PR, we add the megablox GMM (grouped matrix multiplication) kernel. The current implementation adds a new file, megablox_gmm. I plan to merge it into custom_kernel with the rest of the kernels.
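For context, a minimal usage sketch of the new kernel (the module path, shapes, and dtypes are assumptions based on the test and __init__ excerpts quoted later in this thread):

import torch
import torch_xla
from torch_xla.experimental import megablox  # path assumed from the `from .gmm import gmm` re-export below

# Grouped matmul: rows of lhs are split into num_groups contiguous groups and
# group i is multiplied by rhs[i]. group_sizes is 1-D int32 and sums to m.
m, k, n, num_groups = 512, 128, 256, 4
lhs = torch.rand(m, k, dtype=torch.bfloat16).to('xla')               # [m, k]
rhs = torch.rand(num_groups, k, n, dtype=torch.bfloat16).to('xla')   # [num_groups, k, n]
group_sizes = torch.tensor([128, 128, 128, 128], dtype=torch.int32)  # CPU tensor, as in the test

out = megablox.gmm(lhs, rhs, group_sizes)  # [m, n]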

@miladm miladm marked this pull request as draft April 18, 2024 22:55
@miladm miladm self-assigned this Apr 23, 2024
@miladm miladm added the pallas label Apr 23, 2024
@wonjoolee95 wonjoolee95 self-requested a review May 1, 2024 00:04
@miladm miladm marked this pull request as ready for review May 3, 2024 05:00
@miladm miladm requested a review from alanwaketan May 3, 2024 05:01
@miladm
Collaborator Author

miladm commented May 4, 2024

cc @tgale96 for review.

@miladm miladm requested a review from tgale96 May 6, 2024 18:52
GroupMetadata = Any # TODO(enriqueps): Clean this up and use a namedtuple


def _make_group_metadata(
Collaborator

I wonder if you could trace the metadata function we have in the library with the GMM to avoid duplicating this tricky bit of code? If not this is fine, just curious.

Collaborator Author

Tracing doesn't seem to be an option AFAIK. Though, it would be great if we found a way to call the JAX implementation of this method and make this whole implementation leaner. I suggest we do it as a follow-up PR.

cc @alanwaketan

Collaborator

Why can't we just use the JAX version?

Collaborator

When we use the JAX version to do this compute, let's make sure we pass JAX CPU tensors so that this part of the compute can be done on CPU instead. As far as I can tell, only group_sizes is used and it's 1D, so it should be pretty lightweight to compute. We should also benchmark this against the reference_gmm in case this part drastically increases the tracing time. On that matter, we should cache the result.
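A rough sketch of the caching idea, assuming a simplified _make_group_metadata(group_sizes, m, tm) signature (the function under review takes more arguments, and _cached_group_metadata is a hypothetical wrapper name):

import functools
import torch

@functools.lru_cache(maxsize=None)
def _cached_group_metadata(group_sizes_key: tuple, m: int, tm: int):
  # Rebuild a small 1-D CPU tensor from the hashable key so the metadata
  # computation stays on the host and is reused across steps.
  group_sizes = torch.tensor(group_sizes_key, dtype=torch.int32)
  return _make_group_metadata(group_sizes=group_sizes, m=m, tm=tm)

# Call site:
# group_metadata, num_tiles = _cached_group_metadata(tuple(group_sizes.tolist()), m, tm)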

Collaborator

Actually, we can't do this given that group_sizes is data produced in the middle of the graph, which means we would need a graph break.

@wonjoolee95
Collaborator

Picking this up; now rebased onto master to fix the conflicts. This PR should be ready to be reviewed/merged. I'll run the TPU CI to verify one more time.

@wonjoolee95
Collaborator

@JackCaoG thanks for the comments, this should be ready for another round of review.

@JackCaoG JackCaoG added the tpuci label May 8, 2024
Collaborator

@JackCaoG JackCaoG left a comment

I added the tpuci label and reran the CI. Feel free to merge once the v4 test passes.

@alanwaketan
Collaborator

alanwaketan commented May 9, 2024

Hopefully, I can take a look tomorrow after going over all the reading materials w.r.t. MoE and megablocks. If I can't get to it, feel free to land it as is. We can always follow up.

@wonjoolee95
Collaborator

Given that the CI (+ TPU CI) is green, I'll go ahead and merge this. I'll follow up with any fixes if needed.

@wonjoolee95 wonjoolee95 merged commit 40f7e1f into master May 10, 2024
20 checks passed
Collaborator

@alanwaketan alanwaketan left a comment

@wonjoolee95 Do you think we can make a follow-up PR to simplify this before moving to tgmm?

lhs: torch.Tensor,
rhs: torch.Tensor,
group_sizes: torch.Tensor,
preferred_element_type: torch.dtype = torch.float32,
Collaborator

I think we should omit the preferred_element_type, tiling, group_offset, existing_out, transpose_rhs, and interpret parameters unless we know for sure that users need them.
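For reference, the slimmed-down surface suggested here could look like the following (a sketch of the suggestion, not the current code):

def gmm(
    lhs: torch.Tensor,          # [m, k]
    rhs: torch.Tensor,          # [num_groups, k, n]
    group_sizes: torch.Tensor,  # [num_groups], int32, sums to m
) -> torch.Tensor:              # [m, n]
  ...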

return (group_offsets, group_ids, m_tile_ids), num_tiles


def _zero_uninitialized_memory(
Collaborator

Why can't we just use the JAX version?

GroupMetadata = Any # TODO(enriqueps): Clean this up and use a namedtuple


def _make_group_metadata(
Collaborator

Why can't we just use the JAX version?

import numpy as np


def _validate_args(
Collaborator

Why can't we just use the JAX version?

@@ -0,0 +1,22 @@
"""Common utilities for Pallas kernels."""
Collaborator

This file can be deleted if we directly use the helper from JAX.

@@ -0,0 +1 @@
from .gmm import gmm
Collaborator

Once we remove all the duplicated code, we can move this method back to custom_kernel.py.

from jax.experimental import pallas as pl


class MegabloxTest(unittest.TestCase):
Collaborator

Why can't we merge this to test_pallas.py?

group_offset_torch = torch.from_numpy(np.array(group_offset)).to("xla")
output_shape = torch.Size([m, n])
out = torch_xla._XLAC._xla_tpu_custom_call([
num_active_tiles, group_metadata0, group_metadata1, group_metadata2,
Collaborator

We should only duplicate the logic to get us these parameters. Anything else can be removed.

group_offset_torch, lhs, rhs
], payload, [output_shape], [preferred_element_type])

if existing_out is None and num_current_groups < num_total_groups:
Collaborator

As far as I can tell, this is only needed once we have expert parallelism, and I still can't tell whether we will get there.


class MegabloxTest(unittest.TestCase):

def _reference_gmm(
Collaborator

Can we just do it in torch instead of np?
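For illustration, a torch-only reference could be as simple as this (a sketch, not the code under review):

def _reference_gmm(self, lhs: torch.Tensor, rhs: torch.Tensor,
                   group_sizes: torch.Tensor) -> torch.Tensor:
  # lhs: [m, k], rhs: [num_groups, k, n], group_sizes sums to m.
  out, start = [], 0
  for g, size in enumerate(group_sizes.tolist()):
    out.append(lhs[start:start + size] @ rhs[g])  # [size, n]
    start += size
  return torch.cat(out, dim=0)  # [m, n]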

start += group_sizes[i]
return np.array(np.concatenate(out, axis=0))

def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor:
Collaborator

As far as I can tell, for us, we just need to make sure our piping is correct and we don't need to ensure gmm itself is correct. That's JAX's job. So, let's remove this and pick one or two cases that are tuned to our wrapper.

starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
return torch.from_numpy(ends - starts).to(torch.int32)

def _tolerances(self, lhs_dtype: torch.dtype, rhs_dtype: torch.dtype,
Collaborator

If we use torch, we don't need this; we can just use torch.allclose.
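e.g., a single check like the following would do (tolerance values here are placeholders, not tuned values):

self.assertTrue(
    torch.allclose(out.cpu().float(), expected.float(), rtol=1e-2, atol=1e-3))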

return 1e-3, 1e-2 # atol, rtol
return 1e-4, 1e-2 # atol, rtol

LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]
Collaborator

What's this?


LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]

def _init_test_cases(self):
Collaborator

We might not need all of these.


lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
Collaborator

This is a CPU tensor!

lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
out = megablox.gmm(lhs, rhs, group_sizes)
Collaborator

We always output fp32 in this test case regardless of the input dtypes....

alanwaketan added a commit that referenced this pull request May 23, 2024
Summary:
This is an effort to refactor the code from #6940 and aims to remove useless code in that part. It reduces the amount of code from ~400 lines to ~50 lines. However, a bummer is the original gmm kernel doesn't work at all... It assumes group_sizes is a CPU tensor. That means we need to materialize this input in order to use this gmm kernel, and that will introduce graph breaks in the computation. I will need yet another follow-up to make this code actually functional... Good news is the test cases seem functional, yay...

Test Plan:
python test/test_megablox.py
qihqi pushed a commit that referenced this pull request May 29, 2024