address comments
alanwaketan committed Mar 26, 2024
1 parent 6b543f1 commit 860a296
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions torch_xla/experimental/custom_kernel.py
@@ -8,6 +8,8 @@
 from torch.library import impl
 from torch_xla.core.xla_model import XLA_LIB
 
+_XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"
+
 XLA_LIB.define(
     "tpu_custom_call_(Tensor(a!) output, Tensor[] inputs, str payload) -> ()",)
 
@@ -75,13 +77,12 @@ def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
   import jax._src.pallas.mosaic.pallas_call_registration
 
   def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
-    XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"
     if dtype == torch.float32:
-      if XLA_USE_BF16:
+      if _XLA_USE_BF16:
         return jnp.bfloat16
       return jnp.float32
     elif dtype == torch.float64:
-      if XLA_USE_BF16:
+      if _XLA_USE_BF16:
         return jnp.bfloat16
       return jnp.float64
     elif dtype == torch.float16:
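In short, the commit hoists the XLA_USE_BF16 environment check out of convert_torch_dtype_to_jax into a module-level _XLA_USE_BF16 flag, so the environment is read once at import time rather than on every call. The snippet below is a minimal, self-contained sketch of the resulting dtype mapping, not the library source; the float16 branch and the trailing error case are assumptions, since the diff is truncated at that point.

import os

import jax.numpy as jnp
import torch

# Read the flag once at import time, mirroring the new module-level check.
_XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
  # With XLA_USE_BF16=1, 32/64-bit floats are mapped down to bfloat16.
  if dtype == torch.float32:
    if _XLA_USE_BF16:
      return jnp.bfloat16
    return jnp.float32
  elif dtype == torch.float64:
    if _XLA_USE_BF16:
      return jnp.bfloat16
    return jnp.float64
  elif dtype == torch.float16:
    # Assumed behavior beyond the truncated diff: float16 passes through.
    return jnp.float16
  else:
    raise ValueError(f"Unsupported dtype for conversion: {dtype}")

Usage: with XLA_USE_BF16=1 in the environment, convert_torch_dtype_to_jax(torch.float32) yields jnp.bfloat16; with the variable unset or "0", it yields jnp.float32.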
