diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index af95dc955b66..8576c908e0a5 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -5,13 +5,13 @@ on: - master - r[0-9]+.[0-9]+ paths-ignore: - - 'experimental/torch_xla2/**' + - 'experimental/**' push: branches: - master - r[0-9]+.[0-9]+ paths-ignore: - - 'experimental/torch_xla2/**' + - 'experimental/**' workflow_dispatch: concurrency: diff --git a/.github/workflows/build_upstream_image.yml b/.github/workflows/build_upstream_image.yml index bb8ce87f01ce..37992bc20f8e 100644 --- a/.github/workflows/build_upstream_image.yml +++ b/.github/workflows/build_upstream_image.yml @@ -5,7 +5,7 @@ on: - master - r[0-9]+.[0-9]+ paths-ignore: - - 'experimental/torch_xla2/**' + - 'experimental/**' workflow_dispatch: jobs: build: diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index f449983abf44..1a999e441dc2 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -7,6 +7,7 @@ instantiate_device_type_tests, ops) from torch.utils import _pytree as pytree from torch_xla2 import tensor +import torch_xla2 skiplist = { diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py index 54af0eccab42..f7dbde712636 100644 --- a/experimental/torch_xla2/torch_xla2/__init__.py +++ b/experimental/torch_xla2/torch_xla2/__init__.py @@ -16,7 +16,6 @@ from jax._src import xla_bridge os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') -jax.config.update('jax_enable_x64', True) # torch_xla2:oss-begin old_pjrt_options = jax.config.jax_pjrt_client_create_options @@ -80,4 +79,16 @@ def disable_globally(): unsupported_dtype=unsupported_dtype) import jax -torch._register_device_module('jax', jax) \ No newline at end of file +torch._register_device_module('jax', jax) + + +def enable_accuracy_mode(): + jax.config.update('jax_enable_x64', True) + jax.config.update('jax_default_matmul_precision', 'highest') + default_env().config.internal_respect_torch_return_dtypes = True + + +def enable_performance_mode(): + jax.config.update('jax_enable_x64', False) + jax.config.update('jax_default_matmul_precision', 'default') + default_env().config.internal_respect_torch_return_dtypes = False \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/config.py b/experimental/torch_xla2/torch_xla2/config.py index 119f3b44d7e3..8a0870996a2b 100644 --- a/experimental/torch_xla2/torch_xla2/config.py +++ b/experimental/torch_xla2/torch_xla2/config.py @@ -15,3 +15,4 @@ class Configuration: # device treat_cuda_as_jax_device: bool = True use_torch_native_for_cpu_tensor: bool = False + internal_respect_torch_return_dtypes: bool = False