From f0af628b84473dbf630dfbeb4f8f967912c05dbf Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 10 Oct 2023 14:07:46 -0700 Subject: [PATCH] Conditionally set default TPU settings in `__init__.py` (#5696) * Set TPU_MEGACORE in configure_topology * remove * Move back to __init__.py --- torch_xla/__init__.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index cfde84566178..eeaa2aaba0c9 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -3,6 +3,8 @@ import re import tempfile +from ._internal import tpu + logging.basicConfig() logger = logging.getLogger(__name__) @@ -30,17 +32,16 @@ def _setup_xla_flags(): os.environ['XLA_FLAGS'] = ' '.join(flags) -def _set_missing_env(name, value): - if name not in os.environ: - os.environ[name] = value +def _setup_default_env(): + os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1') + os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') + if tpu.num_available_chips() > 0: + os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1') + os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA') -def _setup_default_env(): - _set_missing_env('TF_CPP_MIN_LOG_LEVEL', '1') - _set_missing_env('GRPC_VERBOSITY', 'ERROR') - _set_missing_env('ALLOW_MULTIPLE_LIBTPU_LOAD', '1') - _set_missing_env('TPU_ML_PLATFORM', 'PyTorch/XLA') - _set_missing_env('TPU_MEGACORE', 'megacore_dense') + if tpu.version() == 4: + os.environ.setdefault('TPU_MEGACORE', 'megacore_dense') _fd, _tmp_fname = -1, '' @@ -48,7 +49,7 @@ def _setup_default_env(): def _setup_debug_env(): fd, tmp_fname = tempfile.mkstemp('.ptxla', text=True) - _set_missing_env('XLA_FNTRACKER_FILE', tmp_fname) + os.environ.setdefault('XLA_FNTRACKER_FILE', tmp_fname) return fd, tmp_fname