[ROCm] Add initial ROCm PJRT support #7896

Open · wants to merge 5 commits into base: master
10 changes: 10 additions & 0 deletions .bazelrc
@@ -54,6 +54,16 @@ build:cuda --@local_config_cuda//:enable_cuda
build:cuda --define=xla_python_enable_gpu=true
build:cuda --cxxopt=-DXLA_CUDA=1

build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain

build:rocm --define=using_rocm_hipcc=true
build:rocm --define=tensorflow_mkldnn_contraction_kernel=0

build:rocm --@xla//xla/python:enable_gpu=true
build:rocm --define=xla_python_enable_gpu=true
build:rocm --cxxopt=-DXLA_ROCM=1

# Coverage with cuda/gcc/nvcc requires manually setting coverage flags.
coverage:cuda --per_file_copt=third_party/.*,torch_xla/.*@--coverage
coverage:cuda --linkopt=-lgcov
7 changes: 7 additions & 0 deletions BUILD
@@ -6,6 +6,11 @@ load(
load("@python//:defs.bzl", "compile_pip_requirements")
load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")

load(
    "@local_config_rocm//rocm:build_defs.bzl",
    "if_rocm_is_configured",
)

compile_pip_requirements(
    name = "requirements",
    extra_args = [
@@ -43,6 +48,8 @@ cc_binary(
        "@torch//:libtorch_python",
    ] + if_cuda_is_configured([
        "@xla//xla/stream_executor:cuda_platform",
    ]) + if_rocm_is_configured([
        "@xla//xla/stream_executor:rocm_platform",
    ]),
)

2 changes: 2 additions & 0 deletions build_util.py
@@ -47,6 +47,8 @@ def bazel_options_from_env() -> Iterable[str]:
    bazel_flags.append('--config=cuda')
  if check_env_flag('XLA_CPU_USE_ACL'):
    bazel_flags.append('--config=acl')
  if check_env_flag('XLA_ROCM'):
    bazel_flags.append('--config=rocm')

  return bazel_flags
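
For illustration (not part of the diff): a minimal sketch of how the new flag is expected to surface through bazel_options_from_env(), assuming XLA_ROCM is the only relevant environment variable set.

# Hedged sketch: with XLA_ROCM=1 in the environment, the generated Bazel
# options should now include '--config=rocm'.
# (Assumes the repository root is on sys.path, as in plugins/rocm/setup.py.)
import os

import build_util

os.environ['XLA_ROCM'] = '1'
flags = list(build_util.bazel_options_from_env())
assert '--config=rocm' in flags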

14 changes: 14 additions & 0 deletions plugins/rocm/pyproject.toml
@@ -0,0 +1,14 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "torch_xla_rocm_plugin"
version = "0.0.1"
description = "PyTorch/XLA ROCm Plugin"

[tool.setuptools.package-data]
torch_xla_rocm_plugin = ["lib/*.so"]

[project.entry-points."torch_xla.plugins"]
rocm = "torch_xla_rocm_plugin:RocmPlugin"
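
For context (not part of the diff): once this package is installed, the torch_xla.plugins entry point declared above can be discovered with the standard importlib.metadata API; a minimal sketch, assuming Python 3.10+ for the group= keyword.

# Hedged sketch: discovering and loading the "torch_xla.plugins" entry point.
from importlib.metadata import entry_points

for ep in entry_points(group="torch_xla.plugins"):
  print(ep.name, ep.value)  # expected: rocm torch_xla_rocm_plugin:RocmPlugin
  plugin = ep.load()()  # resolves the RocmPlugin class and instantiates it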
13 changes: 13 additions & 0 deletions plugins/rocm/setup.py
@@ -0,0 +1,13 @@
import os
import sys

# add `build_util` to import path
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))

import build_util
import setuptools

build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so',
                       'torch_xla_rocm_plugin/lib', ['--config=rocm'])

setuptools.setup()
58 changes: 58 additions & 0 deletions plugins/rocm/torch_xla_rocm_plugin/__init__.py
@@ -0,0 +1,58 @@
import os
from typing import Tuple

from torch_xla.experimental import plugins
import torch_xla.utils.utils as xu


class RocmPlugin(plugins.DevicePlugin):

  def _get_process_rank(self) -> Tuple[int, int]:
    local_process_rank = xu.getenv_as("PJRT_LOCAL_PROCESS_RANK", int,
                                      xu.getenv_as("LOCAL_RANK", int, 0))
    global_process_rank = xu.getenv_as("RANK", int, local_process_rank)

    return local_process_rank, global_process_rank

  def _get_world_size(self) -> Tuple[int, int]:
    local_world_size = xu.getenv_as("PJRT_LOCAL_PROCESS_COUNT", int,
                                    xu.getenv_as("LOCAL_WORLD_SIZE", int, 1))
    global_world_size = xu.getenv_as("WORLD_SIZE", int, local_world_size)

    return local_world_size, global_world_size

  def library_path(self) -> str:
    return os.path.join(
        os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so')

  def physical_chip_count(self) -> int:
    # TODO: default to actual device count
    return xu.getenv_as('GPU_NUM_DEVICES', int, 1)

  def configure_single_process(self):
    pass

  def client_create_options(self) -> dict:
    local_process_rank, global_process_rank = self._get_process_rank()
    local_world_size, global_world_size = self._get_world_size()

    # The available options are defined in OpenXLA:
    # https://github.com/openxla/xla/blob/1bb2a74be91fabf5f9aa2702b2592b5b022c9052/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L58-L67
    options = {
        "platform_name":
            "gpu",
        "allocator":
            "default",
        "memory_fraction":
            xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, None),
        "preallocate":
            xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, None),
        # Use all devices by default and when using SPMD
        "visible_devices": [local_process_rank]
                           if local_world_size > 1 else None,
        "node_id":
            global_process_rank,
        "num_nodes":
            global_world_size,
    }

    return {k: v for k, v in options.items() if v is not None}

  def requires_xla_coordinator(self) -> bool:
    _, global_world_size = self._get_world_size()
    return global_world_size > 1
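
As a usage illustration (not part of the diff): the plugin can also be registered manually instead of relying on entry-point discovery; the device-type string 'ROCM' below is an assumption, mirroring how other PJRT plugins are registered.

# Hedged sketch: manual plugin registration before initializing the runtime.
import torch_xla.runtime as xr
from torch_xla.experimental import plugins
from torch_xla_rocm_plugin import RocmPlugin

plugins.register_plugin('ROCM', RocmPlugin())  # device-type name is an assumption
xr.set_device_type('ROCM')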
2 changes: 1 addition & 1 deletion torch_xla/__init__.py
@@ -7,7 +7,7 @@

import torch

if not torch.cuda.is_available():
if not torch.cuda.is_available() or torch.version.hip:
  # Load _XLAC_cuda_functions to RTLD_GLOBAL, so that it can be used by _XLAC.
  flags = sys.getdlopenflags()
  sys.setdlopenflags(flags | os.RTLD_NOW | os.RTLD_GLOBAL)