diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh
new file mode 100755
index 00000000000..9c988f56b60
--- /dev/null
+++ b/test/neuron/run_tests.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+set -xue
+
+python3 test/neuron/test_neuron_utils.py
diff --git a/test/neuron/test_neuron_utils.py b/test/neuron/test_neuron_utils.py
new file mode 100644
index 00000000000..b99110556b4
--- /dev/null
+++ b/test/neuron/test_neuron_utils.py
@@ -0,0 +1,58 @@
+import os
+import pytest
+import sys
+import unittest
+from torch_xla._internal.neuron_utils import *
+
+
+class NeuronUtilsTest(unittest.TestCase):
+
+  def test_get_visible_cores_list(self):
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "1"
+    assert (get_visible_cores_list() == [1])
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "1,2,3"
+    assert (get_visible_cores_list() == [1, 2, 3])
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3"
+    assert (get_visible_cores_list() == [1, 2, 3])
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3,5-8"
+    assert (get_visible_cores_list() == [1, 2, 3, 5, 6, 7, 8])
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "1,3,5-8"
+    assert (get_visible_cores_list() == [1, 3, 5, 6, 7, 8])
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3,5-8,3-5"
+    with pytest.raises(ValueError):
+      get_visible_cores_list()
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3,5-8-5"
+    with pytest.raises(ValueError):
+      get_visible_cores_list()
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "a-b,5-8-5"
+    with pytest.raises(Exception):
+      get_visible_cores_list()
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "a"
+    with pytest.raises(Exception):
+      get_visible_cores_list()
+
+  def test_remap_visible_cores(self):
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "1"
+    remap_visible_cores(0, 1)
+    assert (os.environ['NEURON_RT_VISIBLE_CORES'] == "1")
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "1,2,3"
+    remap_visible_cores(1, 3)
+    assert (os.environ['NEURON_RT_VISIBLE_CORES'] == "2")
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3"
+    remap_visible_cores(2, 3)
+    assert (os.environ['NEURON_RT_VISIBLE_CORES'] == "3")
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3,5-8"
+    remap_visible_cores(5, 7)
+    assert (os.environ['NEURON_RT_VISIBLE_CORES'] == "7")
+    os.environ["NEURON_RT_VISIBLE_CORES"] = "1,3,5-8"
+    remap_visible_cores(5, 6)
+    assert (os.environ['NEURON_RT_VISIBLE_CORES'] == "8")
+    with pytest.raises(ValueError):
+      remap_visible_cores(5, 9)
+    with pytest.raises(ValueError):
+      remap_visible_cores(6, 6)
+
+
+if __name__ == "__main__":
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/pjrt/test_runtime.py b/test/pjrt/test_runtime.py
index ac314e3c60c..fcb44e2cb93 100644
--- a/test/pjrt/test_runtime.py
+++ b/test/pjrt/test_runtime.py
@@ -86,7 +86,7 @@ def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
     reload(torch_xla)
     logs_context = contextlib.nullcontext()
     if expect_using_pjrt:
-      self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU'])
+      self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU', 'NEURON'])
     else:
       self.assertIsNone(xr.device_type())
 
diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 2f8e1e947de..3a75796b155 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -116,11 +116,31 @@ def _summarize_fn_tracker():
 
 def _aws_ec2_inf_trn_init():
   try:
-    from torch_neuronx import xla
+    from libneuronxla.libneuronpjrt_path import libneuronpjrt_path
   except ImportError:
-    return
+    pass
   else:
-    xla.init()
+    # Need to set NEURON_LIBRARY_PATH here for proper Neuron Cache behavior
+    os.environ.setdefault('NEURON_LIBRARY_PATH', libneuronpjrt_path())
+    # Enable additional features and overrides
+    try:
+      from torch_neuronx import xla
+    except ImportError:
+      pass
+    else:
+      xla.init()
+
+    # Basic initializations even if torch-neuronx is not available
+    from ._internal import neuron
+    if os.path.basename(sys.argv[0]) != 'neuron_parallel_compile':
+      import libneuronxla
+      libneuronxla.configure_environment()
+      neuron.set_envvar_defaults()
+      neuron.configure_pjrt_environment()
+    # Found libneuronxla
+    return True
+  # Did not find libneuronxla
+  return False
 
 
 def _setup_tpu_vm_library_path() -> bool:
@@ -179,7 +199,7 @@ def _check_deprecated_env_var():
 _found_libtpu = _setup_tpu_vm_library_path()
 
 # Setup Neuron library for AWS EC2 inf/trn instances.
-_aws_ec2_inf_trn_init()
+_found_libneuronxla = _aws_ec2_inf_trn_init()
 
 
 def _prepare_to_exit():
diff --git a/torch_xla/_internal/neuron.py b/torch_xla/_internal/neuron.py
index 5d5481c18d4..e7a98cdcf5b 100644
--- a/torch_xla/_internal/neuron.py
+++ b/torch_xla/_internal/neuron.py
@@ -3,33 +3,115 @@
 
 from torch_xla.experimental import plugins
 
+import sys
+import torch.distributed as dist
+
+from .neuron_utils import get_visible_cores_list, remap_visible_cores
+
+logging.basicConfig()
+logger = logging.getLogger(__name__)
+
+
+# Set root communication address/port
+def set_rt_root_comm_id():
+  if os.environ.get('NEURON_RT_ROOT_COMM_ID', None) is None:
+    if 'MASTER_ADDR' not in os.environ:
+      logging.warning(
+          "MASTER_ADDR environment variable is not set, defaulting to localhost"
+      )
+    root_port = 62182
+    root_addr = os.environ.get('MASTER_ADDR', 'localhost')
+    is_ipv6 = len(root_addr.split(":")) >= 3
+    if is_ipv6:
+      modified = False
+      if not root_addr.startswith("["):
+        root_addr = "[" + root_addr
+        modified = True
+      if not root_addr.endswith("]"):
+        root_addr = root_addr + "]"
+        modified = True
+      if modified:
+        logger.warning(
+            "IPv6 address detected for MASTER_ADDR and missing brackets added: {}"
+            .format(root_addr))
+    os.environ['NEURON_RT_ROOT_COMM_ID'] = '{}:{}'.format(root_addr, root_port)
+
+
+def set_envvar_defaults():
+  os.environ.setdefault('ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB', '50')
+
+
+def configure_pjrt_environment():
+  """
+  Setting all necessary PJRT default environment variables.
+  """
+  from torch.distributed import is_torchelastic_launched
+
+  # Set root communication address/port
+  set_rt_root_comm_id()
+
+  # Set env variables if we don't use GSPMD and are using PJRT with torchrun
+  if os.environ.get('XLA_USE_SPMD', '0') != '1' \
+      and is_torchelastic_launched():
+    # Env variables that only need to be set once
+    # NEURON_PJRT_PROCESSES_NUM_DEVICES is a list of core counts and is too long for very large clusters,
+    # so use NEURON_PJRT_WORLD_SIZE to pass the world size and use a core count of 1 per process in the PJRT client.
+    if 'NEURON_PJRT_PROCESSES_NUM_DEVICES' not in os.environ and 'NEURON_PJRT_WORLD_SIZE' not in os.environ:
+      if 'WORLD_SIZE' not in os.environ:
+        logger.warning(
+            'WORLD_SIZE environment variable not set, defaulting to 1.')
+      os.environ["NEURON_PJRT_WORLD_SIZE"] = os.environ.get("WORLD_SIZE", "1")
+      if 'LOCAL_WORLD_SIZE' not in os.environ:
+        logger.warning(
+            'LOCAL_WORLD_SIZE environment variable not set, defaulting to 1.')
+      os.environ['PJRT_LOCAL_PROCESS_COUNT'] = os.environ.get(
+          'LOCAL_WORLD_SIZE', '1')
+
+    # Env variables that need to be set once per process
+    if not os.environ.get('NEURON_RT_VISIBLE_CORES', None):
+      os.environ['NEURON_RT_VISIBLE_CORES'] = os.environ.get('LOCAL_RANK', '0')
+    else:
+      local_rank = int(os.environ.get('LOCAL_RANK', '0'))
+      local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', '1'))
+      remap_visible_cores(local_rank, local_world_size)
+
+    if 'RANK' not in os.environ:
+      logger.warning('RANK environment variable is not set, defaulting to 0.')
+    os.environ['NEURON_PJRT_PROCESS_INDEX'] = os.environ.get('RANK', '0')
+    if 'LOCAL_RANK' not in os.environ:
+      logger.warning(
+          'LOCAL_RANK environment variable is not set, defaulting to 0.')
+    os.environ['PJRT_LOCAL_PROCESS_RANK'] = os.environ.get('LOCAL_RANK', '0')
+
 
 def num_local_processes() -> int:
-  if 'MASTER_ADDR' not in os.environ:
-    logging.warning("MASTER_ADDR not setting, defaulting to localhost")
-  os.environ['NEURON_RT_ROOT_COMM_ID'] = '{}:{}'.format(
-      os.environ.get('MASTER_ADDR', 'localhost'), '62182')
-  if "NEURONCORE_NUM_DEVICES" not in os.environ:
-    logging.warning("NEURONCORE_NUM_DEVICES not set, defaulting to 1")
+  set_rt_root_comm_id()
   num_processes = int(os.environ.get("NEURONCORE_NUM_DEVICES", "1"))
   os.environ['NEURON_PJRT_PROCESSES_NUM_DEVICES'] = ','.join(
       ['1' for _ in range(num_processes)])
-
   return num_processes
 
 
+# When torchrun is used, setting these environment variables causes the
+# second instance in a 2-node cluster to think it is node 0 instead of node 1.
+# Skip these settings and let configure_pjrt_environment set the
+# distributed PJRT environment variables instead.
+# If NEURONCORE_NUM_DEVICES is used, then go ahead and set the environment variables.
 def initialize_env(local_rank, local_world_size):
-  os.environ["NEURON_PJRT_PROCESS_INDEX"] = str(local_rank)
-  assert (
-      local_rank < local_world_size
-  ), "ERROR in initialize_env: PJRT_LOCAL_PROCESS_RANK is not less than PJRT_LOCAL_PROCESS_COUNT"
-  os.environ["NEURON_RT_VISIBLE_CORES"] = str(local_rank)
+  from torch.distributed import is_torchelastic_launched
+  if not is_torchelastic_launched():
+    os.environ["NEURON_PJRT_PROCESS_INDEX"] = str(local_rank)
+    if not os.environ.get('NEURON_RT_VISIBLE_CORES', None):
+      os.environ["NEURON_RT_VISIBLE_CORES"] = str(local_rank)
+    else:
+      remap_visible_cores(local_rank, local_world_size)
 
 
 class NeuronPlugin(plugins.DevicePlugin):
 
   def library_path(self):
-    return os.environ.get("NEURON_LIBRARY_PATH", "libneuronpjrt.so")
+    from libneuronxla.libneuronpjrt_path import libneuronpjrt_path
+    return os.environ.get("NEURON_LIBRARY_PATH", libneuronpjrt_path())
 
   def configure_multiprocess(self, local_rank, local_world_size):
     initialize_env(local_rank, local_world_size)
diff --git a/torch_xla/_internal/neuron_utils.py b/torch_xla/_internal/neuron_utils.py
new file mode 100644
index 00000000000..70c1aedbbb0
--- /dev/null
+++ b/torch_xla/_internal/neuron_utils.py
@@ -0,0 +1,66 @@
+import os
+import logging
+logging.basicConfig()
+logger = logging.getLogger(__name__)
+
+
+def convert_range(range_spec):
+  try:
+    lowerupper = list(map(int, range_spec.split("-")))
+  except Exception as e:
+    print(f"ERROR: Malformed range specs in NEURON_RT_VISIBLE_CORES; " +
+          f"expecting <int> or <int>-<int> (got {range_spec})")
+    raise e
+  if len(lowerupper) > 2:
+    raise ValueError(
+        f"ERROR: Range specs in NEURON_RT_VISIBLE_CORES should be of " +
+        f"the form <int> or <int>-<int> (got {range_spec})")
+  if len(lowerupper) == 2:
+    if lowerupper[0] > lowerupper[1]:
+      raise ValueError(
+          f"ERROR: Range specs in NEURON_RT_VISIBLE_CORES should " +
+          f"be of the form <int> or <int>-<int> (got {range_spec})")
+    lowerupper = range(lowerupper[0], lowerupper[1] + 1)
+  return lowerupper
+
+
+def get_visible_cores_list():
+  import os
+
+  range_list = os.environ.get("NEURON_RT_VISIBLE_CORES")
+  cores_list = None
+  if range_list:
+    range_list = range_list.split(",")
+    cores_list = []
+    for i in range_list:
+      new = convert_range(i)
+      if (set(cores_list) & set(new)) != set():
+        raise ValueError(
+            "ERROR: Please ensure the ranges in NEURON_RT_VISIBLE_CORES are mutually exclusive."
+        )
+      cores_list.extend(new)
+  return cores_list
+
+
+def remap_visible_cores(local_rank, local_world_size):
+  cores_list = get_visible_cores_list()
+  count = len(cores_list)
+  assert (local_world_size > 0), "Local world size should be non-zero"
+  if count <= 1 and local_world_size == 1:
+    # Allow user to pass NEURON_RT_VISIBLE_CORES for single-core workload
+    pass
+  elif local_world_size != count:
+    raise ValueError(
+        f"LOCAL_WORLD_SIZE (torchrun) or PJRT_LOCAL_PROCESS_COUNT (xmp.spawn) value of {local_world_size} "
+        +
+        f"is not equal to count {count} from NEURON_RT_VISIBLE_CORES {cores_list}"
+    )
+  elif local_rank >= count:
+    raise ValueError(
+        f"LOCAL_RANK (torchrun) or PJRT_LOCAL_PROCESS_RANK (xmp.spawn) value of {local_rank} is higher than "
+        + f"count {count} from NEURON_RT_VISIBLE_CORES {cores_list}")
+  else:
+    remapped_core = cores_list[local_rank]
+    logger.warning(f"Remapping NEURON_RT_VISIBLE_CORES {cores_list} to " +
+                   f"NEURON_RT_VISIBLE_CORES[LOCAL_RANK]={remapped_core}")
+    os.environ['NEURON_RT_VISIBLE_CORES'] = str(remapped_core)
diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index 0b963e378ec..e4560df6c70 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -72,6 +72,9 @@ def _maybe_select_default_device():
         + num_devices_str)
     os.environ[xenv.PJRT_DEVICE] = 'CUDA'
     os.environ[xenv.GPU_NUM_DEVICES] = num_devices_str
+  elif torch_xla._found_libneuronxla:
+    logging.warning('Found libneuronpjrt.so. Setting PJRT_DEVICE=NEURON.')
+    os.environ[xenv.PJRT_DEVICE] = 'NEURON'
   else:
     logging.warning('Defaulting to PJRT_DEVICE=CPU')
     os.environ[xenv.PJRT_DEVICE] = 'CPU'
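
For reviewers, a quick illustration of the core-visibility behavior the new tests exercise. This is a minimal usage sketch, not part of the change itself; it uses only the helpers added in torch_xla/_internal/neuron_utils.py, and the values mirror the test cases above.

import os

from torch_xla._internal.neuron_utils import get_visible_cores_list, remap_visible_cores

# "1-3,5-8" expands to the individual cores 1, 2, 3, 5, 6, 7 and 8.
os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3,5-8"
assert get_visible_cores_list() == [1, 2, 3, 5, 6, 7, 8]

# With 7 local processes, local rank 5 owns cores_list[5], i.e. core 7, so
# NEURON_RT_VISIBLE_CORES is rewritten to that single core for this process.
remap_visible_cores(5, 7)
assert os.environ["NEURON_RT_VISIBLE_CORES"] == "7"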