Commit

formatting

will-cromar committed Jan 18, 2024
1 parent e22cfda commit e4c3d85
Showing 6 changed files with 43 additions and 19 deletions.
18 changes: 13 additions & 5 deletions build_util.py
@@ -4,6 +4,7 @@
import sys
import shutil


def check_env_flag(name: str, default: str = '') -> bool:
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']

@@ -33,7 +34,7 @@ def bazel_options_from_env() -> Iterable[str]:
cache_silo_name = os.getenv('SILO_NAME', default='dev')
if cache_silo_name:
bazel_flags.append('--remote_default_exec_properties=cache-silo-key=%s' %
cache_silo_name)
cache_silo_name)

if check_env_flag('BUILD_CPP_TESTS', default='0'):
bazel_flags.append('//test/cpp:all')
@@ -53,8 +54,14 @@ def bazel_options_from_env() -> Iterable[str]:

return bazel_flags

def bazel_build(bazel_target: str, destination_dir: str, options: Iterable[str] = []):
bazel_argv = ['bazel', 'build', bazel_target, f"--symlink_prefix={os.path.join(os.getcwd(), 'bazel-')}"]

def bazel_build(bazel_target: str,
destination_dir: str,
options: Iterable[str] = []):
bazel_argv = [
'bazel', 'build', bazel_target,
f"--symlink_prefix={os.path.join(os.getcwd(), 'bazel-')}"
]

# Remove duplicated flags because they confuse bazel
flags = set(bazel_options_from_env() + options)
@@ -63,11 +70,12 @@ def bazel_build(bazel_target: str, destination_dir: str, options: Iterable[str]
print(' '.join(bazel_argv), flush=True)
subprocess.check_call(bazel_argv, stdout=sys.stdout, stderr=sys.stderr)

target_path = bazel_target.replace('@xla//', 'external/xla/').replace('//', '').replace(':', '/')
target_path = bazel_target.replace('@xla//', 'external/xla/').replace(
'//', '').replace(':', '/')
output_path = os.path.join('bazel-bin', target_path)
output_filename = os.path.basename(output_path)

if not os.path.exists(destination_dir):
os.makedirs(destination_dir)
os.makedirs(destination_dir)

shutil.copyfile(output_path, os.path.join(destination_dir, output_filename))
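For context on the reformatted helper above: bazel_build turns a Bazel label into the path of the built artifact under bazel-bin before copying it to the destination directory. The sketch below replays that label-to-path translation outside of a build, using the two target labels that appear elsewhere in this commit; the helper name is ours and it only illustrates the .replace() chain, it is not a Bazel API.

import os

def target_to_output_path(bazel_target: str) -> str:
  # Same translation as in build_util.bazel_build: map the '@xla//' repository
  # prefix to 'external/xla/', drop a leading '//', and turn ':' into '/'.
  target_path = bazel_target.replace('@xla//', 'external/xla/').replace(
      '//', '').replace(':', '/')
  return os.path.join('bazel-bin', target_path)

print(target_to_output_path('//plugins/cpu:pjrt_c_api_cpu_plugin.so'))
# -> bazel-bin/plugins/cpu/pjrt_c_api_cpu_plugin.so
print(target_to_output_path('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so'))
# -> bazel-bin/external/xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so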
3 changes: 2 additions & 1 deletion plugins/cpu/setup.py
@@ -1,6 +1,7 @@
import build_util
import setuptools

build_util.bazel_build('//plugins/cpu:pjrt_c_api_cpu_plugin.so', 'torch_xla_cpu_plugin/lib')
build_util.bazel_build('//plugins/cpu:pjrt_c_api_cpu_plugin.so',
'torch_xla_cpu_plugin/lib')

setuptools.setup()
5 changes: 4 additions & 1 deletion plugins/cpu/torch_xla_cpu_plugin/__init__.py
@@ -2,9 +2,12 @@
from torch_xla.experimental import plugins
from torch_xla._internal import tpu


class CpuPlugin(plugins.DevicePlugin):

def library_path(self) -> str:
return os.path.join(os.path.dirname(__file__), 'lib/pjrt_c_api_cpu_plugin.so')
return os.path.join(
os.path.dirname(__file__), 'lib/pjrt_c_api_cpu_plugin.so')

def physical_chip_count(self) -> int:
return 1
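For readers following along: the CpuPlugin above only tells torch_xla where the PJRT C API shared library lives. A minimal registration sketch follows, assuming the register_plugin and use_dynamic_plugins helpers exposed by torch_xla.experimental.plugins; the actual wiring in this package may go through packaging entry points instead.

from torch_xla.experimental import plugins
from torch_xla_cpu_plugin import CpuPlugin

# Hypothetical manual registration; apart from CpuPlugin itself, these helper
# names are assumptions about the plugins module, not taken from this commit.
plugins.use_dynamic_plugins()
plugins.register_plugin('CPU', CpuPlugin())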
3 changes: 2 additions & 1 deletion plugins/cuda/setup.py
@@ -1,6 +1,7 @@
import build_util
import setuptools

build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so', 'torch_xla_cuda_plugin/lib', ['--config=cuda'])
build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so',
'torch_xla_cuda_plugin/lib', ['--config=cuda'])

setuptools.setup()
28 changes: 19 additions & 9 deletions plugins/cuda/torch_xla_cuda_plugin/__init__.py
@@ -2,9 +2,12 @@
from torch_xla.experimental import plugins
import torch_xla.utils.utils as xu


class GpuPlugin(plugins.DevicePlugin):

def library_path(self) -> str:
return os.path.join(os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so')
return os.path.join(
os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so')

def physical_chip_count(self) -> int:
# TODO: default to actual device count
@@ -18,14 +21,21 @@ def client_create_options(self) -> dict:

# The available options are defined in OpenXLA: https://github.com/openxla/xla/blob/1bb2a74be91fabf5f9aa2702b2592b5b022c9052/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L58-L67
return {
"platform_name": "gpu",
# TODO(wcromar): make this configurable
"allocator": "cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool, False) else "default",
"memory_fraction": xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, .75),
"preallocate": xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, True),
"visible_devices": [local_process_rank],
"node_id": global_process_rank,
"num_nodes": global_world_size,
"platform_name":
"gpu",
# TODO(wcromar): make this configurable
"allocator":
"cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool,
False) else "default",
"memory_fraction":
xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, .75),
"preallocate":
xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, True),
"visible_devices": [local_process_rank],
"node_id":
global_process_rank,
"num_nodes":
global_world_size,
}

def requires_xla_coordinator(self) -> bool:
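Side note on the client options reformatted above: each allocator setting is read from an environment variable through xu.getenv_as, so the dict is driven entirely by the process environment. A small sketch of that lookup follows, with made-up example values for PJRT_ALLOCATOR_CUDA_ASYNC and PJRT_ALLOCATOR_FRACTION; the variable names and the getenv_as calls come from the diff itself.

import os
import torch_xla.utils.utils as xu

os.environ['PJRT_ALLOCATOR_CUDA_ASYNC'] = '1'   # example value only
os.environ['PJRT_ALLOCATOR_FRACTION'] = '0.5'   # example value only

# Mirrors the lookups inside GpuPlugin.client_create_options.
allocator = 'cuda_async' if xu.getenv_as('PJRT_ALLOCATOR_CUDA_ASYNC', bool,
                                          False) else 'default'
memory_fraction = xu.getenv_as('PJRT_ALLOCATOR_FRACTION', float, .75)
print(allocator, memory_fraction)  # cuda_async 0.5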
5 changes: 3 additions & 2 deletions setup.py
@@ -75,7 +75,6 @@ def _get_build_mode():
return sys.argv[i]



def get_git_head_sha(base_dir):
xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
cwd=base_dir).decode('ascii').strip()
@@ -191,6 +190,7 @@ def run(self):
# Copy libtpu.so into torch_xla/lib
maybe_bundle_libtpu(base_dir)


class BazelExtension(Extension):
"""A C/C++ extension that is defined as a Bazel BUILD target."""

@@ -222,7 +222,8 @@ def bazel_build(self, ext):
f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}"
]

cxx_abi = os.getenv('CXX_ABI') or getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None)
cxx_abi = os.getenv('CXX_ABI') or getattr(torch._C,
'_GLIBCXX_USE_CXX11_ABI', None)
if cxx_abi is not None:
bazel_argv.append(f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')

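On the CXX_ABI lookup rewrapped above: the value can arrive either as a string from the CXX_ABI environment variable or as a bool from torch._C._GLIBCXX_USE_CXX11_ABI, and int() normalizes both into the 0/1 the compile flag expects. A quick sketch of that resolution, where the CXX_ABI=0 override is an example value rather than something set by this commit:

import os
import torch

os.environ.setdefault('CXX_ABI', '0')  # example override only
cxx_abi = os.getenv('CXX_ABI') or getattr(torch._C,
                                          '_GLIBCXX_USE_CXX11_ABI', None)
if cxx_abi is not None:
  # '0'/'1' strings and True/False both collapse to 0 or 1 here.
  print(f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')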
