From 93a04d1bcb76b1c31df321b148fa9c918df77630 Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Tue, 10 Oct 2023 18:39:16 +0000
Subject: [PATCH 1/3] Set TPU_MEGACORE in configure_topology

---
 torch_xla/_internal/tpu.py     | 3 +++
 torch_xla/core/xla_env_vars.py | 1 +
 2 files changed, 4 insertions(+)

diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py
index 89fbca4dcbc..a2b2b0eaf7b 100644
--- a/torch_xla/_internal/tpu.py
+++ b/torch_xla/_internal/tpu.py
@@ -266,6 +266,9 @@ def configure_topology(local_rank: int,
   os.environ.setdefault(xenv.TPU_VISIBLE_CHIPS, str(local_rank))
   os.environ.setdefault(xenv.TPU_PROCESS_PORT, str(ports[local_rank]))
 
+  if version() == 4:
+    os.environ.setdefault(xenv.TPU_MEGACORE, 'megacore_dense')
+
 
 def discover_master_worker_ip(use_localhost: bool = True) -> str:
   """Find the IP of the TPU host with TPU:0.
diff --git a/torch_xla/core/xla_env_vars.py b/torch_xla/core/xla_env_vars.py
index f67ea2d9fb6..24aaa54b60c 100644
--- a/torch_xla/core/xla_env_vars.py
+++ b/torch_xla/core/xla_env_vars.py
@@ -22,6 +22,7 @@
 TPU_PROCESS_ADDRESSES = 'TPU_PROCESS_ADDRESSES'
 TPU_VISIBLE_CHIPS = 'TPU_VISIBLE_CHIPS'
 TPU_PROCESS_PORT = 'TPU_PROCESS_PORT'
+TPU_MEGACORE = 'TPU_MEGACORE'
 PJRT_CPU_ASYNC_CLIENT = 'PJRT_CPU_ASYNC_CLIENT'
 PJRT_GPU_ASYNC_CLIENT = 'PJRT_GPU_ASYNC_CLIENT'
 PJRT_DIST_SERVICE_ADDR = 'PJRT_DIST_SERVICE_ADDR'

From de7522560a45c9d14dfae93755d068d4682f0217 Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Tue, 10 Oct 2023 18:42:04 +0000
Subject: [PATCH 2/3] remove

---
 torch_xla/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index cfde8456617..13983e3b02c 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -40,7 +40,6 @@ def _setup_default_env():
   _set_missing_env('GRPC_VERBOSITY', 'ERROR')
   _set_missing_env('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
   _set_missing_env('TPU_ML_PLATFORM', 'PyTorch/XLA')
-  _set_missing_env('TPU_MEGACORE', 'megacore_dense')
 
 
 _fd, _tmp_fname = -1, ''

From 3b19f155f3346862aa2cd0365d635659af83c2f9 Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Tue, 10 Oct 2023 20:21:42 +0000
Subject: [PATCH 3/3] Move back to __init__.py

---
 torch_xla/__init__.py          | 20 +++++++++++---------
 torch_xla/_internal/tpu.py     |  3 ---
 torch_xla/core/xla_env_vars.py |  1 -
 3 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 13983e3b02c..eeaa2aaba0c 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -3,6 +3,8 @@
 import re
 import tempfile
 
+from ._internal import tpu
+
 logging.basicConfig()
 logger = logging.getLogger(__name__)
 
@@ -30,16 +32,16 @@ def _setup_xla_flags():
   os.environ['XLA_FLAGS'] = ' '.join(flags)
 
 
-def _set_missing_env(name, value):
-  if name not in os.environ:
-    os.environ[name] = value
+def _setup_default_env():
+  os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
+  os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
 
+  if tpu.num_available_chips() > 0:
+    os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
+    os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA')
 
-def _setup_default_env():
-  _set_missing_env('TF_CPP_MIN_LOG_LEVEL', '1')
-  _set_missing_env('GRPC_VERBOSITY', 'ERROR')
-  _set_missing_env('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
-  _set_missing_env('TPU_ML_PLATFORM', 'PyTorch/XLA')
+  if tpu.version() == 4:
+    os.environ.setdefault('TPU_MEGACORE', 'megacore_dense')
 
 
 _fd, _tmp_fname = -1, ''
@@ -47,7 +49,7 @@
 
 def _setup_debug_env():
   fd, tmp_fname = tempfile.mkstemp('.ptxla', text=True)
-  _set_missing_env('XLA_FNTRACKER_FILE', tmp_fname)
+  os.environ.setdefault('XLA_FNTRACKER_FILE', tmp_fname)
   return fd, tmp_fname
 
 
diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py
index a2b2b0eaf7b..89fbca4dcbc 100644
--- a/torch_xla/_internal/tpu.py
+++ b/torch_xla/_internal/tpu.py
@@ -266,9 +266,6 @@ def configure_topology(local_rank: int,
   os.environ.setdefault(xenv.TPU_VISIBLE_CHIPS, str(local_rank))
   os.environ.setdefault(xenv.TPU_PROCESS_PORT, str(ports[local_rank]))
 
-  if version() == 4:
-    os.environ.setdefault(xenv.TPU_MEGACORE, 'megacore_dense')
-
 
 def discover_master_worker_ip(use_localhost: bool = True) -> str:
   """Find the IP of the TPU host with TPU:0.
diff --git a/torch_xla/core/xla_env_vars.py b/torch_xla/core/xla_env_vars.py
index 24aaa54b60c..f67ea2d9fb6 100644
--- a/torch_xla/core/xla_env_vars.py
+++ b/torch_xla/core/xla_env_vars.py
@@ -22,7 +22,6 @@
 TPU_PROCESS_ADDRESSES = 'TPU_PROCESS_ADDRESSES'
 TPU_VISIBLE_CHIPS = 'TPU_VISIBLE_CHIPS'
 TPU_PROCESS_PORT = 'TPU_PROCESS_PORT'
-TPU_MEGACORE = 'TPU_MEGACORE'
 PJRT_CPU_ASYNC_CLIENT = 'PJRT_CPU_ASYNC_CLIENT'
 PJRT_GPU_ASYNC_CLIENT = 'PJRT_GPU_ASYNC_CLIENT'
 PJRT_DIST_SERVICE_ADDR = 'PJRT_DIST_SERVICE_ADDR'
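
Note (reviewer sketch, not part of the patches): every default above goes through os.environ.setdefault, so a value exported before `import torch_xla` always wins over the library default. A minimal illustration of that behavior, assuming a v4 TPU host where torch_xla runs _setup_default_env() at import time; the override value 'megacore' is hypothetical:

    import os

    # Hypothetical user override, set before torch_xla is imported.
    os.environ['TPU_MEGACORE'] = 'megacore'

    import torch_xla  # _setup_default_env() calls os.environ.setdefault(...)

    # setdefault only fills in keys that are missing, so the override survives.
    assert os.environ['TPU_MEGACORE'] == 'megacore'

On hosts where tpu.version() != 4, TPU_MEGACORE is simply left unset unless the user exports it.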