Skip to content

Commit

Permalink
Fix unit test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjoolee95 committed May 9, 2024
1 parent 3cbc14f commit da821f5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
5 changes: 2 additions & 3 deletions test/test_megablox.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.megablox.gmm as g
import torch_xla.experimental.megablox as megablox
from torch_xla import runtime as xr
from torch_xla._internal import tpu

Expand Down Expand Up @@ -130,7 +130,6 @@ def _init_test_cases(self):
def test_gmm(self):
self._init_test_cases()
for test_case in self.tests_cases:
print("Test Case: ", test_case)
num_groups = test_case['num_groups']
k = test_case['k']
m = test_case['m']
Expand All @@ -141,7 +140,7 @@ def test_gmm(self):
lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
out = g.gmm(lhs, rhs, group_sizes)
out = megablox.gmm(lhs, rhs, group_sizes)

ref_out = self._reference_gmm(lhs.cpu().float().numpy(),
rhs.cpu().float().numpy(),
Expand Down
10 changes: 5 additions & 5 deletions torch_xla/experimental/megablox/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def _validate_args(
rhs: torch.Tensor,
group_sizes: torch.Tensor,
expected_rhs_dims: int = 3,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.dtype]:
) -> 'tuple[jnp.ndarray, jnp.ndarray, jnp.dtype]':
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
jax_import_guard()
Expand Down Expand Up @@ -59,10 +59,10 @@ def _calculate_irregular_num_tiles(x: int, tx: int) -> tuple[int, int]:

def _make_group_metadata(
*,
group_sizes: jnp.ndarray,
group_sizes: 'jnp.ndarray',
m: int,
tm: int,
start_group: jnp.ndarray,
start_group: 'jnp.ndarray',
num_nonzero_groups: int,
visit_empty_groups: bool = True,
) -> GroupMetadata:
Expand Down Expand Up @@ -239,9 +239,9 @@ def _make_group_metadata(


def _zero_uninitialized_memory(
out: jnp.ndarray,
out: 'jnp.ndarray',
*,
start_group: jnp.ndarray,
start_group: 'jnp.ndarray',
num_nonzero_groups: int,
group_metadata: GroupMetadata,
) -> torch.Tensor:
Expand Down

0 comments on commit da821f5

Please sign in to comment.