From e4c3d859488cb6f802699fc51fc2850d7bd85232 Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Thu, 18 Jan 2024 21:15:13 +0000
Subject: [PATCH] formatting

---
 build_util.py                                | 18 ++++++++----
 plugins/cpu/setup.py                         |  3 +-
 plugins/cpu/torch_xla_cpu_plugin/__init__.py |  5 +++-
 plugins/cuda/setup.py                        |  3 +-
 .../cuda/torch_xla_cuda_plugin/__init__.py   | 28 +++++++++++++------
 setup.py                                     |  5 ++--
 6 files changed, 43 insertions(+), 19 deletions(-)

diff --git a/build_util.py b/build_util.py
index 236c86c151b4..78e4bd5e4539 100644
--- a/build_util.py
+++ b/build_util.py
@@ -4,6 +4,7 @@
 import sys
 import shutil
 
+
 def check_env_flag(name: str, default: str = '') -> bool:
   return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
 
@@ -33,7 +34,7 @@ def bazel_options_from_env() -> Iterable[str]:
   cache_silo_name = os.getenv('SILO_NAME', default='dev')
   if cache_silo_name:
     bazel_flags.append('--remote_default_exec_properties=cache-silo-key=%s' %
-        cache_silo_name)
+                       cache_silo_name)
 
   if check_env_flag('BUILD_CPP_TESTS', default='0'):
     bazel_flags.append('//test/cpp:all')
@@ -53,8 +54,14 @@
 
   return bazel_flags
 
-def bazel_build(bazel_target: str, destination_dir: str, options: Iterable[str] = []):
-  bazel_argv = ['bazel', 'build', bazel_target, f"--symlink_prefix={os.path.join(os.getcwd(), 'bazel-')}"]
+
+def bazel_build(bazel_target: str,
+                destination_dir: str,
+                options: Iterable[str] = []):
+  bazel_argv = [
+      'bazel', 'build', bazel_target,
+      f"--symlink_prefix={os.path.join(os.getcwd(), 'bazel-')}"
+  ]
 
   # Remove duplicated flags because they confuse bazel
   flags = set(bazel_options_from_env() + options)
@@ -63,11 +70,12 @@ def bazel_build(bazel_target: str, destination_dir: str, options: Iterable[str]
   print(' '.join(bazel_argv), flush=True)
   subprocess.check_call(bazel_argv, stdout=sys.stdout, stderr=sys.stderr)
 
-  target_path = bazel_target.replace('@xla//', 'external/xla/').replace('//', '').replace(':', '/')
+  target_path = bazel_target.replace('@xla//', 'external/xla/').replace(
+      '//', '').replace(':', '/')
   output_path = os.path.join('bazel-bin', target_path)
   output_filename = os.path.basename(output_path)
   if not os.path.exists(destination_dir):
-      os.makedirs(destination_dir)
+    os.makedirs(destination_dir)
 
   shutil.copyfile(output_path, os.path.join(destination_dir, output_filename))
 
diff --git a/plugins/cpu/setup.py b/plugins/cpu/setup.py
index 9bf28a0ac3b1..9182c86a635c 100644
--- a/plugins/cpu/setup.py
+++ b/plugins/cpu/setup.py
@@ -1,6 +1,7 @@
 import build_util
 import setuptools
 
-build_util.bazel_build('//plugins/cpu:pjrt_c_api_cpu_plugin.so', 'torch_xla_cpu_plugin/lib')
+build_util.bazel_build('//plugins/cpu:pjrt_c_api_cpu_plugin.so',
+                       'torch_xla_cpu_plugin/lib')
 
 setuptools.setup()
diff --git a/plugins/cpu/torch_xla_cpu_plugin/__init__.py b/plugins/cpu/torch_xla_cpu_plugin/__init__.py
index e22f4c65a06c..ab2588aae9d2 100644
--- a/plugins/cpu/torch_xla_cpu_plugin/__init__.py
+++ b/plugins/cpu/torch_xla_cpu_plugin/__init__.py
@@ -2,9 +2,12 @@
 from torch_xla.experimental import plugins
 from torch_xla._internal import tpu
 
+
 class CpuPlugin(plugins.DevicePlugin):
+
   def library_path(self) -> str:
-    return os.path.join(os.path.dirname(__file__), 'lib/pjrt_c_api_cpu_plugin.so')
+    return os.path.join(
+        os.path.dirname(__file__), 'lib/pjrt_c_api_cpu_plugin.so')
 
   def physical_chip_count(self) -> int:
     return 1
diff --git a/plugins/cuda/setup.py b/plugins/cuda/setup.py
index 0b6928db9173..8f6aaf00b74f 100644
--- a/plugins/cuda/setup.py
+++ b/plugins/cuda/setup.py
@@ -1,6 +1,7 @@
 import build_util
 import setuptools
 
-build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so', 'torch_xla_cuda_plugin/lib', ['--config=cuda'])
+build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so',
+                       'torch_xla_cuda_plugin/lib', ['--config=cuda'])
 
 setuptools.setup()
diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py
index c5eabbdc2f07..d480a0439b62 100644
--- a/plugins/cuda/torch_xla_cuda_plugin/__init__.py
+++ b/plugins/cuda/torch_xla_cuda_plugin/__init__.py
@@ -2,9 +2,12 @@
 from torch_xla.experimental import plugins
 import torch_xla.utils.utils as xu
 
+
 class GpuPlugin(plugins.DevicePlugin):
+
   def library_path(self) -> str:
-    return os.path.join(os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so')
+    return os.path.join(
+        os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so')
 
   def physical_chip_count(self) -> int:
     # TODO: default to actual device count
@@ -18,14 +21,21 @@ def client_create_options(self) -> dict:
 
     # The available options are defined in OpenXLA: https://github.com/openxla/xla/blob/1bb2a74be91fabf5f9aa2702b2592b5b022c9052/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L58-L67
     return {
-        "platform_name": "gpu",
-        # TODO(wcromar): make this configurable
-        "allocator": "cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool, False) else "default",
-        "memory_fraction": xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, .75),
-        "preallocate": xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, True),
-        "visible_devices": [local_process_rank],
-        "node_id": global_process_rank,
-        "num_nodes": global_world_size,
+        "platform_name":
+            "gpu",
+        # TODO(wcromar): make this configurable
+        "allocator":
+            "cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool,
+                                         False) else "default",
+        "memory_fraction":
+            xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, .75),
+        "preallocate":
+            xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, True),
+        "visible_devices": [local_process_rank],
+        "node_id":
+            global_process_rank,
+        "num_nodes":
+            global_world_size,
     }
 
   def requires_xla_coordinator(self) -> bool:
diff --git a/setup.py b/setup.py
index 263a22b10886..92d2b7ca5634 100644
--- a/setup.py
+++ b/setup.py
@@ -75,7 +75,6 @@ def _get_build_mode():
       return sys.argv[i]
 
 
-
 def get_git_head_sha(base_dir):
   xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
                                         cwd=base_dir).decode('ascii').strip()
@@ -191,6 +190,7 @@ def run(self):
     # Copy libtpu.so into torch_xla/lib
     maybe_bundle_libtpu(base_dir)
 
+
 class BazelExtension(Extension):
   """A C/C++ extension that is defined as a Bazel BUILD target."""
 
@@ -222,7 +222,8 @@ def bazel_build(self, ext):
         f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}"
     ]
 
-    cxx_abi = os.getenv('CXX_ABI') or getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None)
+    cxx_abi = os.getenv('CXX_ABI') or getattr(torch._C,
+                                              '_GLIBCXX_USE_CXX11_ABI', None)
     if cxx_abi is not None:
       bazel_argv.append(f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')
 