From f54fdc3ce0c6f104607f8c6c6a75509444b31aee Mon Sep 17 00:00:00 2001 From: Milad Mohammadi Date: Fri, 3 May 2024 16:23:29 +0000 Subject: [PATCH] library fix --- torch_xla/experimental/megablox/common.py | 2 +- torch_xla/experimental/megablox/gmm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/megablox/common.py b/torch_xla/experimental/megablox/common.py index 429c145a9e34..c934dd7ab28e 100644 --- a/torch_xla/experimental/megablox/common.py +++ b/torch_xla/experimental/megablox/common.py @@ -2,7 +2,7 @@ from typing import Union import torch -import tpu_features +from torch_xla.experimental.megablox import tpu_features def assert_is_supported_dtype(dtype: torch.dtype) -> None: diff --git a/torch_xla/experimental/megablox/gmm.py b/torch_xla/experimental/megablox/gmm.py index c45c12f86911..1b1548440d27 100644 --- a/torch_xla/experimental/megablox/gmm.py +++ b/torch_xla/experimental/megablox/gmm.py @@ -1,7 +1,7 @@ """Grouped matrix multiplication kernels for TPU written in Pallas.""" from typing import Any, Callable, Optional, Union -import common +from torch_xla.experimental.megablox import common import torch import torch_xla import numpy as np