diff --git a/test/test_megablox.py b/test/test_megablox.py index e4ce5513490..af4ecf76b31 100644 --- a/test/test_megablox.py +++ b/test/test_megablox.py @@ -8,7 +8,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -import torch_xla.experimental.megablox.gmm as g +import torch_xla.experimental.megablox as megablox from torch_xla import runtime as xr from torch_xla._internal import tpu @@ -130,7 +130,6 @@ def _init_test_cases(self): def test_gmm(self): self._init_test_cases() for test_case in self.tests_cases: - print("Test Case: ", test_case) num_groups = test_case['num_groups'] k = test_case['k'] m = test_case['m'] @@ -141,7 +140,7 @@ def test_gmm(self): lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla') rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla') group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups) - out = g.gmm(lhs, rhs, group_sizes) + out = megablox.gmm(lhs, rhs, group_sizes) ref_out = self._reference_gmm(lhs.cpu().float().numpy(), rhs.cpu().float().numpy(), diff --git a/torch_xla/experimental/megablox/gmm.py b/torch_xla/experimental/megablox/gmm.py index 625f8985879..518553d474e 100644 --- a/torch_xla/experimental/megablox/gmm.py +++ b/torch_xla/experimental/megablox/gmm.py @@ -14,7 +14,7 @@ def _validate_args( rhs: torch.Tensor, group_sizes: torch.Tensor, expected_rhs_dims: int = 3, -) -> tuple[jnp.ndarray, jnp.ndarray, jnp.dtype]: +) -> 'tuple[jnp.ndarray, jnp.ndarray, jnp.dtype]': # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. jax_import_guard() @@ -59,10 +59,10 @@ def _calculate_irregular_num_tiles(x: int, tx: int) -> tuple[int, int]: def _make_group_metadata( *, - group_sizes: jnp.ndarray, + group_sizes: 'jnp.ndarray', m: int, tm: int, - start_group: jnp.ndarray, + start_group: 'jnp.ndarray', num_nonzero_groups: int, visit_empty_groups: bool = True, ) -> GroupMetadata: @@ -239,9 +239,9 @@ def _make_group_metadata( def _zero_uninitialized_memory( - out: jnp.ndarray, + out: 'jnp.ndarray', *, - start_group: jnp.ndarray, + start_group: 'jnp.ndarray', num_nonzero_groups: int, group_metadata: GroupMetadata, ) -> torch.Tensor: