diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index cfde8456617..eeaa2aaba0c 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -3,6 +3,8 @@
 import re
 import tempfile
 
+from ._internal import tpu
+
 logging.basicConfig()
 logger = logging.getLogger(__name__)
 
@@ -30,17 +32,16 @@ def _setup_xla_flags():
   os.environ['XLA_FLAGS'] = ' '.join(flags)
 
 
-def _set_missing_env(name, value):
-  if name not in os.environ:
-    os.environ[name] = value
+def _setup_default_env():
+  os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
+  os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
 
+  if tpu.num_available_chips() > 0:
+    os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
+    os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA')
 
-def _setup_default_env():
-  _set_missing_env('TF_CPP_MIN_LOG_LEVEL', '1')
-  _set_missing_env('GRPC_VERBOSITY', 'ERROR')
-  _set_missing_env('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
-  _set_missing_env('TPU_ML_PLATFORM', 'PyTorch/XLA')
-  _set_missing_env('TPU_MEGACORE', 'megacore_dense')
+    if tpu.version() == 4:
+      os.environ.setdefault('TPU_MEGACORE', 'megacore_dense')
 
 
 _fd, _tmp_fname = -1, ''
@@ -48,7 +49,7 @@ def _setup_default_env():
 
 def _setup_debug_env():
   fd, tmp_fname = tempfile.mkstemp('.ptxla', text=True)
-  _set_missing_env('XLA_FNTRACKER_FILE', tmp_fname)
+  os.environ.setdefault('XLA_FNTRACKER_FILE', tmp_fname)
   return fd, tmp_fname
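Review note, not part of the patch: `os.environ.setdefault` reproduces the "only set when missing" behavior of the removed `_set_missing_env` helper, so values the user exported beforehand are never overwritten, and the TPU-specific defaults are now applied only when `tpu.num_available_chips() > 0` (with `TPU_MEGACORE` further gated on `tpu.version() == 4`). A minimal standalone sketch of the setdefault semantics; the concrete values below are just illustrative:

```python
import os

# setdefault writes a value only when the variable is not already present,
# so library defaults never clobber anything the user set beforehand.
os.environ['GRPC_VERBOSITY'] = 'DEBUG'              # pretend the user set this
os.environ.pop('TF_CPP_MIN_LOG_LEVEL', None)        # ensure this one is unset

os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')    # no-op: already set
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')  # applied: was missing

assert os.environ['GRPC_VERBOSITY'] == 'DEBUG'
assert os.environ['TF_CPP_MIN_LOG_LEVEL'] == '1'
```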