From 2492ddc516b73eda4299fa11d3a04b29dcd84046 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20=C5=9Euhan?= Date: Sun, 11 Nov 2018 13:45:41 -0800 Subject: [PATCH] Initial import as a separate torch_xla extension Skip tests which require additional work. --- .gitmodules | 3 + LICENSE | 29 + README.md | 23 + build_torch_xla_libs.sh | 23 + setup.py | 116 +++ test/test_operations.py | 912 ++++++++++++++++++ third_party/tensorflow | 1 + third_party/xla_client/BUILD | 88 ++ third_party/xla_client/computation_client.cc | 167 ++++ third_party/xla_client/computation_client.h | 75 ++ .../xla_client/tf_exported_symbols.lds | 9 + third_party/xla_client/tf_version_script.lds | 13 + third_party/xla_client/unique.h | 35 + .../xla_client/xla_computation_client.cc | 165 ++++ .../xla_client/xla_computation_client.h | 99 ++ third_party/xla_client/xla_util.cc | 36 + third_party/xla_client/xla_util.h | 30 + .../xla_client/xrt_computation_client.cc | 632 ++++++++++++ .../xla_client/xrt_computation_client.h | 330 +++++++ torch_xla/__init__.py | 2 + torch_xla/csrc/batch_norm.cpp | 44 + torch_xla/csrc/batch_norm.h | 32 + torch_xla/csrc/convolution.cpp | 302 ++++++ torch_xla/csrc/convolution.h | 34 + torch_xla/csrc/cross_replica_reduces.cpp | 17 + torch_xla/csrc/cross_replica_reduces.h | 13 + torch_xla/csrc/data_ops.cpp | 170 ++++ torch_xla/csrc/data_ops.h | 38 + torch_xla/csrc/elementwise.cpp | 62 ++ torch_xla/csrc/elementwise.h | 26 + torch_xla/csrc/graph_context.cpp | 60 ++ torch_xla/csrc/graph_context.h | 92 ++ torch_xla/csrc/helpers.cpp | 72 ++ torch_xla/csrc/helpers.h | 56 ++ torch_xla/csrc/init_python_bindings.cpp | 172 ++++ torch_xla/csrc/init_python_bindings.h | 12 + torch_xla/csrc/log_softmax.cpp | 80 ++ torch_xla/csrc/log_softmax.h | 18 + torch_xla/csrc/module.cpp | 605 ++++++++++++ torch_xla/csrc/module.h | 144 +++ torch_xla/csrc/nll_loss.cpp | 95 ++ torch_xla/csrc/nll_loss.h | 19 + torch_xla/csrc/passes/eval_static_size.cpp | 41 + torch_xla/csrc/passes/eval_static_size.h | 12 + .../passes/remove_unused_forward_outputs.cpp | 124 +++ .../passes/remove_unused_forward_outputs.h | 12 + .../passes/replace_untraced_operators.cpp | 98 ++ .../csrc/passes/replace_untraced_operators.h | 12 + .../passes/threshold_backward_peephole.cpp | 41 + .../csrc/passes/threshold_backward_peephole.h | 14 + torch_xla/csrc/pooling.cpp | 154 +++ torch_xla/csrc/pooling.h | 28 + torch_xla/csrc/reduction.cpp | 20 + torch_xla/csrc/reduction.h | 14 + torch_xla/csrc/tensor.cpp | 559 +++++++++++ torch_xla/csrc/tensor.h | 190 ++++ torch_xla/csrc/torch_util.cpp | 40 + torch_xla/csrc/torch_util.h | 20 + torch_xla/csrc/translator.cpp | 483 ++++++++++ torch_xla/csrc/translator.h | 60 ++ 60 files changed, 6873 insertions(+) create mode 100644 .gitmodules create mode 100644 LICENSE create mode 100755 build_torch_xla_libs.sh create mode 100644 setup.py create mode 100644 test/test_operations.py create mode 160000 third_party/tensorflow create mode 100644 third_party/xla_client/BUILD create mode 100644 third_party/xla_client/computation_client.cc create mode 100644 third_party/xla_client/computation_client.h create mode 100644 third_party/xla_client/tf_exported_symbols.lds create mode 100644 third_party/xla_client/tf_version_script.lds create mode 100644 third_party/xla_client/unique.h create mode 100644 third_party/xla_client/xla_computation_client.cc create mode 100644 third_party/xla_client/xla_computation_client.h create mode 100644 third_party/xla_client/xla_util.cc create mode 100644 third_party/xla_client/xla_util.h create mode 100644 
third_party/xla_client/xrt_computation_client.cc create mode 100644 third_party/xla_client/xrt_computation_client.h create mode 100644 torch_xla/__init__.py create mode 100644 torch_xla/csrc/batch_norm.cpp create mode 100644 torch_xla/csrc/batch_norm.h create mode 100644 torch_xla/csrc/convolution.cpp create mode 100644 torch_xla/csrc/convolution.h create mode 100644 torch_xla/csrc/cross_replica_reduces.cpp create mode 100644 torch_xla/csrc/cross_replica_reduces.h create mode 100644 torch_xla/csrc/data_ops.cpp create mode 100644 torch_xla/csrc/data_ops.h create mode 100644 torch_xla/csrc/elementwise.cpp create mode 100644 torch_xla/csrc/elementwise.h create mode 100644 torch_xla/csrc/graph_context.cpp create mode 100644 torch_xla/csrc/graph_context.h create mode 100644 torch_xla/csrc/helpers.cpp create mode 100644 torch_xla/csrc/helpers.h create mode 100644 torch_xla/csrc/init_python_bindings.cpp create mode 100644 torch_xla/csrc/init_python_bindings.h create mode 100644 torch_xla/csrc/log_softmax.cpp create mode 100644 torch_xla/csrc/log_softmax.h create mode 100644 torch_xla/csrc/module.cpp create mode 100644 torch_xla/csrc/module.h create mode 100644 torch_xla/csrc/nll_loss.cpp create mode 100644 torch_xla/csrc/nll_loss.h create mode 100644 torch_xla/csrc/passes/eval_static_size.cpp create mode 100644 torch_xla/csrc/passes/eval_static_size.h create mode 100644 torch_xla/csrc/passes/remove_unused_forward_outputs.cpp create mode 100644 torch_xla/csrc/passes/remove_unused_forward_outputs.h create mode 100644 torch_xla/csrc/passes/replace_untraced_operators.cpp create mode 100644 torch_xla/csrc/passes/replace_untraced_operators.h create mode 100644 torch_xla/csrc/passes/threshold_backward_peephole.cpp create mode 100644 torch_xla/csrc/passes/threshold_backward_peephole.h create mode 100644 torch_xla/csrc/pooling.cpp create mode 100644 torch_xla/csrc/pooling.h create mode 100644 torch_xla/csrc/reduction.cpp create mode 100644 torch_xla/csrc/reduction.h create mode 100644 torch_xla/csrc/tensor.cpp create mode 100644 torch_xla/csrc/tensor.h create mode 100644 torch_xla/csrc/torch_util.cpp create mode 100644 torch_xla/csrc/torch_util.h create mode 100644 torch_xla/csrc/translator.cpp create mode 100644 torch_xla/csrc/translator.h diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000000..b1cb803bfcd --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/tensorflow"] + path = third_party/tensorflow + url = https://github.com/tensorflow/tensorflow.git diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000000..6fa55f3e9b8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +Copyright (c) 2018 Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. 
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
index e69de29bb2d..e10ab972fc0 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1,23 @@
+# How To Build And Run PyTorch For TPU
+
+To build:
+
+* Build PyTorch from source, following the regular [instructions](https://github.com/pytorch/pytorch#from-source).
+* Clone this repository in the root folder of the PyTorch sources used for the previous step.
+  Run `git submodule update --init` to get the third-party dependencies and `python setup.py install` to build and install the extension.
+
+To run the tests, follow __one__ of the options below:
+
+* Run on CPU using the local client:
+
+  `export XLA_USE_XRT=0 XLA_GRPC_HOST="" XLA_PLATFORM="CPU"`
+
+* Run on CPU using the XRT client:
+
+  `export XLA_USE_XRT=1 XRT_DEVICE_MAP="CPU:0;/job:localhost/replica:0/task:0/device:XLA_CPU:0" XRT_WORKERS="localhost:0;"`
+
+* Run on TPU using the XRT client:
+
+  `export XLA_USE_XRT=1 XRT_DEVICE_MAP="TPU:0;/job:tpu_worker/replica:0/task:0/device:TPU:0" XRT_WORKERS="tpu_worker:0;grpc://localhost:51000"`
+
+Then run `python test/test_operations.py`. Some of the tests are currently skipped.
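+
+As a quick sanity check that the extension imports and runs, the following minimal sketch mirrors the pattern exercised by `test/test_operations.py` (the `MulAdd` module and the shapes are purely illustrative, and one of the environment configurations above is assumed to have been exported):
+
+```python
+import torch
+import torch.nn as nn
+
+import torch_xla
+
+
+class MulAdd(nn.Module):
+  def forward(self, x, y):
+    return x * y + y
+
+
+x = torch.rand(3, 5)
+y = torch.rand(3, 5)
+traced = torch.jit.trace(MulAdd(), (x, y))
+xla_model = torch_xla._C.XlaModule(traced)
+inputs_xla = (torch_xla._C.XLATensor(x), torch_xla._C.XLATensor(y))
+out = xla_model(inputs_xla)   # executes the traced graph through XLA
+print(out[0].to_tensor())     # copies the result back to a regular torch.Tensor
+```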
diff --git a/build_torch_xla_libs.sh b/build_torch_xla_libs.sh new file mode 100755 index 00000000000..ccd483086cc --- /dev/null +++ b/build_torch_xla_libs.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -ex + +cd "$(dirname "$0")" +PWD=`printf "%q\n" "$(pwd)"` +BASE_DIR="$PWD" +echo $BASE_DIR +THIRD_PARTY_DIR="$BASE_DIR/third_party" + +cp -r -f $THIRD_PARTY_DIR/xla_client $THIRD_PARTY_DIR/tensorflow/tensorflow/compiler/xla/ + +pushd $THIRD_PARTY_DIR/tensorflow +git reset --hard +git clean -f +bazel build -c opt //tensorflow/compiler/xla/xla_client:libxla_computation_client.so +popd + +mkdir -p torch_xla/lib +chmod 0644 $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/compiler/xla/xla_client/libxla_computation_client.so +cp $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/compiler/xla/xla_client/libxla_computation_client.so torch_xla/lib +chmod 0644 $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/libtensorflow_framework.so +cp $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/libtensorflow_framework.so torch_xla/lib diff --git a/setup.py b/setup.py new file mode 100644 index 00000000000..ce3b3d909b0 --- /dev/null +++ b/setup.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python + +from setuptools import setup, find_packages +from torch.utils.cpp_extension import BuildExtension, CppExtension +import os +import platform +import subprocess +import sys + + +def _check_env_flag(name, default=''): + return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] + +torch_xla_sources = [ + 'torch_xla/csrc/batch_norm.cpp', + 'torch_xla/csrc/convolution.cpp', + 'torch_xla/csrc/cross_replica_reduces.cpp', + 'torch_xla/csrc/data_ops.cpp', + 'torch_xla/csrc/elementwise.cpp', + 'torch_xla/csrc/graph_context.cpp', + 'torch_xla/csrc/helpers.cpp', + 'torch_xla/csrc/init_python_bindings.cpp', + 'torch_xla/csrc/log_softmax.cpp', + 'torch_xla/csrc/module.cpp', + 'torch_xla/csrc/nll_loss.cpp', + 'torch_xla/csrc/pooling.cpp', + 'torch_xla/csrc/reduction.cpp', + 'torch_xla/csrc/tensor.cpp', + 'torch_xla/csrc/torch_util.cpp', + 'torch_xla/csrc/translator.cpp', + 'torch_xla/csrc/passes/eval_static_size.cpp', + 'torch_xla/csrc/passes/remove_unused_forward_outputs.cpp', + 'torch_xla/csrc/passes/replace_untraced_operators.cpp', + 'torch_xla/csrc/passes/threshold_backward_peephole.cpp', +] + +build_libs_cmd = './build_torch_xla_libs.sh' + +if subprocess.call(build_libs_cmd) != 0: + print("Failed to run '{}'".format(build_libs_cmd)) + sys.exit(1) + +# Constant known variables used throughout this file +cwd = os.path.dirname(os.path.abspath(__file__)) +lib_path = os.path.join(cwd, 'torch_xla', 'lib') +pytorch_source_path = os.getenv('PYTORCH_SOURCE_PATH', '..') +third_party_path = os.path.join(cwd, 'third_party') + +include_dirs = [ + third_party_path + '/tensorflow/bazel-tensorflow', + third_party_path + '/tensorflow/bazel-genfiles', + third_party_path + '/tensorflow/bazel-tensorflow/external/protobuf_archive/src', + third_party_path + '/tensorflow/bazel-tensorflow/external/eigen_archive', + third_party_path + '/tensorflow/bazel-tensorflow/external/com_google_absl', +] +include_dirs += [ + pytorch_source_path, + os.path.join(pytorch_source_path, 'torch', 'csrc'), + os.path.join(pytorch_source_path, 'torch', 'lib', 'tmp_install', 'include'), +] + +library_dirs = [] +library_dirs.append(lib_path) + +extra_link_args = [] + +DEBUG = _check_env_flag('DEBUG') +IS_WINDOWS = (platform.system() == 'Windows') +IS_DARWIN = (platform.system() == 'Darwin') +IS_LINUX = (platform.system() == 'Linux') + + +def 
make_relative_rpath(path): + if IS_DARWIN: + return '-Wl,-rpath,@loader_path/' + path + elif IS_WINDOWS: + return '' + else: + return '-Wl,-rpath,$ORIGIN/' + path + +extra_compile_args = [] + +if DEBUG: + if IS_WINDOWS: + extra_link_args.append('/DEBUG:FULL') + else: + extra_compile_args += ['-O0', '-g'] + extra_link_args += ['-O0', '-g'] + +extra_link_args += ['-lxla_computation_client'] + +setup( + name='torch_xla', + version='0.1', + description='XLA bridge for PyTorch', + url='https://github.com/pytorch/xla', + author='Alex Suhan, Davide Libenzi', + author_email='asuhan@google.com', + # Exclude the build files. + packages=find_packages(exclude=['build']), + ext_modules=[ + CppExtension( + '_C', + torch_xla_sources, + include_dirs=include_dirs, + extra_compile_args=extra_compile_args, + library_dirs=library_dirs, + extra_link_args=extra_link_args + [make_relative_rpath('torch_xla/lib')], + ), + ], + package_data={ + 'torch_xla': [ + 'lib/*.so*', + ] + }, + cmdclass={'build_ext': BuildExtension}) diff --git a/test/test_operations.py b/test/test_operations.py new file mode 100644 index 00000000000..1b3923c2532 --- /dev/null +++ b/test/test_operations.py @@ -0,0 +1,912 @@ +# Parse local options first, and rewrite the sys.argv[]. +# We need to do that before import "common", as otherwise we get an error for +# unrecognized arguments. +import argparse +import sys + +parser = argparse.ArgumentParser(add_help=False) +parser.add_argument('--replicated', action='store_true') +parser.add_argument('--long_test', action='store_true') +FLAGS, leftovers = parser.parse_known_args() +sys.argv = [sys.argv[0]] + leftovers +sys.path.append('../test') + +# Normal imports section starts here. +import collections +from common_utils import TestCase, run_tests, iter_indices +import itertools +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch_xla +import unittest + + +DeviceSupport = collections.namedtuple('DeviceSupport', ['num_devices']) + + +class Holder(object): + pass + + +def _as_list(t): + return t if isinstance(t, (tuple, list)) else [t] + + +def _get_device_support(devname): + assert devname in ['TPU', 'CPU'] + # If the Cloud TPU config file is present, we support TPUs. 
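+  # The XRT_TPU_CONFIG environment variable is treated the same way, and the
+  # TPU_NUM_DEVICES / CPU_NUM_DEVICES variables can override the assumed
+  # per-host device counts (8 TPU devices, 1 CPU device by default).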
+ if (os.path.isfile(os.path.join(os.environ['HOME'], '.pytorch_tpu.conf')) or + os.environ.get('XRT_TPU_CONFIG', None)): + if devname == 'TPU': + return DeviceSupport(num_devices=int( + os.environ.get('TPU_NUM_DEVICES', 8))) + return DeviceSupport(num_devices=int( + os.environ.get('CPU_NUM_DEVICES', 1))) + xrt = os.environ.get('XLA_USE_XRT', None) + if xrt is None or int(xrt) == 0: + xla_platform = os.environ.get('XLA_PLATFORM', None) + if xla_platform == devname: + return DeviceSupport(num_devices=1) + else: + xrt_devmap = os.environ.get('XRT_DEVICE_MAP', None) + if xrt_devmap is None: + return None + num_devices = 0 + for dev_spec in xrt_devmap.split('|'): + dev_parts = dev_spec.split(';') + if dev_parts[0].startswith(devname): + num_devices += 1 + if num_devices > 0: + return DeviceSupport(num_devices=num_devices) + return None + + +def _support_replicated(devname, num_devices): + devsup = _get_device_support(devname) + if not devsup: + return False + return devsup.num_devices >= num_devices + + +def _random_inputs(shapes, num_replicas=1): + random_tensors = [] + for _ in range(0, num_replicas): + replica_inputs = [] + for shape in shapes: + replica_inputs.append(torch.randn(*shape)) + random_tensors.append(tuple(replica_inputs)) + return tuple(random_tensors) + + +def _random_like(tensor_list): + random_tensors = [] + for o in tensor_list: + if o.dtype == torch.float32 or o.dtype == torch.float64: + random_tensors += [torch.randn(*o.shape, dtype=o.dtype)] + elif o.dtype == torch.int64: + # TODO remove this, we shouldn't be needing to pass random_tensor for long types + random_tensors += [torch.empty_like(o)] + else: + raise RuntimeError('Unsupported type: ', o.dtype) + return random_tensors + + +def _zeros_like(tensor_list): + zeros_tensors = [] + for o in tensor_list: + if o.dtype == torch.float32 or o.dtype == torch.float64: + zeros_tensors += [torch.zeros(*o.shape, dtype=o.dtype)] + elif o.dtype == torch.int64: + # TODO remove this, we shouldn't be needing to pass zeros_tensor for long types + zeros_tensors += [torch.zeros_like(o)] + else: + raise RuntimeError('Unsupported type: ', o.dtype) + return zeros_tensors + + +def _dump_differences(target, result, rtol=1e-5, atol=1e-3): + env = Holder() + env.max_diff = 0.0 + env.max_rel = None + env.max_index = None + + def check_values(a, b, index): + r = max(abs(a), abs(b)) * rtol + diff = abs(a - b) + if diff > min(r, atol): + print('a={}\tb={}\tdiff={}\tindex={}'.format(a, b, diff, index)) + if diff > env.max_diff: + env.max_diff = diff + env.max_rel = diff / max(abs(a), abs(b)) + env.max_index = index + + if isinstance(target, torch.Tensor): + assert isinstance(result, torch.Tensor) + assert target.size() == result.size() + for i in iter_indices(target): + check_values(target[i], result[i], i) + elif isinstance(target, (list, tuple)): + assert isinstance(result, (list, tuple)) + assert len(target) == len(result) + for i, v in enumerate(target): + check_values(v, result[i], [i]) + elif isinstance(target, float): + assert isinstance(result, float) + check_values(target, result, []) + if env.max_index is not None: + print('\nmax_diff={}\tmax_rel={}\tindex={}'.format( + env.max_diff, env.max_rel, env.max_index)) + + +def _xla_run(model, input, device='TPU'): + if isinstance(input, (tuple, list)): + traced_model = torch.jit.trace(model, *input[0]) + xla_model = torch_xla._C.XlaModule( + traced_model, use_full_conv_precision=True) + input_xla = [] + for n, replica_input in enumerate(input): + xla_replica_input = [] + for i in replica_input: 
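+        # Pin this replica's inputs to its own device, e.g. 'TPU:0', 'TPU:1', ...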
+ xla_replica_input.append( + torch_xla._C.XLATensor(i, '{}:{}'.format(device, n))) + input_xla.append(tuple(xla_replica_input)) + output_xla = xla_model(*input_xla) + output = [] + for xla_replica_outputs in output_xla: + replica_outputs = [] + for o in _as_list(xla_replica_outputs): + replica_outputs.append(o.to_tensor()) + output.append(tuple(replica_outputs)) + return tuple(output) + else: + traced_model = torch.jit.trace(model, input) + xla_model = torch_xla._C.XlaModule( + traced_model, use_full_conv_precision=True) + input_xla = torch_xla._C.XLATensor(input) + output_xla = xla_model(tuple([input_xla])) + return output_xla[0].to_tensor() + + +def _forward_passes(graph): + torch._C._jit_pass_canonicalize(graph) + torch_xla._C._jit_pass_eval_static_size(graph) + torch._C._jit_pass_constant_propagation(graph) + torch_xla._C._jit_pass_replace_untraced_operators(graph) + torch._C._jit_pass_dce(graph) + + +def _backward_passes(graph): + torch._C._jit_pass_specialize_undef(graph) + torch_xla._C._jit_pass_eval_static_size(graph) + torch._C._jit_pass_constant_propagation(graph) + torch_xla._C._jit_pass_threshold_backward_peephole(graph) + torch._C._jit_pass_dce(graph) + + +class XlaTestCase(TestCase): + def assertEqualRel(self, out, expected, rel_err=1e-2, abs_err=1e-5): + try: + diff_tensor = (out - expected).abs() + max_rel_err = torch.max(out.abs(), expected.abs()) * rel_err + # Allow higher relative differences as long as we're still below the + # absolute error. + max_abs_err = torch.max(max_rel_err, torch.ones_like(out) * abs_err) + super(XlaTestCase, self).assertEqual(diff_tensor.size(), + max_abs_err.size()) + if torch.le(diff_tensor, max_abs_err).min().item() == 0: + self.fail('Relative error higher than the maximum tolerance') + except: + _dump_differences(out, expected, rtol=rel_err, atol=abs_err) + raise + + def assertEqualDbg(self, out, expected): + try: + super(XlaTestCase, self).assertEqual(out, expected) + except: + _dump_differences(out, expected, rtol=1e-8, atol=1e-8) + raise + + def compareReplicated(self, model, inputs, xla_outputs): + self.assertEqual(len(inputs), len(xla_outputs)) + for i, input in enumerate(inputs): + expected = _as_list(model(*input)) + xla_output = _as_list(xla_outputs[i]) + self.assertEqual(len(expected), len(xla_output)) + for j, expected_tensor in enumerate(expected): + self.assertEqualDbg(xla_output[j], expected_tensor) + + +class TestMulAdd(XlaTestCase): + def test(self): + + class XlaMulAdd(nn.Module): + def forward(self, x, y): + return x * y + y + + x = torch.rand(3, 5) + y = torch.rand(3, 5) + model = XlaMulAdd() + traced_model = torch.jit.trace(model, (x, y)) + xla_model = torch_xla._C.XlaModule(traced_model) + inputs_xla = [torch_xla._C.XLATensor(x), torch_xla._C.XLATensor(y)] + output_xla = xla_model((tuple(inputs_xla))) + expected = model(x, y) + self.assertEqualDbg(output_xla[0].to_tensor().data, expected.data) + + +class TestRelu(XlaTestCase): + def test(self): + + class XlaRelu(nn.Module): + def forward(self, x): + return F.relu(x) + + x = torch.randn(2, 1, 4, 6) + model = XlaRelu() + out = _xla_run(model, x) + expected = model(x) + self.assertEqualDbg(out.data, expected.data) + + +class TestThreshold(XlaTestCase): + def test(self): + + class XlaThreshold(nn.Module): + def __init__(self): + super(XlaThreshold, self).__init__() + self.threshold = nn.Threshold(0.4, 20) + + def forward(self, x): + return self.threshold(x) + + x = torch.rand(4, 2) + model = XlaThreshold() + out = _xla_run(model, x) + expected = model(x) + 
self.assertEqualDbg(out.data, expected.data) + + +class TestTranspose(XlaTestCase): + def test(self): + + class XlaTranspose(nn.Module): + def forward(self, x): + return torch.t(x) + + x = torch.rand(2, 3) + model = XlaTranspose() + out = _xla_run(model, x) + expected = model(x) + self.assertEqualDbg(out.data, expected.data) + + +class TestView(XlaTestCase): + def test(self): + + class XlaView(nn.Module): + def forward(self, x): + return x.view(-1, 16) + + x = torch.rand(4, 8) + model = XlaView() + out = _xla_run(model, x) + expected = model(x) + self.assertEqualDbg(out.data, expected.data) + + +@unittest.skip('RuntimeError: differentiation of prim::ListConstruct is not supported, or it is missing necessary type information') +class TestStack(XlaTestCase): + def test(self): + + class XlaStack(nn.Module): + def __init__(self, dim): + super(XlaStack, self).__init__() + self.dim = dim + + def forward(self, x, y): + return torch.stack((x, y), self.dim) + + x = torch.rand(2, 5) + y = torch.rand(2, 5) + for dim in [0, 1]: + model = XlaStack(dim) + traced_model = torch.jit.trace(model, (x, y)) + xla_model = torch_xla._C.XlaModule(traced_model) + inputs_xla = [torch_xla._C.XLATensor(x), torch_xla._C.XLATensor(y)] + output_xla = xla_model((tuple(inputs_xla))) + expected = model(x, y) + self.assertEqualDbg(output_xla[0].to_tensor().data, expected.data) + + +class TestExpand(XlaTestCase): + def test(self): + + class XlaExpand(nn.Module): + def forward(self, x): + return x.expand(2, 5) + + x = torch.rand(5) + model = XlaExpand() + out = _xla_run(model, x) + expected = model(x) + self.assertEqualDbg(out.data, expected.data) + + +@unittest.skip('Automatic shape inference not supported: f32[5] and f32[4,5]') +class TestLinear(XlaTestCase): + def test(self): + + class XlaLinear(nn.Module): + def __init__(self): + super(XlaLinear, self).__init__() + self.linear = nn.Linear(2, 5) + + def forward(self, x): + return self.linear(x) + + x = torch.rand(4, 2) + model = XlaLinear() + out = _xla_run(model, x) + expected = model(x) + self.assertEqualDbg(out.data, expected.data) + + +class TestNonContiguousTensor(XlaTestCase): + def test(self): + + class XlaPlusSelf(nn.Module): + def forward(self, x): + return x + x + + x = torch.rand(3, 7) + model = XlaPlusSelf() + out = _xla_run(model, x) + expected = model(x) + self.assertEqualDbg(out.data, expected.data) + out_t = _xla_run(model, x.t()) + expected_t = model(x.t()) + self.assertEqualDbg(out_t.data, expected_t.data) + self.assertEqualDbg(out_t.data, out.t().data) + + +@unittest.skip('Pending autodiff support') +class TestConv(XlaTestCase): + def test(self): + + class XlaConv(nn.Module): + def __init__(self, stride, padding, bias): + super(XlaConv, self).__init__() + self.conv = nn.Conv2d(10, 100, 5, stride=stride, + padding=padding, bias=bias) + + def forward(self, x): + return self.conv(x) + + for stride in range(1, 4): + for padding in range(0, 3): + for bias in [True, False]: + x = torch.randn(32, 10, 28, 28) + model = XlaConv(stride, padding, bias) + out = _xla_run(model, x) + expected = model(x) + self.assertEqualRel(out.data, expected.data) + + +class TestMaxPool(XlaTestCase): + def test(self): + + class XlaMaxPool(nn.Module): + def __init__(self, stride, padding): + super(XlaMaxPool, self).__init__() + self.stride = stride + self.padding = padding + + def forward(self, x): + return F.max_pool2d(x, 3, stride=self.stride, + padding=self.padding) + + x = torch.rand(1, 64, 112, 112) + for stride in [None, 2]: + for padding in [0, 1]: + model = XlaMaxPool(stride, 
padding) + out = _xla_run(model, x) + expected = model(x) + self.assertEqualDbg(out.data, expected.data) + + +class TestAvgPool(XlaTestCase): + def test(self): + + class XlaAvgPool(nn.Module): + def __init__(self, stride, padding, count_include_pad): + super(XlaAvgPool, self).__init__() + self.stride = stride + self.padding = padding + self.count_include_pad = count_include_pad + + def forward(self, x): + return F.avg_pool2d(x, 2, self.stride, self.padding, False, self.count_include_pad) + + x = torch.rand(1, 1, 3, 3) + for stride in [1, 2, None]: + for padding in [0, 1]: + for count_include_pad in [False, True]: + model = XlaAvgPool(stride, padding, count_include_pad) + out = _xla_run(model, x) + expected = model(x) + self.assertEqualDbg(out.data, expected.data) + + +class TestLogSoftmax(XlaTestCase): + def test(self): + + class XlaLogSoftmax(nn.Module): + def __init__(self, dim): + super(XlaLogSoftmax, self).__init__() + self.dim = dim + + def forward(self, x): + return F.log_softmax(x, self.dim) + + x = torch.rand(5, 3, 4, 2) + for dim in range(0, x.dim()): + model = XlaLogSoftmax(dim) + out = _xla_run(model, x) + expected = model(x) + self.assertEqualRel(out.data, expected.data, rel_err=1e-4, abs_err=1) + + +@unittest.skip('Pending autodiff support') +class TestBatchNorm(XlaTestCase): + def test(self): + + class XlaBatchNorm(nn.Module): + def __init__(self, training): + super(XlaBatchNorm, self).__init__() + if training: + self.bn = nn.BatchNorm2d(3) + else: + self.bn = nn.BatchNorm2d(3, track_running_stats=False) + + def forward(self, x): + return self.bn(x) + + x = torch.rand(14, 3, 5, 7) + model = XlaBatchNorm(True) + out = _xla_run(model, x) + expected = model(x) + self.assertEqualDbg(out.data, expected.data) + + +class XlaMNIST(nn.Module): + def __init__(self): + super(XlaMNIST, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2(x), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +@unittest.skip('Pending autodiff support') +class TestMNIST(XlaTestCase): + def test(self): + batch_size = 32 + x = torch.randn(batch_size, 1, 28, 28) + model = XlaMNIST() + out = _xla_run(model, x) + expected = model(x) + self.assertEqualDbg(out.data, expected.data) + + +class TestSum(XlaTestCase): + def test(self): + + class XlaSum(nn.Module): + def __init__(self, dim): + super(XlaSum, self).__init__() + self.dim = dim + + def forward(self, x): + return x.sum(dim=self.dim) + + x = torch.randn(2, 3, 4, 6) + for dim in range(0, x.dim()): + model = XlaSum(dim) + out = _xla_run(model, x) + expected = model(x) + self.assertEqualDbg(out.data, expected.data) + + +class XlaNllLoss(nn.Module): + def __init__(self): + super(XlaNllLoss, self).__init__() + self.nll_loss = nn.NLLLoss() + + def forward(self, x, labels): + return self.nll_loss(x, labels) + + +@unittest.skip('Pending autodiff support') +class TestNllLoss(TestCase): + def test(self): + input = torch.randn(3, 5, requires_grad=True) + target = torch.empty(3, dtype=torch.long).random_(5) + model = XlaNllLoss() + traced_model = torch.jit.trace(model, (input, target)) + xla_model = torch_xla._C.XlaModule(traced_model) + xla_inputs = [torch_xla._C.XLATensor(input), torch_xla._C.XLATensor(target)] + output_xla = xla_model((tuple(xla_inputs))) + expected = model(input, 
target) + self.assertEqual(output_xla[0].to_tensor().data, expected.data) + + +class TestLongGraphChain(XlaTestCase): + def test(self): + orig_x = torch.Tensor([[1, 2], [3, 4]]) + orig_y = torch.Tensor([[0.1, 0.2], [0.3, 0.4]]) + x = orig_x + y = orig_y + xla_x = torch_xla._C.XLATensor(orig_x) + xla_y = torch_xla._C.XLATensor(orig_y) + for i in range(0, 10000): + x = x + 2 * y + xla_x = xla_x.add(2, xla_y) + self.assertEqualRel(x, xla_x.to_tensor(), rel_err=1e-3, abs_err=5) + + +class TestGradients(XlaTestCase): + def checkGrad(self, model, inputs, grad_outputs='random', xla=True, + rel_err=1e-2, abs_err=1e-5): + # Trace and symbolically differentiate + traced_model = torch.jit.trace(model, *inputs) + fwd = traced_model._get_method('forward') + _forward_passes(fwd.graph) + + inputs_params = inputs + list(model.parameters()) + inputs_params_buffers = inputs + list(fwd.params()) + + gradient = torch._C._jit_differentiate(fwd.graph) + _forward_passes(gradient.f) + _backward_passes(gradient.df) + + ############################################################## + # Run forward and backwarg graphs via jit interpreter + exec_f = torch._C.GraphExecutor(gradient.f, False) + exec_df = torch._C.GraphExecutor(gradient.df, False) + + # forward function + raw_outputs = exec_f(*inputs_params_buffers) + raw_outputs = _as_list(raw_outputs) + outputs = raw_outputs[:gradient.f_real_outputs] + + if grad_outputs == 'random': + grad_outputs = (_random_like(outputs) + + _zeros_like(raw_outputs[gradient.f_real_outputs:])) + + raw_grad_outputs = [] + raw_grad_outputs += grad_outputs + raw_grad_outputs += [inputs_params_buffers[i] + for i in gradient.df_input_captured_inputs] + raw_grad_outputs += [raw_outputs[i] + for i in gradient.df_input_captured_outputs] + + grad_inputs = exec_df(*raw_grad_outputs) + grad_inputs = _as_list(grad_inputs) + + ############################################################## + # backward with XLA + if xla: + xla_model = torch_xla._C.XlaModule(traced_model, use_full_conv_precision=True) + inputs_xla = [torch_xla._C.XLATensor(input) for input in inputs] + xla_model((tuple(inputs_xla))) + grads_output_xla = [torch_xla._C.XLATensor(grad_output) + for grad_output in grad_outputs[:gradient.f_real_outputs]] + xla_model.backward((tuple(grads_output_xla))) + grad_inputs_xla = [input_xla.grad.to_tensor() + for input_xla in inputs_xla] + grad_inputs_xla.extend([p.grad.to_tensor() + for p in xla_model.parameters()[0]]) + ############################################################## + # forward + backward with regular autograd / torch + outputs_gt = model(*inputs) + outputs_gt = _as_list(outputs_gt) + grad_inputs_gt = torch.autograd.grad(outputs_gt, + inputs_params, + grad_outputs, + only_inputs=True) + for out_jit, out_autograd in zip(outputs, outputs_gt): + self.assertEqualRel(out_jit, out_autograd, rel_err=rel_err, + abs_err=abs_err) + + for grad_input_jit, grad_input_autograd in zip(grad_inputs, grad_inputs_gt): + self.assertEqualRel(grad_input_jit, grad_input_autograd, + rel_err=rel_err, abs_err=abs_err) + + # TODO: test buffers as well (running_mean, etc.) 
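+    # Finally, check the gradients computed through XLA against the ones
+    # produced by the JIT interpreter above.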
+ if xla: + for i, (grad_input_jit, grad_input_xla) in enumerate(zip(grad_inputs, + grad_inputs_xla)): + self.assertEqualRel(grad_input_jit, grad_input_xla, rel_err, abs_err) + + def test_avgpool(self): + class AvgPoolGrad(nn.Module): + def __init__(self, stride, padding, count_include_pad): + super(AvgPoolGrad, self).__init__() + self.stride = stride + self.padding = padding + self.count_include_pad = count_include_pad + + def forward(self, x): + return F.avg_pool2d(x, 2, self.stride, self.padding, False, + self.count_include_pad) + + for stride in [1, 2]: + for padding in [0, 1]: + for count_include_pad in [False, True]: + model = AvgPoolGrad(stride, padding, count_include_pad) + inputs = [torch.randn(4, 1, 28, 28, requires_grad=True)] + self.checkGrad(model, inputs, xla=True) + + def test_threshold(self): + class ThresholdPoolGrad(nn.Module): + def __init__(self): + super(ThresholdPoolGrad, self).__init__() + self.threshold = nn.Threshold(0.4, 20) + + def forward(self, x): + return self.threshold(x) + + model = ThresholdPoolGrad() + inputs = [torch.randn(4, 2, requires_grad=True)] + self.checkGrad(model, inputs, xla=True) + + @unittest.skip('RuntimeError: expected 3 inputs, but got 4') + def test_maxpool(self): + class MaxPoolGrad(nn.Module): + def forward(self, x): + return F.max_pool2d(x, 2) + + model = MaxPoolGrad() + inputs = [torch.randn(4, 1, 28, 28, requires_grad=True)] + self.checkGrad(model, inputs, xla=True) + + def test_tanh(self): + class TanhGrad(nn.Module): + def forward(self, x): + return torch.tanh(x) + + model = TanhGrad() + inputs = [torch.randn(4, 2, requires_grad=True)] + self.checkGrad(model, inputs, xla=True) + + def test_sigmoid(self): + class SigmoidGrad(nn.Module): + def forward(self, x): + return torch.sigmoid(x) + + model = SigmoidGrad() + inputs = [torch.randn(4, 2, requires_grad=True)] + self.checkGrad(model, inputs, xla=True, rel_err=1e-2, abs_err=1e-2) + + @unittest.skip('differentiation of prim::ListUnpack is not supported, or it is missing necessary type information') + def test_chunk(self): + class ChunkGrad(nn.Module): + def forward(self, x): + return x.chunk(2, 1) + + model = ChunkGrad() + inputs = [torch.randn(4, 4, requires_grad=True)] + self.checkGrad(model, inputs, xla=True) + + @unittest.skip('bool value of Tensor with more than one value is ambiguous') + def test_lstm_cell(self): + class LSTMCellGrad(nn.Module): + def __init__(self): + super(LSTMCellGrad, self).__init__() + self.i2h = nn.Linear(3, 8) + self.h2h = nn.Linear(2, 8) + + def forward(self, x, hx, cx): + gates = self.i2h(x) + self.h2h(hx) + + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + cy = (forgetgate * cx) + (ingate * cellgate) + hy = outgate * torch.tanh(cy) + return hy, cy + + model = LSTMCellGrad() + inputs = [torch.randn(4, 3, requires_grad=True), + torch.randn(4, 2, requires_grad=True), + torch.randn(4, 2, requires_grad=True)] + self.checkGrad(model, inputs, xla=True) + + @unittest.skip('Pending autodiff support') + def test_conv2d(self): + if FLAGS.long_test: + config = [ + [1, 7, 15, 32], # ichans + [1, 4, 21, 32], # ochans + [1, 2, 3, 5], # size + [1, 2], # stride + [0, 1], # padding + [True, False], # bias + ] + else: + config = [ + [1, 5], # ichans + [1, 4], # ochans + [1, 3], # size + [1], # stride + [0], # padding + [False], # bias + ] + for ichans, ochans, size, stride, padding, bias in ( + 
itertools.product(*config)): + # TODO: dilation, groups, transpose + model = nn.Conv2d(ichans, ochans, size, stride, padding, bias=bias) + inputs = [torch.randn(4, ichans, 28, 28, requires_grad=True)] + self.checkGrad(model, inputs, xla=True, abs_err=1e-3) + + @unittest.skip('Pending autodiff support') + def test_batchnorm2d(self): + for chans in [1, 15, 32]: + for eps in [1e-5, 1e-3, 1e-2]: + # TODO: momentum, training, affine + model = nn.BatchNorm2d(chans, eps=eps) + inputs = [torch.randn(4, chans, 28, 28, requires_grad=True)] + self.checkGrad(model, inputs, xla=True) + + def test_logsoftmax(self): + for dim in [0, 1]: # todo test 3d as well + for batch in [1, 3, 4]: + class LSMGrad(nn.Module): + def forward(self, x): + return F.log_softmax(x, dim) + + model = LSMGrad() + inputs = [torch.randn(batch, 9, requires_grad=True)] + self.checkGrad(model, inputs, xla=True) + + @unittest.skip('Pending autodiff support') + def test_nll_loss(self): + input = torch.randn(3, 5, requires_grad=True) + target = torch.empty(3, dtype=torch.long).random_(5) + model = XlaNllLoss() + traced_model = torch.jit.trace(model, (input, target)) + xla_model = torch_xla._C.XlaModule(traced_model) + xla_inputs = [torch_xla._C.XLATensor(input), torch_xla._C.XLATensor(target)] + output_xla = xla_model((tuple(xla_inputs))) + xla_model.backward(output_xla) + output = model(input, target) + output.backward() + self.assertEqual(input.grad.data, xla_inputs[0].grad.data.to_tensor()) + + @unittest.skip('Pending autodiff support') + def test_mnist(self): + model = XlaMNIST() + inputs = [torch.randn(4, 1, 28, 28, requires_grad=True)] + self.checkGrad(model, inputs, xla=True) + + @unittest.skip('Pending autodiff support') + def test_resnet(self): + import torchvision + model = torchvision.models.resnet50() + inputs = [torch.randn(4, 3, 224, 224, requires_grad=True)] + self.checkGrad(model, inputs, xla=False) + + +class TestOptimizer(XlaTestCase): + def test_inplace_add_mul(self): + orig_x = torch.Tensor([[1, 2, 3], [4, 5, 6]]) + orig_y = torch.Tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + x = orig_x + y = orig_y + xla_x = torch_xla._C.XLATensor(orig_x) + xla_y = torch_xla._C.XLATensor(orig_y) + self.assertEqualDbg(x.add_(2, y).mul_(y), + xla_x.add_(2, xla_y).mul_(xla_y).to_tensor()) + self.assertEqualDbg(x.add_(y).mul_(y), + xla_x.add_(xla_y).mul_(xla_y).to_tensor()) + + def test_add_mul(self): + orig_x = torch.Tensor([[1, 2], [3, 4]]) + orig_y = torch.Tensor([[0.1, 0.2], [0.3, 0.4]]) + x = orig_x + y = orig_y + xla_x = torch_xla._C.XLATensor(orig_x) + xla_y = torch_xla._C.XLATensor(orig_y) + xla_ones = torch_xla._C.XLATensor(torch.ones_like(x)) + self.assertEqualDbg(x + 3 * y, xla_x.add(3, xla_y).to_tensor()) + self.assertEqualDbg(x * y, xla_x.mul(xla_y).to_tensor()) + z = (x + 9) * (y + 3) + xla_z = xla_x.add(9, xla_ones).mul(xla_y.add(3, xla_ones)) + self.assertEqualDbg(z, xla_z.to_tensor()) + self.assertEqualDbg(x + y, (xla_x + xla_y).to_tensor()) + self.assertEqualDbg(x * y, (xla_x * xla_y).to_tensor()) + self.assertEqualDbg(x * 11.0, (xla_x * 11.0).to_tensor()) + self.assertEqualDbg(x / 3.11, (xla_x / 3.11).to_tensor()) + self.assertEqualDbg(y / x, (xla_y / xla_x).to_tensor()) + + def checkSgd(self, lr, momentum, weight_decay, nsteps, do_zero_grad, + manually_batched): + input = torch.randn(4, 4, requires_grad=True) + model = nn.Linear(4, 20) + traced_model = torch.jit.trace(model, input) + xla_model = torch_xla._C.XlaModule(traced_model, use_full_conv_precision=True) + input_xla = [torch_xla._C.XLATensor(input)] + 
xla_model((tuple(input_xla))) + if manually_batched: + xla_optimizer = optim.XlaSGD(xla_model.parameters()[0], lr=lr, + momentum=momentum, weight_decay=weight_decay) + else: + xla_optimizer = optim.SGD(xla_model.parameters()[0], lr=lr, + momentum=momentum, weight_decay=weight_decay) + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, + weight_decay=weight_decay) + output = model(input) + grad_output = torch.randn(*output.shape) # random gradients + grad_output_xla = [torch_xla._C.XLATensor(grad_output)] + output.backward(grad_output) + xla_model.backward((tuple(grad_output_xla))) + if do_zero_grad: + optimizer.zero_grad() + xla_optimizer.zero_grad() + for _ in range(0, nsteps): + xla_optimizer.step() + optimizer.step() + xla_updated_params = [p.to_tensor().data for p in xla_model.parameters()[0]] + updated_params = [p.data for p in model.parameters()] + for i in range(0, len(updated_params)): + self.assertEqualRel(xla_updated_params[i], updated_params[i]) + + @unittest.skip('Pending optimizer changes') + def test_sgd(self): + for weight_decay in [0, 5e-4]: + for manually_batched in [False, True]: + self.checkSgd(lr=0.1, momentum=0, weight_decay=weight_decay, + nsteps=1, do_zero_grad=True, + manually_batched=manually_batched) + self.checkSgd(lr=0.1, momentum=0, weight_decay=weight_decay, nsteps=2, + do_zero_grad=False, manually_batched=manually_batched) + self.checkSgd(lr=0.1, momentum=0.5, weight_decay=weight_decay, nsteps=1, + do_zero_grad=True, manually_batched=manually_batched) + self.checkSgd(lr=0.1, momentum=0.5, weight_decay=weight_decay, nsteps=2, + do_zero_grad=False, manually_batched=manually_batched) + + +# Disabled always for now. +@unittest.skipIf(not (FLAGS.replicated and _support_replicated('TPU', 8)), + 'Replicated (8) TPU only') +class TestReplicatedSum(XlaTestCase): + def test(self): + + class XlaSum(nn.Module): + def forward(self, x, y): + return x + y + + model = XlaSum() + for num_replicas in [2, 3, 4, 5, 6, 7, 8]: + inputs = _random_inputs(((3, 3), (3, 3)), num_replicas=num_replicas) + out = _xla_run(model, inputs) + self.compareReplicated(model, inputs, out) + + +if __name__ == '__main__': + torch.set_default_tensor_type('torch.FloatTensor') + run_tests() diff --git a/third_party/tensorflow b/third_party/tensorflow new file mode 160000 index 00000000000..ca1636667db --- /dev/null +++ b/third_party/tensorflow @@ -0,0 +1 @@ +Subproject commit ca1636667dbfc7eeb35545f150916751742b5cd2 diff --git a/third_party/xla_client/BUILD b/third_party/xla_client/BUILD new file mode 100644 index 00000000000..5fcfaa2d8ca --- /dev/null +++ b/third_party/xla_client/BUILD @@ -0,0 +1,88 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") + +exports_files( + [ + "tf_version_script.lds", + "tf_exported_symbols.lds", + ], +) + +tf_cc_shared_object( + name = "libxla_computation_client.so", + linkopts = select({ + "//tensorflow:darwin": [ + "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file + "$(location //tensorflow/compiler/xla/xla_client:tf_exported_symbols.lds)", + ], + "//tensorflow:windows": [], + "//conditions:default": [ + "-z defs", + "-s", + "-Wl,--version-script", # This line must be directly followed by the version_script.lds file + "$(location //tensorflow/compiler/xla/xla_client:tf_version_script.lds)", + 
], + }), + visibility = ["//visibility:public"], + deps = [ + "computation_client_impl", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/rpc:grpc_stub", + "//tensorflow/compiler/xla/xla_client:tf_exported_symbols.lds", + "//tensorflow/compiler/xla/xla_client:tf_version_script.lds", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "computation_client_impl", + srcs = [ + "computation_client.cc", + "xla_computation_client.cc", + "xla_util.cc", + "xrt_computation_client.cc", + ], + hdrs = [ + "computation_client.h", + "unique.h", + "xla_computation_client.h", + "xla_util.h", + "xrt_computation_client.h", + ], + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/jit:xla_cpu_device", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/rpc:grpc_stub", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xrt:xrt_proto", + "//tensorflow/compiler/xrt:xrt_server", + "//tensorflow/compiler/xrt/cc:xrt_ops", + "//tensorflow/contrib/tpu:all_ops", + "//tensorflow/contrib/tpu/proto:topology_proto_cc", + "//tensorflow/core:lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_runtime", + "//tensorflow/core/kernels:conv_ops", + "//tensorflow/stream_executor:stream_executor_impl", + "@com_google_absl//absl/strings", + ], +) diff --git a/third_party/xla_client/computation_client.cc b/third_party/xla_client/computation_client.cc new file mode 100644 index 00000000000..55fdd6baefa --- /dev/null +++ b/third_party/xla_client/computation_client.cc @@ -0,0 +1,167 @@ +#include "tensorflow/compiler/xla/xla_client/computation_client.h" + +#include + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_client/xla_computation_client.h" +#include "tensorflow/compiler/xla/xla_client/xrt_computation_client.h" + +namespace xla { + +namespace { + +string GetEnvString(const char* name, const string& defval) { + const char* env = std::getenv(name); + return env != nullptr ? env : defval; +} + +int64 GetEnvInt(const char* name, int64 defval) { + const char* env = std::getenv(name); + return env != nullptr ? std::atol(env) : defval; +} + +string GetTpuClusterConfigPath() { + string home_folder = GetEnvString("HOME", "."); + return absl::StrCat(home_folder, "/", ".pytorch_tpu.conf"); +} + +bool ShouldUseXrtClient(string* config_path) { + *config_path = GetTpuClusterConfigPath(); + if (access(config_path->c_str(), F_OK) != -1) { + // If we have a TPU cluster config file, we are in Cloud TPU world, so steer + // towards config file based XRT client. 
+ return true; + } + config_path->clear(); + return GetEnvInt("XLA_USE_XRT", -1) > 0; +} + +XrtComputationClient::Worker ParseWorker(const string& worker) { + std::vector parts = absl::StrSplit(worker, ':'); + CHECK(parts.size() == 1 || parts.size() == 2) << worker; + return parts.size() == 1 + ? XrtComputationClient::Worker(parts[0], 0) + : XrtComputationClient::Worker(parts[0], std::stoi(parts[1])); +} + +void AddXrtHostDevices(const string& worker_name, int task_no, + const string& server, + std::map* device_ordinals, + XrtComputationClient::Options* options) { + struct Devices { + const char* name; + int count; + } const devices[] = { + {"TPU", GetEnvInt("TPU_NUM_DEVICES", 8)}, + {"CPU", GetEnvInt("CPU_NUM_DEVICES", 1)}, + }; + string host_port = server.compare(0, 7, "grpc://") == 0 + ? server + : absl::StrCat("grpc://", server); + options->workers_map.emplace( + XrtComputationClient::Worker(worker_name, task_no), host_port); + for (auto& device : devices) { + int& device_ordinal = (*device_ordinals)[device.name]; + for (int j = 0; j < device.count; ++j, ++device_ordinal) { + string device_name = absl::StrCat(device.name, ":", device_ordinal); + string xrt_device_name = + absl::StrCat("/job:", worker_name, "/replica:0/task:", task_no, + "/device:", device_name); + options->device_map.emplace(device_name, xrt_device_name); + } + } +} + +StatusOr ParseEnvBasedTpuClusterConfig( + XrtComputationClient::Options* options) { + string tpu_config = GetEnvString("XRT_TPU_CONFIG", ""); + if (tpu_config.empty()) { + return false; + } + std::map device_ordinals; + std::vector spec_parts = absl::StrSplit(tpu_config, '|'); + TF_RET_CHECK(!spec_parts.empty()) << tpu_config; + for (const auto& spec : spec_parts) { + std::vector host_parts = absl::StrSplit(spec, ';'); + TF_RET_CHECK(host_parts.size() == 3) << spec; + AddXrtHostDevices(host_parts[0], std::stoi(host_parts[1]), host_parts[2], + &device_ordinals, options); + } + options->default_device = "TPU:0"; + return true; +} + +Status ParseTpuClusterConfig(const string& xrt_config_path, + XrtComputationClient::Options* options) { + std::map device_ordinals; + std::ifstream config_file(xrt_config_path); + string line; + while (std::getline(config_file, line)) { + if (line.compare(0, 7, "worker:") == 0) { + std::vector parts = + absl::StrSplit(line.substr(7), ' ', absl::SkipWhitespace()); + TF_RET_CHECK(parts.size() >= 2) << line; + const string& worker_name = parts[0]; + for (std::size_t i = 1; i < parts.size(); ++i) { + AddXrtHostDevices(worker_name, i - 1, parts[i], &device_ordinals, + options); + } + } + } + options->default_device = "TPU:0"; + return Status::OK(); +} + +} // namespace + +StatusOr> ComputationClient::Create() { + std::unique_ptr client; + string xrt_config_path; + if (ShouldUseXrtClient(&xrt_config_path)) { + XrtComputationClient::Options options; + if (!xrt_config_path.empty()) { + LOG(INFO) << "Loading XRT configuration from " << xrt_config_path; + TF_RETURN_IF_ERROR(ParseTpuClusterConfig(xrt_config_path, &options)); + } else { + TF_ASSIGN_OR_RETURN(bool configured, + ParseEnvBasedTpuClusterConfig(&options)); + if (!configured) { + string device_spec = + GetEnvString("XRT_DEVICE_MAP", + "TPU:0;/job:tpu_worker/replica:0/task:0/device:TPU:0"); + for (const auto& device_target : absl::StrSplit(device_spec, '|')) { + std::vector parts = absl::StrSplit(device_target, ';'); + TF_RET_CHECK(parts.size() == 2) << device_target; + if (options.default_device.empty()) { + options.default_device = parts[0]; + } + 
options.device_map.emplace(parts[0], parts[1]); + } + string workers_spec = + GetEnvString("XRT_WORKERS", "tpu_worker:0;grpc://localhost:51000"); + for (const auto& name_target : absl::StrSplit(workers_spec, '|')) { + std::vector parts = absl::StrSplit(name_target, ';'); + TF_RET_CHECK(parts.size() == 2); + options.workers_map.emplace(ParseWorker(parts[0]), parts[1]); + } + } + } + client.reset(new XrtComputationClient(options)); + } else { + XlaComputationClient::Options options; + options.host_name = GetEnvString("XLA_GRPC_HOST", "localhost"); + options.port = GetEnvInt("XLA_GRPC_PORT", 51000); + options.platform = GetEnvString("XLA_PLATFORM", "TPU"); + client.reset(new XlaComputationClient(options)); + } + return std::move(client); +} + +} // namespace xla diff --git a/third_party/xla_client/computation_client.h b/third_party/xla_client/computation_client.h new file mode 100644 index 00000000000..1d64771eded --- /dev/null +++ b/third_party/xla_client/computation_client.h @@ -0,0 +1,75 @@ +#ifndef TENSORFLOW_COMPILER_XLA_RPC_COMPUTATION_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_RPC_COMPUTATION_CLIENT_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +class ComputationClient { + public: + class Data { + public: + Data(string device, Shape device_shape) + : device_(std::move(device)), device_shape_(std::move(device_shape)) {} + + virtual ~Data() {} + + const string& device() const { return device_; } + + const Shape& shape() const { return device_shape_; } + + private: + string device_; + Shape device_shape_; + }; + + static StatusOr> Create(); + + virtual ~ComputationClient() {} + + virtual std::shared_ptr TransferParameterToServer( + const xla::Literal& literal, const string& device) = 0; + + // Executes "computation" with "arguments" and returns the result. If + // "output_shape" isn't null, use it as a hint for the computation output + // layout. + virtual std::shared_ptr ExecuteComputation( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* output_shape) = 0; + + virtual std::unique_ptr ExecuteComputationAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* output_shape) = 0; + + // Executes the computation in replicated mode. + // The size of the arguments vector is the number of replicas to execute. + // The destination devices for each replicated computation come from the + // devices the Data objects are stored into. Within arguments[i], every Data + // object must be coming from the same device. The optional output_shape can + // be used to force the shape (and layout) or the computation result. Returns + // a vector (of the same size of the arguments vector) with the results of the + // parallel execution. The result[i] will be the result of the computation fed + // with arguments[i]. 
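+  // For example, with two replicas, arguments holds two vectors of per-replica
+  // inputs and the returned vector holds the two per-replica results.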
+ virtual std::vector> ExecuteReplicated( + const XlaComputation& computation, + const std::vector>& arguments, + const Shape* output_shape) = 0; + + virtual StatusOr>> DeconstructTuple( + const Data& data) = 0; + + virtual string GetDefaultDevice() const = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_RPC_COMPUTATION_CLIENT_H_ diff --git a/third_party/xla_client/tf_exported_symbols.lds b/third_party/xla_client/tf_exported_symbols.lds new file mode 100644 index 00000000000..2a47055ac35 --- /dev/null +++ b/third_party/xla_client/tf_exported_symbols.lds @@ -0,0 +1,9 @@ +*tensorflow* +*perftools*gputools* +*tf_* +*TF_* +*TFE_* +*nsync_* +*pywrap_xla* +*xla* +*ConvBackpropComputeDimensionsV2* diff --git a/third_party/xla_client/tf_version_script.lds b/third_party/xla_client/tf_version_script.lds new file mode 100644 index 00000000000..6eacc2db999 --- /dev/null +++ b/third_party/xla_client/tf_version_script.lds @@ -0,0 +1,13 @@ +tensorflow { + global: + *tensorflow*; + *perftools*gputools*; + *TF_*; + *TFE_*; + *nsync_*; + *pywrap_xla*; + *xla*; + *ConvBackpropComputeDimensionsV2*; + local: + *; +}; diff --git a/third_party/xla_client/unique.h b/third_party/xla_client/unique.h new file mode 100644 index 00000000000..29a4036778a --- /dev/null +++ b/third_party/xla_client/unique.h @@ -0,0 +1,35 @@ +#ifndef TENSORFLOW_COMPILER_XLA_XLA_CLIENT_UNIQUE_H_ +#define TENSORFLOW_COMPILER_XLA_XLA_CLIENT_UNIQUE_H_ + +#include + +#include "absl/types/optional.h" + +namespace xla { +namespace xla_util { + +// Helper class to allow tracking zero or more things, which should be forcibly +// be one only thing. +template > +class Unique { + public: + std::pair set(const T& value) { + if (value_) { + CHECK(C()(*value_, value)) << "'" << *value_ << "' vs '" << value << "'"; + return std::pair(false, *value_); + } + value_ = value; + return std::pair(true, *value_); + } + + operator bool() const { return value_.has_value(); } + const T& operator*() const { return *value_; } + + private: + absl::optional value_; +}; + +} // namespace xla_util +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_XLA_CLIENT_UNIQUE_H_ diff --git a/third_party/xla_client/xla_computation_client.cc b/third_party/xla_client/xla_computation_client.cc new file mode 100644 index 00000000000..8189bc622ec --- /dev/null +++ b/third_party/xla_client/xla_computation_client.cc @@ -0,0 +1,165 @@ +#include "tensorflow/compiler/xla/xla_client/xla_computation_client.h" + +#include "grpc++/create_channel.h" +#include "grpc++/support/channel_arguments.h" +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/rpc/grpc_stub.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/xla_client/unique.h" +#include "tensorflow/compiler/xla/xla_client/xla_util.h" + +namespace xla { + +XlaComputationClient::XlaComputationClient( + XlaComputationClient::Options options) + : options_(std::move(options)) { + if (!options_.host_name.empty()) { + ::grpc::ChannelArguments ch_args; + ch_args.SetMaxReceiveMessageSize(-1); + auto channel = ::grpc::CreateCustomChannel( + absl::StrCat(options_.host_name, ":", options_.port), + ::grpc::InsecureChannelCredentials(), ch_args); + channel->WaitForConnected(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(10, GPR_TIMESPAN))); + 
LOG(INFO) << "Channel to '" << options_.host_name + << "' is connected on port " << options_.port; + + xla_service_ = grpc::XlaService::NewStub(channel); + stub_.reset(new GRPCStub(xla_service_.get())); + client_ptr_.reset(new Client(stub_.get())); + client_ = client_ptr_.get(); + } else { + se::Platform* platform = nullptr; + if (!options_.platform.empty()) { + platform = PlatformUtil::GetPlatform(options_.platform).ValueOrDie(); + } + LOG(INFO) << "Creating XLA computation client for '" + << (options_.platform.empty() ? "default" : options_.platform) + << "' platform"; + client_ = ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); + } +} + +std::shared_ptr +XlaComputationClient::ExecuteComputation( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, const Shape* output_shape) { + FlushReleasedHandles(); + + ExecutionOptions eo; + *eo.mutable_debug_options() = legacy_flags::GetDebugOptionsFromFlags(); + if (output_shape != nullptr) { + *eo.mutable_shape_with_output_layout() = *output_shape; + } + string device; + std::vector arguments_data = + GetArgumentsData(arguments, &device); + StatusOr> result_or_status = + client_->Execute(computation, arguments_data, &eo); + xrt_util::CheckComputationStatus(result_or_status.status(), computation); + + ProgramShape program_shape; + if (output_shape == nullptr) { + program_shape = computation.GetProgramShape().ValueOrDie(); + output_shape = &program_shape.result(); + } + return std::make_shared( + std::move(result_or_status.ValueOrDie()), device, *output_shape, + [this](XlaData* xla_data) { ReleaseXlaData(xla_data); }); +} + +std::unique_ptr XlaComputationClient::ExecuteComputationAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, const Shape* output_shape) { + FlushReleasedHandles(); + + ExecutionOptions eo; + *eo.mutable_debug_options() = legacy_flags::GetDebugOptionsFromFlags(); + if (output_shape != nullptr) { + *eo.mutable_shape_with_output_layout() = *output_shape; + } + std::vector arguments_data = + GetArgumentsData(arguments, /*device=*/nullptr); + StatusOr result_or_status = + client_->ExecuteAndTransfer(computation, arguments_data, &eo); + xrt_util::CheckComputationStatus(result_or_status.status(), computation); + return std::unique_ptr( + new Literal(std::move(result_or_status.ValueOrDie()))); +} + +std::vector> +XlaComputationClient::ExecuteReplicated( + const XlaComputation& computation, + const std::vector>& arguments, + const Shape* output_shape) { + LOG(FATAL) << "ExecuteReplicated() API not yet implemented!"; +} + +std::shared_ptr +XlaComputationClient::TransferParameterToServer(const Literal& literal, + const string& device) { + FlushReleasedHandles(); + + std::unique_ptr handle = + client_->TransferToServer(literal).ValueOrDie(); + return std::make_shared( + std::move(handle), GetEffectiveDevice(device), literal.shape(), + [this](XlaData* xla_data) { ReleaseXlaData(xla_data); }); +} + +StatusOr>> +XlaComputationClient::DeconstructTuple(const Data& data) { + const XlaData& xla_data = dynamic_cast(data); + TF_ASSIGN_OR_RETURN(auto exploded_tuple, + client_->DeconstructTuple(*xla_data.handle)); + std::vector> tuple; + for (int64 i = 0; i < exploded_tuple.size(); ++i) { + tuple.push_back(std::make_shared( + std::move(exploded_tuple[i]), xla_data.device(), + ShapeUtil::GetTupleElementShape(xla_data.shape(), i), + [this](XlaData* xla_data) { ReleaseXlaData(xla_data); })); + } + return std::move(tuple); +} + +std::vector XlaComputationClient::GetArgumentsData( 
+ tensorflow::gtl::ArraySlice arguments, string* device) const { + xla_util::Unique unique_device; + std::vector arguments_data; + for (auto data : arguments) { + XlaData* xla_data = dynamic_cast(data); + unique_device.set(xla_data->device()); + arguments_data.push_back(xla_data->handle.get()); + } + if (device != nullptr) { + if (unique_device) { + *device = *unique_device; + } else { + *device = GetDefaultDevice(); + } + } + return arguments_data; +} + +string XlaComputationClient::GetEffectiveDevice(const string& device) const { + return device.empty() ? GetDefaultDevice() : device; +} + +void XlaComputationClient::FlushReleasedHandles() { + std::vector> released_handles; + released_handles.swap(released_handles_); + GlobalData::Release(std::move(released_handles)); +} + +void XlaComputationClient::ReleaseXlaData(XlaData* xla_data) { + released_handles_.push_back(xla_data->Release()); +} + +string XlaComputationClient::GetDefaultDevice() const { + return options_.platform + ":0"; +} + +} // namespace xla diff --git a/third_party/xla_client/xla_computation_client.h b/third_party/xla_client/xla_computation_client.h new file mode 100644 index 00000000000..cd4007af398 --- /dev/null +++ b/third_party/xla_client/xla_computation_client.h @@ -0,0 +1,99 @@ +#ifndef TENSORFLOW_COMPILER_XLA_RPC_XLA_COMPUTATION_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_RPC_XLA_COMPUTATION_CLIENT_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/rpc/grpc_stub.h" +#include "tensorflow/compiler/xla/xla_client/computation_client.h" + +namespace xla { + +class XlaComputationClient : public ComputationClient { + struct XlaData : public Data { + using Releaser = std::function; + + XlaData(std::unique_ptr handle, string device, + Shape device_shape, Releaser releaser) + : Data(std::move(device), std::move(device_shape)), + handle(std::move(handle)), + releaser(std::move(releaser)) {} + + ~XlaData() override { + if (releaser) { + releaser(this); + } + } + + std::unique_ptr Release() { + CHECK(releaser != nullptr); + releaser = nullptr; + return std::move(handle); + } + + std::unique_ptr handle; + Releaser releaser; + }; + + public: + struct Options { + Options() : host_name(), port(-1), platform() {} + + string host_name; + int port; + string platform; + }; + + XlaComputationClient(Options options); + + std::shared_ptr TransferParameterToServer( + const Literal& literal, const string& device) override; + + std::shared_ptr ExecuteComputation( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* output_shape) override; + + std::unique_ptr ExecuteComputationAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* output_shape) override; + + std::vector> ExecuteReplicated( + const XlaComputation& computation, + const std::vector>& arguments, + const Shape* output_shape) override; + + StatusOr>> DeconstructTuple( + const Data& data) override; + + string GetDefaultDevice() const override; + + private: + std::vector GetArgumentsData( + tensorflow::gtl::ArraySlice arguments, string* device) const; + + // Returns the device argument if not empty, or the value returned by the + // GetDefaultDevice() API. + string GetEffectiveDevice(const string& device) const; + + // Flushes all the outstanding released handles in one RPC swipe. + void FlushReleasedHandles(); + + // Batches an XLA handle for release. 
+ void ReleaseXlaData(XlaData* xla_data); + + Options options_; + Client* client_ = nullptr; + std::unique_ptr client_ptr_; + std::unique_ptr xla_service_; + std::unique_ptr stub_; + std::vector> released_handles_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_RPC_XLA_COMPUTATION_CLIENT_H_ diff --git a/third_party/xla_client/xla_util.cc b/third_party/xla_client/xla_util.cc new file mode 100644 index 00000000000..5d10a173a0b --- /dev/null +++ b/third_party/xla_client/xla_util.cc @@ -0,0 +1,36 @@ +#include "tensorflow/compiler/xla/xla_client/xla_util.h" + +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace xrt_util { + +using namespace tensorflow; + +StatusOr> CreateModuleFromProto( + const HloModuleProto& proto, const DebugOptions& debug_options) { + TF_ASSIGN_OR_RETURN( + auto hlo_module_config, + HloModule::CreateModuleConfigFromProto(proto, debug_options)); + return HloModule::CreateFromProto(proto, hlo_module_config); +} + +StatusOr GetComputationHloText(const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(auto hlo_module, + CreateModuleFromProto(computation.proto())); + return hlo_module->ToString(); +} + +void CheckComputationStatus(const Status& status, + const XlaComputation& computation) { + if (!status.ok()) { + string hlo_text = GetComputationHloText(computation).ValueOrDie(); + XLA_LOG_LINES(ERROR, hlo_text); + LOG(FATAL) << status; + } +} + +} // namespace xrt_util +} // namespace xla diff --git a/third_party/xla_client/xla_util.h b/third_party/xla_client/xla_util.h new file mode 100644 index 00000000000..44f3e539da5 --- /dev/null +++ b/third_party/xla_client/xla_util.h @@ -0,0 +1,30 @@ +#ifndef TENSORFLOW_COMPILER_XLA_RPC_XLA_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_RPC_XLA_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { +namespace xrt_util { + +// Creates the HLO module which is generated by the input PB message. +StatusOr> CreateModuleFromProto( + const HloModuleProto& proto, + const DebugOptions& debug_options = DebugOptions()); + +// Returns a textual representation of the input XLA computation. +StatusOr GetComputationHloText(const XlaComputation& computation); + +// Checks whether an action on the given computation generated an error, and if +// that was the case, emit error and computation HLO text. 
+void CheckComputationStatus(const Status& status, + const XlaComputation& computation); + +} // namespace xrt_util +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_RPC_XLA_UTIL_H_ diff --git a/third_party/xla_client/xrt_computation_client.cc b/third_party/xla_client/xrt_computation_client.cc new file mode 100644 index 00000000000..0238c585d37 --- /dev/null +++ b/third_party/xla_client/xrt_computation_client.cc @@ -0,0 +1,632 @@ +#include "tensorflow/compiler/xla/xla_client/xrt_computation_client.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_client/unique.h" +#include "tensorflow/compiler/xla/xla_client/xla_util.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace xla { +namespace { + +static const char* const kCpuDevice = "/device:CPU:0"; + +} // namespace + +XrtComputationClient::XrtComputationClient( + XrtComputationClient::Options options) + : options_(std::move(options)) { + auto default_device_target = + options_.device_map.find(options_.default_device); + CHECK(default_device_target != options_.device_map.end()); + for (const auto& dev_target : options_.device_map) { + LOG(INFO) << "XRT device " << dev_target.first << " -> " + << dev_target.second; + } + LOG(INFO) << "XRT default device: " << default_device_target->first; + InitializeDevices(); +} + +std::shared_ptr +XrtComputationClient::ExecuteComputation( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, const Shape* output_shape) { + ApiCallInitialize(); + + std::vector devices; + tensorflow::ClientSession::FeedType feed_inputs; + auto exec_ops = + CreateExecuteOps(computation, BuildParallelArguments(arguments), + output_shape, &devices, &feed_inputs); + SessionData* session = GetSessionForDevice(devices.front()); + std::vector outputs; + TF_CHECK_OK(session->root.status()); + xrt_util::CheckComputationStatus( + session->session.Run(feed_inputs, {exec_ops.front().execute_output}, + &outputs), + computation); + CHECK_EQ(outputs.size(), 1); + + return std::make_shared( + devices.front(), outputs[0].scalar()(), + exec_ops.front().result_shape, + [this](XrtData* xrt_data) { ReleaseXrtData(xrt_data); }); +} + +std::unique_ptr XrtComputationClient::ExecuteComputationAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, const Shape* output_shape) { + ApiCallInitialize(); + + ProgramShape program_shape; + if (output_shape == nullptr) { + program_shape = computation.GetProgramShape().ValueOrDie(); + output_shape = &program_shape.result(); + } + string device = GetArgumentsDevice(arguments); + auto xrt_computation = CreateXrtComputation(computation, /*num_replicas=*/1, + {device}, output_shape); + tensorflow::ClientSession::FeedType feed_inputs; + auto inputs = GetArgumentsInputs(arguments, device, &feed_inputs); + SessionData* session = GetSessionForDevice(device); + tensorflow::Scope device_scope = + session->root.WithDevice(TorchDeviceToXrtDevice(device)); + const CachedNode& cached_node = + GetCompileExecuteReadNode(device_scope, device); + + feed_inputs.insert( + {cached_node.holders[0], xrt_computation->SerializeAsString()}); + + xrt::XRTExecutionConfig exec_config; + exec_config.set_release_input_handles(false); + exec_config.set_release_compilation_handle(true); + feed_inputs.insert({cached_node.holders[1], exec_config.SerializeAsString()}); + 
feed_inputs.insert({cached_node.holders[2], inputs}); + + std::vector outputs; + TF_CHECK_OK(session->root.status()); + xrt_util::CheckComputationStatus( + session->session.Run(feed_inputs, {*cached_node.output}, &outputs), + computation); + CHECK_EQ(outputs.size(), 1); + + LiteralProto response; + CHECK(response.ParseFromString(outputs[0].scalar()())); + return std::unique_ptr( + new Literal(Literal::CreateFromProto(response).ValueOrDie())); +} + +std::shared_ptr +XrtComputationClient::TransferParameterToServer(const Literal& literal, + const string& device) { + ApiCallInitialize(); + + string effective_device = GetEffectiveDevice(device); + const string& xrt_device = TorchDeviceToXrtDevice(effective_device); + SessionData* session = GetSessionForXrtDevice(xrt_device); + xrt::XLAAllocation alloc; + alloc.set_device_ordinal(GetDeviceOrdinal(xrt_device)); + *alloc.mutable_value() = literal.ToProto(); + + tensorflow::ClientSession::FeedType feed_inputs; + tensorflow::Scope device_scope = session->root.WithDevice(xrt_device); + const CachedNode& cached_node = + GetAllocateNode(device_scope, effective_device); + feed_inputs.insert({cached_node.holders[0], alloc.SerializeAsString()}); + + std::vector outputs; + TF_CHECK_OK(session->root.status()); + TF_CHECK_OK( + session->session.Run(feed_inputs, {*cached_node.output}, &outputs)); + CHECK_EQ(outputs.size(), 1); + return std::make_shared( + effective_device, outputs[0].scalar()(), literal.shape(), + [this](XrtData* xrt_data) { ReleaseXrtData(xrt_data); }); +} + +std::vector> +XrtComputationClient::ExecuteReplicated( + const XlaComputation& computation, + const std::vector>& arguments, + const Shape* output_shape) { + ApiCallInitialize(); + + std::vector devices; + tensorflow::ClientSession::FeedType feed_inputs; + auto exec_ops = CreateExecuteOps(computation, arguments, output_shape, + &devices, &feed_inputs); + // In the PyTorch/XRT interface we keep a map (options_.workers_map) from a + // worker+taskno, to the GRPC server which is the entry point for that worker. + // Since XRT could re-distribute ops internally, if we have N hosts + // (worker+taskno), we could have all the workers pointing to a single GRPC + // entry point, or we could have each worker pointing directly to the target + // host. + // The advantage of the latter approach, is that we do not bottleneck + // (especially when feeding inputs) the single GRPC entry point. + // Using the N:1 approach, the session_replicas below will contain a single + // session, and all the replica executions will go through it (and distributed + // by XRT on the service side). + // Chosing the 1:1 approach (one session per worker), we will have N sessions + // within the session_replicas map, which we will be executing independently. + std::map> session_replicas; + for (size_t i = 0; i < devices.size(); ++i) { + SessionData* session = GetSessionForDevice(devices[i]); + session_replicas[session].push_back(i); + } + // TODO(dlibenzi): These could be run in parallel. 
+ std::vector> results(devices.size()); + for (auto& sess_replica : session_replicas) { + std::vector exec_nodes; + for (auto replica : sess_replica.second) { + exec_nodes.push_back(exec_ops[replica].execute_output); + } + std::vector outputs; + TF_CHECK_OK(sess_replica.first->root.status()); + xrt_util::CheckComputationStatus( + sess_replica.first->session.Run(feed_inputs, exec_nodes, &outputs), + computation); + CHECK_EQ(outputs.size(), exec_nodes.size()); + + for (size_t i = 0; i < outputs.size(); ++i) { + auto replica = sess_replica.second[i]; + results[replica] = std::make_shared( + devices[replica], outputs[i].scalar()(), + exec_ops[replica].result_shape, + [this](XrtData* xrt_data) { ReleaseXrtData(xrt_data); }); + } + } + return results; +} + +StatusOr>> +XrtComputationClient::DeconstructTuple(const Data& data) { + ApiCallInitialize(); + + const XrtData& xrt_data = dynamic_cast(data); + SessionData* session = GetSessionForDevice(xrt_data.device()); + tensorflow::Scope device_scope = + session->root.WithDevice(TorchDeviceToXrtDevice(xrt_data.device())); + tensorflow::Scope cpu_scope = session->root.WithDevice(kCpuDevice); + int64 count = ShapeUtil::TupleElementCount(xrt_data.shape()); + std::vector sub_outputs; + tensorflow::ClientSession::FeedType feed_inputs; + for (int64 i = 0; i < count; ++i) { + const CachedNode& cached_node = + GetSubTupleNode(device_scope, xrt_data.device()); + feed_inputs.insert({cached_node.holders[0], xrt_data.handle}); + tensorflow::Tensor index_tensor(tensorflow::DT_INT32, + tensorflow::TensorShape({1})); + index_tensor.flat()(0) = i; + feed_inputs.insert({cached_node.holders[1], index_tensor}); + sub_outputs.push_back(*cached_node.output); + } + std::vector outputs; + TF_CHECK_OK(session->root.status()); + TF_CHECK_OK(session->session.Run(feed_inputs, sub_outputs, &outputs)); + CHECK_EQ(outputs.size(), count); + + std::vector> components; + for (int64 i = 0; i < count; ++i) { + components.push_back(std::make_shared( + xrt_data.device(), outputs[i].scalar()(), + ShapeUtil::GetTupleElementShape(xrt_data.shape(), i), + [this](XrtData* xrt_data) { ReleaseXrtData(xrt_data); })); + } + return std::move(components); +} + +XrtComputationClient::SessionData* XrtComputationClient::GetSessionForTarget( + const string& target) { + auto target_session = session_map_.find(target); + if (target_session == session_map_.end()) { + target_session = + session_map_ + .emplace(target, + std::unique_ptr(new SessionData(target))) + .first; + } + return target_session->second.get(); +} + +XrtComputationClient::SessionData* XrtComputationClient::GetSessionForXrtDevice( + const string& xrt_device) { + auto worker_hostport = GetWorkerForXrtDevice(xrt_device); + return GetSessionForTarget(worker_hostport.second); +} + +XrtComputationClient::SessionData* XrtComputationClient::GetSessionForDevice( + const string& device) { + return GetSessionForXrtDevice(TorchDeviceToXrtDevice(device)); +} + +const string& XrtComputationClient::GetEffectiveDevice( + const string& device) const { + return !device.empty() ? 
device : options_.default_device; +} + +const string& XrtComputationClient::TorchDeviceToXrtDevice( + const string& device) const { + auto device_target = options_.device_map.find(GetEffectiveDevice(device)); + CHECK(device_target != options_.device_map.end()) + << "Unable to find device: " << device; + return device_target->second; +} + +std::unique_ptr XrtComputationClient::CreateXrtComputation( + const XlaComputation& computation, int64 num_replicas, + const std::vector& devices, const Shape* output_shape) const { + CHECK_EQ(num_replicas, devices.size()); + std::unique_ptr xrt_computation( + new xrt::XLAComputation()); + auto config = xrt_computation->mutable_config(); + config->set_num_replicas(num_replicas); + config->set_num_cores_per_replica(1); + if (num_replicas > 1) { + auto device_assignment = config->mutable_device_assignment(); + auto computation_device = device_assignment->add_computation_devices(); + for (int64 i = 0; i < num_replicas; ++i) { + const string& xrt_device = TorchDeviceToXrtDevice(devices[i]); + const auto& core_coords = GetDeviceMeshCoords(xrt_device); + auto replica_device = computation_device->add_replica_devices(); + for (auto coord : core_coords) { + replica_device->add_value(coord); + } + } + } + *config->mutable_program_shape() = computation.GetProgramShape().ValueOrDie(); + if (output_shape != nullptr) { + *config->mutable_program_shape()->mutable_result() = *output_shape; + } + *xrt_computation->mutable_hlo_snapshot() = + *computation.Snapshot().ValueOrDie(); + return xrt_computation; +} + +string XrtComputationClient::GetArgumentsDevice( + tensorflow::gtl::ArraySlice arguments) const { + xla_util::Unique unique_device; + for (size_t i = 0; i < arguments.size(); ++i) { + XrtData* xrt_data = dynamic_cast(arguments[i]); + unique_device.set(xrt_data->device()); + } + // If the computation has no arguments, use the default device. + // Maybe the execute-computation APIs needs to be more explicit about it. + return unique_device ? 
*unique_device : options_.default_device; +} + +std::vector XrtComputationClient::GetReplicasDevices( + const std::vector>& arguments) const { + std::vector devices; + std::set unique_devices; + for (size_t i = 0; i < arguments.size(); ++i) { + devices.push_back(GetArgumentsDevice(arguments[i])); + CHECK(unique_devices.insert(devices.back()).second) + << "Cannot have two different replicas using the same device: " + << devices.back(); + } + return devices; +} + +tensorflow::Tensor XrtComputationClient::GetArgumentsInputs( + tensorflow::gtl::ArraySlice arguments, const string& device, + tensorflow::ClientSession::FeedType* feed_inputs) { + tensorflow::Tensor inputs_tensor(tensorflow::DT_INT64, + tensorflow::TensorShape({arguments.size()})); + for (size_t i = 0; i < arguments.size(); ++i) { + XrtData* xrt_data = dynamic_cast(arguments[i]); + CHECK_EQ(device, xrt_data->device()); + inputs_tensor.flat()(i) = xrt_data->handle; + } + return inputs_tensor; +} + +std::vector +XrtComputationClient::CreateExecuteOps( + const XlaComputation& computation, + const std::vector>& arguments, const Shape* output_shape, + std::vector* devices, + tensorflow::ClientSession::FeedType* feed_inputs) { + ProgramShape program_shape; + if (output_shape == nullptr) { + program_shape = computation.GetProgramShape().ValueOrDie(); + output_shape = &program_shape.result(); + } + *devices = GetReplicasDevices(arguments); + auto xrt_computation = CreateXrtComputation(computation, arguments.size(), + *devices, output_shape); + + absl::optional computation_holder; + xla_util::Unique unique_session; + std::vector exec_ops; + for (size_t i = 0; i < arguments.size(); ++i) { + auto inputs = GetArgumentsInputs(arguments[i], devices->at(i), feed_inputs); + const string& xrt_device = TorchDeviceToXrtDevice(devices->at(i)); + SessionData* session = + unique_session.set(GetSessionForXrtDevice(xrt_device)).second; + tensorflow::Scope device_scope = session->root.WithDevice(xrt_device); + const CachedNode& cached_node = + GetCompileExecuteNode(device_scope, devices->at(i)); + feed_inputs->insert( + {cached_node.holders[0], xrt_computation->SerializeAsString()}); + + xrt::XRTExecutionConfig exec_config; + exec_config.set_core_index_in_replica(0); + exec_config.set_release_input_handles(false); + exec_config.set_release_compilation_handle(true); + feed_inputs->insert( + {cached_node.holders[1], exec_config.SerializeAsString()}); + feed_inputs->insert({cached_node.holders[2], inputs}); + + exec_ops.emplace_back(*cached_node.output, *output_shape); + } + return exec_ops; +} + +void XrtComputationClient::ReleaseHandles( + tensorflow::gtl::ArraySlice handles) { + struct SessionReleases { + tensorflow::ClientSession::FeedType feed_inputs; + std::vector releases; + }; + std::map session_releases; + for (auto& handle : handles) { + SessionData* session = GetSessionForDevice(handle.device); + SessionReleases* release = &session_releases[session]; + tensorflow::Scope device_scope = + session->root.WithDevice(TorchDeviceToXrtDevice(handle.device)); + const CachedNode& cached_node = + GetReleaseAllocationHandleNode(device_scope, handle.device); + release->feed_inputs.insert({cached_node.holders[0], handle.handle}); + release->releases.push_back(*cached_node.operation); + } + for (const auto& session_releases : session_releases) { + std::vector outputs; + TF_CHECK_OK(session_releases.first->root.status()); + TF_CHECK_OK(session_releases.first->session.Run( + session_releases.second.feed_inputs, {}, + session_releases.second.releases, &outputs)); + 
} +} + +void XrtComputationClient::FlushReleasedHandles() { + ReleaseHandles(released_handles_); + released_handles_.clear(); +} + +void XrtComputationClient::ApiCallInitialize() { + RewindCaches(); + FlushReleasedHandles(); +} + +void XrtComputationClient::ReleaseXrtData(XrtData* xrt_data) { + xrt_data->Release(); + released_handles_.emplace_back(xrt_data->device(), xrt_data->handle); +} + +std::pair +XrtComputationClient::GetWorkerForXrtDevice(const string& xrt_device) const { + tensorflow::DeviceNameUtils::ParsedName parsed_device; + CHECK( + tensorflow::DeviceNameUtils::ParseFullName(xrt_device, &parsed_device) && + parsed_device.has_job && parsed_device.has_task) + << xrt_device; + + auto worker_hostport = + options_.workers_map.find(Worker(parsed_device.job, parsed_device.task)); + CHECK(worker_hostport != options_.workers_map.end()) << xrt_device; + return std::pair(worker_hostport->first, + worker_hostport->second); +} + +const std::vector& XrtComputationClient::GetDeviceMeshCoords( + const string& xrt_device) const { + auto it = device_mesh_coords_.find(xrt_device); + if (it == device_mesh_coords_.end()) { + LOG(FATAL) << "Missing mesh coordinates for device: " << xrt_device; + } + return it->second; +} + +tensorflow::tpu::TopologyProto XrtComputationClient::InitializeAndFetchTopology( + const string& xrt_device) { + auto worker_hostport = GetWorkerForXrtDevice(xrt_device); + LOG(INFO) << "Initializing TPU system for worker " + << worker_hostport.first.name << ":" + << worker_hostport.first.task_no << " at " + << worker_hostport.second; + string system_device = + absl::StrCat("/job:", worker_hostport.first.name, + "/replica:0/task:", worker_hostport.first.task_no, + "/device:TPU_SYSTEM:0"); + SessionData* session = GetSessionForTarget(worker_hostport.second); + tensorflow::Scope tpu_system_scope = session->root.WithDevice(system_device); + const auto unique_name = + tpu_system_scope.GetUniqueNameForOp("ConfigureDistributedTPU"); + auto builder = tensorflow::NodeBuilder(unique_name, "ConfigureDistributedTPU") + .Attr("embedding_config", "") + .Attr("tpu_embedding_config", "") + .Attr("is_global_init", false); + tpu_system_scope.UpdateBuilder(&builder); + + tensorflow::Node* result; + session->root.UpdateStatus( + builder.Finalize(tpu_system_scope.graph(), &result)); + TF_CHECK_OK(tpu_system_scope.status()); + session->root.UpdateStatus(tpu_system_scope.DoShapeInference(result)); + + std::vector outputs; + TF_CHECK_OK(session->root.status()); + TF_CHECK_OK(session->session.Run({tensorflow::Output(result, 0)}, &outputs)); + CHECK_EQ(outputs.size(), 1); + + tensorflow::tpu::TopologyProto topology_proto; + CHECK(topology_proto.ParseFromString(outputs[0].scalar()())); + return topology_proto; +} + +void XrtComputationClient::InitializeDevices() { + auto it = options_.device_map.find("TPU:0"); + if (it != options_.device_map.end()) { + tensorflow::tpu::TopologyProto topology_proto = + InitializeAndFetchTopology(it->second); + LOG(INFO) << "TPU topology: " << topology_proto.DebugString(); + + tensorflow::DeviceNameUtils::ParsedName parsed_device; + CHECK(tensorflow::DeviceNameUtils::ParseFullName(it->second, + &parsed_device) && + parsed_device.has_job) + << it->second; + string tpu_job_name = parsed_device.job; + for (const auto& dev_target : options_.device_map) { + CHECK(tensorflow::DeviceNameUtils::ParseFullName(dev_target.second, + &parsed_device) && + parsed_device.has_job && parsed_device.has_task && + parsed_device.has_id) + << dev_target.second; + if (parsed_device.job != 
tpu_job_name) { + continue; + } + CHECK_LE(parsed_device.task, topology_proto.num_tasks()); + CHECK_LE(parsed_device.id, topology_proto.num_tpu_devices_per_task()); + // The topology proto 'device_coordinates' is a linear list of + // [num_tasks][devices_per_task][mesh_shape_size] coordinates, where the + // mesh coordinates are usually [x, y, c] ('x' and 'y' being the spatial + // chip coordinated and 'c' the core number). + int64 base_index = parsed_device.task * + topology_proto.num_tpu_devices_per_task() * + topology_proto.mesh_shape_size() + + parsed_device.id * topology_proto.mesh_shape_size(); + std::vector device_mesh_coords(topology_proto.mesh_shape_size()); + for (int i = 0; i < topology_proto.mesh_shape_size(); ++i) { + device_mesh_coords[i] = + topology_proto.device_coordinates(base_index + i); + } + device_mesh_coords_.insert( + {dev_target.second, std::move(device_mesh_coords)}); + } + } +} + +string XrtComputationClient::GetDefaultDevice() const { + return options_.default_device; +} + +void XrtComputationClient::RewindCaches() { + for (auto& key_cache : node_cache_) { + key_cache.second.rewind(); + } +} + +const XrtComputationClient::CachedNode& +XrtComputationClient::GetCompileExecuteNode(const tensorflow::Scope& scope, + const string& device) { + NodeCache* cache = + &node_cache_[NodeCacheKey(device, NodeTypes::kCompileExecute)]; + if (cache->empty()) { + std::vector holders( + {tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING), + tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING), + tensorflow::ops::Placeholder( + scope, tensorflow::DT_INT64, + tensorflow::ops::Placeholder::Shape({-1}))}); + auto computation_handle = tensorflow::ops::XRTCompile(scope, holders[0]); + std::unique_ptr node(new CachedNode( + tensorflow::ops::XRTExecute(scope, computation_handle.handle, + holders[1], + {tensorflow::Output(holders[2])}), + std::move(holders))); + cache->add(std::move(node)); + } + return cache->get(); +} + +const XrtComputationClient::CachedNode& +XrtComputationClient::GetCompileExecuteReadNode(const tensorflow::Scope& scope, + const string& device) { + NodeCache* cache = + &node_cache_[NodeCacheKey(device, NodeTypes::kCompileExecuteRead)]; + if (cache->empty()) { + std::vector holders( + {tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING), + tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING), + tensorflow::ops::Placeholder( + scope, tensorflow::DT_INT64, + tensorflow::ops::Placeholder::Shape({-1}))}); + auto computation_handle = tensorflow::ops::XRTCompile(scope, holders[0]); + auto execute_op = tensorflow::ops::XRTExecute( + scope, computation_handle.handle, holders[1], + {tensorflow::Output(holders[2])}); + std::unique_ptr node(new CachedNode( + tensorflow::ops::XRTReadLiteralAndRelease(scope, execute_op), + std::move(holders))); + cache->add(std::move(node)); + } + return cache->get(); +} + +const XrtComputationClient::CachedNode& XrtComputationClient::GetAllocateNode( + const tensorflow::Scope& scope, const string& device) { + NodeCache* cache = &node_cache_[NodeCacheKey(device, NodeTypes::kAllocate)]; + if (cache->empty()) { + std::vector holders( + {tensorflow::ops::Placeholder(scope, tensorflow::DT_STRING)}); + std::unique_ptr node(new CachedNode( + tensorflow::ops::XRTAllocate(scope, holders[0]), std::move(holders))); + cache->add(std::move(node)); + } + return cache->get(); +} + +const XrtComputationClient::CachedNode& +XrtComputationClient::GetReleaseAllocationHandleNode( + const tensorflow::Scope& scope, const string& device) { + 
NodeCache* cache = + &node_cache_[NodeCacheKey(device, NodeTypes::kReleaseAllocationHandle)]; + if (cache->empty()) { + std::vector holders( + {tensorflow::ops::Placeholder(scope, tensorflow::DT_INT64)}); + std::unique_ptr node(new CachedNode( + tensorflow::ops::XRTReleaseAllocationHandle(scope, holders[0]), + std::move(holders))); + cache->add(std::move(node)); + } + return cache->get(); +} + +const XrtComputationClient::CachedNode& XrtComputationClient::GetSubTupleNode( + const tensorflow::Scope& scope, const string& device) { + NodeCache* cache = &node_cache_[NodeCacheKey(device, NodeTypes::kSubTuple)]; + if (cache->empty()) { + std::vector holders( + {tensorflow::ops::Placeholder(scope, tensorflow::DT_INT64), + tensorflow::ops::Placeholder( + scope, tensorflow::DT_INT32, + tensorflow::ops::Placeholder::Shape({1}))}); + std::unique_ptr node(new CachedNode( + tensorflow::ops::XRTSubTuple(scope, holders[0], holders[1]), + std::move(holders))); + cache->add(std::move(node)); + } + return cache->get(); +} + +std::vector> +XrtComputationClient::BuildParallelArguments( + tensorflow::gtl::ArraySlice arguments) { + std::vector> para_arguments(1); + para_arguments[0].insert(para_arguments[0].end(), arguments.begin(), + arguments.end()); + return para_arguments; +} + +int64 XrtComputationClient::GetDeviceOrdinal(const string& device) { + auto pos = device.rfind(':'); + CHECK_NE(pos, string::npos) << device; + return std::stoi(device.substr(pos + 1)); +} + +} // namespace xla diff --git a/third_party/xla_client/xrt_computation_client.h b/third_party/xla_client/xrt_computation_client.h new file mode 100644 index 00000000000..e03c26e8c88 --- /dev/null +++ b/third_party/xla_client/xrt_computation_client.h @@ -0,0 +1,330 @@ +#ifndef TENSORFLOW_COMPILER_XLA_RPC_XRT_COMPUTATION_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_RPC_XRT_COMPUTATION_CLIENT_H_ + +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/xla_client/computation_client.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" +#include "tensorflow/compiler/xrt/xrt.pb.h" +#include "tensorflow/contrib/tpu/proto/topology.pb.h" + +namespace xla { + +class XrtComputationClient : public ComputationClient { + struct DeviceHandle { + DeviceHandle(string device, int64 handle) + : device(std::move(device)), handle(handle) {} + + string device; + int64 handle; + }; + + struct XrtData : public Data { + using Releaser = std::function; + + XrtData(string device, int64 handle, Shape device_shape, Releaser releaser) + : Data(std::move(device), std::move(device_shape)), + handle(handle), + releaser(std::move(releaser)) {} + + ~XrtData() override { + if (releaser) { + releaser(this); + } + } + + void Release() { releaser = nullptr; } + + int64 handle; + Releaser releaser; + }; + + public: + struct Worker { + Worker(string name, int task_no) + : name(std::move(name)), task_no(task_no) {} + + bool operator<(const Worker& rhs) const { + if (task_no != rhs.task_no) { + return task_no < rhs.task_no; + } + return name.compare(rhs.name) < 0; + } + + string name; + int task_no; + }; + + struct Options { + string default_device; + // Maps a PyTorch 
device ID (example, "GPU:0", "TPU:0") to the full + // coordinates in TF device format + // (ie, /job:tpu_worker/replica:0/task:0/device:TPU:0), of the worker + // exposing that device. + std::map device_map; + // Maps a TPU Worker with an HOST:PORT string. + std::map workers_map; + }; + + XrtComputationClient(Options options); + + std::shared_ptr TransferParameterToServer( + const xla::Literal& literal, const string& device) override; + + std::shared_ptr ExecuteComputation( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* output_shape) override; + + std::unique_ptr ExecuteComputationAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* output_shape) override; + + std::vector> ExecuteReplicated( + const XlaComputation& computation, + const std::vector>& arguments, + const Shape* output_shape) override; + + StatusOr>> DeconstructTuple( + const Data& data) override; + + string GetDefaultDevice() const override; + + private: + struct SessionData { + SessionData(const string& target) + : root(tensorflow::Scope::NewRootScope()), session(root, target) {} + + tensorflow::Scope root; + tensorflow::ClientSession session; + }; + + // A cached node captures that single node, or the mini-graph root node, + // together with the place-holders necessary to feed the node/sub-graph. + // The end-point node can be either a tensorflow Operation or an Output. + struct CachedNode { + CachedNode(tensorflow::Output output, + std::vector holders) + : output(std::move(output)), holders(std::move(holders)) {} + CachedNode(tensorflow::Operation operation, + std::vector holders) + : operation(std::move(operation)), holders(std::move(holders)) {} + + absl::optional output; + absl::optional operation; + std::vector holders; + }; + + // The node cache holds a set of CachedNode of the same kind (by the means of + // the NodeTypes entries). + struct NodeCache { + bool empty() const { return next >= nodes.size(); } + + void add(std::unique_ptr node) { + nodes.push_back(std::move(node)); + } + + const CachedNode& get() { + CHECK_LT(next, nodes.size()); + return *nodes[next++]; + } + + void rewind() { next = 0; } + + size_t next = 0; + std::vector> nodes; + }; + + // Every "kind" of cached node (or group of nodes - mini graph), have an ID + // entry here. + enum class NodeTypes { + kCompileExecute, + kCompileExecuteRead, + kAllocate, + kSubTuple, + kReleaseAllocationHandle, + }; + + struct NodeCacheKey { + NodeCacheKey(string device, NodeTypes type) + : device(std::move(device)), type(type) {} + + bool operator<(const NodeCacheKey& rhs) const { + return type != rhs.type ? 
(type < rhs.type) + : (device.compare(rhs.device) < 0); + } + + string device; + NodeTypes type; + }; + + struct ExecuteContext { + ExecuteContext(tensorflow::Output execute_output, Shape result_shape) + : execute_output(std::move(execute_output)), + result_shape(std::move(result_shape)) {} + + tensorflow::Output execute_output; + Shape result_shape; + }; + + SessionData* GetSessionForTarget(const string& target); + SessionData* GetSessionForXrtDevice(const string& xrt_device); + SessionData* GetSessionForDevice(const string& device); + + const string& GetEffectiveDevice(const string& device) const; + + const string& TorchDeviceToXrtDevice(const string& device) const; + + std::unique_ptr CreateXrtComputation( + const XlaComputation& computation, int64 num_replicas, + const std::vector& devices, const Shape* output_shape) const; + + // Retrieves the unique, common, device for all the inputs. Issue a CHECK if + // the inputs are not on a common device, as we cannot create an XLA + // computation spanning multiple devices ATM. + string GetArgumentsDevice(tensorflow::gtl::ArraySlice arguments) const; + + // Retrieves the common device for each replica inputs (arguments[i]). The + // common device for each replica inputs must be unique across the replicas. + std::vector GetReplicasDevices( + const std::vector>& arguments) const; + + tensorflow::Tensor GetArgumentsInputs( + tensorflow::gtl::ArraySlice arguments, const string& device, + tensorflow::ClientSession::FeedType* feed_inputs); + + std::vector CreateExecuteOps( + const XlaComputation& computation, + const std::vector>& arguments, + const Shape* output_shape, std::vector* devices, + tensorflow::ClientSession::FeedType* feed_inputs); + + // Retrieves the worker,worker_host pair for a given XRT device (ie, + // /job:tpu_worker/replica:0/task:0/device:TPU:0). + std::pair GetWorkerForXrtDevice( + const string& xrt_device) const; + + void ReleaseHandles(tensorflow::gtl::ArraySlice handles); + + // Flushes all the outstanding released handles in one RPC swipe. + void FlushReleasedHandles(); + + // Function which is called at every entry into the XRT computation client + // APIs. Performs tasks to intialize the per-call context, like flushing all + // the accumulated handle releases, and rewinding the XRT node caches. + void ApiCallInitialize(); + + void ReleaseXrtData(XrtData* xrt_data); + + // Retrieves the mesh coordinates of a given XRT device. + const std::vector& GetDeviceMeshCoords(const string& xrt_device) const; + + tensorflow::tpu::TopologyProto InitializeAndFetchTopology( + const string& xrt_device); + + void InitializeDevices(); + + // Rewinds all the XRT node caches, marking all the cached nodes as free. 
+ void RewindCaches(); + + // Creates an XRT graph with an XRTCompile, feeding into an XRTExecute + // operation: + // + // XRTExecute( + // XRTCompile(holders[0]), + // holders[1], + // holders[2] + // ) + // + // With: + // holders[0] = XLA Computation place-holder (DT_STRING) + // holders[1] = xrt::XRTExecutionConfig place-holder (DT_STRING) + // holders[2] = Inputs for the XRTExecute (DT_INT64[]) + const CachedNode& GetCompileExecuteNode(const tensorflow::Scope& scope, + const string& device); + + // Creates an XRT graph with an XRTCompile, feeding into an XRTExecute, + // feeding into an XRTReadLiteralAndRelease operation: + // + // XRTReadLiteralAndRelease( + // XRTExecute( + // XRTCompile(holders[0]), + // holders[1], + // holders[2] + // ) + // ) + // + // With: + // holders[0] = XLA Computation place-holder (DT_STRING) + // holders[1] = xrt::XRTExecutionConfig place-holder (DT_STRING) + // holders[2] = Inputs for the XRTExecute (DT_INT64[]) + const CachedNode& GetCompileExecuteReadNode(const tensorflow::Scope& scope, + const string& device); + + // Creates an XRTAllocate node: + // + // XRTAllocate( + // holders[0] + // ) + // + // With: + // holders[0] = xrt::XLAAllocation place-holder (DT_STRING) + const CachedNode& GetAllocateNode(const tensorflow::Scope& scope, + const string& device); + + // Creates an XRTReleaseAllocationHandle node: + // + // XRTReleaseAllocationHandle( + // holders[0] + // ) + // + // With: + // holders[0] = To be released handle place-holder (DT_INT64) + const CachedNode& GetReleaseAllocationHandleNode( + const tensorflow::Scope& scope, const string& device); + + // Creates an XRTSubTuple node: + // + // XRTSubTuple( + // holders[0], + // holders[1] + // ) + // + // With: + // holders[0] = Tuple handle place-holder (DT_INT64) + // holders[1] = Tuple index place-holder (DT_INT32[]) + const CachedNode& GetSubTupleNode(const tensorflow::Scope& scope, + const string& device); + + // Builds an argument vector usable in a replicated context, out of a single + // replica argument vector. Essentially turns a [N] into a [1][N]. + static std::vector> BuildParallelArguments( + tensorflow::gtl::ArraySlice arguments); + + // Retrieves the ordinal number out of a device string. This is the number + // after the last ':' character of the device string. 
+ static int64 GetDeviceOrdinal(const string& device); + + Options options_; + std::map> device_mesh_coords_; + std::map> session_map_; + std::vector released_handles_; + std::map node_cache_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_RPC_XRT_COMPUTATION_CLIENT_H_ diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py new file mode 100644 index 00000000000..788a66ee974 --- /dev/null +++ b/torch_xla/__init__.py @@ -0,0 +1,2 @@ +import torch +import _C diff --git a/torch_xla/csrc/batch_norm.cpp b/torch_xla/csrc/batch_norm.cpp new file mode 100644 index 00000000000..4c6afd4f354 --- /dev/null +++ b/torch_xla/csrc/batch_norm.cpp @@ -0,0 +1,44 @@ +#include "batch_norm.h" +#include "helpers.h" + +namespace torch { +namespace jit { + +BatchNormOutput BuildBatchNorm(const Node* node, const xla::XlaOp& input, + const xla::XlaOp& weight, + const xla::XlaOp& bias) { + auto builder = input.builder(); + const float eps_value = node->get(attr::eps).value().to(); + const auto eps = XlaHelpers::ScalarValue(eps_value, builder); + const auto one = XlaHelpers::ScalarValue(1, builder); + const auto half = XlaHelpers::ScalarValue(0.5f, builder); + + auto outputs = xla::BatchNormTraining(input, weight, bias, eps_value, 1); + auto output = xla::GetTupleElement(outputs, 0); + auto save_mean = xla::GetTupleElement(outputs, 1); + auto save_var = xla::GetTupleElement(outputs, 2); + auto save_invstd_eps = one / xla::Pow(save_var + eps, half); + return {output, save_mean, save_invstd_eps}; +} + +BatchNormGrads BuildBatchNormBackward(const Node* node, const xla::XlaOp& grad, + const xla::XlaOp& input, + const xla::XlaOp& weight, + const xla::XlaOp& save_mean, + const xla::XlaOp& save_invstd_eps) { + auto builder = grad.builder(); + const float eps_value = node->get(attr::eps).value().to(); + const auto eps = XlaHelpers::ScalarValue(eps_value, builder); + const auto one = XlaHelpers::ScalarValue(1, builder); + const auto two = XlaHelpers::ScalarValue(2, builder); + const auto save_var = xla::Pow(one / save_invstd_eps, two) - eps; + const auto grads = xla::BatchNormGrad(input, weight, save_mean, save_var, + grad, eps_value, 1); + const auto grad_input = xla::GetTupleElement(grads, 0); + const auto grad_weight = xla::GetTupleElement(grads, 1); + const auto grad_bias = xla::GetTupleElement(grads, 2); + return {grad_input, grad_weight, grad_bias}; +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/batch_norm.h b/torch_xla/csrc/batch_norm.h new file mode 100644 index 00000000000..02c3977e8cc --- /dev/null +++ b/torch_xla/csrc/batch_norm.h @@ -0,0 +1,32 @@ +#pragma once + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "torch/csrc/jit/ir.h" + +namespace torch { +namespace jit { + +struct BatchNormOutput { + xla::XlaOp output; + xla::XlaOp save_mean; // batch_mean + xla::XlaOp save_invstd_eps; // 1 / sqrt(batch_var + eps) +}; + +struct BatchNormGrads { + xla::XlaOp grad_input; + xla::XlaOp grad_weight; + xla::XlaOp grad_bias; +}; + +BatchNormOutput BuildBatchNorm(const Node* node, const xla::XlaOp& input, + const xla::XlaOp& weight, + const xla::XlaOp& bias); + +BatchNormGrads BuildBatchNormBackward(const Node* node, const xla::XlaOp& grad, + const xla::XlaOp& input, + const xla::XlaOp& weight, + const xla::XlaOp& save_mean, + const xla::XlaOp& save_invstd_eps); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/convolution.cpp b/torch_xla/csrc/convolution.cpp new file mode 100644 index 00000000000..c5555389899 --- /dev/null +++ 
b/torch_xla/csrc/convolution.cpp @@ -0,0 +1,302 @@ +#include "convolution.h" +#include "helpers.h" +#include "tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "translator.h" + +namespace torch { +namespace jit { + +namespace { + +// Computes the input gradient for a convolution. +xla::XlaOp BuildThnnConv2dBackwardInput( + const Node* node, const xla::XlaOp& grad, const xla::XlaOp& weight, + const xla::PrecisionConfig::Precision conv_precision) { + const auto node_inputs = node->inputs(); + CHECK_EQ(node_inputs.size(), 9); + const auto padding_attr = + node->get>(attr::padding).value(); + CHECK_EQ(padding_attr.size(), 2); + // Adjust input size to account for specified padding. + auto input_size = XlaHelpers::TensorDimensionSizes(node_inputs[1]); + for (int i = 0; i < 2; ++i) { + input_size[2 + i] += 2 * padding_attr[i]; + } + tensorflow::TensorShape input_shape(XlaHelpers::I64List(input_size)); + const auto filter = xla::Transpose(weight, {2, 3, 1, 0}); + auto builder = grad.builder(); + const auto filter_size = + XlaHelpers::ShapeSizes(builder->GetShape(filter).ValueOrDie()); + tensorflow::TensorShape filter_shape(filter_size); + tensorflow::TensorShape out_backprop_shape( + XlaHelpers::I64List(XlaHelpers::TensorDimensionSizes(node_inputs[0]))); + const auto stride_attr = + node->get>(attr::stride).value(); + std::vector strides{1, 1}; + std::copy(stride_attr.begin(), stride_attr.end(), + std::back_inserter(strides)); + tensorflow::ConvBackpropDimensions dims; + constexpr int num_spatial_dims = 2; + std::vector dilations{1, 1, 1, 1}; + const auto status = ConvBackpropComputeDimensionsV2( + "thnn_conv2d_backward", num_spatial_dims, input_shape, filter_shape, + out_backprop_shape, dilations, strides, tensorflow::Padding::VALID, + tensorflow::TensorFormat::FORMAT_NCHW, &dims); + CHECK(status.ok()) << status.error_message(); + + constexpr int batch_dim = 0; + constexpr int feature_dim = 1; + + // The input gradients are computed by a convolution of the output + // gradients and the filter, with some appropriate padding. See the + // comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(batch_dim); + dnums.set_output_batch_dimension(batch_dim); + dnums.set_input_feature_dimension(feature_dim); + dnums.set_output_feature_dimension(feature_dim); + + // TF filter shape is [ H, W, ..., inC, outC ] + // Transpose the input and output features for computing the gradient. + dnums.set_kernel_input_feature_dimension(num_spatial_dims + 1); + dnums.set_kernel_output_feature_dimension(num_spatial_dims); + + std::vector kernel_spatial_dims(num_spatial_dims); + std::vector> padding(num_spatial_dims); + std::vector lhs_dilation(num_spatial_dims); + std::vector rhs_dilation(num_spatial_dims); + std::vector ones(num_spatial_dims, 1); + for (int i = 0; i < num_spatial_dims; ++i) { + xla::int64 dim = 2 + i; + dnums.add_input_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(i); + dnums.add_output_spatial_dimensions(dim); + + kernel_spatial_dims[i] = i; + padding[i] = {dims.spatial_dims[i].pad_before, + dims.spatial_dims[i].pad_after}; + lhs_dilation[i] = dims.spatial_dims[i].stride; + rhs_dilation[i] = dilations[dim]; + } + + // Mirror the filter in the spatial dimensions. 
+ xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); + + // We'll need to undo the initial input padding once on the input backprop + // result since edges are constant and have to be discarded for the gradient. + xla::PaddingConfig padding_config; + for (int i = 0; i < 2; ++i) { + padding_config.add_dimensions(); + } + for (int i = 0; i < 2; ++i) { + auto* dims = padding_config.add_dimensions(); + dims->set_edge_padding_low(-padding_attr[i]); + dims->set_edge_padding_high(-padding_attr[i]); + } + + // activation gradients + // = gradients (with padding and dilation) mirrored_weights + xla::PrecisionConfig precision_config = + XlaHelpers::BuildPrecisionConfig(conv_precision); + return xla::Pad( + xla::ConvGeneralDilated(grad, mirrored_weights, + /*window_strides=*/ones, padding, lhs_dilation, + rhs_dilation, dnums, + /*feature_group_count=*/1, &precision_config), + XlaHelpers::ScalarValue(0, builder), padding_config); +} + +// Computes the weight gradient for a convolution. +xla::XlaOp BuildThnnConv2dBackwardWeight( + const Node* node, const xla::XlaOp& grad, const xla::XlaOp& input, + const xla::PrecisionConfig::Precision conv_precision) { + constexpr int n_dim = 0; + constexpr int c_dim = 1; + const auto node_inputs = node->inputs(); + CHECK_EQ(node_inputs.size(), 9); + const auto padding_attr = + node->get>(attr::padding).value(); + CHECK_EQ(padding_attr.size(), 2); + // Adjust input size to account for specified padding. + auto input_size = XlaHelpers::TensorDimensionSizes(node_inputs[1]); + for (int i = 0; i < 2; ++i) { + input_size[2 + i] += 2 * padding_attr[i]; + } + tensorflow::TensorShape activations_shape(XlaHelpers::I64List(input_size)); + const auto filter_size = + XlaHelpers::I64List(XlaHelpers::TensorDimensionSizes(node_inputs[2])); + std::vector filter_size_backward{filter_size[2], filter_size[3], + filter_size[1], filter_size[0]}; + tensorflow::TensorShape filter_shape(filter_size_backward); + tensorflow::TensorShape out_backprop_shape( + XlaHelpers::I64List(XlaHelpers::TensorDimensionSizes(node_inputs[0]))); + const auto stride_attr = + node->get>(attr::stride).value(); + std::vector strides{1, 1}; + std::copy(stride_attr.begin(), stride_attr.end(), + std::back_inserter(strides)); + tensorflow::ConvBackpropDimensions dims; + constexpr int num_spatial_dims = 2; + std::vector dilations{1, 1, 1, 1}; + const auto status = ConvBackpropComputeDimensionsV2( + "thnn_conv2d_backward", num_spatial_dims, activations_shape, filter_shape, + out_backprop_shape, dilations, strides, tensorflow::Padding::VALID, + tensorflow::TensorFormat::FORMAT_NCHW, &dims); + CHECK(status.ok()) << status.error_message(); + + // The filter gradients are computed by a convolution of the input + // activations and the output gradients, with some appropriate padding. + // See the comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] + // For the gradient computation, we flip the roles of the batch and + // feature dimensions. + // Each spatial entry has size in_depth * batch + + // Swap n_dim and c_dim in the activations. + dnums.set_input_batch_dimension(c_dim); + dnums.set_input_feature_dimension(n_dim); + + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] + // where the batch becomes the input feature for the convolution. 
+ dnums.set_kernel_input_feature_dimension(n_dim); + dnums.set_kernel_output_feature_dimension(c_dim); + + std::vector> padding(num_spatial_dims); + std::vector rhs_dilation(num_spatial_dims); + std::vector window_strides(num_spatial_dims); + std::vector ones(num_spatial_dims, 1); + + // Tensorflow filter shape is [ H, W, ..., inC, outC ]. + for (int i = 0; i < num_spatial_dims; ++i) { + dnums.add_output_spatial_dimensions(i); + } + dnums.set_output_batch_dimension(num_spatial_dims); + dnums.set_output_feature_dimension(num_spatial_dims + 1); + + for (int i = 0; i < num_spatial_dims; ++i) { + xla::int64 dim = 2 + i; + dnums.add_input_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(dim); + + // We will also need to pad the input with zeros such that after the + // convolution, we get the right size for the filter. + // The padded_in_rows should be such that when we convolve this with the + // expanded_out_rows as a filter, we should get filter_rows back. + // + const xla::int64 padded_in_size = + dims.spatial_dims[i].expanded_output_size + + (dims.spatial_dims[i].filter_size - 1) * dilations[dim]; + + // However it can be smaller than input_rows: in this + // case it means some of the inputs are not used. + // + // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: + // + // INPUT = [ A B C ] + // + // FILTER = [ x y ] + // + // and the output will only have one column: a = A * x + B * y + // + // and input "C" is not used at all. + // + // We apply negative padding in this case. + const xla::int64 pad_total = + padded_in_size - dims.spatial_dims[i].input_size; + + // Pad the bottom/right side with the remaining space. + const xla::int64 pad_before = 0; + + padding[i] = {pad_before, pad_total - pad_before}; + rhs_dilation[i] = dims.spatial_dims[i].stride; + window_strides[i] = dilations[dim]; + } + + // Redo the initial input padding. 
+ const auto padding_config = XlaHelpers::MakeXlaPaddingConfig(padding_attr); + + auto builder = grad.builder(); + const auto padded_input = xla::Pad( + input, XlaHelpers::ScalarValue(0, builder), padding_config); + + xla::PrecisionConfig precision_config = + XlaHelpers::BuildPrecisionConfig(conv_precision); + return xla::Transpose( + xla::ConvGeneralDilated(padded_input, grad, window_strides, padding, + /*lhs_dilation=*/ones, rhs_dilation, dnums, + /*feature_group_count=*/1, &precision_config), + {3, 2, 0, 1}); +} + +std::vector> MakePadding(const Node* node) { + std::vector> dims_padding; + const auto padding = node->get>(attr::padding).value(); + for (const auto dim_padding : padding) { + dims_padding.emplace_back(dim_padding, dim_padding); + } + return dims_padding; +} + +} // namespace + +xla::XlaOp BuildConvolution( + const Node* node, const xla::XlaOp& input, const xla::XlaOp& kernel, + const xla::PrecisionConfig::Precision conv_precision) { + const auto window_strides = XlaHelpers::I64List( + node->get>(attr::stride).value()); + const auto dims_padding = MakePadding(node); + xla::PrecisionConfig precision_config = + XlaHelpers::BuildPrecisionConfig(conv_precision); + return xla::ConvWithGeneralPadding( + input, kernel, window_strides, dims_padding, + /*feature_group_count*/ 1, &precision_config); +} + +xla::XlaOp BuildConvolutionBias( + const Node* node, const xla::XlaOp& input, const xla::XlaOp& kernel, + const xla::XlaOp& bias, + const xla::PrecisionConfig::Precision conv_precision) { + const auto node_inputs = node->inputs(); + CHECK_GE(node_inputs.size(), size_t(4)); + const auto window_strides = XlaHelpers::I64List( + node->get>(attr::stride).value()); + const auto bias_size = XlaHelpers::TensorDimensionSizes(node_inputs[3]); + const auto node_outputs = node->outputs(); + auto broadcast_sizes = + XlaHelpers::I64List(XlaHelpers::TensorDimensionSizes(node_outputs[0])); + CHECK_EQ(broadcast_sizes.size(), 4); + // Remove the channels dimension. + broadcast_sizes.erase(broadcast_sizes.begin() + 1); + // Make the bias match the output dimensions. + const auto bias_broadcast = + xla::Transpose(xla::Broadcast(bias, broadcast_sizes), {0, 3, 1, 2}); + const auto conv = BuildConvolution(node, input, kernel, conv_precision); + return conv + bias_broadcast; +} + +Conv2DGrads BuildConv2dBackward( + const Node* node, const xla::XlaOp& grad, const xla::XlaOp& input, + const xla::XlaOp& weight, + const xla::PrecisionConfig::Precision conv_precision) { + const auto grad_input = + BuildThnnConv2dBackwardInput(node, grad, weight, conv_precision); + // TODO: support weight and bias gradients + const auto grad_weight = + BuildThnnConv2dBackwardWeight(node, grad, input, conv_precision); + auto builder = grad.builder(); + const auto grad_bias = + xla::Reduce(grad, XlaHelpers::ScalarValue(0, builder), + XlaHelpers::CreateAddComputation(), {0, 2, 3}); + return {grad_input, grad_weight, grad_bias}; +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/convolution.h b/torch_xla/csrc/convolution.h new file mode 100644 index 00000000000..0b994b4c666 --- /dev/null +++ b/torch_xla/csrc/convolution.h @@ -0,0 +1,34 @@ +#pragma once + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "torch/csrc/jit/ir.h" + +namespace torch { +namespace jit { + +// Computes the convolution of the given input and kernel with the given +// precision, with the stride and padding specified by the node attributes. 
+xla::XlaOp BuildConvolution( + const Node* node, const xla::XlaOp& input, const xla::XlaOp& kernel, + const xla::PrecisionConfig::Precision conv_precision); + +// Same as above, then broadcasts the bias and adds it to the result. +xla::XlaOp BuildConvolutionBias( + const Node* node, const xla::XlaOp& input, const xla::XlaOp& kernel, + const xla::XlaOp& bias, + const xla::PrecisionConfig::Precision conv_precision); + +struct Conv2DGrads { + xla::XlaOp grad_input; + xla::XlaOp grad_weight; + xla::XlaOp grad_bias; +}; + +// Computes the gradients for a convolution. +Conv2DGrads BuildConv2dBackward( + const Node* node, const xla::XlaOp& grad, const xla::XlaOp& input, + const xla::XlaOp& weight, + const xla::PrecisionConfig::Precision conv_precision); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp new file mode 100644 index 00000000000..1ddaf7378cc --- /dev/null +++ b/torch_xla/csrc/cross_replica_reduces.cpp @@ -0,0 +1,17 @@ +#include "cross_replica_reduces.h" +#include +#include "helpers.h" + +namespace torch { +namespace jit { + +xla::XlaOp BuildCrossReplicaSum(const xla::XlaOp& operand, int num_replicas) { + xla::XlaOp crs = xla::CrossReplicaSum(operand); + auto scaling_value = + XlaHelpers::ScalarValue(1.0 / num_replicas, operand.builder()); + auto shape = operand.builder()->GetShape(crs).ValueOrDie(); + return crs * xla::Broadcast(scaling_value, XlaHelpers::ShapeSizes(shape)); +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/cross_replica_reduces.h b/torch_xla/csrc/cross_replica_reduces.h new file mode 100644 index 00000000000..d3d67ddee85 --- /dev/null +++ b/torch_xla/csrc/cross_replica_reduces.h @@ -0,0 +1,13 @@ +#pragma once + +#include "tensorflow/compiler/xla/client/xla_builder.h" + +namespace torch { +namespace jit { + +// Builds a Cross Replica Sum operation on the operand, and scales the result by +// 1.0/num_replicas. +xla::XlaOp BuildCrossReplicaSum(const xla::XlaOp& operand, int num_replicas); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp new file mode 100644 index 00000000000..24cec568d48 --- /dev/null +++ b/torch_xla/csrc/data_ops.cpp @@ -0,0 +1,170 @@ +#include "data_ops.h" +#include "helpers.h" + +namespace torch { +namespace jit { + +namespace { + +// Graph nodes specify -1 for unknown dimensions. Return true iff all dimension +// sizes are positive. +bool IsCompleteShape(const std::vector& dim_sizes) { + return std::all_of(dim_sizes.begin(), dim_sizes.end(), + [](const int64_t dim_size) { return dim_size >= 0; }); +} + +} // namespace + +xla::XlaOp BuildView(const Node* node, const xla::XlaOp& input) { + const auto node_inputs = node->inputs(); + CHECK_EQ(node_inputs.size(), 2); + const auto input_sizes = XlaHelpers::TensorDimensionSizes(node_inputs[0]); + const auto node_outputs = node->outputs(); + CHECK_EQ(node_outputs.size(), 1); + // Try to use the second argument of the operator as the target shape. + std::vector output_sizes; + switch (node->kind()) { + case aten::view: + output_sizes = node->get>(attr::size).value(); + break; + case aten::reshape: + output_sizes = node->get>(attr::shape).value(); + break; + default: + LOG(FATAL) << "Unexpected node kind, must be view or reshape"; + } + // If the second argument doesn't fully specify the target shape, use the size + // of the output. 
+ if (!IsCompleteShape(output_sizes)) { + CHECK(node_outputs[0]->type()->cast()); + output_sizes = XlaHelpers::TensorDimensionSizes(node_outputs[0]); + } + JIT_ASSERTM(IsCompleteShape(output_sizes), + "Cannot infer target size for aten::view"); + return xla::Reshape(input, XlaHelpers::I64List(output_sizes)); +} + +xla::XlaOp BuildExpand(const Node* node, const xla::XlaOp& input) { + const auto node_inputs = node->inputs(); + CHECK_GE(node_inputs.size(), 1); + auto input_sizes = XlaHelpers::TensorDimensionSizes(node_inputs[0]); + const auto node_outputs = node->outputs(); + CHECK_EQ(node_outputs.size(), 1); + const auto output_sizes = XlaHelpers::TensorDimensionSizes(node_outputs[0]); + // Adjust the rank of the input to match the rank of the output. + CHECK_LE(input_sizes.size(), output_sizes.size()); + for (size_t i = 0; i < output_sizes.size() - input_sizes.size(); ++i) { + input_sizes.insert(input_sizes.begin(), 1); + } + const auto implicit_reshape = + xla::Reshape(input, XlaHelpers::I64List(input_sizes)); + // Squeeze the trivial (of size 1) dimensions. + std::vector non_singleton_dimensions; + std::copy_if(input_sizes.begin(), input_sizes.end(), + std::back_inserter(non_singleton_dimensions), + [](const size_t dim_size) { return dim_size != 1; }); + const auto squeezed_input = + xla::Reshape(implicit_reshape, non_singleton_dimensions); + // Broadcast the squeezed tensor, the additional dimensions are to the left. + std::vector broadcast_sizes; + for (size_t i = 0; i < input_sizes.size(); ++i) { + if (input_sizes[i] == 1) { + broadcast_sizes.push_back(output_sizes[i]); + } + } + const auto broadcast = xla::Broadcast(squeezed_input, broadcast_sizes); + // Bring the dimensions added by broadcast where the trivial dimensions were. + std::vector reshape_permutation; + for (size_t i = 0; i < input_sizes.size(); ++i) { + if (input_sizes[i] == 1) { + reshape_permutation.push_back(i); + } + } + for (size_t i = 0; i < input_sizes.size(); ++i) { + if (input_sizes[i] != 1) { + reshape_permutation.push_back(i); + } + } + return xla::Reshape(broadcast, reshape_permutation, + XlaHelpers::I64List(output_sizes)); +} + +// Finds a prim::ListConstruct operation by id in the graph of "parent". +std::vector InputListAttr(const Node* parent, const size_t id) { + const auto nodes = parent->owningGraph()->block()->nodes(); + std::vector result; + for (const auto node : nodes) { + if (node->kind() != prim::ListConstruct) { + continue; + } + const auto node_outputs = node->outputs(); + CHECK_EQ(node_outputs.size(), size_t(1)); + const auto output = node_outputs[0]; + if (output->unique() != id) { + continue; + } + const auto node_inputs = node->inputs(); + for (const auto input : node_inputs) { + result.push_back(input); + } + return result; + } + CHECK(false) << "Constant with id " << id << " not found."; +} + +xla::XlaOp BuildStack(const Node* node, + const std::function& node_op, + xla::XlaBuilder* b) { + const auto node_inputs = node->inputs(); + CHECK_EQ(node_inputs.size(), size_t(2)); + const auto stack_inputs = InputListAttr(node, node_inputs[0]->unique()); + const auto dim = node->get(attr::dim).value(); + std::vector reshaped_inputs; + // Reshape inputs along the dim axis. 
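+  // For example (hypothetical shapes, illustration only): stacking three
+  // [4, 5] inputs with dim = 1 reshapes each input to [4, 1, 5], and the
+  // concatenation below yields a [4, 3, 5] result.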
+ for (size_t i = 0; i < stack_inputs.size(); ++i) { + auto reshaped_input_size = + XlaHelpers::I64List(XlaHelpers::TensorDimensionSizes(stack_inputs[i])); + reshaped_input_size.insert(reshaped_input_size.begin() + dim, 1); + const auto stack_input = stack_inputs[i]; + reshaped_inputs.push_back( + xla::Reshape(node_op(stack_input), reshaped_input_size)); + } + return xla::ConcatInDim(b, reshaped_inputs, dim); +} + +xla::XlaOp BuildCat(const Node* node, + const std::function& node_op, + xla::XlaBuilder* b) { + const auto node_inputs = node->inputs(); + CHECK_EQ(node_inputs.size(), size_t(2)); + const auto stack_inputs = InputListAttr(node, node_inputs[0]->unique()); + const auto dim = node->get(attr::dim).value(); + std::vector cat_inputs; + // Reshape inputs along the dim axis. + for (size_t i = 0; i < stack_inputs.size(); ++i) { + const auto stack_input = stack_inputs[i]; + cat_inputs.push_back(node_op(stack_input)); + } + return xla::ConcatInDim(b, cat_inputs, dim); +} + +std::vector BuildChunk(const Node* node, const xla::XlaOp& input) { + const auto node_input = node->inputs()[0]; + int64_t chunks = node->get(attr::chunks).value(); + int64_t dim = node->get(attr::dim).value(); + int64_t size_in_dim = XlaHelpers::TensorDimensionSizes(node_input)[dim]; + int64_t split_size = (size_in_dim + chunks - 1) / chunks; + std::vector split_sizes(chunks, split_size); + split_sizes[chunks - 1] = split_size - (split_size * chunks - size_in_dim); + std::vector splits(chunks); + int64_t start_idx = 0; + for (int64_t i = 0; i < chunks; ++i) { + const auto length = split_sizes[i]; + splits[i] = SliceInDim(input, start_idx, start_idx + length, 1, dim); + start_idx += length; + } + return splits; +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h new file mode 100644 index 00000000000..f6494b35d6d --- /dev/null +++ b/torch_xla/csrc/data_ops.h @@ -0,0 +1,38 @@ +#pragma once + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "torch/csrc/jit/ir.h" + +// Collection of XLA lowerings for operations which only involve some form of +// data movement and no computation. + +namespace torch { +namespace jit { + +// Creates a new tensor with the same data as the input tensor and the size +// specified by the "size" attribute of the given node. +xla::XlaOp BuildView(const Node* node, const xla::XlaOp& input); + +// Creates a new tensor with the singleton dimensions expanded to the sizes +// specified by the "size" attribute of the given node. +xla::XlaOp BuildExpand(const Node* node, const xla::XlaOp& input); + +// Concatenates a list of tensors along a new dimension specified by the "dim" +// attribute of the given node. +xla::XlaOp BuildStack(const Node* node, + const std::function& node_op, + xla::XlaBuilder* b); + +// Concatenates a list of tensors along an existing dimension specified by the +// "dim" attribute of the given node. +xla::XlaOp BuildCat(const Node* node, + const std::function& node_op, + xla::XlaBuilder* b); + +// Splits a tensor into a specific number of chunks specified by the "chunks" +// attribute of the given node, along an existing dimension specified by the +// "dim" attribute of the given node. 
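+// Worked example (hypothetical values, illustration only): splitting a
+// dimension of size 10 into chunks = 4 gives a chunk size of ceil(10 / 4) = 3,
+// with the last chunk holding 3 - (3 * 4 - 10) = 1 elements, i.e. {3, 3, 3, 1}.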
+std::vector BuildChunk(const Node* node, const xla::XlaOp& input); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp new file mode 100644 index 00000000000..34ee6373e51 --- /dev/null +++ b/torch_xla/csrc/elementwise.cpp @@ -0,0 +1,62 @@ +#include "elementwise.h" +#include "helpers.h" + +namespace torch { +namespace jit { + +xla::XlaOp BuildArithmeticOp(const Node* node, const xla::XlaOp& lhs, + const xla::XlaOp& rhs) { + switch (node->kind()) { + case aten::add: { + return lhs + rhs; + } + case aten::mul: { + return lhs * rhs; + } + default: + LOG(FATAL) << "Invalid binary operator kind: " << node->kind(); + } +} + +xla::XlaOp BuildComparisonOp(const Node* node, const xla::XlaOp& operand) { + auto builder = operand.builder(); + const auto xla_other = XlaHelpers::ScalarValue( + node->get(attr::other).value().to(), builder); + xla::XlaOp pred; + switch (node->kind()) { + case aten::gt: { + pred = xla::Gt(operand, xla_other); + break; + } + default: + LOG(FATAL) << "Invalid binary operator kind: " << node->kind(); + } + return xla::ConvertElementType(pred, xla::PrimitiveType::S8); +} + +xla::XlaOp BuildThreshold(const Node* node, const xla::XlaOp& input, + const xla::XlaOp& output, const float threshold, + const float value, xla::XlaBuilder* b) { + const auto node_inputs = node->inputs(); + const auto input_sizes = XlaHelpers::TensorDimensionSizes(node_inputs[0]); + std::vector broadcast_sizes(input_sizes.begin(), + input_sizes.end()); + const auto xla_threshold = XlaHelpers::ScalarValue(threshold, b); + const auto xla_value = XlaHelpers::ScalarValue(value, b); + return xla::Select(xla::Gt(input, xla_threshold), output, + xla::Broadcast(xla_value, broadcast_sizes)); +} + +xla::XlaOp BuildTypeAs(const Node* node, const xla::XlaOp& operand) { + const auto node_outputs = node->outputs(); + CHECK_EQ(node_outputs.size(), 1); + const auto output_tensor_type = + node_outputs[0]->type()->cast(); + CHECK(output_tensor_type); + const auto target_type = + XlaHelpers::MakeXlaPrimitiveType(output_tensor_type->scalarType()); + return xla::ConvertElementType(operand, target_type); +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/elementwise.h b/torch_xla/csrc/elementwise.h new file mode 100644 index 00000000000..f7acb889600 --- /dev/null +++ b/torch_xla/csrc/elementwise.h @@ -0,0 +1,26 @@ +#pragma once + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "torch/csrc/jit/ir.h" + +namespace torch { +namespace jit { + +// Computes binary arithmetic operations. +xla::XlaOp BuildArithmeticOp(const Node* node, const xla::XlaOp& lhs, + const xla::XlaOp& rhs); + +// Computes binary comparison operations. +xla::XlaOp BuildComparisonOp(const Node* node, const xla::XlaOp& operand); + +// Converts the given operand to the type specified by the given node. +xla::XlaOp BuildTypeAs(const Node* node, const xla::XlaOp& operand); + +// Computes the elementwise threshold of the input: if the value is below the +// threshold, replace it with the provided value, otherwise leave it unchanged. 
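+// For instance (an illustration, not part of this change): calling it with
+// output == input, threshold == 0 and value == 0 yields a ReLU, since inputs
+// greater than 0 are kept and everything else becomes 0.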
+xla::XlaOp BuildThreshold(const Node* node, const xla::XlaOp& input, + const xla::XlaOp& output, const float threshold, + const float value, xla::XlaBuilder* b); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/graph_context.cpp b/torch_xla/csrc/graph_context.cpp new file mode 100644 index 00000000000..47e2346e48a --- /dev/null +++ b/torch_xla/csrc/graph_context.cpp @@ -0,0 +1,60 @@ +#include "graph_context.h" +#include "absl/strings/str_cat.h" + +namespace torch { +namespace jit { + +xla::XlaOp XlaGraphContext::GetParameter( + const std::shared_ptr& data) { + auto it = parameters_map_.find(data.get()); + if (it == parameters_map_.end()) { + xla::XlaOp param = + xla::Parameter(builder(), parameters_.size(), data->shape(), + absl::StrCat("param_", parameters_.size())); + parameters_.push_back(data); + it = parameters_map_.emplace(data.get(), param).first; + } + return it->second; +} + +std::vector XlaGraphContext::GetParametersData() + const { + std::vector parameters; + for (auto& param : parameters_) { + parameters.push_back(param.get()); + } + return parameters; +} + +xla::int64 XlaGraphContext::AddResult(xla::XlaOp op) { + root_tuple_.push_back(std::move(op)); + return root_tuple_.size() - 1; +} + +xla::StatusOr XlaGraphContext::Build() { + if (!root_tuple_.empty()) { + auto root = xla::Tuple(builder(), root_tuple_); + return builder()->Build(root); + } + return builder()->Build(); +} + +xla::StatusOr XlaGraphContext::Build( + const xla::XlaOp& root) { + CHECK(root_tuple_.empty()); + return builder()->Build(root); +} + +XlaGraphNode::XlaGraphNode( + Generator generator, xla::Shape shape, + tensorflow::gtl::ArraySlice> inputs) + : generator_(std::move(generator)), + shape_(std::move(shape)), + inputs_(inputs.begin(), inputs.end()) { + for (auto& input : inputs_) { + graph_size_ += input->graph_size(); + } +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/graph_context.h b/torch_xla/csrc/graph_context.h new file mode 100644 index 00000000000..c4ede69df35 --- /dev/null +++ b/torch_xla/csrc/graph_context.h @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_client/computation_client.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace torch { +namespace jit { + +// Tracks an evolving XLA computation. +class XlaGraphContext { + public: + XlaGraphContext() : builder_("XlaGraphContext") {} + + xla::XlaBuilder* builder() { return &builder_; } + + // If a parameter associated with data has already been declared, it will be + // returned. Otherwise a new one will be created, associated with the tensor + // held in data. + xla::XlaOp GetParameter( + const std::shared_ptr& data); + + // Retrieves the vector holding all the tensors associated with the parameter + // instructions which have been created. + std::vector GetParametersData() const; + + // Adds the output of a given operation to the result tuple. + xla::int64 AddResult(xla::XlaOp op); + + // Build the XLA computation capturing all the operations created with the + // embedded XLA builder (returned by the builder() API). + xla::StatusOr Build(); + + // Build the XLA computation capturing all the operations created with the + // embedded XLA builder (returned by the builder() API). + // Uses root as return value forthe computation. It is an error to use this + // API after having called the AddResult() API. 
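+  // Minimal usage sketch (hypothetical, not part of this change):
+  //   XlaGraphContext ctx;
+  //   xla::XlaOp param = ctx.GetParameter(data);
+  //   auto computation = ctx.Build(param + param);  // no AddResult() calls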
+ xla::StatusOr Build(const xla::XlaOp& root); + + private: + xla::XlaBuilder builder_; + std::vector> parameters_; + std::map parameters_map_; + std::vector root_tuple_; +}; + +// A class whose task is to encapsulate the generation of an XLA operation. +class XlaGraphNode { + public: + // The generation function used by the XLA tensors to create xla::XlaOp nodes. + using Generator = std::function( + XlaGraphContext*, const XlaGraphNode&)>; + + static std::shared_ptr New( + Generator generator, xla::Shape shape, + tensorflow::gtl::ArraySlice> inputs) { + return std::make_shared(std::move(generator), + std::move(shape), std::move(inputs)); + } + + XlaGraphNode( + Generator generator, xla::Shape shape, + tensorflow::gtl::ArraySlice> inputs); + + // Runs the generator function using the ctx argument, and returns the XLA + // operation which is the end result of the generation. + xla::StatusOr Generate(XlaGraphContext* ctx) const { + return generator_(ctx, *this); + } + + const xla::Shape& shape() const { return shape_; } + + const std::shared_ptr& input(xla::int64 ordinal) const { + return inputs_[ordinal]; + } + + xla::int64 graph_size() const { return graph_size_; } + + private: + Generator generator_; + xla::Shape shape_; + std::vector> inputs_; + xla::int64 graph_size_ = 1; +}; + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp new file mode 100644 index 00000000000..2788ba850b7 --- /dev/null +++ b/torch_xla/csrc/helpers.cpp @@ -0,0 +1,72 @@ +#include "helpers.h" + +namespace torch { +namespace jit { + +xla::PrecisionConfig XlaHelpers::BuildPrecisionConfig( + const xla::PrecisionConfig::Precision conv_precision) { + xla::PrecisionConfig precision_config; + // Dot and convolution take two operators. 
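+  // Illustration only: with conv_precision == xla::PrecisionConfig::HIGHEST,
+  // the operand_precision list resized below becomes {HIGHEST, HIGHEST}.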
+ precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, conv_precision); + return precision_config; +} + +std::vector XlaHelpers::TensorDimensionSizes(const Value* tensor) { + const auto tensor_type = tensor->type()->cast(); + CHECK(tensor_type); + return tensor_type->sizes(); +} + +std::vector XlaHelpers::I64List(const at::IntList& input) { + std::vector output(input.size()); + std::copy(input.begin(), input.end(), output.begin()); + return output; +} + +xla::PaddingConfig XlaHelpers::MakeXlaPaddingConfig( + const std::vector& padding) { + xla::PaddingConfig padding_config; + for (int i = 0; i < 2; ++i) { + padding_config.add_dimensions(); + } + for (int i = 0; i < 2; ++i) { + auto* dims = padding_config.add_dimensions(); + dims->set_edge_padding_low(padding[i]); + dims->set_edge_padding_high(padding[i]); + } + return padding_config; +} + +xla::XlaComputation XlaHelpers::CreateAddComputation() { + xla::XlaBuilder reduction_builder("xla_add_computation"); + const auto x = xla::Parameter( + &reduction_builder, 0, + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}), "x"); + const auto y = xla::Parameter( + &reduction_builder, 1, + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}), "y"); + Add(x, y); + return reduction_builder.Build().ConsumeValueOrDie(); +} + +xla::PrimitiveType XlaHelpers::MakeXlaPrimitiveType( + const at::ScalarType scalar_type) { + switch (scalar_type) { + case at::ScalarType::Float: + return xla::PrimitiveType::F32; + case at::ScalarType::Long: + return xla::PrimitiveType::S64; + default: + LOG(FATAL) << "Type not supported: " << scalar_type; + } +} + +std::vector XlaHelpers::ShapeSizes(const xla::Shape& shape) { + std::vector shape_sizes(shape.dimensions().begin(), + shape.dimensions().end()); + return shape_sizes; +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h new file mode 100644 index 00000000000..33b299cdcca --- /dev/null +++ b/torch_xla/csrc/helpers.h @@ -0,0 +1,56 @@ +#pragma once + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "torch/csrc/jit/ir.h" + +namespace torch { +namespace jit { + +// Miscellaneous helpers for XLA lowering. +class XlaHelpers { + public: + // Creates a XLA constant for the given scalar_value. + template + static xla::XlaOp ScalarValue(T scalar_value, xla::XlaBuilder* builder) { + const auto scalar_literal = xla::LiteralUtil::CreateR0(scalar_value); + return xla::ConstantLiteral(builder, scalar_literal); + } + + // Returns the list of dimension sizes for the given shape. + static std::vector ShapeSizes(const xla::Shape& shape); + + // Creates a scalar broadcasted to a given shape. + template + static xla::XlaOp ScalarBroadcast(T scalar_value, const xla::Shape& shape, + xla::XlaBuilder* builder) { + auto scalar_op = ScalarValue(scalar_value, builder); + return xla::Broadcast(scalar_op, ShapeSizes(shape)); + } + + // Creates a convolution or dot precision configuration. + static xla::PrecisionConfig BuildPrecisionConfig( + const xla::PrecisionConfig::Precision conv_precision); + + // Returns the dimension sizes for the given tensor. + static std::vector TensorDimensionSizes(const Value* tensor); + + // Converts int64_t's to XLA int64's. + static std::vector I64List(const at::IntList& input); + + // Creates an XLA padding configuration from a padding attribute value. 
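+  // For example (hypothetical values, illustration only): padding = {1, 2}
+  // produces a rank-4 config which leaves dimensions 0 and 1 untouched and
+  // pads dimension 2 by 1 and dimension 3 by 2 on both the low and high edges.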
+ static xla::PaddingConfig MakeXlaPaddingConfig( + const std::vector& padding); + + // Creates a binary add computation. + static xla::XlaComputation CreateAddComputation(); + + // Converts the given scalar type to an XLA primitive type. + static xla::PrimitiveType MakeXlaPrimitiveType( + const at::ScalarType scalar_type); +}; + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp new file mode 100644 index 00000000000..e8aadcac332 --- /dev/null +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -0,0 +1,172 @@ +#include "init_python_bindings.h" +#include "module.h" +#include "passes/eval_static_size.h" +#include "passes/replace_untraced_operators.h" +#include "passes/threshold_backward_peephole.h" +#include "torch_util.h" + +namespace torch { +namespace jit { + +namespace { + +void InitXlaModuleBindings(py::module m) { + py::class_>(m, "XlaModule") + .def(py::init([](const std::shared_ptr module, + bool use_full_conv_precision) { + return std::make_shared(module, + use_full_conv_precision); + }), + py::arg("module"), py::arg("use_full_conv_precision") = false) + .def("__call__", + [](XlaModule& xla_module, py::args args) -> py::object { + auto inputs = XlaCreateTensorList(args); + auto outputs = xla_module.forward(inputs); + return XlaPackTensorList(outputs); + }) + .def("backward", + [](XlaModule& xla_module, py::args args) { + auto inputs = XlaCreateTensorList(args); + xla_module.backward(inputs); + }) + .def("parameters", + [](XlaModule& xla_module) { return xla_module.parameters(); }) + .def("parameters_buffers", [](XlaModule& xla_module) { + return xla_module.parameters_buffers(); + }); + m.def("_xla_mul_add_multi", + [](const double scale_dest, + const std::vector>& dest_tuple, + const double alpha, + const std::vector>& source_tuple) { + XLATensor::MulAddMulti(scale_dest, dest_tuple, alpha, source_tuple); + }); + m.def("_xla_zero_multi", + [](const std::vector>& dest_tuple) { + XLATensor::ZeroMulti(dest_tuple); + }); +} + +void InitXlaPassesBindings(py::module m) { + m.def("_jit_pass_eval_static_size", EvalStaticSize); + m.def("_jit_pass_replace_untraced_operators", ReplaceUntracedOperators); + m.def("_jit_pass_threshold_backward_peephole", ThresholdBackwardPeephole); +} + +void InitXlaTensorBindings(py::module m) { + py::class_>(m, "XLATensor") + .def(py::init([](autograd::Variable tensor, const std::string& device) { + return std::make_shared( + tensor, XLATensor::DeviceFromString(device)); + }), + py::arg("tensor"), py::arg("device") = "") + .def("to_tensor", [](XLATensor& s) { return s.toTensor(); }) + .def("size", [](const XLATensor& s) { return s.Size(); }) + .def("__add__", [](std::shared_ptr self, + XLATensor& other) { return self->add(other, 1.0); }) + .def("add", [](std::shared_ptr self, double alpha, + XLATensor& other) { return self->add(other, alpha); }) + .def("add_", + [](std::shared_ptr self, double alpha, XLATensor& other) { + self->add_(other, alpha); + return self; + }) + .def("add_", + [](std::shared_ptr self, XLATensor& other) { + self->add_(other, 1.); + return self; + }) + .def("__mul__", + [](std::shared_ptr self, XLATensor& other) { + return self->mul(other); + }, + py::arg("other")) + .def("__mul__", [](std::shared_ptr self, + double other) { return self->mul(other); }) + .def("mul", + [](std::shared_ptr self, XLATensor& other) { + return self->mul(other); + }, + py::arg("other")) + .def("mul", [](std::shared_ptr self, + double other) { return self->mul(other); }) + 
.def("mul_", + [](std::shared_ptr self, XLATensor& other) { + self->mul_(other); + return self; + }, + py::arg("other")) + .def("mul_", + [](std::shared_ptr self, double other) { + self->mul_(other); + return self; + }) + .def("__div__", + [](std::shared_ptr self, XLATensor& other) { + return self->div(other); + }, + py::arg("other")) + .def("__div__", [](std::shared_ptr self, + double other) { return self->div(other); }) + .def("__truediv__", + [](std::shared_ptr self, XLATensor& other) { + return self->div(other); + }, + py::arg("other")) + .def("__truediv__", [](std::shared_ptr self, + double other) { return self->div(other); }) + .def("cross_replica_sum", + [](std::shared_ptr self, const py::list& groups) { + std::vector> crs_groups; + for (auto& group : groups) { + crs_groups.emplace_back(); + for (auto& replica_id : group.cast()) { + crs_groups.back().push_back(replica_id.cast()); + } + } + return self->cross_replica_sum(crs_groups); + }) + .def("zero_", + [](std::shared_ptr self) { + self->zero_(); + return self; + }) + .def("detach_", + [](std::shared_ptr self) { + self->detach_(); + return self; + }) + .def_property_readonly( + "data", + [](std::shared_ptr self) { + return py::cast>(self->Clone()); + }) + .def_property_readonly("is_leaf", [](const XLATensor&) { return true; }) + .def_property_readonly( + "grad", + [](XLATensor& m) -> py::object { + if (m.grad() == nullptr) { + return py::none(); + } else { + return py::cast>(m.grad()); + } + }) + .def("__repr__", [](XLATensor& m) { + std::ostringstream s; + s << m.toTensor(); + return s.str(); + }); +} + +} // namespace + +void InitXlaBindings(py::module m) { + InitXlaModuleBindings(m); + InitXlaPassesBindings(m); + InitXlaTensorBindings(m); +} + +} // namespace jit +} // namespace torch + +PYBIND11_MODULE(_C, m) { torch::jit::InitXlaBindings(m); } diff --git a/torch_xla/csrc/init_python_bindings.h b/torch_xla/csrc/init_python_bindings.h new file mode 100644 index 00000000000..a4df65898d2 --- /dev/null +++ b/torch_xla/csrc/init_python_bindings.h @@ -0,0 +1,12 @@ +#pragma once + +#include "torch/csrc/jit/pybind.h" + +namespace torch { +namespace jit { + +// Initialize bindings for XLA module, tensor and optimization passes. +void InitXlaBindings(py::module m); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/log_softmax.cpp b/torch_xla/csrc/log_softmax.cpp new file mode 100644 index 00000000000..b64a2fab7cb --- /dev/null +++ b/torch_xla/csrc/log_softmax.cpp @@ -0,0 +1,80 @@ +#include "log_softmax.h" +#include "helpers.h" + +namespace torch { +namespace jit { + +namespace { + +xla::XlaComputation CreateMaxComputation() { + xla::XlaBuilder reduction_builder("xla_max_computation"); + const auto x = xla::Parameter( + &reduction_builder, 0, + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}), "x"); + const auto y = xla::Parameter( + &reduction_builder, 1, + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}), "y"); + xla::Max(x, y); + return reduction_builder.Build().ConsumeValueOrDie(); +} + +} // namespace + +xla::XlaOp BuildLogSoftmax(const Node* node, const xla::XlaOp& logits) { + // Inspired from tf2xla. 
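+  // The lowering below computes, along the "dim" attribute:
+  //   log_softmax(x) = (x - max(x)) - log(sum(exp(x - max(x))))
+  // which is the numerically stable form of log(exp(x) / sum(exp(x))).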
+ const auto node_inputs = node->inputs(); + CHECK_EQ(node_inputs.size(), size_t(2)); + xla::int64 dim = node->get(attr::dim).value(); + + auto input_size = XlaHelpers::TensorDimensionSizes(node_inputs[0]); + + std::vector broadcast_dimensions; + for (size_t broadcast_dim = 0; broadcast_dim < input_size.size(); + ++broadcast_dim) { + if (broadcast_dim == dim) { + continue; + } + broadcast_dimensions.push_back(broadcast_dim); + } + + const auto max_func = CreateMaxComputation(); + const auto min_value = xla::LiteralUtil::MinValue(xla::PrimitiveType::F32); + auto builder = logits.builder(); + const auto logits_max = xla::Reduce( + logits, xla::ConstantLiteral(builder, min_value), max_func, {dim}); + const auto shifted_logits = + xla::Sub(logits, logits_max, broadcast_dimensions); + const auto exp_shifted = xla::Exp(shifted_logits); + const auto init_value = XlaHelpers::ScalarValue(0, builder); + const auto reduce = xla::Reduce(exp_shifted, init_value, + XlaHelpers::CreateAddComputation(), {dim}); + return xla::Sub(shifted_logits, xla::Log(reduce), broadcast_dimensions); +} + +xla::XlaOp BuildLogSoftmaxGrad(const Node* node, const xla::XlaOp& grad_output, + const xla::XlaOp& output) { + // Inspired from tf2xla. + xla::int64 dim = node->get(attr::dim).value(); + + const auto node_inputs = node->inputs(); + auto input_size = XlaHelpers::TensorDimensionSizes(node_inputs[0]); + std::vector broadcast_dimensions; + for (size_t broadcast_dim = 0; broadcast_dim < input_size.size(); + ++broadcast_dim) { + if (broadcast_dim == dim) { + continue; + } + broadcast_dimensions.push_back(broadcast_dim); + } + + auto builder = grad_output.builder(); + const auto init_value = XlaHelpers::ScalarValue(0, builder); + const auto sum = xla::Reduce(grad_output, init_value, + XlaHelpers::CreateAddComputation(), {dim}); + + return xla::Sub(grad_output, + xla::Mul(xla::Exp(output), sum, broadcast_dimensions)); +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/log_softmax.h b/torch_xla/csrc/log_softmax.h new file mode 100644 index 00000000000..839a1c62a52 --- /dev/null +++ b/torch_xla/csrc/log_softmax.h @@ -0,0 +1,18 @@ +#pragma once + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "torch/csrc/jit/ir.h" + +namespace torch { +namespace jit { + +// Computes log(softmax(logits)) along the dimension specified by the "dim" +// attribute of the given node. +xla::XlaOp BuildLogSoftmax(const Node* node, const xla::XlaOp& logits); + +// Computes the gradient of the input of the LogSoftmax function. 
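+// Given the forward "output" and the incoming "grad_output", this evaluates
+// grad_output - exp(output) * sum(grad_output) along the "dim" attribute of
+// the given node.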
+xla::XlaOp BuildLogSoftmaxGrad(const Node* node, const xla::XlaOp& grad_output, + const xla::XlaOp& output); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/module.cpp b/torch_xla/csrc/module.cpp new file mode 100644 index 00000000000..51052d4781a --- /dev/null +++ b/torch_xla/csrc/module.cpp @@ -0,0 +1,605 @@ +#include "module.h" + +#include +#include "c10/util/Exception.h" +#include "cross_replica_reduces.h" +#include "passes/eval_static_size.h" +#include "passes/remove_unused_forward_outputs.h" +#include "passes/replace_untraced_operators.h" +#include "passes/threshold_backward_peephole.h" +#include "torch/csrc/jit/passes/canonicalize_ops.h" +#include "torch/csrc/jit/passes/common_subexpression_elimination.h" +#include "torch/csrc/jit/passes/constant_propagation.h" +#include "torch/csrc/jit/passes/dead_code_elimination.h" +#include "torch/csrc/jit/passes/specialize_undef.h" + +namespace torch { +namespace jit { +namespace { + +void GatherParameters(std::vector* values, + std::vector* requires_grad, + const script::Module& m) { + for (auto& param : m.get_parameters()) { + values->push_back(param->slot()); + requires_grad->push_back(!param->is_buffer); + } + for (const auto& sub : m.get_modules()) { + GatherParameters(values, requires_grad, *sub->module); + } +} + +// If "result_dh" is a tuple, decompose it and return a list of tensors created +// from its components, otherwise return a singleton list containing a single +// tensor created from it. +XlaModule::TensorBatchVector::value_type DecomposeComputationResult( + std::shared_ptr result_dh, + const xla::Shape& result_shape, uint64_t module_id) { + std::vector> decomposed_result; + std::vector forward_ret_shape = GetComponentShapes(result_shape); + auto client = XlaGetClient(); + if (forward_ret_shape.size() > 1) { + // The result is a tuple. Decompose it into a list of tensors. + auto result_components = client->DeconstructTuple(*result_dh).ValueOrDie(); + CHECK_EQ(forward_ret_shape.size(), result_components.size()); + for (size_t i = 0; i < result_components.size(); ++i) { + auto& tuple_element = result_components[i]; + auto result_component = + std::make_shared(std::move(tuple_element), module_id); + decomposed_result.push_back(result_component); + } + } else { + CHECK_EQ(forward_ret_shape.size(), 1); + auto result_component = + std::make_shared(std::move(result_dh), module_id); + decomposed_result.push_back(result_component); + } + return decomposed_result; +} + +XlaModule::TensorBatchVector DecomposeComputationResult( + const std::vector>& results, + const xla::Shape& result_shape, uint64_t module_id) { + XlaModule::TensorBatchVector batch_tensors; + for (auto& result : results) { + batch_tensors.push_back( + DecomposeComputationResult(result, result_shape, module_id)); + } + return batch_tensors; +} + +} // namespace + +std::atomic XlaModule::s_module_id_(1); +constexpr uint64_t XlaModule::kInvalidModuleId; + +XlaModule::XlaModule(const std::shared_ptr module, + bool use_full_conv_precision) + : use_full_conv_precision_(use_full_conv_precision), + enable_trace_fusion_(true), + module_id_(s_module_id_++), + script_module_(module) {} + +void XlaModule::Initialize(const TensorBatchVector& inputs) { + if (script_module_ == nullptr) { + return; + } + + // Get forward graph. + const auto forward = script_module_->find_method("forward"); + JIT_ASSERT(forward); + std::shared_ptr forward_graph = forward->graph()->copy(); + // Run forward passes. 
+ CanonicalizeOps(forward_graph); + EvalStaticSize(forward_graph); + ConstantPropagation(forward_graph); + ReplaceUntracedOperators(forward_graph); + EliminateDeadCode(forward_graph); + + // Convert model parameters to vector of XLATensors. + std::vector params_buffers_regather; + std::vector param_requires_grad; + GatherParameters(¶ms_buffers_regather, ¶m_requires_grad, + *script_module_); + // The loop below is going to send individual parameters to the different + // cores. We might need to do something smarter here. + auto devices = CommonDevicesForReplicas(inputs); + for (const auto& device : devices) { + TensorBatchVector::value_type replica_params; + TensorBatchVector::value_type optimizable_replica_params; + for (size_t j = 0; j < params_buffers_regather.size(); ++j) { + replica_params.push_back(std::make_shared( + autograd::as_variable_ref(*params_buffers_regather[j]), device)); + if (param_requires_grad[j]) { + optimizable_replica_params.push_back(replica_params.back()); + } + } + all_params_.push_back(std::move(replica_params)); + optimizable_params_.push_back(std::move(optimizable_replica_params)); + } + // Collect the requires-gradient property making sure all the replica inputs + // agree on it. + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& replica_inputs = inputs[i]; + if (i == 0) { + for (const auto& p : replica_inputs) { + inputs_require_grad_.push_back(p->RequiresGrad()); + } + } else { + for (size_t j = 0; j < replica_inputs.size(); ++j) { + CHECK(inputs_require_grad_[j] == replica_inputs[j]->RequiresGrad()) + << "Input " << j << " of replica " << i + << " does not match the requires-grad property"; + } + } + } + inputs_require_grad_.insert(inputs_require_grad_.end(), + param_requires_grad.begin(), + param_requires_grad.end()); + + // Automatically differentiate the forward graph to get the backward graph. + // Since differentiation is mutating the graph, do it on a copy. + auto forward_graph_copy = forward_graph->copy(); + Gradient gradient = differentiate(forward_graph_copy); + + // Run the forward passes. + CanonicalizeOps(gradient.f); + EvalStaticSize(gradient.f); + ConstantPropagation(gradient.f); + ReplaceUntracedOperators(gradient.f); + EliminateDeadCode(gradient.f); + // Run the backward passes. + specializeUndef(*(gradient.df.get())); + EvalStaticSize(gradient.df); + ConstantPropagation(gradient.df); + ThresholdBackwardPeephole(gradient.df); + EliminateDeadCode(gradient.df); + // Run pass on forward and backward graphs that drops outputs that XLA doesn't + // need. + RemoveUnusedForwardOutputs(gradient); + + // Record the number of outputs for the forward computation and the captured + // input and output indices to be used by the backward computation. + f_real_outputs_ = gradient.f_real_outputs; + df_input_captured_inputs_ = gradient.df_input_captured_inputs; + df_input_captured_outputs_ = gradient.df_input_captured_outputs; + + // Take ownership of the forward and differentiated graphs and release the + // reference to the script module to mark initialization as done. + f_ = gradient.f; + df_ = gradient.df; + // Mark the module as initialized. + script_module_ = nullptr; +} + +void XlaModule::CheckInitialized() const { + // script_module_ is null after initialization. 
+ if (script_module_ != nullptr) { + AT_ERROR("Module not initialized; did forward method run?"); + } +} + +XlaModule::TensorBatchVector XlaModule::forward( + const TensorBatchVector& inputs) { + Initialize(inputs); + if (enable_trace_fusion_) { + const auto return_node = df_->return_node(); + const auto node_inputs = return_node->inputs(); + if (!node_inputs.empty()) { + return RunFusedTrain(inputs); + } + } + return RunUnfusedForward(inputs); +} + +void XlaModule::backward(const TensorBatchVector& grad_outputs) { + CheckInitialized(); + // Tensors could have pending in-place operations, apply them first to reset + // their parent module and thus invalidate the gradients we set aside from the + // fused computation. + FlushTensorsOperations({&grad_outputs, &optimizable_params_}); + + // If we're in trace fusion mode, we start with the assumption that the input + // gradients are still valid and invalidate it if we don't receive the output + // from the forward trace to compute the gradient on. If not, we have no + // gradients by definition, since only forward pass has executed. + bool input_gradients_valid = enable_trace_fusion_; + for (size_t i = 0; forward_computation_ && i < grad_outputs.size(); ++i) { + for (const auto& grad_output : grad_outputs[i]) { + if (grad_output->ForwardModuleId() != module_id_ && + enable_trace_fusion_) { + // This is not a direct output of the forward pass. Redo the forward + // computation to capture the intermediate outputs correctly and set + // enable_trace_fusion_ to false to avoid doing fusion for the next + // training batches. + forward_computation_ = at::nullopt; + RunUnfusedForward(inputs_); + input_gradients_valid = false; + enable_trace_fusion_ = false; + break; + } + } + } + if (input_gradients_valid) { + // We already have the gradients from the fused computation, just set the + // gradients for input and parameters. + for (size_t i = 0; i < inputs_.size(); ++i) { + auto& replica_inputs = inputs_[i]; + auto& replica_grad_inputs = grad_inputs_[i]; + auto& replica_optimizable_params = optimizable_params_[i]; + JIT_ASSERT(inputs_require_grad_.size() >= + replica_inputs.size() + replica_optimizable_params.size()); + size_t grad_index = 0; + for (size_t j = 0; j < replica_inputs.size(); j++) { + if (inputs_require_grad_[j]) { + replica_inputs[j]->setGrad(replica_grad_inputs[grad_index]); + ++grad_index; + } + } + for (size_t j = 0; j < replica_optimizable_params.size(); j++) { + replica_optimizable_params[j]->setGrad(replica_grad_inputs[grad_index]); + ++grad_index; + } + } + return; + } + // NOTE: The order of the input parameters passed to the BuildComputation() + // call to build the backward computation is critical, as they have to match + // the sequence of the graph->inputs() vector. Before the gradients passed in + // from the backward() call, then then zeroed virtual inputs, and then the + // captured inputs/outputs. + TensorBatchVector raw_grad_outputs; + std::vector zero_input; + for (size_t i = 0; i < grad_outputs.size(); ++i) { + TensorBatchVector::value_type replica_raw_grad_outputs; + for (auto p : grad_outputs[i]) { + replica_raw_grad_outputs.push_back(p); + if (i == 0) { + zero_input.push_back(false); + } + } + for (auto p : captured_outputs_[i]) { + // TODO(asuhan): Remove the all zero grad outputs from the forward trace + // output. 
+ replica_raw_grad_outputs.push_back(p); + if (i == 0) { + zero_input.push_back(true); + } + } + for (auto p : captured_inputs_outputs_[i]) { + replica_raw_grad_outputs.push_back(p); + if (i == 0) { + zero_input.push_back(false); + } + } + raw_grad_outputs.push_back(std::move(replica_raw_grad_outputs)); + } + // If backward graph is not compiled, compile it. + if (!backward_computation_) { + // The shape for all the replicas are the same, so use replica[0] for + // building the shapes vector for the BuildComputation() call. + const auto& replica_raw_grad_outputs = raw_grad_outputs.front(); + std::vector backward_shapes; + for (size_t j = 0; j < replica_raw_grad_outputs.size(); ++j) { + backward_shapes.push_back(XlaTranslator::ParameterShape( + replica_raw_grad_outputs[j]->shape(), zero_input[j])); + } + + XlaTranslator xla_bwd_impl(df_, GetPrecisionConfig()); + backward_computation_ = xla_bwd_impl.BuildComputation( + backward_shapes, GetBackwardBuildOptions(0, inputs_.size())); + } + // Collect the computation client data vector. + DataBatchVector raw_grad_outputs_data = + GetDataBatchVector(raw_grad_outputs, &zero_input); + auto devices = CommonDevicesForReplicas(grad_outputs); + const auto program_shape = + backward_computation_->GetProgramShape().ValueOrDie(); + const auto result_shape = program_shape.result(); + auto result_shape_with_layout = + MakeShapeWithDeviceLayout(result_shape, devices.front().hw_type); + + TensorBatchVector grad_inputs = + Execute(*backward_computation_, raw_grad_outputs_data, result_shape, + &result_shape_with_layout, kInvalidModuleId); + + for (size_t i = 0; i < inputs_.size(); ++i) { + auto& replica_grad_inputs = grad_inputs[i]; + auto& replica_inputs = inputs_[i]; + auto& replica_optimizable_params = optimizable_params_[i]; + JIT_ASSERT((replica_inputs.size() + replica_optimizable_params.size()) == + replica_grad_inputs.size()); + // Set .grad attributes of the input and parameter tensors. + for (size_t j = 0; j < replica_inputs.size(); j++) { + replica_inputs[j]->setGrad(replica_grad_inputs[j]); + } + for (size_t j = 0; j < replica_optimizable_params.size(); j++) { + auto t = replica_grad_inputs[j + replica_inputs.size()]; + replica_optimizable_params[j]->setGrad(t); + } + } + // Release handles to saved / captured inputs and outputs. + inputs_.clear(); + captured_outputs_.clear(); + captured_inputs_outputs_.clear(); +} + +XlaModule::TensorBatchVector XlaModule::RunFusedTrain( + const TensorBatchVector& inputs) { + Initialize(inputs); + TensorBatchVector inputs_params_buffers = PrepareForwardInput(inputs); + if (!forward_computation_) { + // Shapes are going to be the same for all replicas, so use the ones of the + // first replica here. + std::vector forward_shapes; + for (auto p : inputs_params_buffers.front()) { + forward_shapes.push_back( + XlaTranslator::ParameterShape(p->shape(), /*zero_input=*/false)); + } + BuildFusedTrainComputation(forward_shapes); + } + DataBatchVector inputs_params_buffers_data = + GetDataBatchVector(inputs_params_buffers, /*zero_input=*/nullptr); + const auto program_shape = + forward_computation_->GetProgramShape().ValueOrDie(); + const auto result_shape = program_shape.result(); + // The result is always a tuple of outputs and gradients. 
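+  // Illustration with hypothetical counts: if f_real_outputs_ == 1, each
+  // replica's tuple is (forward_output, gradient_0, ..., gradient_N), which
+  // the loop below splits into forward_result and grad_inputs_.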
+ CHECK(xla::ShapeUtil::IsTuple(result_shape)) + << xla::ShapeUtil::HumanString(result_shape); + const auto device = XLATensor::CommonDeviceForTensors(inputs.front()); + auto result_shape_with_layout = + MakeShapeWithDeviceLayout(result_shape, device.hw_type); + + TensorBatchVector result_components = + Execute(*forward_computation_, inputs_params_buffers_data, result_shape, + &result_shape_with_layout, module_id_); + + // First f_real_outputs_ are the forward outputs returned to user code. + CHECK_LE(f_real_outputs_, result_components.front().size()); + grad_inputs_.clear(); + TensorBatchVector forward_result; + for (auto& replica_result_components : result_components) { + TensorBatchVector::value_type replica_forward_result; + TensorBatchVector::value_type replica_grad_inputs; + for (size_t j = 0; j < f_real_outputs_; ++j) { + replica_forward_result.push_back(replica_result_components[j]); + } + for (size_t j = f_real_outputs_; j < replica_result_components.size(); + ++j) { + replica_grad_inputs.push_back(replica_result_components[j]); + } + forward_result.push_back(std::move(replica_forward_result)); + grad_inputs_.push_back(std::move(replica_grad_inputs)); + } + return forward_result; +} + +const XlaModule::TensorBatchVector& XlaModule::parameters() { + CheckInitialized(); + return optimizable_params_; +} + +const XlaModule::TensorBatchVector& XlaModule::parameters_buffers() { + CheckInitialized(); + return all_params_; +} + +xla::PrecisionConfig::Precision XlaModule::GetPrecisionConfig() const { + return use_full_conv_precision_ ? xla::PrecisionConfig::HIGHEST + : xla::PrecisionConfig::DEFAULT; +} + +void XlaModule::BuildFusedTrainComputation( + const std::vector& forward_shapes) { + XlaTranslator xla_fwd_impl(f_, GetPrecisionConfig()); + xla::XlaBuilder b("XlaFusedComputation"); + // Build the forward pass program without compiling it, the backward pass + // needs to be called before finalizing it. + auto computation_in_outs = + xla_fwd_impl.BuildComputationProgram(forward_shapes, &b); + // Take the XLA outputs from the forward pass and set them for the backward + // call in the same order the standalone, unfused version takes its arguments. + CHECK(!computation_in_outs.outputs.empty()); + std::vector grad_outputs; + for (size_t i = 0; i < f_real_outputs_; i++) { + grad_outputs.push_back(computation_in_outs.outputs[i]); + } + std::vector captured_outputs; + for (size_t i = f_real_outputs_; i < computation_in_outs.outputs.size(); + i++) { + captured_outputs.push_back(computation_in_outs.outputs[i]); + } + std::vector captured_inputs_outputs; + for (auto i : df_input_captured_inputs_) { + captured_inputs_outputs.push_back(computation_in_outs.inputs[i]); + } + for (auto i : df_input_captured_outputs_) { + captured_inputs_outputs.push_back(computation_in_outs.outputs[i]); + } + // NOTE: The order of the input parameters passed to the BuildComputation() + // call to build the backward computation is critical, as they have to match + // the sequence of the graph->inputs() vector. Before the gradients returned + // by the forward pass, then then zeroed virtual inputs, and then the captured + // inputs/outputs. 
+ std::vector backward_shapes; + std::vector backward_operands; + for (auto p : grad_outputs) { + backward_shapes.push_back(XlaTranslator::ParameterShape( + b.GetShape(p).ValueOrDie(), /*zero_input=*/false)); + backward_operands.push_back(p); + } + for (auto p : captured_outputs) { + backward_shapes.push_back(XlaTranslator::ParameterShape( + b.GetShape(p).ValueOrDie(), /*zero_input=*/true)); + } + for (auto p : captured_inputs_outputs) { + backward_shapes.push_back(XlaTranslator::ParameterShape( + b.GetShape(p).ValueOrDie(), /*zero_input=*/false)); + backward_operands.push_back(p); + } + // The arguments are set up correctly, call into the backward computation. + XlaTranslator xla_bwd_impl(df_, GetPrecisionConfig()); + auto backward_computation = xla_bwd_impl.BuildComputation( + backward_shapes, + GetBackwardBuildOptions(f_real_outputs_, inputs_.size())); + xla::Call(&b, backward_computation, backward_operands); + forward_computation_ = b.Build().ValueOrDie(); +} + +XlaModule::TensorBatchVector XlaModule::RunUnfusedForward( + const TensorBatchVector& inputs) { + TensorBatchVector inputs_params_buffers = PrepareForwardInput(inputs); + + // Lazy-convert forward graph to XlaComputation. + if (!forward_computation_) { + // Shapes are going to be the same for all replicas, so use the ones of the + // first replica here. + std::vector forward_shapes; + for (auto p : inputs_params_buffers.front()) { + forward_shapes.push_back( + XlaTranslator::ParameterShape(p->shape(), /*zero_input=*/false)); + } + + XlaTranslator xla_fwd_impl(f_, GetPrecisionConfig()); + forward_computation_ = xla_fwd_impl.BuildComputation(forward_shapes); + } + DataBatchVector inputs_params_buffers_data = + GetDataBatchVector(inputs_params_buffers, /*zero_input=*/nullptr); + const auto program_shape = + forward_computation_->GetProgramShape().ValueOrDie(); + const auto result_shape = program_shape.result(); + const auto device = XLATensor::CommonDeviceForTensors(inputs.front()); + auto result_shape_with_layout = + MakeShapeWithDeviceLayout(result_shape, device.hw_type); + + TensorBatchVector raw_outputs = + Execute(*forward_computation_, inputs_params_buffers_data, result_shape, + &result_shape_with_layout, kInvalidModuleId); + + TensorBatchVector outputs; + for (size_t i = 0; i < raw_outputs.size(); ++i) { + auto& replica_raw_outputs = raw_outputs[i]; + TensorBatchVector::value_type replica_outputs; + for (size_t j = 0; j < f_real_outputs_; j++) { + replica_outputs.push_back(replica_raw_outputs[j]); + } + outputs.push_back(std::move(replica_outputs)); + + TensorBatchVector::value_type replica_captured_outputs; + for (size_t j = f_real_outputs_; j < replica_raw_outputs.size(); j++) { + replica_captured_outputs.push_back(replica_raw_outputs[j]); + } + captured_outputs_.push_back(std::move(replica_captured_outputs)); + + auto& replica_inputs_params_buffers = inputs_params_buffers[i]; + TensorBatchVector::value_type replica_captured_inputs_outputs; + for (auto j : df_input_captured_inputs_) { + replica_captured_inputs_outputs.push_back( + replica_inputs_params_buffers[j]); + } + for (auto j : df_input_captured_outputs_) { + replica_captured_inputs_outputs.push_back(replica_raw_outputs[j]); + } + captured_inputs_outputs_.push_back( + std::move(replica_captured_inputs_outputs)); + } + return outputs; +} + +XlaModule::TensorBatchVector XlaModule::PrepareForwardInput( + const TensorBatchVector& inputs) { + FlushTensorsOperations({&inputs, &optimizable_params_}); + // Clear the previous forward's captured vectors. 
+ // This is needed in case backward is not yet run, but two forward calls were + // made. + captured_outputs_.clear(); + captured_inputs_outputs_.clear(); + // Needed so that in backward, we can set .grad attributes correctly. + inputs_ = inputs; + + TensorBatchVector inputs_params_buffers; + CHECK_EQ(inputs_.size(), all_params_.size()); + for (size_t i = 0; i < inputs_.size(); ++i) { + TensorBatchVector::value_type replica_inputs_params_buffers; + for (auto p : inputs_[i]) { + replica_inputs_params_buffers.push_back(p); + } + for (auto p : all_params_[i]) { + replica_inputs_params_buffers.push_back(p); + } + inputs_params_buffers.push_back(std::move(replica_inputs_params_buffers)); + } + return inputs_params_buffers; +} + +XlaModule::TensorBatchVector XlaModule::Execute( + const xla::XlaComputation& computation, const DataBatchVector& inputs, + const xla::Shape& result_shape, const xla::Shape* output_shape, + uint64_t module_id) { + auto client = XlaGetClient(); + TensorBatchVector result; + if (inputs.size() == 1) { + auto results = + client->ExecuteComputation(computation, inputs.front(), output_shape); + result.push_back(DecomposeComputationResult(std::move(results), + result_shape, module_id)); + } else { + auto results = client->ExecuteReplicated(computation, inputs, output_shape); + result = + DecomposeComputationResult(std::move(results), result_shape, module_id); + } + return result; +} + +XlaTranslator::BuildOptions XlaModule::GetBackwardBuildOptions( + size_t param_to_return_count, size_t num_replicas) { + XlaTranslator::BuildOptions options; + options.param_to_return_count = param_to_return_count; + if (num_replicas > 1) { + options.output_transform = [this, num_replicas](const xla::XlaOp& op, + size_t) { + return BuildCrossReplicaSum(op, num_replicas); + }; + } + return options; +} + +void XlaModule::FlushTensorsOperations( + std::initializer_list batch_tensors) { + for (auto batch_tensor : batch_tensors) { + for (const auto& tensors : *batch_tensor) { + XLATensor::ApplyPendingGraph(tensors); + } + } +} + +XlaModule::DataBatchVector XlaModule::GetDataBatchVector( + const TensorBatchVector& inputs, const std::vector* zero_input) { + DataBatchVector inputs_data; + for (auto& replica_inputs : inputs) { + DataBatchVector::value_type replica_inputs_data; + for (size_t j = 0; j < replica_inputs.size(); ++j) { + if (zero_input == nullptr || !zero_input->at(j)) { + replica_inputs_data.push_back(replica_inputs[j]->GetXlaData().get()); + } + } + inputs_data.push_back(std::move(replica_inputs_data)); + } + return inputs_data; +} + +std::vector XlaModule::CommonDevicesForReplicas( + const TensorBatchVector& inputs) { + std::vector devices; + std::set unique_devices; + for (auto& replica_inputs : inputs) { + devices.push_back(XLATensor::CommonDeviceForTensors(replica_inputs)); + CHECK(unique_devices.insert(devices.back()).second) + << "Duplicated device in different replicas: " + << devices.back().ToString(); + } + return devices; +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/module.h b/torch_xla/csrc/module.h new file mode 100644 index 00000000000..2db4e91250b --- /dev/null +++ b/torch_xla/csrc/module.h @@ -0,0 +1,144 @@ +#pragma once + +#include + +#include "tensor.h" +#include "torch/csrc/jit/script/module.h" +#include "torch/csrc/utils/disallow_copy.h" +#include "translator.h" + +#include +#include + +namespace torch { +namespace jit { + +struct XlaModule : public std::enable_shared_from_this { + TH_DISALLOW_COPY_AND_ASSIGN(XlaModule); + + // The i-th 
entry in this vector, is a vector of XLA tensors which belong the + // i-th replica. + using TensorBatchVector = + std::vector>>; + + // Creates a new XlaModule from a PyTorch script module "module". + // "use_full_conv_precision" controls whether to use maximum precision + // available in hardware for convolutions. + XlaModule(const std::shared_ptr module, + bool use_full_conv_precision); + + TensorBatchVector forward(const TensorBatchVector& inputs); + // For the given gradient outputs, compute the gradient of input and + // parameters and set it as their grad field. + void backward(const TensorBatchVector& grad_outputs); + + const TensorBatchVector& parameters(); + const TensorBatchVector& parameters_buffers(); + + static constexpr uint64_t kInvalidModuleId = 0; + + private: + // The i-th entry in this vector, is a vector of XLA computation data which + // belong the i-th replica. + using DataBatchVector = + std::vector>; + + void Initialize(const TensorBatchVector& inputs); + + void CheckInitialized() const; + + xla::PrecisionConfig::Precision GetPrecisionConfig() const; + + // Builds the fused forward and backward computation for RunFusedTrain. + void BuildFusedTrainComputation( + const std::vector& forward_shapes); + + // Runs the original, unfused forward computation on the given inputs. + TensorBatchVector RunUnfusedForward(const TensorBatchVector& inputs); + + // Runs a fused forward and backward computation. Takes the same input as the + // forward computation, returns the outputs for the forward computation and + // the gradients for model parameters. + TensorBatchVector RunFusedTrain(const TensorBatchVector& inputs); + + // Collect the inputs and model parameters and clear the captured inputs / + // outputs state. + TensorBatchVector PrepareForwardInput(const TensorBatchVector& inputs); + + // Executes the provided XLA computation. The execution will be replicated in + // as many replicas as the size of the inputs first dimension. + // The result_shape is the shape returned by the XLA computation, while + // output_shape (if not nullptr) is the shape+layout requested to the XLA + // compiler. The module_id is used to track changes in the tensors taking + // place of the fused computation, and will be assigned to the output tensors. + TensorBatchVector Execute(const xla::XlaComputation& computation, + const DataBatchVector& inputs, + const xla::Shape& result_shape, + const xla::Shape* output_shape, uint64_t module_id); + + // Creates the build options to be used to create a backward pass computation. + XlaTranslator::BuildOptions GetBackwardBuildOptions( + size_t param_to_return_count, size_t num_replicas); + + static void FlushTensorsOperations( + std::initializer_list batch_tensors); + + // Extracts the XLA computation data from the inputs, and returns a matching + // batch vector where data[i][j] holds the data beind the XLA tensor + // inputs[i][j]. + // Elements in the return vector are populated only if zero_input is nullptr, + // or if zero_input[j] is false. + static DataBatchVector GetDataBatchVector( + const TensorBatchVector& inputs, const std::vector* zero_input); + + // Returns the common device for every replica copy of the inputs. + // All common devices must be different in different replicas. + static std::vector CommonDevicesForReplicas( + const TensorBatchVector& inputs); + + // The module parameters which are marked for being subject to optimization. 
+ TensorBatchVector optimizable_params_; + // All the module parameters (which include the optimizable_params_ ones). + TensorBatchVector all_params_; + c10::optional forward_computation_; + c10::optional backward_computation_; + + std::shared_ptr f_; + std::shared_ptr df_; + + // info for backwrd captures + size_t f_real_outputs_; + std::vector df_input_captured_inputs_; + std::vector df_input_captured_outputs_; + + // TODO: captured_outputs only needs shape, no need for holding onto full + // Tensor + TensorBatchVector inputs_; + std::vector inputs_require_grad_; + TensorBatchVector captured_outputs_; + TensorBatchVector captured_inputs_outputs_; + + // Specifies whether to use the highest precision available for convolutions. + // Currently it only makes a difference for TPUs. + const bool use_full_conv_precision_; + // Gradients set aside by the fused train computation, to be consumed by the + // backward call if we receive an unmodified tensor from the forward pass. + TensorBatchVector grad_inputs_; + // Keeps track whether an attempt to fuse the forward and backward + // computations failed. Starts on true (we attempt to fuse), permanently goes + // to false on failure. Mitigates doing redundant work (compute gradients we + // can't use) after the first training step, if the fusion fails. + bool enable_trace_fusion_; + // Unique identifier for the module, used to keep track of tensors originating + // from its forward method. + uint64_t module_id_; + + // Keep the script module alive for lazy initialization of this XlaModule. + // Once this XlaModule is initialized, script_module_ will be set to null. + std::shared_ptr script_module_; + + static std::atomic s_module_id_; +}; + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/nll_loss.cpp b/torch_xla/csrc/nll_loss.cpp new file mode 100644 index 00000000000..c150fbaa804 --- /dev/null +++ b/torch_xla/csrc/nll_loss.cpp @@ -0,0 +1,95 @@ +#include "nll_loss.h" +#include "helpers.h" + +namespace torch { +namespace jit { + +namespace { + +// Converts "indices" into a one-hot representation. "depth" is the size of the +// new axis to add. "axis" is the position at which to add the new axis. +// "on_value" and "off_value" represent the values to use for the on and off +// positions, respectively. +xla::XlaOp LabelsToOneHot(xla::XlaBuilder* builder, xla::int64 depth, int axis, + const xla::XlaOp indices, const xla::XlaOp on_value, + const xla::XlaOp off_value) { + const auto indices_shape = builder->GetShape(indices).ValueOrDie(); + const int indices_dims = indices_shape.dimensions_size(); + const int output_dims = indices_dims + 1; + + // Expand the labels with a depth dimension for the classes. + std::vector output_dimensions(indices_shape.dimensions().begin(), + indices_shape.dimensions().end()); + output_dimensions.insert(output_dimensions.begin() + axis, depth); + + // Build a iota tensor populated with values 0 through depth - 1. + std::vector linspace_data(depth); + std::iota(linspace_data.begin(), linspace_data.end(), 0); + std::vector linspace_dims(output_dims, 1); + linspace_dims[axis] = depth; + const auto linspace_xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( + xla::PrimitiveType::S64, linspace_dims); + xla::BorrowingLiteral linspace_literal( + reinterpret_cast(linspace_data.data()), linspace_xla_shape); + + // Now compare the labels in index form to the iota tensor to get the one hot + // format. 
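+  // Example (hypothetical values, illustration only): labels {2, 0} with
+  // depth = 3 and axis = 1 compare equal to the iota {0, 1, 2} at positions
+  // {{0, 0, 1}, {1, 0, 0}}, which the select below maps to on_value/off_value.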
+ std::vector broadcast_dims(indices_shape.dimensions_size()); + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + xla::XlaOp linspace_xla; + xla::XlaOp one_hot_bool = xla::Eq( + indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims); + + // Selects the user-provided off_value and on_value values. + return xla::Select(one_hot_bool, xla::Broadcast(on_value, output_dimensions), + xla::Broadcast(off_value, output_dimensions)); +} + +} // namespace + +// Builds the NLLLoss for log-probabilities "logits" and class indices "labels". +xla::XlaOp BuildNllLoss(const Node* node, const xla::XlaOp& logits, + const xla::XlaOp& labels) { + xla::XlaBuilder* builder = logits.builder(); + xla::Shape logits_shape = builder->GetShape(logits).ValueOrDie(); + xla::XlaOp zero = XlaHelpers::ScalarValue(0, builder); + xla::XlaOp one_hot_labels = LabelsToOneHot( + /*builder=*/builder, + /*depth=*/logits_shape.dimensions(1), + /*axis=*/1, + /*indices=*/labels, + /*on_value=*/XlaHelpers::ScalarValue(1, builder), + /*off_value=*/zero); + // Compute sum(-one_hot_labels * logits) / batch. + xla::XlaOp mul = xla::Mul(xla::Neg(one_hot_labels), logits); + xla::XlaComputation add_func = XlaHelpers::CreateAddComputation(); + xla::XlaOp batch = + XlaHelpers::ScalarValue(logits_shape.dimensions(0), builder); + return xla::ReduceAll(mul, zero, add_func) / batch; +} + +// Builds the NLLLoss gradient for log-probabilities "logits" and class indices +// "labels". +xla::XlaOp BuildNllLossBackward(const Node* node, const xla::XlaOp& logits, + const xla::XlaOp& labels) { + const int kBatchDim = 0; + auto builder = logits.builder(); + const auto zero = XlaHelpers::ScalarValue(0, builder); + const auto one = XlaHelpers::ScalarValue(1, builder); + const auto logits_shape = builder->GetShape(logits).ValueOrDie(); + xla::XlaOp one_hot_labels = LabelsToOneHot( + /*builder=*/builder, + /*depth=*/logits_shape.dimensions(1), + /*axis=*/1, + /*indices=*/labels, + /*on_value=*/XlaHelpers::ScalarValue(1, builder), + /*off_value=*/XlaHelpers::ScalarValue(0, builder)); + const auto batch = XlaHelpers::ScalarValue( + logits_shape.dimensions(kBatchDim), builder); + // Compute -one_hot_labels / batch. + return xla::Neg(one_hot_labels) / batch; +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/nll_loss.h b/torch_xla/csrc/nll_loss.h new file mode 100644 index 00000000000..4c58ed20935 --- /dev/null +++ b/torch_xla/csrc/nll_loss.h @@ -0,0 +1,19 @@ +#pragma once + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "torch/csrc/jit/ir.h" + +namespace torch { +namespace jit { + +// Builds the NLLLoss for log-probabilities "logits" and class indices "labels". +xla::XlaOp BuildNllLoss(const Node* node, const xla::XlaOp& logits, + const xla::XlaOp& labels); + +// Builds the NLLLoss gradient for log-probabilities "logits" and class indices +// "labels". +xla::XlaOp BuildNllLossBackward(const Node* node, const xla::XlaOp& logits, + const xla::XlaOp& labels); + +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch_xla/csrc/passes/eval_static_size.cpp b/torch_xla/csrc/passes/eval_static_size.cpp new file mode 100644 index 00000000000..308196b8ac2 --- /dev/null +++ b/torch_xla/csrc/passes/eval_static_size.cpp @@ -0,0 +1,41 @@ +#include "eval_static_size.h" + +namespace torch { +namespace jit { + +namespace { + +// Evaluates aten::size on a statically known input. 
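// Example (hypothetical graph fragment): for an input %x with a complete
// type such as Float(8, 32), a query like aten::size(%x, dim=1) satisfies
// IsStaticSizeQuery below; RunSizeQuery then returns 32 and EvalStaticSize
// splices a constant 32 in place of the size node's output.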
+int64_t RunSizeQuery(Node* node) { + const auto tensor_type = node->input(0)->type()->cast(); + JIT_ASSERT(tensor_type); + const auto tensor_sizes = tensor_type->sizes(); + const auto dim = node->get(attr::dim).value(); + JIT_ASSERT(dim >= 0); + JIT_ASSERT(static_cast(dim) < tensor_sizes.size()); + return tensor_sizes[dim]; +} + +// Returns true if the size can be evaluated during trace optimization. +bool IsStaticSizeQuery(Node* node) { + return node->kind() == aten::size && + node->input(0)->type()->cast() && + node->get(attr::dim) && + node->get(attr::dim).value() >= 0; +} + +} // namespace + +void EvalStaticSize(const std::shared_ptr& graph) { + auto nodes = graph->block()->nodes(); + for (auto node : nodes) { + if (IsStaticSizeQuery(node)) { + WithInsertPoint insert_point_guard(node); + auto new_output = graph->insertConstant(RunSizeQuery(node)); + node->outputs()[0]->replaceAllUsesWith(new_output); + } + } +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/passes/eval_static_size.h b/torch_xla/csrc/passes/eval_static_size.h new file mode 100644 index 00000000000..8c0b1e57ef6 --- /dev/null +++ b/torch_xla/csrc/passes/eval_static_size.h @@ -0,0 +1,12 @@ +#pragma once + +#include "torch/csrc/jit/ir.h" + +namespace torch { +namespace jit { + +// Evaluate aten::size operators for known shape inputs. +void EvalStaticSize(const std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/passes/remove_unused_forward_outputs.cpp b/torch_xla/csrc/passes/remove_unused_forward_outputs.cpp new file mode 100644 index 00000000000..1c7b9a3c792 --- /dev/null +++ b/torch_xla/csrc/passes/remove_unused_forward_outputs.cpp @@ -0,0 +1,124 @@ +#include "remove_unused_forward_outputs.h" +#include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/passes/dead_code_elimination.h" + +namespace torch { +namespace jit { + +namespace { + +// Remove an unused input from the backward graph from both the outputs and +// captured outputs sections of its input. +void RemoveInputFromBackwardGraph(Gradient& gradient, const size_t output_idx, + const size_t captured_output_idx) { + const auto backward_inputs = gradient.df->inputs(); + const Value* grad_output = backward_inputs[output_idx]; + const Value* captured_output = backward_inputs[captured_output_idx]; + // Remove grad_output and captured_output from the inputs of the backward + // graph. + for (auto it = gradient.df->nodes().begin(), end = gradient.df->nodes().end(); + it != end; ++it) { + const auto node_inputs = it->inputs(); + const auto grad_output_it = + std::find(node_inputs.begin(), node_inputs.end(), grad_output); + // Assert that grad_output doesn't have remaining uses. + JIT_ASSERT(grad_output_it == node_inputs.end()); + const auto captured_output_it = + std::find(node_inputs.begin(), node_inputs.end(), captured_output); + if (captured_output_it != node_inputs.end()) { + WithInsertPoint guard(*it); + Node* undef = gradient.df->insertNode(gradient.df->createUndefined()); + it->replaceInput(captured_output_it - node_inputs.begin(), + undef->output()); + } + } + // captured_output_idx points inside the captured outputs section, output_idx + // points inside the outputs section. We thus have captured_output_idx > + // output_idx because outputs come before captured outputs. Remove the + // captured_output_idx first to avoid invalidation of indices. 
+ JIT_ASSERT(captured_output_idx > output_idx); + gradient.df->eraseInput(captured_output_idx); + gradient.df->eraseInput(output_idx); +} + +// Remove the unused output specified by node_output_idx from the given node, +// with subsequent removal from the backward graph input as well. +void RemoveNodeOutputFromGradient(Node* node, const size_t node_output_idx, + Gradient& gradient) { + const Value* output = node->outputs()[node_output_idx]; + // Find index of this output in forward graph outputs. + const auto forward_outputs = gradient.f->outputs(); + const auto output_it = + std::find(forward_outputs.begin(), forward_outputs.end(), output); + // This output isn't returned from the forward graph, nothing to do. + if (output_it == forward_outputs.end()) { + return; + } + size_t output_idx = output_it - forward_outputs.begin(); + // Remove the given output from the graph outputs. + gradient.f->eraseOutput(output_idx); + // Remove the given output from the node outputs. + node->eraseOutput(node_output_idx); + + // Find the captured_output_idx absolute index of the backward graph input to + // remove. First, position it at the beginning of the captured outputs, right + // after the outputs of the forward graph and the captureed inputs. + size_t captured_output_idx = + forward_outputs.size() + gradient.df_input_captured_inputs.size(); + // Next, find the index and value in df_input_captured_outputs of the node to + // remove. Use it to adjust captured_output_idx and update + // df_input_captured_outputs. + int df_input_captured_outputs_idx = -1; + for (size_t i = 0; i < gradient.df_input_captured_outputs.size(); i++) { + if (static_cast(output_idx) == + gradient.df_input_captured_outputs[i]) { + captured_output_idx += i; + df_input_captured_outputs_idx = i; + break; + } + } + JIT_ASSERT(df_input_captured_outputs_idx != -1); + const size_t df_input_captured_outputs_val = + gradient.df_input_captured_outputs[df_input_captured_outputs_idx]; + // Remove the node from df_input_captured_outputs and adjust references to + // nodes with higher indices in df_input_captured_outputs. + gradient.df_input_captured_outputs.erase( + gradient.df_input_captured_outputs.begin() + + df_input_captured_outputs_idx); + for (size_t i = 0; i < gradient.df_input_captured_outputs.size(); i++) { + if (gradient.df_input_captured_outputs[i] > df_input_captured_outputs_val) { + --gradient.df_input_captured_outputs[i]; + } + } + + // Finally, remove the node from all inputs of the backward graph. 
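  // To recap the bookkeeping above with hypothetical contents: if
  // df_input_captured_outputs was {0, 2, 3} and forward output 2 was removed,
  // the entry 2 is erased and the entry 3 is decremented, leaving {0, 2}.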
+ RemoveInputFromBackwardGraph(/*gradient=*/gradient, /*output_idx=*/output_idx, + /*captured_output_idx=*/captured_output_idx); +} + +} // namespace + +void RemoveUnusedForwardOutputs(Gradient& gradient) { + for (auto it = gradient.f->nodes().begin(), end = gradient.f->nodes().end(); + it != end; ++it) { + JIT_ASSERT(it->blocks().size() == 0); + switch (it->kind()) { + case aten::thnn_conv2d_forward: { + JIT_ASSERT(it->outputs().size() == 3); + RemoveNodeOutputFromGradient(*it, 2, gradient); + RemoveNodeOutputFromGradient(*it, 1, gradient); + break; + } + case aten::max_pool2d_with_indices: { + JIT_ASSERT(it->outputs().size() == 2); + RemoveNodeOutputFromGradient(*it, 1, gradient); + break; + } + default: + break; + } + } +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/passes/remove_unused_forward_outputs.h b/torch_xla/csrc/passes/remove_unused_forward_outputs.h new file mode 100644 index 00000000000..2a3681486b9 --- /dev/null +++ b/torch_xla/csrc/passes/remove_unused_forward_outputs.h @@ -0,0 +1,12 @@ +#pragma once + +#include "torch/csrc/jit/autodiff.h" + +namespace torch { +namespace jit { + +// Remove outputs from forward graph which are not useful for the XLA lowering. +void RemoveUnusedForwardOutputs(Gradient& gradient); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/passes/replace_untraced_operators.cpp b/torch_xla/csrc/passes/replace_untraced_operators.cpp new file mode 100644 index 00000000000..c7c561401d4 --- /dev/null +++ b/torch_xla/csrc/passes/replace_untraced_operators.cpp @@ -0,0 +1,98 @@ +#include "replace_untraced_operators.h" + +namespace torch { +namespace jit { + +namespace { + +// Returns true if the node contains an attribute and has the expected value. +template +bool NodeHasExpectedAttribute(const Node* node, const Symbol attribute_name, + const T& expected) { + const auto maybe_attribute = node->get(attribute_name); + return maybe_attribute && *maybe_attribute == expected; +} + +// Only allow certain aten::_convolution operators to be replaced. +bool CanTraceConvolution(const Node* node) { + return NodeHasExpectedAttribute(node, attr::dilation, + std::vector{1, 1}) && + NodeHasExpectedAttribute(node, attr::output_padding, + std::vector{0, 0}) && + NodeHasExpectedAttribute(node, attr::transposed, false) && + NodeHasExpectedAttribute(node, attr::groups, int64_t(1)) && + NodeHasExpectedAttribute(node, attr::benchmark, false) && + NodeHasExpectedAttribute(node, attr::deterministic, false); +} + +// When possible, replace aten::{_convolution, batch_norm} operators with +// equivalent ones which are part of the operator schema and differentiable. 
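// Illustration (hypothetical node, 3x3 weight): a traced node such as
//   %out = aten::_convolution(%input, %weight, %bias, stride, padding,
//                             dilation=[1, 1], transposed=0, ...)
// with a weight of size [C_out, C_in, 3, 3] is rewritten below into
//   aten::thnn_conv2d_forward(%input, %weight, kernel_size=[3, 3], %bias,
//                             stride, padding)
// whose first output replaces all uses of the original convolution output
// and which has a known symbolic gradient for the later autodiff pass.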
+void ReplaceUntracedOperators(Block* block) { + for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; + ++it) { + for (auto sub : it->blocks()) { + ReplaceUntracedOperators(sub); + } + switch (it->kind()) { + case aten::_convolution: { + WithInsertPoint guard(*it); + auto graph = block->owningGraph(); + auto node = *it; + if (!CanTraceConvolution(node)) { + break; + } + const auto weight = node->namedInput(attr::weight); + const auto weight_type = weight->type()->expect(); + const auto& weight_size = weight_type->sizes(); + const auto kernel_size = graph->insertConstant( + std::vector{weight_size[2], weight_size[3]}); + const auto stride = graph->insertConstant( + node->get>(attr::stride).value()); + const auto padding = graph->insertConstant( + node->get>(attr::padding).value()); + + auto replacement_node = graph->create(aten::thnn_conv2d_forward, 3); + + graph->insertNode(replacement_node); + replacement_node->addInput(node->namedInput(attr::input)); + replacement_node->addInput(weight); + replacement_node->addInput(kernel_size); + replacement_node->addInput(node->namedInput(attr::bias)); + replacement_node->addInput(stride); + replacement_node->addInput(padding); + + replacement_node->outputs()[0]->setType(it->outputs()[0]->type()); + it->output()->replaceAllUsesWith(replacement_node->outputs()[0]); + it.destroyCurrent(); + break; + } + case aten::batch_norm: { + WithInsertPoint guard(*it); + auto graph = block->owningGraph(); + auto node = *it; + auto replacement_node = graph->create(aten::native_batch_norm, 3); + + graph->insertNode(replacement_node); + const auto node_inputs = node->inputs(); + JIT_ASSERT(node_inputs.size() == 9); + for (size_t i = 0; i < node_inputs.size() - 1; ++i) { + replacement_node->addInput(node_inputs[i]); + } + replacement_node->outputs()[0]->setType(it->outputs()[0]->type()); + it->output()->replaceAllUsesWith(replacement_node->outputs()[0]); + it.destroyCurrent(); + break; + } + default: { break; } + } + } +} + +} // namespace + +void ReplaceUntracedOperators(const std::shared_ptr& graph) { + ReplaceUntracedOperators(graph->block()); +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/passes/replace_untraced_operators.h b/torch_xla/csrc/passes/replace_untraced_operators.h new file mode 100644 index 00000000000..4f651f746af --- /dev/null +++ b/torch_xla/csrc/passes/replace_untraced_operators.h @@ -0,0 +1,12 @@ +#pragma once + +#include "torch/csrc/jit/ir.h" + +namespace torch { +namespace jit { + +// Replace certain operators with their differentiable versions. 
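// At the moment this covers aten::_convolution (rewritten to
// aten::thnn_conv2d_forward when its attributes allow tracing) and
// aten::batch_norm (rewritten to aten::native_batch_norm).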
+void ReplaceUntracedOperators(const std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/passes/threshold_backward_peephole.cpp b/torch_xla/csrc/passes/threshold_backward_peephole.cpp new file mode 100644 index 00000000000..2076f26bc81 --- /dev/null +++ b/torch_xla/csrc/passes/threshold_backward_peephole.cpp @@ -0,0 +1,41 @@ +#include "threshold_backward_peephole.h" + +namespace torch { +namespace jit { + +namespace { + +void ThresholdBackwardPeephole(Block* block) { + for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; + ++it) { + for (auto sub_block : it->blocks()) { + ThresholdBackwardPeephole(sub_block); + } + if (it->kind() == aten::mul) { + const auto type_as_cand = it->input(1)->node(); + if (type_as_cand->kind() == aten::type_as) { + const auto gt_cand = type_as_cand->input(0)->node(); + if (gt_cand->kind() == aten::gt) { + WithInsertPoint guard(*it); + auto graph = block->owningGraph(); + auto replacement_node = graph->create(aten::threshold_backward); + graph->insertNode(replacement_node); + replacement_node->addInput(it->input(0)); + replacement_node->addInput(gt_cand->input(0)); + replacement_node->addInput(gt_cand->input(1)); + it->output()->replaceAllUsesWith(replacement_node->outputs()[0]); + it.destroyCurrent(); + } + } + } + } +} + +} // namespace + +void ThresholdBackwardPeephole(const std::shared_ptr& graph) { + ThresholdBackwardPeephole(graph->block()); +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/passes/threshold_backward_peephole.h b/torch_xla/csrc/passes/threshold_backward_peephole.h new file mode 100644 index 00000000000..bc51a77a56d --- /dev/null +++ b/torch_xla/csrc/passes/threshold_backward_peephole.h @@ -0,0 +1,14 @@ +#pragma once + +#include "torch/csrc/jit/ir.h" + +namespace torch { +namespace jit { + +// Recognizes a gt, type_as, mul sequence and replaces it with +// threshold_backward. Works around an issue in the TPU compiler with S8 tensor +// shapes. +void ThresholdBackwardPeephole(const std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/pooling.cpp b/torch_xla/csrc/pooling.cpp new file mode 100644 index 00000000000..1a7a340f0d1 --- /dev/null +++ b/torch_xla/csrc/pooling.cpp @@ -0,0 +1,154 @@ +#include "pooling.h" +#include "helpers.h" +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "torch/csrc/jit/autodiff.h" + +namespace torch { +namespace jit { + +namespace { + +xla::TensorFormat MakeNCHWFormat() { + return {/*batch_dimension=*/0, + /*feature_dimension=*/1, + /*spatial_dimensions=*/std::vector{2, 3}}; +} + +// Holds the attributes common to all pooling operators. +struct PoolingOpAttributes { + std::vector kernel_size; + std::vector stride; + std::vector> padding; +}; + +xla::XlaComputation CreateGeComputation() { + xla::XlaBuilder reduction_builder("xla_ge_computation"); + const auto x = xla::Parameter( + &reduction_builder, 0, + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}), "x"); + const auto y = xla::Parameter( + &reduction_builder, 1, + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {}), "y"); + xla::Ge(x, y); + return reduction_builder.Build().ConsumeValueOrDie(); +} + +// Extract the pooling attributes for the given 2D pooling operator "node". 
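// Example (hypothetical attributes): for a pooling node with
// kernel_size=[3, 3], stride=[2, 2] and padding=[1, 1], the NCHW attributes
// built below are kernel_size={1, 1, 3, 3}, stride={1, 1, 2, 2} and
// padding={{1, 1}, {1, 1}}; batch and feature always use a window and stride
// of 1 so that pooling is applied only over the spatial dimensions.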
+PoolingOpAttributes Pooling2DOpAttributes(const Node* pooling_2d) { + const auto kernel_size_attr = XlaHelpers::I64List( + pooling_2d->get>(attr::kernel_size).value()); + const auto stride_attr = + pooling_2d->get>(attr::stride).value(); + // Create a NCHW kernel size with 1 for batch size and feature. + std::vector kernel_size(2, 1); + kernel_size.insert(kernel_size.end(), kernel_size_attr.begin(), + kernel_size_attr.end()); + // Create a NCHW stride size with 1 for batch size and feature. Same as kernel + // size if not specified. + std::vector stride; + if (stride_attr.empty()) { + stride = kernel_size; + } else { + stride.resize(2, 1); + stride.insert(stride.end(), stride_attr.begin(), stride_attr.end()); + } + const auto padding_attr = + pooling_2d->get>(attr::padding).value(); + CHECK_EQ(padding_attr.size(), 2); + std::vector> padding; + for (const xla::int64 dim_pad : padding_attr) { + padding.push_back(std::make_pair(dim_pad, dim_pad)); + } + return {kernel_size, stride, padding}; +} + +void CheckAvgPool2DIsSupported(const Node* node) { + const auto node_inputs = node->inputs(); + CHECK_GE(node_inputs.size(), size_t(6)); + const auto ceil_mode = node->get(attr::ceil_mode).value(); + if (ceil_mode) { + AT_ERROR("ceil_mode not supported for avg_pool2d yet"); + } +} + +} // namespace + +xla::XlaOp BuildMaxPool2d(const Node* node, const xla::XlaOp& input) { + const auto pooling_op_attributes = Pooling2DOpAttributes(node); + auto builder = input.builder(); + const auto init_value = xla::LiteralUtil::MinValue(xla::PrimitiveType::F32); + const auto xla_init_value = xla::ConstantLiteral(builder, init_value); + const auto padding_config = XlaHelpers::MakeXlaPaddingConfig( + node->get>(attr::padding).value()); + const auto padded_input = xla::Pad(input, xla_init_value, padding_config); + return xla::MaxPool( + /*operand=*/padded_input, + /*kernel_size=*/pooling_op_attributes.kernel_size, + /*stride=*/pooling_op_attributes.stride, + /*padding=*/xla::Padding::kValid, + /*data_format=*/MakeNCHWFormat()); +} + +xla::XlaOp BuildMaxPool2dBackward(const Node* node, + const xla::XlaOp& out_backprop, + const xla::XlaOp& input) { + auto builder = out_backprop.builder(); + const auto init_value = XlaHelpers::ScalarValue(0, builder); + const auto select = CreateGeComputation(); + const auto scatter = XlaHelpers::CreateAddComputation(); + const auto pooling_op_attributes = Pooling2DOpAttributes(node); + std::vector> window_padding; + window_padding.resize(2); + window_padding.insert(window_padding.end(), + pooling_op_attributes.padding.begin(), + pooling_op_attributes.padding.end()); + return xla::SelectAndScatterWithGeneralPadding( + /*operand=*/input, + /*select=*/select, + /*window_dimensions=*/pooling_op_attributes.kernel_size, + /*window_strides=*/pooling_op_attributes.stride, + /*padding=*/window_padding, + /*source=*/out_backprop, + /*init_value=*/init_value, + /*scatter=*/scatter); +} + +xla::XlaOp BuildAvgPool2d(const Node* node, const xla::XlaOp& input) { + // Inspired from tf2xla. 
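  // Note: unlike BuildMaxPool2d above, which explicitly pads the input and
  // then pools with xla::Padding::kValid, xla::AvgPool receives the padding
  // pairs directly, and the node's count_include_pad flag maps onto its
  // counts_include_padding argument.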
+ CheckAvgPool2DIsSupported(node); + const auto pooling_op_attributes = Pooling2DOpAttributes(node); + const auto count_include_pad = + node->get(attr::count_include_pad).value(); + return xla::AvgPool( + /*operand=*/input, + /*kernel_size=*/pooling_op_attributes.kernel_size, + /*stride=*/pooling_op_attributes.stride, + /*padding=*/pooling_op_attributes.padding, + /*data_format=*/MakeNCHWFormat(), + /*counts_include_padding=*/count_include_pad); +} + +xla::XlaOp BuildAvgPool2dBackward(const Node* node, + const xla::XlaOp& out_backprop, + const xla::XlaOp& input) { + // Inspired from tf2xla. + CheckAvgPool2DIsSupported(node); + const auto pooling_op_attributes = Pooling2DOpAttributes(node); + const auto node_inputs = node->inputs(); + auto gradients_size = + XlaHelpers::I64List(XlaHelpers::TensorDimensionSizes(node_inputs[1])); + const auto count_include_pad = + node->get(attr::count_include_pad).value(); + + return xla::AvgPoolGrad( + /*out_backprop=*/out_backprop, + /*gradients_size=*/gradients_size, + /*kernel_size=*/pooling_op_attributes.kernel_size, + /*stride=*/pooling_op_attributes.stride, + /*spatial_padding=*/pooling_op_attributes.padding, + /*data_format=*/MakeNCHWFormat(), + /*counts_include_padding=*/count_include_pad); +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/pooling.h b/torch_xla/csrc/pooling.h new file mode 100644 index 00000000000..1d4bba2c6dc --- /dev/null +++ b/torch_xla/csrc/pooling.h @@ -0,0 +1,28 @@ +#pragma once + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "torch/csrc/jit/ir.h" + +namespace torch { +namespace jit { + +// Computes max pooling for the given input with the attributes specified in the +// given node. +xla::XlaOp BuildMaxPool2d(const Node* node, const xla::XlaOp& input); + +// Computes the gradient for max pooling. +xla::XlaOp BuildMaxPool2dBackward(const Node* node, + const xla::XlaOp& out_backprop, + const xla::XlaOp& input); + +// Computes average pooling for the given input with the attributes specified in +// the given node. +xla::XlaOp BuildAvgPool2d(const Node* node, const xla::XlaOp& input); + +// Computes the gradient for average pooling. +xla::XlaOp BuildAvgPool2dBackward(const Node* node, + const xla::XlaOp& out_backprop, + const xla::XlaOp& input); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp new file mode 100644 index 00000000000..8ad5bdb14fa --- /dev/null +++ b/torch_xla/csrc/reduction.cpp @@ -0,0 +1,20 @@ +#include "reduction.h" +#include "helpers.h" + +namespace torch { +namespace jit { + +xla::XlaOp BuildSum(const Node* node, const xla::XlaOp& operand) { + if (node->get(attr::keepdim).value()) { + AT_ERROR("Sum with keepdim set not supported yet"); + } + auto builder = operand.builder(); + const auto init_value = XlaHelpers::ScalarValue(0, builder); + const auto dimensions_to_reduce = + node->get>(attr::dim).value(); + return xla::Reduce(operand, init_value, XlaHelpers::CreateAddComputation(), + XlaHelpers::I64List(dimensions_to_reduce)); +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/reduction.h b/torch_xla/csrc/reduction.h new file mode 100644 index 00000000000..48748176a71 --- /dev/null +++ b/torch_xla/csrc/reduction.h @@ -0,0 +1,14 @@ +#pragma once + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "torch/csrc/jit/ir.h" + +namespace torch { +namespace jit { + +// Sum the given operand elements along the dimension specified by the "dim" +// attribute of the node. 
+xla::XlaOp BuildSum(const Node* node, const xla::XlaOp& operand); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp new file mode 100644 index 00000000000..f2a5ba0d1f1 --- /dev/null +++ b/torch_xla/csrc/tensor.cpp @@ -0,0 +1,559 @@ +#include "tensor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "c10/util/Exception.h" +#include "helpers.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "torch/csrc/autograd/variable.h" +#include "translator.h" + +namespace torch { +namespace jit { + +namespace { + +// Creates a minor-to-major layout from given dimensions. +xla::Shape MakeTorchTensorLayout(const std::vector& dimensions, + const xla::PrimitiveType type) { + return xla::ShapeUtil::MakeShapeWithDescendingLayout(type, dimensions); +} + +xla::Shape MakeArrayShapeFromDimensions( + const at::IntList& tensor_dimensions, const xla::PrimitiveType type, + const XLATensor::DeviceType device_type) { + const auto dimensions = XlaHelpers::I64List(tensor_dimensions); + if (dimensions.size() == 4 && device_type == XLATensor::DeviceType::TPU) { + // Use a TPU-compatible layout for 4D tensors -- batch and feature in minor + // dimensions. + std::vector hwcn_layout{0, 1, 3, 2}; + return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, hwcn_layout); + } + return MakeTorchTensorLayout(dimensions, type); +} + +template +std::vector LinearizeTensor(const at::Tensor& t, + const size_t total_elements); + +template <> +std::vector LinearizeTensor(const at::Tensor& t, + const size_t total_elements) { + const at::Tensor& cont_t = t.contiguous(); + return std::vector(cont_t.data(), + cont_t.data() + total_elements); +} + +template <> +std::vector LinearizeTensor( + const at::Tensor& t, const size_t total_elements) { + const at::Tensor& cont_t = t.contiguous(); + return std::vector(cont_t.data(), + cont_t.data() + total_elements); +} + +template +std::shared_ptr TensorToXlaImpl( + const at::Tensor& param_tensor, const xla::Shape& param_shape, + const XLATensor::Device& device, xla::ComputationClient* client) { + size_t total_elements = 1; + std::vector dimension_sizes; + for (const auto dimension_size : param_tensor.sizes()) { + dimension_sizes.push_back(dimension_size); + total_elements *= dimension_size; + } + xla::Array parameter_xla_array(dimension_sizes); + parameter_xla_array.SetValues( + LinearizeTensor(param_tensor, total_elements)); + xla::Literal literal(param_shape); + literal.PopulateFromArray(parameter_xla_array); + return client->TransferParameterToServer(literal, + /*device=*/device.ToString()); +} + +std::shared_ptr TensorToXla( + const at::Tensor& param_tensor, const xla::Shape& param_shape, + const XLATensor::Device& device, xla::ComputationClient* client) { + switch (param_tensor.type().scalarType()) { + case at::ScalarType::Float: + return TensorToXlaImpl(param_tensor, param_shape, device, client); + case at::ScalarType::Long: + return TensorToXlaImpl(param_tensor, param_shape, device, + client); + default: + LOG(FATAL) << "Tensor type not supported"; + } +} + +at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal) { + const auto& result_shape = literal.shape(); + std::vector dimensions; + for (const auto result_dimension : result_shape.dimensions()) { + dimensions.push_back(result_dimension); + } + auto literal_type = result_shape.element_type(); + const auto torch_layout = + MakeTorchTensorLayout(XlaHelpers::I64List(dimensions), literal_type); + const auto 
literal_with_torch_layout = literal.Relayout(torch_layout); + switch (literal_type) { + case xla::PrimitiveType::F32: { + const auto result_slice = literal_with_torch_layout.data(); + at::Tensor result_tensor = at::empty(dimensions, at::TensorOptions(at::kFloat)); + std::copy(result_slice.begin(), result_slice.end(), + result_tensor.data()); + return result_tensor; + } + case xla::PrimitiveType::S64: { + const auto result_slice = literal_with_torch_layout.data(); + at::Tensor result_tensor = at::empty(dimensions, at::TensorOptions(at::kLong)); + std::copy(result_slice.begin(), result_slice.end(), + result_tensor.data()); + return result_tensor; + } + default: + AT_ERROR("Unsupported literal type"); + } +} + +std::string DeviceTypeToString(const XLATensor::DeviceType hw_type) { + switch (hw_type) { + case XLATensor::DeviceType::CPU: + return "CPU"; + case XLATensor::DeviceType::GPU: + return "GPU"; + case XLATensor::DeviceType::TPU: + return "TPU"; + } +} + +void SetMulti(const std::vector>& dest_tuple, + std::vector>& + new_dest_elements, + const std::vector& index_mapping) { + CHECK_EQ(index_mapping.size(), new_dest_elements.size()); + // Replace the underlying data for the destination tensors with the data in + // "new_dest_elements". + for (size_t i = 0; i < new_dest_elements.size(); ++i) { + xla::int64 dest_tuple_index = index_mapping[i]; + dest_tuple[dest_tuple_index]->SetXlaData(std::move(new_dest_elements[i])); + } +} + +} // namespace + +std::string XLATensor::Device::ToString() const { + return absl::StrCat(DeviceTypeToString(hw_type), ":", ordinal); +} + +XLATensor::XLATensor(const autograd::Variable& tensor, const Device& device) + : data_(std::make_shared( + TensorToXla( + tensor, + MakeArrayShapeFromDimensions( + tensor.sizes(), + XlaHelpers::MakeXlaPrimitiveType(tensor.type().scalarType()), + device.hw_type), + device, XlaGetClient()), + device, 0)), + requires_grad_(tensor.requires_grad()) {} + +XLATensor::XLATensor(std::shared_ptr xla_data, + uint64_t module_id) + : data_(std::make_shared( + xla_data, DeviceFromString(xla_data->device()), module_id)) {} + +XLATensor::XLATensor(std::shared_ptr xla_graph_node, + const Device& device, uint64_t module_id) + : data_(std::make_shared(std::move(xla_graph_node), device, + module_id)) { + TryLimitGraphSize(); +} + +void XLATensor::MulAddMulti( + const double scale_dest, + const std::vector>& dest_tuple, + const double alpha, + const std::vector>& source_tuple) { + CHECK_EQ(dest_tuple.size(), source_tuple.size()); + XlaGraphContext xla_graph_ctx; + for (size_t i = 0; i < dest_tuple.size(); ++i) { + auto dest_node = dest_tuple[i]->GetXlaGraphNode(); + auto source_node = source_tuple[i]->GetXlaGraphNode(); + auto dest_node_op = dest_node->Generate(&xla_graph_ctx).ValueOrDie(); + auto source_node_op = source_node->Generate(&xla_graph_ctx).ValueOrDie(); + if (alpha != 1) { + const auto alpha_source = XlaHelpers::ScalarBroadcast( + alpha, source_tuple[i]->shape(), xla_graph_ctx.builder()); + source_node_op = xla::Mul(source_node_op, alpha_source); + } + if (scale_dest != 1) { + const auto scale_dest_broadcast = XlaHelpers::ScalarBroadcast( + scale_dest, dest_tuple[i]->shape(), xla_graph_ctx.builder()); + dest_node_op = xla::Mul(dest_node_op, scale_dest_broadcast); + } + xla_graph_ctx.AddResult(xla::Add(dest_node_op, source_node_op)); + } + std::vector index_mapping(dest_tuple.size()); + std::iota(index_mapping.begin(), index_mapping.end(), 0); + ComputeAndDistribute(&xla_graph_ctx, index_mapping, dest_tuple); +} + +void 
XLATensor::ZeroMulti( + const std::vector>& dest_tuple) { + if (dest_tuple.empty()) { + return; + } + // Create a computation which returns zeroes shaped the same as tensors in + // "dest_tuple". + XlaGraphContext xla_graph_ctx; + for (auto& dest : dest_tuple) { + const auto dest_shape = dest->shape(); + const auto zero = + xla::ConstantLiteral(xla_graph_ctx.builder(), + xla::LiteralUtil::Zero(dest_shape.element_type())); + xla_graph_ctx.AddResult( + Broadcast(zero, XlaHelpers::ShapeSizes(dest_shape))); + } + std::vector index_mapping(dest_tuple.size()); + std::iota(index_mapping.begin(), index_mapping.end(), 0); + ComputeAndDistribute(&xla_graph_ctx, index_mapping, dest_tuple); +} + +std::shared_ptr XLATensor::grad() const { return data_->grad; } + +void XLATensor::setGrad(std::shared_ptr grad) { + data_->grad = std::move(grad); +} + +const xla::Shape& XLATensor::shape() const { + return data_->xla_data ? data_->xla_data->shape() + : data_->xla_graph_node->shape(); +} + +const XLATensor::Device& XLATensor::GetDevice() const { return data_->device; } + +const std::shared_ptr& XLATensor::GetXlaData() { + ApplyPendingGraph(); + return data_->xla_data; +} + +void XLATensor::SetXlaData( + std::shared_ptr xla_data) { + data_->xla_data = std::move(xla_data); + data_->xla_graph_node = nullptr; + // A modified tensor doesn't come directly from a module forward call. + data_->module_id = 0; +} + +void XLATensor::SetXlaGraphNode(std::shared_ptr xla_graph_node) { + data_->xla_graph_node = std::move(xla_graph_node); + // A modified tensor doesn't come directly from a module forward call. + data_->module_id = 0; + TryLimitGraphSize(); +} + +void XLATensor::TryLimitGraphSize() { + // If we are accumulating too many nodes in the pending graph, render the XLA + // by executing the pending graph. + static const xla::int64 kMaxPendingGraphSize = 1000; + if (data_->xla_graph_node != nullptr && + data_->xla_graph_node->graph_size() > kMaxPendingGraphSize) { + ApplyPendingGraph(); + } +} + +std::shared_ptr XLATensor::GetXlaGraphNode() const { + return data_->xla_graph_node ? data_->xla_graph_node + : CreateTensorNode(data_->xla_data); +} + +std::vector XLATensor::Size() const { + const xla::Shape& tensor_shape = shape(); + return std::vector(tensor_shape.dimensions().begin(), + tensor_shape.dimensions().end()); +} + +uint64_t XLATensor::ForwardModuleId() const { return data_->module_id; } + +at::Tensor XLATensor::toTensor() { + ApplyPendingGraph(); + // Because there's no transferToClient, we'll define an `identity` graph, and + // execute it. 
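  // The computation built below is effectively the identity function
  //   f(x) = GetTupleElement(Tuple(x), 0) == x
  // and executing it via ExecuteComputationAndTransfer is what moves the
  // device buffer back to the host as a literal we can copy into an
  // at::Tensor.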
+ xla::XlaBuilder b("identity"); + xla::GetTupleElement(xla::Tuple(&b, {xla::Parameter(&b, 0, shape(), "x")}), + 0); + xla::XlaComputation identity = b.Build().ValueOrDie(); + + auto client = XlaGetClient(); + auto result_literal = client->ExecuteComputationAndTransfer( + identity, {data_->xla_data.get()}, nullptr); + auto return_tensor = MakeTensorFromXlaLiteral(*result_literal); + return autograd::make_variable(return_tensor, requires_grad_); +} + +std::shared_ptr XLATensor::CreateTensorNode( + std::shared_ptr data) { + auto generator = [data](XlaGraphContext* ctx, + const XlaGraphNode&) -> xla::StatusOr { + return ctx->GetParameter(data); + }; + return XlaGraphNode::New(std::move(generator), data->shape(), {}); +} + +std::shared_ptr XLATensor::CreateMulNode(XLATensor& other) { + auto generator = [](XlaGraphContext* ctx, + const XlaGraphNode& node) -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(auto node_op, node.input(0)->Generate(ctx)); + TF_ASSIGN_OR_RETURN(auto other_node_op, node.input(1)->Generate(ctx)); + return node_op * other_node_op; + }; + return XlaGraphNode::New(std::move(generator), shape(), + {GetXlaGraphNode(), other.GetXlaGraphNode()}); +} + +std::shared_ptr XLATensor::CreateMulNode( + const at::Scalar& other) { + auto generator = [other]( + XlaGraphContext* ctx, + const XlaGraphNode& node) -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(auto node_op, node.input(0)->Generate(ctx)); + return node_op * XlaHelpers::ScalarBroadcast(other.toDouble(), + node.input(0)->shape(), + ctx->builder()); + }; + return XlaGraphNode::New(std::move(generator), shape(), {GetXlaGraphNode()}); +} + +std::shared_ptr XLATensor::CreateDivNode(XLATensor& other) { + auto generator = [](XlaGraphContext* ctx, + const XlaGraphNode& node) -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(auto node_op, node.input(0)->Generate(ctx)); + TF_ASSIGN_OR_RETURN(auto other_node_op, node.input(1)->Generate(ctx)); + return node_op / other_node_op; + }; + return XlaGraphNode::New(std::move(generator), shape(), + {GetXlaGraphNode(), other.GetXlaGraphNode()}); +} + +std::shared_ptr XLATensor::CreateDivNode( + const at::Scalar& other) { + auto generator = [other]( + XlaGraphContext* ctx, + const XlaGraphNode& node) -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(auto node_op, node.input(0)->Generate(ctx)); + return node_op / XlaHelpers::ScalarBroadcast(other.toDouble(), + node.input(0)->shape(), + ctx->builder()); + }; + return XlaGraphNode::New(std::move(generator), shape(), {GetXlaGraphNode()}); +} + +std::shared_ptr XLATensor::CreateAddNode( + XLATensor& other, const at::Scalar& alpha) { + auto generator = [alpha]( + XlaGraphContext* ctx, + const XlaGraphNode& node) -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(auto node_op, node.input(0)->Generate(ctx)); + TF_ASSIGN_OR_RETURN(auto other_node_op, node.input(1)->Generate(ctx)); + return node_op + other_node_op * XlaHelpers::ScalarBroadcast( + alpha.toDouble(), + node.input(0)->shape(), + ctx->builder()); + }; + return XlaGraphNode::New(std::move(generator), shape(), + {GetXlaGraphNode(), other.GetXlaGraphNode()}); +} + +std::shared_ptr XLATensor::add(XLATensor& other, + const at::Scalar& alpha) { + return std::make_shared(CreateAddNode(other, alpha), data_->device, + 0); +} + +void XLATensor::add_(XLATensor& other, const at::Scalar& alpha) { + SetXlaGraphNode(CreateAddNode(other, alpha)); +} + +std::shared_ptr XLATensor::mul(XLATensor& other) { + return std::make_shared(CreateMulNode(other), data_->device, 0); +} + +std::shared_ptr XLATensor::mul(const at::Scalar& other) { + return 
std::make_shared(CreateMulNode(other), data_->device, 0); +} + +void XLATensor::mul_(XLATensor& other) { + SetXlaGraphNode(CreateMulNode(other)); +} + +void XLATensor::mul_(const at::Scalar& other) { + SetXlaGraphNode(CreateMulNode(other)); +} + +std::shared_ptr XLATensor::div(XLATensor& other) { + return std::make_shared(CreateDivNode(other), data_->device, 0); +} + +std::shared_ptr XLATensor::div(const at::Scalar& other) { + return std::make_shared(CreateDivNode(other), data_->device, 0); +} + +void XLATensor::div_(XLATensor& other) { + SetXlaGraphNode(CreateDivNode(other)); +} + +void XLATensor::div_(const at::Scalar& other) { + SetXlaGraphNode(CreateDivNode(other)); +} + +void XLATensor::zero_() { + xla::Shape tensor_shape = shape(); + auto generator = [tensor_shape]( + XlaGraphContext* ctx, + const XlaGraphNode&) -> xla::StatusOr { + auto zero_literal = xla::LiteralUtil::Zero(tensor_shape.element_type()); + auto const_zero = xla::ConstantLiteral(ctx->builder(), zero_literal); + return xla::Broadcast(const_zero, XlaHelpers::ShapeSizes(tensor_shape)); + }; + SetXlaGraphNode(XlaGraphNode::New(std::move(generator), tensor_shape, {})); +} + +std::shared_ptr XLATensor::cross_replica_sum( + const std::vector>& groups) { + auto generator = [groups]( + XlaGraphContext* ctx, + const XlaGraphNode& node) -> xla::StatusOr { + std::vector crs_groups; + for (auto& group : groups) { + xla::ReplicaGroup rgroup; + for (auto replica_id : group) { + rgroup.add_replica_ids(replica_id); + } + crs_groups.push_back(std::move(rgroup)); + } + TF_ASSIGN_OR_RETURN(auto node_op, node.input(0)->Generate(ctx)); + return xla::CrossReplicaSum(node_op, crs_groups); + }; + auto crs_node = + XlaGraphNode::New(std::move(generator), shape(), {GetXlaGraphNode()}); + return std::make_shared(std::move(crs_node), data_->device, 0); +} + +void XLATensor::ApplyPendingGraph() { + auto& xla_graph_node = current_xla_graph_node(); + if (xla_graph_node != nullptr) { + XlaGraphContext xla_graph_ctx; + auto root = xla_graph_node->Generate(&xla_graph_ctx).ValueOrDie(); + auto computation = xla_graph_ctx.Build(root).ValueOrDie(); + SetXlaData(XlaGetClient()->ExecuteComputation( + computation, xla_graph_ctx.GetParametersData(), nullptr)); + } +} + +void XLATensor::ComputeAndDistribute( + XlaGraphContext* xla_graph_ctx, + const std::vector& index_mapping, + const std::vector>& tensors) { + auto computation = xla_graph_ctx->Build().ValueOrDie(); + auto program_shape = computation.GetProgramShape().ValueOrDie(); + const auto device = CommonDeviceForTensors(tensors); + const auto multi_shape = + MakeShapeWithDeviceLayout(program_shape.result(), device.hw_type); + auto client = XlaGetClient(); + auto result_tuple = client->ExecuteComputation( + computation, xla_graph_ctx->GetParametersData(), &multi_shape); + auto new_dest_elements = client->DeconstructTuple(*result_tuple).ValueOrDie(); + // Replace destination's underlying data with the result of the computation. 
+ SetMulti(tensors, new_dest_elements, index_mapping); +} + +void XLATensor::ApplyPendingGraph( + const std::vector>& tensors) { + XlaGraphContext xla_graph_ctx; + std::vector index_mapping; + for (size_t i = 0; i < tensors.size(); ++i) { + auto& xla_graph_node = tensors[i]->current_xla_graph_node(); + if (xla_graph_node != nullptr) { + auto root = xla_graph_node->Generate(&xla_graph_ctx).ValueOrDie(); + xla_graph_ctx.AddResult(root); + index_mapping.push_back(i); + } + } + if (!index_mapping.empty()) { + // The SetXlaData() call done within SetMulti(), called by the following + // function, will provide to reset the cached XLA graph node. + ComputeAndDistribute(&xla_graph_ctx, index_mapping, tensors); + } +} + +XLATensor::Device XLATensor::DeviceFromString(const std::string& device_spec) { + if (device_spec.empty()) { + const std::string default_device_spec = XlaGetClient()->GetDefaultDevice(); + CHECK(!default_device_spec.empty()); + return DeviceFromString(default_device_spec); + } + std::vector device_spec_parts = absl::StrSplit(device_spec, ':'); + std::string invalid_device_error = + "Invalid device specification: " + device_spec; + if (device_spec_parts.size() != 2) { + AT_ERROR(invalid_device_error); + } + int device_ordinal = std::stoi(device_spec_parts[1]); + std::string device_hw_type = device_spec_parts[0]; + if (device_hw_type == "CPU") { + return {XLATensor::DeviceType::CPU, device_ordinal}; + } + if (device_hw_type == "GPU") { + return {XLATensor::DeviceType::GPU, device_ordinal}; + } + if (device_hw_type == "TPU") { + return {XLATensor::DeviceType::TPU, device_ordinal}; + } + AT_ERROR(invalid_device_error); +} + +XLATensor::Device XLATensor::CommonDeviceForTensors( + const std::vector>& tensors) { + CHECK(!tensors.empty()); + const XLATensor::Device& device = tensors.front()->GetDevice(); + for (const auto& tensor : tensors) { + const XLATensor::Device& tensor_device = tensor->GetDevice(); + if (tensor_device != device) { + AT_ERROR("All input tensors should have the same device"); + } + } + return device; +} + +std::vector GetComponentShapes(const xla::Shape& shape) { + std::vector component_shapes; + if (xla::ShapeUtil::IsTuple(shape)) { + for (const xla::Shape& component_shape : shape.tuple_shapes()) { + CHECK(!xla::ShapeUtil::IsTuple(component_shape)); + component_shapes.push_back(component_shape); + } + } else { + component_shapes.push_back(shape); + } + return component_shapes; +} + +xla::Shape MakeShapeWithDeviceLayout(const xla::Shape& shape, + const XLATensor::DeviceType device_type) { + std::vector shape_components = GetComponentShapes(shape); + std::vector shape_components_with_layout; + CHECK(!shape_components.empty()); + for (const auto& shape_component : shape_components) { + std::vector shape_component_dimensions( + shape_component.dimensions().begin(), + shape_component.dimensions().end()); + shape_components_with_layout.push_back(MakeArrayShapeFromDimensions( + shape_component_dimensions, shape_component.element_type(), + device_type)); + } + return shape_components_with_layout.size() > 1 + ? 
xla::ShapeUtil::MakeTupleShape(shape_components_with_layout) + : shape_components_with_layout.front(); +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h new file mode 100644 index 00000000000..be1e168fdc7 --- /dev/null +++ b/torch_xla/csrc/tensor.h @@ -0,0 +1,190 @@ +#pragma once + +#include "graph_context.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_client/computation_client.h" +#include "torch/csrc/autograd/variable.h" +#include "torch/csrc/jit/ir.h" + +namespace torch { +namespace jit { + +class XLATensor { + struct Data; + + public: + TH_DISALLOW_COPY_AND_ASSIGN(XLATensor); + + enum class DeviceType { CPU, GPU, TPU }; + + struct Device { + Device() = default; + Device(DeviceType hw_type, int ordinal) + : hw_type(hw_type), ordinal(ordinal) {} + + bool operator==(const Device& other) const { + return hw_type == other.hw_type && ordinal == other.ordinal; + } + + bool operator!=(const Device& other) const { return !(*this == other); } + + bool operator<(const Device& rhs) const { + if (hw_type != rhs.hw_type) { + return hw_type < rhs.hw_type; + } + return ordinal < rhs.ordinal; + } + + std::string ToString() const; + + DeviceType hw_type = DeviceType::CPU; + int ordinal = 0; + }; + + XLATensor(const autograd::Variable& tensor, const Device& device); + XLATensor(std::shared_ptr xla_data, + uint64_t module_id); + XLATensor(std::shared_ptr xla_graph_node, const Device& device, + uint64_t module_id); + XLATensor(std::shared_ptr data) : data_(std::move(data)) {} + + // Creates a new XLA tensor sharing the core tensor data structure, with + // require-gradients disabled. + std::shared_ptr Clone() const { + return std::make_shared(data_); + } + + bool RequiresGrad() const { return requires_grad_; } + + void detach_() { requires_grad_ = false; } + + at::Tensor toTensor(); + + std::shared_ptr grad() const; + void setGrad(std::shared_ptr grad); + + const xla::Shape& shape() const; + const Device& GetDevice() const; + const std::shared_ptr& GetXlaData(); + void SetXlaData(std::shared_ptr xla_data); + std::shared_ptr GetXlaGraphNode() const; + std::vector Size() const; + uint64_t ForwardModuleId() const; + + // Basic tensor operations used by the optimizers. + std::shared_ptr add(XLATensor& other, const at::Scalar& alpha); + void add_(XLATensor& other, const at::Scalar& alpha); + + std::shared_ptr mul(XLATensor& other); + std::shared_ptr mul(const at::Scalar& other); + void mul_(XLATensor& other); + void mul_(const at::Scalar& other); + + std::shared_ptr div(XLATensor& other); + std::shared_ptr div(const at::Scalar& other); + void div_(XLATensor& other); + void div_(const at::Scalar& other); + + void zero_(); + + std::shared_ptr cross_replica_sum( + const std::vector>& groups); + + // Applies the queue of operations in preparation for using the data. + void ApplyPendingGraph(); + + // Converts the given "device_spec" string to a device. The format is + // :, where hw_type is one of TPU, CPU or GPU and ordinal is + // an integer. + static Device DeviceFromString(const std::string& device_spec); + + // Returns the common device for "tensors". Throws if not all tensors have the + // same device. + static Device CommonDeviceForTensors( + const std::vector>& tensors); + + // In place scale and add for multiple tensors. 
The operation applies to all + // tensors "dest" in "dest_tuple" and is: + // dest = scale_dest * dest + alpha * source + // where "source" is the corresponding tensor in "source_tuple". + // This is a (temporary) building block for manually batched SGD optimizer. We + // have ways to automatically batch the optimizer application to all weights + // in the model; for expediency, we'll instead do this to minimize the number + // of moving parts needed to achieve better usability. + static void MulAddMulti( + const double scale_dest, + const std::vector>& dest_tuple, + const double alpha, + const std::vector>& source_tuple); + + // Zero all the tensors in "dest_tuple", it exists for the same reason as + // "MulAddMulti". + static void ZeroMulti( + const std::vector>& dest_tuple); + + // Applies the queue of operations for a list of tensors. + static void ApplyPendingGraph( + const std::vector>& tensors); + + private: + struct Data { + Data(std::shared_ptr xla_data, + const Device& device, uint64_t module_id) + : xla_data(std::move(xla_data)), device(device), module_id(module_id) {} + Data(std::shared_ptr xla_graph_node, const Device& device, + uint64_t module_id) + : xla_graph_node(std::move(xla_graph_node)), + device(device), + module_id(module_id) {} + + std::shared_ptr xla_data; + std::shared_ptr grad; + std::shared_ptr xla_graph_node; + Device device; + uint64_t module_id = 0; + }; + + void SetXlaGraphNode(std::shared_ptr xla_graph_node); + + const std::shared_ptr& current_xla_graph_node() const { + return data_->xla_graph_node; + } + + // We build an XLA graph accumulating XLA operations, but at a given point we + // need to force a rendering, otherwise the graph can grow without control. + // Think: + // for i in range(0, 100000): + // a = a + b + void TryLimitGraphSize(); + + std::shared_ptr CreateAddNode(XLATensor& other, + const at::Scalar& alpha); + std::shared_ptr CreateMulNode(XLATensor& other); + std::shared_ptr CreateMulNode(const at::Scalar& other); + std::shared_ptr CreateDivNode(XLATensor& other); + std::shared_ptr CreateDivNode(const at::Scalar& other); + + static void ComputeAndDistribute( + XlaGraphContext* xla_graph_ctx, + const std::vector& index_mapping, + const std::vector>& tensors); + + static std::shared_ptr CreateTensorNode( + std::shared_ptr data); + + std::shared_ptr data_; + bool requires_grad_ = false; +}; + +// If "shape" is a tuple, return the element shapes, otherwise return a +// singleton list containing the original shape. +std::vector GetComponentShapes(const xla::Shape& shape); + +// Create a shape with "device_type" compatible layout from the given "shape". 
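// For 4D shapes on a TPU device this picks a layout with batch and feature
// as the minor dimensions (see MakeArrayShapeFromDimensions in tensor.cpp);
// for all other shapes and devices it falls back to the default descending,
// torch-style minor-to-major layout.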
+xla::Shape MakeShapeWithDeviceLayout(const xla::Shape& shape, + const XLATensor::DeviceType device_type); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/torch_util.cpp b/torch_xla/csrc/torch_util.cpp new file mode 100644 index 00000000000..15abdc51bbc --- /dev/null +++ b/torch_xla/csrc/torch_util.cpp @@ -0,0 +1,40 @@ +#include "torch_util.h" + +namespace torch { +namespace jit { + +XlaModule::TensorBatchVector XlaCreateTensorList(const py::tuple& tuple) { + XlaModule::TensorBatchVector result; + result.reserve(tuple.size()); + for (auto& replica_tuple : tuple) { + XlaModule::TensorBatchVector::value_type replica_result; + for (auto& e : replica_tuple) { + auto variable = py::cast>(e); + replica_result.push_back(variable); + } + result.push_back(std::move(replica_result)); + } + return result; +} + +py::object XlaPackTensorList(const XlaModule::TensorBatchVector& outputs) { + py::tuple tuple(outputs.size()); + for (size_t i = 0; i < outputs.size(); ++i) { + const auto& replica_outputs = outputs[i]; + if (replica_outputs.empty()) { + tuple[i] = py::none(); + } else if (replica_outputs.size() == 1) { + tuple[i] = py::cast(replica_outputs[0]); + } else { + py::tuple replica_tuple(replica_outputs.size()); + for (size_t j = 0; j < replica_outputs.size(); j++) { + replica_tuple[j] = py::cast(replica_outputs[j]); + } + tuple[i] = replica_tuple; + } + } + return tuple; +} + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/torch_util.h b/torch_xla/csrc/torch_util.h new file mode 100644 index 00000000000..553586b4649 --- /dev/null +++ b/torch_xla/csrc/torch_util.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +#include "module.h" +#include "tensor.h" +#include "torch/csrc/jit/pybind_utils.h" + +namespace torch { +namespace jit { + +// Extracts a vector of XLA tensors out of a PyThon tuple. +XlaModule::TensorBatchVector XlaCreateTensorList(const py::tuple& tuple); + +// Packs a vector of XLA tensors into a Python tuple, if they are more than one. +py::object XlaPackTensorList(const XlaModule::TensorBatchVector& outputs); + +} // namespace jit +} // namespace torch diff --git a/torch_xla/csrc/translator.cpp b/torch_xla/csrc/translator.cpp new file mode 100644 index 00000000000..4891c8fecb9 --- /dev/null +++ b/torch_xla/csrc/translator.cpp @@ -0,0 +1,483 @@ +#include "translator.h" +#include +#include +#include +#include +#include "batch_norm.h" +#include "convolution.h" +#include "data_ops.h" +#include "elementwise.h" +#include "helpers.h" +#include "log_softmax.h" +#include "nll_loss.h" +#include "pooling.h" +#include "reduction.h" +#include "tensor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_client/computation_client.h" +#include "torch/csrc/jit/passes/dead_code_elimination.h" + +namespace torch { +namespace jit { + +namespace { + +std::once_flag create_client_flag; +std::unique_ptr computation_client; + +void CreateClient(std::unique_ptr* client) { + *client = std::move(xla::ComputationClient::Create().ValueOrDie()); +} + +// Create an identity operation of `ret` to make it the root of the computation. +void ActivateReturnNode(const xla::XlaOp& ret, xla::XlaBuilder* b) { + xla::GetTupleElement(xla::Tuple(b, {ret}), 0); +} + +// Context class to hold together all the necessary state for the XLA +// computation building process out of a PyTorch graph. 
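// Rough usage sketch (schematic, mirroring the node loop further down): for
// each IR node the translator looks up the XLA ops of the node's inputs,
// builds the corresponding XLA op, and registers it under the output Value's
// unique id so later nodes can find it, e.g.
//   auto lhs = cctx.OpForInput(node, 0);  // c10::optional, may be undefined
//   auto rhs = cctx.OpForInput(node, 1);
//   cctx.AddNodeOp(node, BuildArithmeticOp(node, *lhs, *rhs));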
+class ComputationContext { + public: + static size_t OutputId(const Node* node) { + const auto node_outputs = node->outputs(); + CHECK_EQ(node_outputs.size(), 1); + return node_outputs[0]->unique(); + } + + void AddNodeOpById(size_t id, xla::XlaOp op) { + const auto it_ok = node_xla_ops_.emplace(id, std::move(op)); + CHECK(it_ok.second) << "Duplicated IR node ID: " << id; + } + + void AddNodeOp(const Node* node, xla::XlaOp op) { + AddNodeOpById(OutputId(node), op); + } + + void AddValueOp(const Value* value, xla::XlaOp op) { + AddNodeOpById(value->unique(), std::move(op)); + } + + void AddInputOp(xla::XlaOp op) { input_ops_.push_back(std::move(op)); } + + void AddUndefinedInput(size_t index) { undefined_inputs_.insert(index); } + + const xla::XlaOp& GetOpForValue(const Value* value) const { + auto it = node_xla_ops_.find(value->unique()); + CHECK(it != node_xla_ops_.end()) << value->uniqueName(); + return it->second; + } + + c10::optional OpForInput(const Node* node, + size_t input_index) const { + const auto node_inputs = node->inputs(); + const auto input = node_inputs.at(input_index); + // Check if is prim::Undefined. + if (undefined_inputs_.count(input->unique()) > 0) { + return at::nullopt; + } + // Check in constructed xla ops. + auto it = node_xla_ops_.find(input->unique()); + if (it == node_xla_ops_.end()) { + return at::nullopt; + } + return it->second; + } + + std::vector ReleaseInputs() { return std::move(input_ops_); } + + size_t GetInputsSize() const { return input_ops_.size(); } + + const std::unordered_map& GetNodeOps() const { + return node_xla_ops_; + } + + const std::unordered_set& GetUndefinedInputs() const { + return undefined_inputs_; + } + + private: + std::vector input_ops_; + std::unordered_map node_xla_ops_; + std::unordered_set undefined_inputs_; +}; + +} // namespace + +xla::ComputationClient* XlaGetClient() { + std::call_once(create_client_flag, CreateClient, &computation_client); + return computation_client.get(); +} + +XlaTranslator::XlaTranslator( + const std::shared_ptr& graph, + const xla::PrecisionConfig::Precision conv_precision) + : graph_(graph), conv_precision_(conv_precision) {} + +xla::XlaComputation XlaTranslator::BuildComputation( + const std::vector& parameter_shapes, + const BuildOptions& options) const { + xla::XlaBuilder b("XlaComputation"); + const auto returned_tuple = BuildComputationProgram(parameter_shapes, &b); + std::vector returned_tuple_outputs; + CHECK_GE(returned_tuple.inputs.size(), options.param_to_return_count); + // The forward computation in a fused forward and backward computation needs + // to make its inputs available to the backward computation. + for (size_t i = 0; i < options.param_to_return_count; ++i) { + returned_tuple_outputs.push_back(returned_tuple.inputs[i]); + } + returned_tuple_outputs.insert(returned_tuple_outputs.end(), + returned_tuple.outputs.begin(), + returned_tuple.outputs.end()); + if (options.output_transform) { + for (size_t i = 0; i < returned_tuple_outputs.size(); ++i) { + returned_tuple_outputs[i] = + options.output_transform(returned_tuple_outputs[i], i); + } + } + if (returned_tuple_outputs.size() > 1) { + xla::Tuple(&b, returned_tuple_outputs); + } else { + // Ensure that the returned value is the root of the computation. 
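    // XlaBuilder::Build() uses the last instruction added to the builder as
    // the computation root, so the Tuple/GetTupleElement wrapper emitted by
    // ActivateReturnNode guarantees the single output becomes that last
    // instruction even if it was created earlier during translation.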
+ ActivateReturnNode(returned_tuple_outputs[0], &b); + } + return b.Build().ValueOrDie(); +} + +XlaComputationInOut XlaTranslator::BuildComputationProgram( + const std::vector& parameter_shapes, + xla::XlaBuilder* b) const { + ComputationContext cctx; + const auto graph_inputs = graph_->inputs(); + for (size_t parameter_number = 0; parameter_number < graph_inputs.size(); + ++parameter_number) { + Value* graph_input = graph_inputs[parameter_number]; + if (!parameter_shapes[parameter_number].zero_input) { + auto param_no = cctx.GetInputsSize(); + const auto parameter_op = + xla::Parameter(b, param_no, parameter_shapes[parameter_number].shape, + "parameter_" + std::to_string(param_no)); + cctx.AddValueOp(graph_input, parameter_op); + cctx.AddInputOp(parameter_op); + } else { + // The backward method of the model creates all-zeros grad outputs we + // represent as XLATensor with no data and empty shape. + cctx.AddValueOp(graph_input, + XlaHelpers::ScalarBroadcast( + 0, parameter_shapes[parameter_number].shape, b)); + } + } + auto nodes = graph_->block()->nodes(); + for (auto node : nodes) { + switch (node->kind()) { + case aten::add: + case aten::mul: { + const auto node_inputs = node->inputs(); + if (node_inputs.size() < 2) { + AT_ERROR("Unsupported arity for binary operator ", + node->kind().toQualString()); + } + xla::XlaOp xla_output; + auto input_op_1 = cctx.OpForInput(node, 1); + if (!input_op_1) { + const auto other = XlaHelpers::ScalarValue( + node->get(attr::other).value().to(), b); + xla_output = + BuildArithmeticOp(node, *cctx.OpForInput(node, 0), other); + } else { + xla_output = + BuildArithmeticOp(node, *cctx.OpForInput(node, 0), *input_op_1); + } + cctx.AddNodeOp(node, xla_output); + break; + } + case aten::gt: { + if (node->inputs().size() != 2) { + AT_ERROR("Unsupported arity for aten::gt"); + } + xla::XlaOp xla_output = + BuildComparisonOp(node, *cctx.OpForInput(node, 0)); + cctx.AddNodeOp(node, xla_output); + break; + } + case aten::type_as: { + CHECK_EQ(node->inputs().size(), 2); + xla::XlaOp xla_output = BuildTypeAs(node, *cctx.OpForInput(node, 0)); + cctx.AddNodeOp(node, xla_output); + break; + } + case aten::convolution: + case aten::thnn_conv2d_forward: { + if (node->inputs().size() < 3) { + AT_ERROR("Unsupported number of inputs for convolution: ", + node->inputs().size()); + } + + xla::XlaOp xla_output; + auto opt_op = cctx.OpForInput(node, 3); + if (opt_op) { // bias exists + xla_output = BuildConvolutionBias(node, *cctx.OpForInput(node, 0), + *cctx.OpForInput(node, 1), *opt_op, + conv_precision_); + } else { + xla_output = + BuildConvolution(node, *cctx.OpForInput(node, 0), + *cctx.OpForInput(node, 1), conv_precision_); + } + cctx.AddNodeOp(node, xla_output); + break; + } + case aten::thnn_conv2d_backward: { + CHECK_EQ(node->inputs().size(), 9); + const auto conv2d_grads = BuildConv2dBackward( + node, *cctx.OpForInput(node, 0), *cctx.OpForInput(node, 1), + *cctx.OpForInput(node, 2), conv_precision_); + const auto node_outputs = node->outputs(); + cctx.AddValueOp(node_outputs[0], conv2d_grads.grad_input); + cctx.AddValueOp(node_outputs[1], conv2d_grads.grad_weight); + cctx.AddValueOp(node_outputs[2], conv2d_grads.grad_bias); + break; + } + case aten::t: { + CHECK_EQ(node->inputs().size(), 1); + xla::XlaOp xla_output = + xla::Transpose(*cctx.OpForInput(node, 0), {1, 0}); + cctx.AddNodeOp(node, xla_output); + break; + } + case aten::addmm: { + if (node->inputs().size() < 3) { + AT_ERROR("Unsupported number of inputs for linear layer: ", + node->inputs().size()); + } + 
+        xla::PrecisionConfig precision_config =
+            XlaHelpers::BuildPrecisionConfig(conv_precision_);
+        xla::XlaOp xla_output =
+            xla::Dot(*cctx.OpForInput(node, 1), *cctx.OpForInput(node, 2),
+                     &precision_config) +
+            *cctx.OpForInput(node, 0);
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::mm: {
+        CHECK_EQ(node->inputs().size(), 2);
+        xla::PrecisionConfig precision_config =
+            XlaHelpers::BuildPrecisionConfig(conv_precision_);
+        xla::XlaOp xla_output =
+            xla::Dot(*cctx.OpForInput(node, 0), *cctx.OpForInput(node, 1),
+                     &precision_config);
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::max_pool2d_with_indices: {
+        CHECK_GE(node->inputs().size(), 1);
+        CHECK_GE(node->outputs().size(), 1);
+        xla::XlaOp xla_output = BuildMaxPool2d(node, *cctx.OpForInput(node, 0));
+        const auto node_outputs = node->outputs();
+        CHECK_GE(node_outputs.size(), 1);
+        cctx.AddValueOp(node_outputs[0], xla_output);
+        break;
+      }
+      case aten::max_pool2d_with_indices_backward: {
+        CHECK_EQ(node->inputs().size(), 8);
+        xla::XlaOp xla_output = BuildMaxPool2dBackward(
+            node, *cctx.OpForInput(node, 0), *cctx.OpForInput(node, 1));
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::avg_pool2d: {
+        CHECK_GE(node->inputs().size(), 1);
+        xla::XlaOp xla_output = BuildAvgPool2d(node, *cctx.OpForInput(node, 0));
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::avg_pool2d_backward: {
+        CHECK_GE(node->inputs().size(), 2);
+        xla::XlaOp xla_output = BuildAvgPool2dBackward(
+            node, *cctx.OpForInput(node, 0), *cctx.OpForInput(node, 1));
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::neg: {
+        CHECK_EQ(node->inputs().size(), 1);
+        const auto xla_input = *cctx.OpForInput(node, 0);
+        xla::XlaOp xla_output = Neg(xla_input);
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::tanh: {
+        CHECK_EQ(node->inputs().size(), 1);
+        const auto xla_input = *cctx.OpForInput(node, 0);
+        xla::XlaOp xla_output = Tanh(xla_input);
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::sigmoid: {
+        CHECK_EQ(node->inputs().size(), 1);
+        const auto xla_input = *cctx.OpForInput(node, 0);
+        const auto half = XlaHelpers::ScalarValue(0.5, b);
+        xla::XlaOp xla_output = half + half * Tanh(half * xla_input);
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::relu: {
+        CHECK_EQ(node->inputs().size(), 1);
+        xla::XlaOp xla_output = xla::Max(*cctx.OpForInput(node, 0),
+                                         XlaHelpers::ScalarValue(0, b));
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::threshold: {
+        CHECK_EQ(node->inputs().size(), 3);
+        xla::XlaOp xla_output = BuildThreshold(
+            node, *cctx.OpForInput(node, 0), *cctx.OpForInput(node, 0),
+            node->get(attr::threshold).value().to(),
+            node->get(attr::value).value().to(), b);
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::threshold_backward: {
+        CHECK_EQ(node->inputs().size(), 3);
+        xla::XlaOp xla_output = BuildThreshold(
+            node, *cctx.OpForInput(node, 1), *cctx.OpForInput(node, 0),
+            node->get(attr::threshold).value().to(), 0, b);
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::log_softmax: {
+        CHECK_EQ(node->inputs().size(), size_t(2));
+        xla::XlaOp xla_output =
+            BuildLogSoftmax(node, *cctx.OpForInput(node, 0));
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::_log_softmax_backward_data: {
+        CHECK_EQ(node->inputs().size(), 4);
+        xla::XlaOp xla_output = BuildLogSoftmaxGrad(
+            node, *cctx.OpForInput(node, 0), *cctx.OpForInput(node, 1));
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::reshape:
+      case aten::view: {
+        CHECK_EQ(node->inputs().size(), 2);
+        xla::XlaOp xla_output = BuildView(node, *cctx.OpForInput(node, 0));
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::expand: {
+        CHECK_GE(node->inputs().size(), 1);
+        xla::XlaOp xla_output = BuildExpand(node, *cctx.OpForInput(node, 0));
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::stack: {
+        CHECK_EQ(node->inputs().size(), 2);
+        xla::XlaOp xla_output =
+            BuildStack(node,
+                       [&cctx](const Value* node) -> xla::XlaOp {
+                         return cctx.GetOpForValue(node);
+                       },
+                       b);
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::cat: {
+        CHECK_EQ(node->inputs().size(), 2);
+        xla::XlaOp xla_output =
+            BuildCat(node,
+                     [&cctx](const Value* node) -> xla::XlaOp {
+                       return cctx.GetOpForValue(node);
+                     },
+                     b);
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::chunk: {
+        std::vector<xla::XlaOp> xla_outputs =
+            BuildChunk(node, *cctx.OpForInput(node, 0));
+        const auto node_outputs = node->outputs();
+        for (size_t i = 0; i < node_outputs.size(); ++i) {
+          cctx.AddValueOp(node_outputs[i], xla_outputs[i]);
+        }
+        break;
+      }
+      case aten::native_batch_norm:
+      case aten::batch_norm: {
+        CHECK_EQ(node->inputs().size(), 8);
+        const auto outputs = BuildBatchNorm(node, *cctx.OpForInput(node, 0),
+                                            *cctx.OpForInput(node, 1),
+                                            *cctx.OpForInput(node, 2));
+        const auto node_outputs = node->outputs();
+        cctx.AddValueOp(node_outputs[0], outputs.output);
+        if (node->kind() == aten::batch_norm) {
+          CHECK_EQ(node->outputs().size(), 1);
+        }
+        // aten::batch_norm only has 1 output;
+        // native_batch_norm_forward has output, save_mean, save_std.
+        if (node->kind() == aten::native_batch_norm) {
+          cctx.AddValueOp(node_outputs[1], outputs.save_mean);
+          cctx.AddValueOp(node_outputs[2], outputs.save_invstd_eps);
+        }
+        break;
+      }
+      case aten::native_batch_norm_backward: {
+        CHECK_EQ(node->inputs().size(), 10);
+        auto grads = BuildBatchNormBackward(
+            node, *cctx.OpForInput(node, 0),  // grad_output
+            *cctx.OpForInput(node, 1),        // input
+            *cctx.OpForInput(node, 2),        // weight
+            *cctx.OpForInput(node, 7),        // save_mean
+            *cctx.OpForInput(node, 8));       // save_std
+        const auto node_outputs = node->outputs();
+        cctx.AddValueOp(node_outputs[0], grads.grad_input);
+        cctx.AddValueOp(node_outputs[1], grads.grad_weight);
+        cctx.AddValueOp(node_outputs[2], grads.grad_bias);
+        break;
+      }
+      case aten::sum: {
+        CHECK_GE(node->inputs().size(), 1);
+        xla::XlaOp xla_output = BuildSum(node, *cctx.OpForInput(node, 0));
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::nll_loss: {
+        CHECK_EQ(node->inputs().size(), 5);
+        xla::XlaOp xla_output = BuildNllLoss(node, *cctx.OpForInput(node, 0),
+                                             *cctx.OpForInput(node, 1));
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case aten::nll_loss_backward: {
+        CHECK_EQ(node->inputs().size(), 7);
+        xla::XlaOp xla_output = BuildNllLossBackward(
+            node, *cctx.OpForInput(node, 1), *cctx.OpForInput(node, 2));
+        cctx.AddNodeOp(node, xla_output);
+        break;
+      }
+      case prim::Constant:
+      case prim::ListConstruct: {
+        break;
+      }
+      case prim::Undefined: {
+        cctx.AddUndefinedInput(ComputationContext::OutputId(node));
+        break;
+      }
+      default:
+        AT_ERROR("Unsupported operator: ", node->kind().toQualString());
+    }
+  }
+  const auto return_node = graph_->return_node();
+  const auto node_inputs = return_node->inputs();
+  // TODO: tighten the id check for returned tuples.
+  if (return_node->kind() != prim::Return || node_inputs.empty()) {
+    AT_ERROR("Unexpected end of graph");
+  }
+  std::vector<xla::XlaOp> returned_tuple;
+  for (const auto return_input : node_inputs) {
+    returned_tuple.push_back(cctx.GetOpForValue(return_input));
+  }
+  return XlaComputationInOut{cctx.ReleaseInputs(), returned_tuple};
+}
+
+}  // namespace jit
+}  // namespace torch
diff --git a/torch_xla/csrc/translator.h b/torch_xla/csrc/translator.h
new file mode 100644
index 00000000000..07f78262afa
--- /dev/null
+++ b/torch_xla/csrc/translator.h
@@ -0,0 +1,60 @@
+#pragma once
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_client/computation_client.h"
+#include "torch/csrc/jit/ir.h"
+
+namespace torch {
+namespace jit {
+
+struct XlaComputationInOut {
+  std::vector<xla::XlaOp> inputs;
+  std::vector<xla::XlaOp> outputs;
+};
+
+class XlaTranslator {
+ public:
+  struct ParameterShape {
+    // A shape created with zero_input == true, when passed to the
+    // BuildComputation*() APIs, will generate an artificial all-zeros input
+    // (of the proper shape) for the XLA computation.
+    ParameterShape(xla::Shape shape, bool zero_input)
+        : shape(std::move(shape)), zero_input(zero_input) {}
+
+    xla::Shape shape;
+    bool zero_input;
+  };
+
+  struct BuildOptions {
+    BuildOptions() {}
+
+    // The number of parameters to return, before the real computation outputs.
+    size_t param_to_return_count = 0;
+    // Optional transform function which is called to apply a transformation to
+    // the computation outputs before they get merged into the output tuple.
+    std::function output_transform;
+  };
+
+  XlaTranslator(const std::shared_ptr<Graph>& graph,
+                const xla::PrecisionConfig::Precision conv_precision);
+
+  // Builds and compiles the XLA computation for graph_.
+  xla::XlaComputation BuildComputation(
+      const std::vector<ParameterShape>& parameter_shapes,
+      const BuildOptions& options = BuildOptions()) const;
+
+  // Builds the XLA computation for graph_ without compiling it and returns the
+  // XLA operations for inputs and outputs.
+  XlaComputationInOut BuildComputationProgram(
+      const std::vector<ParameterShape>& parameter_shapes,
+      xla::XlaBuilder* b) const;
+
+ private:
+  std::shared_ptr<Graph> graph_;
+  xla::PrecisionConfig::Precision conv_precision_;
+};
+
+xla::ComputationClient* XlaGetClient();
+
+}  // namespace jit
+}  // namespace torch
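
Editor's note (not part of the patch): the sketch below illustrates how the XlaTranslator API declared in translator.h might be driven, assuming a traced JIT graph and the XLA shape of its single input are already available. The function and variable names (BuildExampleComputation, graph, input_shape) are hypothetical and are only meant to show the ParameterShape / BuildComputation flow; they are not part of this commit.

// Minimal usage sketch of the XlaTranslator API, under the assumptions above.
#include <memory>
#include <vector>

#include "torch_xla/csrc/translator.h"

xla::XlaComputation BuildExampleComputation(
    const std::shared_ptr<torch::jit::Graph>& graph,
    const xla::Shape& input_shape) {
  // Translate the traced graph, requesting the highest convolution precision.
  torch::jit::XlaTranslator translator(graph, xla::PrecisionConfig::HIGHEST);
  // One ParameterShape per graph input; zero_input == false means a real XLA
  // parameter is created instead of an artificial all-zeros input.
  std::vector<torch::jit::XlaTranslator::ParameterShape> parameter_shapes;
  parameter_shapes.emplace_back(input_shape, /*zero_input=*/false);
  // Build the xla::XlaComputation for the graph; the result can then be handed
  // to the computation client returned by torch::jit::XlaGetClient().
  return translator.BuildComputation(parameter_shapes);
}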