diff --git a/torch_xla/experimental/megablox/common.py b/torch_xla/experimental/megablox/common.py index 429c145a9e34..c934dd7ab28e 100644 --- a/torch_xla/experimental/megablox/common.py +++ b/torch_xla/experimental/megablox/common.py @@ -2,7 +2,7 @@ from typing import Union import torch -import tpu_features +from torch_xla.experimental.megablox import tpu_features def assert_is_supported_dtype(dtype: torch.dtype) -> None: diff --git a/torch_xla/experimental/megablox/gmm.py b/torch_xla/experimental/megablox/gmm.py index c45c12f86911..1b1548440d27 100644 --- a/torch_xla/experimental/megablox/gmm.py +++ b/torch_xla/experimental/megablox/gmm.py @@ -1,7 +1,7 @@ """Grouped matrix multiplication kernels for TPU written in Pallas.""" from typing import Any, Callable, Optional, Union -import common +from torch_xla.experimental.megablox import common import torch import torch_xla import numpy as np