From 5f4b2451ab8fc9b8d3738f3f2d237835177f2d3f Mon Sep 17 00:00:00 2001
From: mmakevic
Date: Fri, 15 Mar 2024 13:23:08 +0000
Subject: [PATCH 1/5] Add rocm pjrt plugin

---
Reviewer note (this region is ignored by git-am): the line structure of
this series was restored from a whitespace-collapsed copy, so indentation
inside the hunks is reconstructed and the blob ids on the "index" lines
were not regenerated. The return annotations of _get_process_rank and
_get_world_size were corrected from "int" to "tuple" because both methods
return a (local, global) pair.

 .bazelrc                                      |  8 +++
 BUILD                                         |  7 +++
 build_util.py                                 |  2 +
 plugins/rocm/pyproject.toml                   | 14 +++++
 plugins/rocm/setup.py                         | 13 +++++
 .../rocm/torch_xla_rocm_plugin/__init__.py    | 56 +++++++++++++++++++
 6 files changed, 100 insertions(+)
 create mode 100644 plugins/rocm/pyproject.toml
 create mode 100644 plugins/rocm/setup.py
 create mode 100644 plugins/rocm/torch_xla_rocm_plugin/__init__.py

diff --git a/.bazelrc b/.bazelrc
index 694cf3fd125..09ec72e0a25 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -54,6 +54,14 @@ build:cuda --@local_config_cuda//:enable_cuda
 build:cuda --define=xla_python_enable_gpu=true
 build:cuda --cxxopt=-DXLA_CUDA=1
 
+build:rocm --repo_env TF_NEED_ROCM=1
+build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
+
+
+build:rocm --@xla//xla/python:enable_gpu=true
+build:rocm --define=xla_python_enable_gpu=true
+build:rocm --cxxopt=-DXLA_ROCM=1
+
 # Coverage with cuda/gcc/nvcc requires manually setting coverage flags.
 coverage:cuda --per_file_copt=third_party/.*,torch_xla/.*@--coverage
 coverage:cuda --linkopt=-lgcov
diff --git a/BUILD b/BUILD
index c9a2578c722..462a84af71b 100644
--- a/BUILD
+++ b/BUILD
@@ -6,6 +6,11 @@ load(
 load("@python//:defs.bzl", "compile_pip_requirements")
 load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
 
+load(
+    "@local_config_rocm//rocm:build_defs.bzl",
+    "if_rocm_is_configured",
+)
+
 compile_pip_requirements(
     name = "requirements",
     extra_args = [
@@ -43,6 +48,8 @@ cc_binary(
         "@torch//:libtorch_python",
     ] + if_cuda_is_configured([
         "@xla//xla/stream_executor:cuda_platform",
+    ]) + if_rocm_is_configured([
+        "@xla//xla/stream_executor:rocm_platform",
     ]),
 )
 
diff --git a/build_util.py b/build_util.py
index 487f5116323..8be7c6bfa3b 100644
--- a/build_util.py
+++ b/build_util.py
@@ -47,6 +47,8 @@ def bazel_options_from_env() -> Iterable[str]:
     bazel_flags.append('--config=cuda')
   if check_env_flag('XLA_CPU_USE_ACL'):
     bazel_flags.append('--config=acl')
+  if check_env_flag('XLA_ROCM'):
+    bazel_flags.append('--config=rocm')
 
   return bazel_flags
 
diff --git a/plugins/rocm/pyproject.toml b/plugins/rocm/pyproject.toml
new file mode 100644
index 00000000000..3c9d029cff1
--- /dev/null
+++ b/plugins/rocm/pyproject.toml
@@ -0,0 +1,14 @@
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "torch_xla_rocm_plugin"
+version = "0.0.1"
+description = "PyTorch/XLA ROCM Plugin"
+
+[tool.setuptools.package-data]
+torch_xla_rocm_plugin = ["lib/*.so"]
+
+[project.entry-points."torch_xla.plugins"]
+rocm = "torch_xla_rocm_plugin:RocmPlugin"
\ No newline at end of file
diff --git a/plugins/rocm/setup.py b/plugins/rocm/setup.py
new file mode 100644
index 00000000000..5bf74935257
--- /dev/null
+++ b/plugins/rocm/setup.py
@@ -0,0 +1,13 @@
+import os
+import sys
+
+# add `build_util` to import path
+sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
+
+import build_util
+import setuptools
+
+build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so',
+                       'torch_xla_rocm_plugin/lib', ['--config=rocm'])
+
+setuptools.setup()
diff --git a/plugins/rocm/torch_xla_rocm_plugin/__init__.py b/plugins/rocm/torch_xla_rocm_plugin/__init__.py
new file mode 100644
index 00000000000..2f87b96b927
--- /dev/null
+++ b/plugins/rocm/torch_xla_rocm_plugin/__init__.py
@@ -0,0 +1,56 @@
+import os
+from torch_xla.experimental import plugins
+import torch_xla.utils.utils as xu
+
+
+class RocmPlugin(plugins.DevicePlugin):
+  def _get_process_rank(self) -> tuple:
+    local_process_rank = xu.getenv_as("PJRT_LOCAL_PROCESS_RANK", int,
+                                      xu.getenv_as("LOCAL_RANK", int, 0))
+    global_process_rank = xu.getenv_as("RANK", int, local_process_rank)
+
+    return local_process_rank, global_process_rank
+
+  def _get_world_size(self) -> tuple:
+    local_world_size = xu.getenv_as("PJRT_LOCAL_PROCESS_COUNT", int,
+                                    xu.getenv_as("LOCAL_WORLD_SIZE", int, 1))
+    global_world_size = xu.getenv_as("WORLD_SIZE", int, local_world_size)
+
+    return local_world_size, global_world_size
+
+  def library_path(self) -> str:
+    return os.path.join(os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so')
+
+  def physical_chip_count(self) -> int:
+    # TODO: default to actual device count
+    return xu.getenv_as('GPU_NUM_DEVICES', int, 1)
+
+  def client_create_options(self) -> dict:
+    local_process_rank, global_process_rank = self._get_process_rank()
+    local_world_size, global_world_size = self._get_world_size()
+
+    # The available options are defined in OpenXLA: https://github.com/openxla/xla/blob/1bb2a74be91fabf5f9aa2702b2592b5b022c9052/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L58-L67
+    options = {
+        "platform_name":
+            "gpu",
+        # TODO(wcromar): make this configurable
+        "allocator":
+            "default",
+        "memory_fraction":
+            xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, None),
+        "preallocate":
+            xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, None),
+        # Use all devices by default and when using SPMD
+        "visible_devices": [local_process_rank]
+                           if local_world_size > 1 else None,
+        "node_id":
+            global_process_rank,
+        "num_nodes":
+            global_world_size,
+    }
+
+    return {k: v for k, v in options.items() if v is not None}
+
+  def requires_xla_coordinator(self) -> bool:
+    _, global_world_size = self._get_world_size()
+    return global_world_size > 1
\ No newline at end of file

From 3a9391514e86b45577ef66ec953229c42a63bfe0 Mon Sep 17 00:00:00 2001
From: Milica Makevic
Date: Mon, 12 Aug 2024 08:45:05 +0000
Subject: [PATCH 2/5] Add missing options to .bazelrc

---
 .bazelrc | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.bazelrc b/.bazelrc
index 09ec72e0a25..ccd45ae61e3 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -57,6 +57,8 @@ build:cuda --cxxopt=-DXLA_CUDA=1
 build:rocm --repo_env TF_NEED_ROCM=1
 build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
 
+build:rocm --define=using_rocm_hipcc=true
+build:rocm --define=tensorflow_mkldnn_contraction_kernel=0
 
 build:rocm --@xla//xla/python:enable_gpu=true
 build:rocm --define=xla_python_enable_gpu=true

From 94fbec517bb81024c18e8091eb5bd29e659a96e4 Mon Sep 17 00:00:00 2001
From: Milica Makevic
Date: Tue, 20 Aug 2024 08:31:42 +0000
Subject: [PATCH 3/5] Differentiate between rocm and cuda

---
 torch_xla/__init__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index e214e7a47a7..c96c2b8f706 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -7,7 +7,7 @@
 
 import torch
 
-if not torch.cuda.is_available():
+if not torch.cuda.is_available() or torch.version.hip:
   # Load _XLAC_cuda_functions to RTLD_GLOBAL, so that it can be used by _XLAC.
   flags = sys.getdlopenflags()
   sys.setdlopenflags(flags | os.RTLD_NOW | os.RTLD_GLOBAL)
@@ -17,6 +17,7 @@
   # Then, restore the original flags.
   sys.setdlopenflags(flags)
 
+
 import _XLAC
 from ._internal import tpu
 from .version import __version__

From 35955aedde11625fe007a12e3b4fbe86da20c1cd Mon Sep 17 00:00:00 2001
From: Milica Makevic
Date: Tue, 20 Aug 2024 13:06:38 +0000
Subject: [PATCH 4/5] Add configure_single_process to __init__.py

---
 plugins/rocm/torch_xla_rocm_plugin/__init__.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/plugins/rocm/torch_xla_rocm_plugin/__init__.py b/plugins/rocm/torch_xla_rocm_plugin/__init__.py
index 2f87b96b927..939d50c59ad 100644
--- a/plugins/rocm/torch_xla_rocm_plugin/__init__.py
+++ b/plugins/rocm/torch_xla_rocm_plugin/__init__.py
@@ -25,6 +25,9 @@ def physical_chip_count(self) -> int:
     # TODO: default to actual device count
     return xu.getenv_as('GPU_NUM_DEVICES', int, 1)
 
+  def configure_single_process(self):
+    pass
+
   def client_create_options(self) -> dict:
     local_process_rank, global_process_rank = self._get_process_rank()
     local_world_size, global_world_size = self._get_world_size()
@@ -33,7 +36,6 @@ def client_create_options(self) -> dict:
     options = {
         "platform_name":
             "gpu",
-        # TODO(wcromar): make this configurable
         "allocator":
             "default",
         "memory_fraction":

From dc1874529fcfb2c9ccae20e305980de733241cf1 Mon Sep 17 00:00:00 2001
From: Milica Makevic
Date: Tue, 20 Aug 2024 13:50:08 +0000
Subject: [PATCH 5/5] Fix formatting

---
 torch_xla/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index c96c2b8f706..5b119a6dc8d 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -17,7 +17,6 @@
   # Then, restore the original flags.
   sys.setdlopenflags(flags)
 
-
 import _XLAC
 from ._internal import tpu
 from .version import __version__