From 1bbe4da79753bf0a4bd94ff106cf1267f841f302 Mon Sep 17 00:00:00 2001 From: stgpetrovic Date: Thu, 16 Feb 2023 17:28:05 +0100 Subject: [PATCH] Bazel (#4636) * Replace tensorflow with a bazel external repository * Basic migration to bazel for xla_client. * Revert to blob * Add vscode config. * Update newlines * Merge with pjrt client test build changes. * Migrate tests to new build * Format test and plugin * Order imports * Conditionally apply tf patches; apply pt patches always. * Format python * configure formatters * Mirror TF pin update an fixes in bazel. * Support local and sandboxed build based on flags * Add cloud cache URLs for llvm. * Merge with upstream * Update TF pin * Fix patching regression * Remove the citcleci setup downloading llvm * Rework the experimental dockerfile for bazel support --- .bazelrc | 176 ++++++++++++++++ .bazelversion | 1 + .circleci/config.yml | 1 - .gitignore | 5 +- .gitmodules | 5 - .vscode/settings.json | 18 ++ WORKSPACE | 61 ++++++ build_torch_xla_libs.sh | 50 ++--- docker/experimental/Dockerfile | 4 - scripts/apply_patches.sh | 10 +- setup.py | 45 +--- test/cpp/CMakeLists.txt | 15 +- test/cpp/cpp_test_util.cpp | 4 +- test/cpp/cpp_test_util.h | 2 +- test/cpp/metrics_snapshot.cpp | 4 +- test/cpp/metrics_snapshot.h | 2 +- test/cpp/test_async_task.cpp | 2 +- test/cpp/test_aten_xla_tensor.cpp | 2 +- test/cpp/test_mayberef.cpp | 2 +- test/cpp/test_replication.cpp | 8 +- test/cpp/test_xla_sharding.cpp | 4 +- test/cpp/test_xla_util_cache.cpp | 4 +- test/cpp/torch_xla_test.cpp | 4 +- tf_patches/BUILD | 0 tf_patches/bazel.diff | 13 ++ tf_patches/cache_urls.diff | 30 +++ third_party/tensorflow | 1 - third_party/xla_client/BUILD | 198 +++++++++--------- third_party/xla_client/async_task.h | 4 +- third_party/xla_client/computation_client.cc | 19 +- third_party/xla_client/computation_client.h | 8 +- third_party/xla_client/debug_macros.h | 2 +- third_party/xla_client/env_vars.cc | 2 +- third_party/xla_client/mesh_service.cc | 16 +- third_party/xla_client/mesh_service.h | 2 +- third_party/xla_client/metrics.cc | 6 +- third_party/xla_client/metrics.h | 2 +- third_party/xla_client/metrics_analysis.cc | 10 +- third_party/xla_client/metrics_reader.cc | 10 +- third_party/xla_client/multi_wait.cc | 2 +- third_party/xla_client/nccl_distributed.cc | 4 +- .../xla_client/pjrt_computation_client.cc | 10 +- .../xla_client/pjrt_computation_client.h | 6 +- .../pjrt_computation_client_test.cc | 8 +- third_party/xla_client/profiler.cc | 2 +- third_party/xla_client/record_reader.cc | 4 +- third_party/xla_client/sys_util.cc | 2 +- third_party/xla_client/sys_util.h | 4 + third_party/xla_client/sys_util_test.cc | 33 +++ third_party/xla_client/tf_logging.cc | 2 +- third_party/xla_client/thread_pool.cc | 6 +- third_party/xla_client/triggered_task.cc | 2 +- third_party/xla_client/unique.h | 2 +- third_party/xla_client/util.cc | 2 +- third_party/xla_client/util.h | 2 +- third_party/xla_client/util_test.cc | 73 +++++++ third_party/xla_client/xla_util.cc | 8 +- third_party/xla_client/xla_util.h | 2 +- .../xla_client/xrt_computation_client.cc | 26 +-- .../xla_client/xrt_computation_client.h | 20 +- third_party/xla_client/xrt_local_service.cc | 4 +- third_party/xla_client/xrt_local_service.h | 2 +- third_party/xla_client/xrt_session.cc | 2 +- third_party/xla_client/xrt_session.h | 2 +- third_party/xla_client/xrt_session_cache.cc | 6 +- third_party/xla_client/xrt_session_cache.h | 2 +- torch_xla/csrc/aten_cpu_fallback.cpp | 6 +- torch_xla/csrc/aten_xla_bridge.cpp | 4 +- torch_xla/csrc/aten_xla_type.cpp | 8 +- torch_xla/csrc/computation.cpp | 2 +- torch_xla/csrc/computation.h | 6 +- torch_xla/csrc/convert_ops.cpp | 2 +- torch_xla/csrc/convolution.cpp | 1 - torch_xla/csrc/cross_replica_reduces.cpp | 4 +- torch_xla/csrc/data_ops.cpp | 6 +- torch_xla/csrc/debug_util.cpp | 6 +- torch_xla/csrc/device.cpp | 4 +- torch_xla/csrc/device.h | 2 +- torch_xla/csrc/elementwise.cpp | 2 +- torch_xla/csrc/function_call_tracker.cpp | 2 +- torch_xla/csrc/generated_file_include.h | 4 +- torch_xla/csrc/helpers.cpp | 8 +- torch_xla/csrc/helpers.h | 4 +- torch_xla/csrc/init_python_bindings.cpp | 24 +-- torch_xla/csrc/ir.cpp | 6 +- torch_xla/csrc/ir.h | 2 +- torch_xla/csrc/ir_builder.h | 2 +- torch_xla/csrc/ir_dump_util.cpp | 4 +- torch_xla/csrc/ir_util.cpp | 2 +- torch_xla/csrc/layout_manager.cpp | 8 +- torch_xla/csrc/lowering_context.cpp | 4 +- torch_xla/csrc/lowering_context.h | 2 +- torch_xla/csrc/nms_op.cpp | 4 +- torch_xla/csrc/op_by_op_executor.cpp | 8 +- torch_xla/csrc/op_by_op_executor.h | 8 +- torch_xla/csrc/ops/adaptive_max_pool2d.cpp | 2 +- torch_xla/csrc/ops/all_gather.cpp | 2 +- torch_xla/csrc/ops/all_reduce.cpp | 2 +- torch_xla/csrc/ops/as_strided.cpp | 2 +- torch_xla/csrc/ops/as_strided_view_update.cpp | 2 +- torch_xla/csrc/ops/avg_pool_nd.cpp | 2 +- torch_xla/csrc/ops/avg_pool_nd_backward.cpp | 2 +- torch_xla/csrc/ops/constant_pad_nd.cpp | 2 +- .../ops/convolution_backward_overrideable.cpp | 2 +- .../csrc/ops/convolution_overrideable.cpp | 2 +- torch_xla/csrc/ops/device_data.h | 2 +- torch_xla/csrc/ops/diagonal.cpp | 2 +- torch_xla/csrc/ops/discrete_uniform.cpp | 2 +- torch_xla/csrc/ops/dynamic_ir.cpp | 2 +- torch_xla/csrc/ops/einsum_utilities.h | 2 +- torch_xla/csrc/ops/expand_symint.cpp | 2 +- torch_xla/csrc/ops/index_get.cpp | 2 +- torch_xla/csrc/ops/index_ops.cpp | 4 +- torch_xla/csrc/ops/log_softmax_backward.cpp | 2 +- torch_xla/csrc/ops/max_pool_nd.cpp | 2 +- torch_xla/csrc/ops/max_pool_nd_backward.cpp | 2 +- torch_xla/csrc/ops/max_unpool_nd.cpp | 2 +- torch_xla/csrc/ops/mse_loss.cpp | 4 +- torch_xla/csrc/ops/mse_loss_backward.cpp | 2 +- .../csrc/ops/native_batch_norm_backward.cpp | 2 +- .../csrc/ops/native_batch_norm_forward.cpp | 2 +- torch_xla/csrc/ops/nll_loss.cpp | 2 +- torch_xla/csrc/ops/nll_loss2d.cpp | 2 +- torch_xla/csrc/ops/nll_loss2d_backward.cpp | 4 +- torch_xla/csrc/ops/nll_loss_backward.cpp | 4 +- torch_xla/csrc/ops/nms.cpp | 2 +- torch_xla/csrc/ops/not_supported.cpp | 2 +- torch_xla/csrc/ops/ops.cpp | 4 +- torch_xla/csrc/ops/ops_xla_shape_fn.cpp | 2 +- torch_xla/csrc/ops/permute.cpp | 2 +- torch_xla/csrc/ops/recv.cpp | 2 +- torch_xla/csrc/ops/repeat.cpp | 2 +- torch_xla/csrc/ops/replication_pad.cpp | 2 +- torch_xla/csrc/ops/resize.cpp | 2 +- torch_xla/csrc/ops/scalar.cpp | 2 +- torch_xla/csrc/ops/scalar.h | 2 +- torch_xla/csrc/ops/select.cpp | 2 +- torch_xla/csrc/ops/send.cpp | 2 +- torch_xla/csrc/ops/softmax_backward.cpp | 2 +- torch_xla/csrc/ops/split.cpp | 2 +- torch_xla/csrc/ops/squeeze.cpp | 2 +- torch_xla/csrc/ops/svd.cpp | 2 +- torch_xla/csrc/ops/uniform.cpp | 2 +- .../csrc/ops/upsample_bilinear2d_backward.cpp | 2 +- .../csrc/ops/upsample_nearest2d_backward.cpp | 2 +- torch_xla/csrc/pooling.cpp | 4 +- torch_xla/csrc/random.cpp | 4 +- torch_xla/csrc/reduction.cpp | 2 +- torch_xla/csrc/resize_ops.cpp | 4 +- torch_xla/csrc/softmax_builder.cpp | 2 +- torch_xla/csrc/tensor.cpp | 14 +- torch_xla/csrc/tensor.h | 10 +- torch_xla/csrc/tensor_impl.cpp | 4 +- torch_xla/csrc/tensor_methods.cpp | 8 +- torch_xla/csrc/tensor_ops.cpp | 4 +- torch_xla/csrc/tensor_util.cpp | 14 +- torch_xla/csrc/tensor_util.h | 2 +- torch_xla/csrc/token_handler.cpp | 2 +- torch_xla/csrc/torch_util.cpp | 4 +- torch_xla/csrc/torch_util.h | 2 +- torch_xla/csrc/view.cpp | 4 +- torch_xla/csrc/xla_backend_impl.h | 2 +- torch_xla/csrc/xla_graph_executor.cpp | 14 +- torch_xla/csrc/xla_graph_executor.h | 10 +- torch_xla/csrc/xla_lower_util.cpp | 4 +- torch_xla/csrc/xla_op_builder.cpp | 2 +- 166 files changed, 859 insertions(+), 502 deletions(-) create mode 100644 .bazelrc create mode 100644 .bazelversion create mode 100644 .vscode/settings.json create mode 100644 WORKSPACE create mode 100644 tf_patches/BUILD create mode 100644 tf_patches/bazel.diff create mode 100644 tf_patches/cache_urls.diff delete mode 160000 third_party/tensorflow create mode 100644 third_party/xla_client/sys_util_test.cc create mode 100644 third_party/xla_client/util_test.cc diff --git a/.bazelrc b/.bazelrc new file mode 100644 index 00000000000..8b49a1e2525 --- /dev/null +++ b/.bazelrc @@ -0,0 +1,176 @@ +############################################################################ +# All default build options below. + +# Enable exceptions in C++. +common --copt=-fexceptions + +# Make Bazel print out all options from rc files. +build --announce_rc + +# TODO(goranpetrovic): figure out visibility of tensorflow libraries. +build --nocheck_visibility + +#build --define open_source_build=true + +# We can set this to `standalone` after https://github.com/bazelbuild/bazel/issues/15359 is resolved. +build --spawn_strategy=sandboxed + +build --enable_platform_specific_config + +build --experimental_cc_shared_library + +# Disable enabled-by-default TensorFlow features that we don't care about. +build --define=no_aws_support=true +build --define=no_hdfs_support=true + +build --define=grpc_no_ares=true + +build -c opt + +build --config=short_logs + +########################################################################### + +build:posix --copt=-Wno-sign-compare +build:posix --cxxopt=-std=c++17 +build:posix --host_cxxopt=-std=c++17 + +build:avx_posix --copt=-mavx +build:avx_posix --host_copt=-mavx + +build:avx_linux --copt=-mavx +build:avx_linux --host_copt=-mavx + +build:native_arch_posix --copt=-march=native +build:native_arch_posix --host_copt=-march=native + +build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 + +build:cuda --repo_env TF_NEED_CUDA=1 +# "sm" means we emit only cubin, which is forward compatible within a GPU generation. +# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. +build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain +build:cuda --@local_config_cuda//:enable_cuda +build:cuda --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true +build:cuda --define=xla_python_enable_gpu=true +build:cuda --cxxopt=-DXLA_CUDA=1 + +build:acl --define==build_with_acl=true + +build:nonccl --define=no_nccl_support=true + +build:linux --config=posix +build:linux --copt=-Wno-unknown-warning-option + +# Suppress all warning messages. +build:short_logs --output_filter=DONT_MATCH_ANYTHING + +#build:tpu --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=true +build:tpu --define=with_tpu_support=true + +######################################################################### +# RBE config options below. +# Flag to enable remote config +common --experimental_repo_remote_exec +######################################################################### + +# Load rc file with user-specific options. +try-import %workspace%/.bazelrc.user + +# Compile database generation config. +build:compdb --features=-layering_check + +# Test requires Java. +test --java_runtime_version=remotejdk_11 + +# Coverage requires Java and GCC. +coverage --config=coverage +coverage --build_tests_only +build:coverage --java_runtime_version=remotejdk_11 +build:coverage --copt=-DNDEBUG +build:coverage --combined_report=lcov +build:coverage --strategy=TestRunner=sandboxed,local +build:coverage --strategy=CoverageReport=sandboxed,local +build:coverage --experimental_use_llvm_covmap +build:coverage --collect_code_coverage +build:coverage --test_tag_filters=-nocoverage +build:coverage --action_env=CC=gcc +build:coverage --action_env=CXX=g++ + +############################################################################ +############## TensorFlow .bazelrc greatest hits ########################### +############################################################################ + +# Modular TF build options +build:dynamic_kernels --define=dynamic_loaded_kernels=true +build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS +build --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1 + +# Default paths for TF_SYSTEM_LIBS +build:linux --define=PREFIX=/usr +build:linux --define=LIBDIR=$(PREFIX)/lib +build:linux --define=INCLUDEDIR=$(PREFIX)/include +build:linux --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include + +# On linux, we dynamically link small amount of kernels +build:linux --config=dynamic_kernels + +# For projects which use TensorFlow as part of a Bazel build process, putting +# nothing in a bazelrc will default to a monolithic build. The following line +# opts in to modular op registration support by default. +build --define framework_shared_object=true +build --define tsl_protobuf_header_only=true + +build --define=use_fast_cpp_protos=true +build --define=allow_oversize_protos=true + +# Enable XLA support by default. +build --define=with_xla_support=true + +# See https://github.com/bazelbuild/bazel/issues/7362 for information on what +# --incompatible_remove_legacy_whole_archive flag does. +# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate +# Tensorflow to the default, however test coverage wasn't enough to catch the +# errors. +# There is ongoing work on Bazel team's side to provide support for transitive +# shared libraries. As part of migrating to transitive shared libraries, we +# hope to provide a better mechanism for control over symbol exporting, and +# then tackle this issue again. +# +# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library +# archives in -whole_archive -no_whole_archive. +build --noincompatible_remove_legacy_whole_archive + +# cc_shared_library ensures no library is linked statically more than once. +build --experimental_link_static_libraries_once=false + +# On linux, don't cross compile by default +build:linux --distinct_host_configuration=false + +# Do not risk cache corruption. See: +# https://github.com/bazelbuild/bazel/issues/3360 +build:linux --experimental_guard_against_concurrent_changes + +# Prevent regressions on those two incompatible changes +# TODO: remove those flags when they are flipped in the default Bazel version TF uses. +build --incompatible_enforce_config_setting_visibility + +# Disable TFRT integration for now unless --config=tfrt is specified. +build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils + +# Suppress most C++ complier warnings to reduce log size but allow +# for specific warnings to still be present. +build:linux --copt="-Wno-all" +build:linux --copt="-Wno-extra" +build:linux --copt="-Wno-deprecated" +build:linux --copt="-Wno-deprecated-declarations" +build:linux --copt="-Wno-ignored-attributes" +build:linux --copt="-Wno-array-bounds" +# Add unused-result as an error on Linux. +build:linux --copt="-Wunused-result" +build:linux --copt="-Werror=unused-result" +# Add switch as an error on Linux. +build:linux --copt="-Wswitch" +build:linux --copt="-Werror=switch" +# Required for building with clang +build:linux --copt="-Wno-error=unused-but-set-variable" diff --git a/.bazelversion b/.bazelversion new file mode 100644 index 00000000000..03f488b076a --- /dev/null +++ b/.bazelversion @@ -0,0 +1 @@ +5.3.0 diff --git a/.circleci/config.yml b/.circleci/config.yml index 47eec28c5e4..6108a221069 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -10,7 +10,6 @@ setup_base_docker: &setup_base_docker name: Setup Base Docker Image command: | .circleci/setup_ci_environment.sh - .circleci/download_llvm_raw.sh launch_docker_and_build: &launch_docker_and_build name: Launch Docker Container and Build diff --git a/.gitignore b/.gitignore index 8fc14019340..6f00d98bd8e 100644 --- a/.gitignore +++ b/.gitignore @@ -18,7 +18,6 @@ torch_xla/csrc/generated/ # Below files are not deleted by "setup.py clean". # Visual Studio Code files -.vscode .vs # Files autogenerated by docs/docs_build.sh @@ -27,3 +26,7 @@ torch_xla/csrc/generated/ # Local terraform state .terraform* + + +# Build system temporary files +/bazel-* diff --git a/.gitmodules b/.gitmodules index 4830c02abbe..e69de29bb2d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,5 +0,0 @@ -[submodule "third_party/tensorflow"] - path = third_party/tensorflow - url = https://github.com/tensorflow/tensorflow.git - ignore = dirty - diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000000..e926c7af017 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,18 @@ +{ + "bsv.bazel.buildFlags": [ + "--config=compdb", + "--sandbox_base=/dev/shm", + ], + "bsv.cc.compdb.targets": [ + "//third_party/xla_client/...", + ], + "coverage-gutters.coverageBaseDir": ".", + "coverage-gutters.showLineCoverage": false, + "coverage-gutters.showGutterCoverage": true, + "coverage-gutters.coverageReportFileName": "./genhtml/index.html", + "coverage-gutters.coverageFileNames": [ "./bazel-out/_coverage/_coverage_report.dat" ], + "lcov.path": [ "./.bazel-out/_coverage/_coverage_report.dat"], + + "python.formatting.provider": "yapf", + "editor.formatOnSave": true +} diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 00000000000..ae695ab99d7 --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,61 @@ +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +# To update TensorFlow to a new revision, +# a) update URL and strip_prefix to the new git commit hash +# b) get the sha256 hash of the commit by running: +# curl -L https://github.com/tensorflow/tensorflow/archive/.tar.gz | sha256sum +# and update the sha256 with the result. +http_archive( + name = "org_tensorflow", + patch_args = [ + "-l", + "-p1", + ], + patch_tool = "patch", + patches = [ + "//tf_patches:bazel.diff", + "//tf_patches:cache_urls.diff", + "//tf_patches:cudnn_int8x32.diff", + "//tf_patches:f16_abi_clang.diff", + "//tf_patches:gpu_race_condition.diff", + "//tf_patches:stream_executor.diff", + "//tf_patches:grpc_version.diff", + "//tf_patches:thread_local_random.diff", + "//tf_patches:xplane.diff", + ], + sha256 = "0fdf5067cd9827be2ae14c2ac59cd482e678134b125943be278ad23ea5342181", + strip_prefix = "tensorflow-f7759359f8420d3ca7b9fd19493f2a01bd47b4ef", + urls = [ + "https://github.com/tensorflow/tensorflow/archive/f7759359f8420d3ca7b9fd19493f2a01bd47b4ef.tar.gz", + ], +) + +# For development, one often wants to make changes to the TF repository as well +# as the PyTorch/XLA repository. You can override the pinned repository above with a +# local checkout by either: +# a) overriding the TF repository on the build.py command line by passing a flag +# like: +# bazel --override_repository=org_tensorflow=/path/to/tensorflow +# or +# b) by commenting out the http_archive above and uncommenting the following: +# local_repository( +# name = "org_tensorflow", +# path = "/path/to/tensorflow", +# ) + +# Initialize TensorFlow's external dependencies. +load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") + +tf_workspace3() + +load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2") + +tf_workspace2() + +load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1") + +tf_workspace1() + +load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0") + +tf_workspace0() diff --git a/build_torch_xla_libs.sh b/build_torch_xla_libs.sh index eb72de6b478..b55b5482a81 100755 --- a/build_torch_xla_libs.sh +++ b/build_torch_xla_libs.sh @@ -34,7 +34,6 @@ if [[ "$XLA_BAZEL_VERBOSE" == "1" ]]; then VERBOSE="-s" fi -BUILD_STRATEGY="standalone" SANDBOX_BASE="${XLA_SANDBOX_BASE}" if [ -z "$XLA_SANDBOX_BASE" ]; then SANDBOX_BASE="/tmp" @@ -43,14 +42,11 @@ if [[ "$XLA_SANDBOX_BUILD" == "1" ]]; then BUILD_STRATEGY="sandboxed --sandbox_base=${SANDBOX_BASE}" else # We can remove this after https://github.com/bazelbuild/bazel/issues/15359 is resolved - unset CC - unset CXX BUILD_STRATEGY="local" fi -TPUVM_FLAG= if [[ "$TPUVM_MODE" == "1" ]]; then - TPUVM_FLAG="--define=with_tpu_support=true" + OPTS+=(--config=tpu) fi MAX_JOBS= @@ -58,46 +54,26 @@ if [[ ! -z "$BAZEL_JOBS" ]]; then MAX_JOBS="--jobs=$BAZEL_JOBS" fi -OPTS+=(--cxxopt="-std=c++17") -if [[ $(basename -- $CC) =~ ^clang ]]; then - OPTS+=(--cxxopt="-Wno-c++11-narrowing") - OPTS+=(--cxxopt="-Wno-c++14-narrowing") -fi - if [[ "$XLA_CUDA" == "1" ]]; then - OPTS+=(--cxxopt="-DXLA_CUDA=1") OPTS+=(--config=cuda) fi if [[ "$XLA_CPU_USE_ACL" == "1" ]]; then - OPTS+=("--define=build_with_acl=true") + OPTS+=(--config=acl) fi if [ "$CMD" == "clean" ]; then - pushd $THIRD_PARTY_DIR/tensorflow bazel clean - popd -else - # Overlay llvm-raw secondary cache. The remote cache should be updated - # nightly with the pinned llvm archive. Note, this commands will be NO-OP if there is no match. - sed -i '/.*github.com\/llvm.*,/a "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT),' \ - $THIRD_PARTY_DIR/tensorflow/third_party/llvm/workspace.bzl - sed -i 's/LLVM_COMMIT)]/LLVM_COMMIT),"https:\/\/storage.googleapis.com\/tpu-pytorch\/llvm-raw\/{commit}.tar.gz".format(commit = LLVM_COMMIT)]/g' \ - $THIRD_PARTY_DIR/tensorflow/tensorflow/compiler/xla/mlir_hlo/WORKSPACE - - cp -r -u -p $THIRD_PARTY_DIR/xla_client $THIRD_PARTY_DIR/tensorflow/tensorflow/compiler/xla/ + exit 0 +fi - pushd $THIRD_PARTY_DIR/tensorflow - # TensorFlow and its dependencies may introduce warning flags from newer compilers - # that PyTorch and PyTorch/XLA's default compilers don't recognize. They become error - # while '-Werror' is used. Therefore, surpress the warnings. - TF_EXTRA_FLAGS="--copt=-Wno-unknown-warning-option" - bazel build $MAX_JOBS $VERBOSE $TPUVM_FLAG $TF_EXTRA_FLAGS --spawn_strategy=$BUILD_STRATEGY --show_progress_rate_limit=20 \ - --define framework_shared_object=false -c "$MODE" "${OPTS[@]}" \ - $XLA_CUDA_CFG //tensorflow/compiler/xla/xla_client:libxla_computation_client.so +# TensorFlow and its dependencies may introduce warning flags from newer compilers +# that PyTorch and PyTorch/XLA's default compilers don't recognize. They become error +# while '-Werror' is used. Therefore, surpress the warnings in .bazelrc or here. +bazel build $MAX_JOBS $VERBOSE --spawn_strategy=$BUILD_STRATEGY --show_progress_rate_limit=20 \ + --define framework_shared_object=false -c "$MODE" "${OPTS[@]}" \ + $XLA_CUDA_CFG //third_party/xla_client:libxla_computation_client.so - popd - mkdir -p torch_xla/lib - chmod 0644 $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/compiler/xla/xla_client/libxla_computation_client.so - cp $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/compiler/xla/xla_client/libxla_computation_client.so torch_xla/lib -fi +mkdir -p torch_xla/lib +chmod 0644 bazel-bin/third_party/xla_client/libxla_computation_client.so +cp bazel-bin/third_party/xla_client/libxla_computation_client.so torch_xla/lib diff --git a/docker/experimental/Dockerfile b/docker/experimental/Dockerfile index 28be48d91bd..57ddd6ca160 100644 --- a/docker/experimental/Dockerfile +++ b/docker/experimental/Dockerfile @@ -69,9 +69,7 @@ WORKDIR /pytorch/xla/ # Contains actual build artifacts FROM builder AS artifacts -COPY tf_patches/ tf_patches/ COPY third_party/ third_party/ -RUN cd third_party/tensorflow && find ../../tf_patches -name '*.diff' | xargs -t -r -n 1 patch -N -p1 -l -i COPY build_torch_xla_libs.sh . @@ -96,9 +94,7 @@ ARG package_version RUN TORCH_XLA_VERSION=${package_version} BUILD_CPP_TESTS=${build_cpp_tests} TPUVM_MODE=${tpuvm} BUNDLE_LIBTPU=${tpuvm} XLA_CUDA=${cuda} TF_CUDA_COMPUTE_CAPABILITIES=${tf_cuda_compute_capabilities} python setup.py bdist_wheel # Expunge cache to keep image size under control -WORKDIR /pytorch/xla/third_party/tensorflow RUN bazel clean --expunge -WORKDIR /pytorch/xla/ RUN pip install dist/*.whl diff --git a/scripts/apply_patches.sh b/scripts/apply_patches.sh index 1c7b053caaa..05b37555e39 100755 --- a/scripts/apply_patches.sh +++ b/scripts/apply_patches.sh @@ -38,7 +38,9 @@ python $CDIR/cond_patch.py \ $XDIR/torch_patches \ $PTDIR -python $CDIR/cond_patch.py \ - $XDIR/tf_patches \ - $TFDIR - +# Apply TF patches only if requested, since bazel handles that normally. +if [[ -n "${APPLY_TF_PATCHES}" ]]; then + python $CDIR/cond_patch.py \ + $XDIR/tf_patches \ + $TFDIR +fi diff --git a/setup.py b/setup.py index b20a85758e8..97fae73fda5 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,6 @@ import zipfile base_dir = os.path.dirname(os.path.abspath(__file__)) -third_party_path = os.path.join(base_dir, 'third_party') _libtpu_version = '0.1.dev20230213' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' @@ -174,27 +173,6 @@ def maybe_bundle_libtpu(base_dir): libtpu_so.write(z.read('libtpu/libtpu.so')) -def generate_protos(base_dir, third_party_path): - # Application proto files should be in torch_xla/pb/src/ and the generated - # files will go in torch_xla/pb/cpp/. - proto_files = glob.glob(os.path.join(base_dir, 'torch_xla/pb/src/*.proto')) - if proto_files: - protoc = os.path.join( - third_party_path, - 'tensorflow/bazel-out/host/bin/external/com_google_protobuf/protoc') - protoc_cmd = [ - protoc, '-I', - os.path.join(third_party_path, 'tensorflow'), '-I', - os.path.join(base_dir, 'torch_xla/pb/src'), '--cpp_out', - os.path.join(base_dir, 'torch_xla/pb/cpp') - ] + proto_files - if subprocess.call(protoc_cmd) != 0: - print( - 'Failed to generate protobuf files: {}'.format(protoc_cmd), - file=sys.stderr) - sys.exit(1) - - def _compile_parallel(self, sources, output_dir=None, @@ -289,9 +267,6 @@ def run(self): # Copy libtpu.so into torch_xla/lib maybe_bundle_libtpu(base_dir) - # Generate the proto C++/python files only after third_party has built. - generate_protos(base_dir, third_party_path) - # Fetch the sources to be built. torch_xla_sources = ( glob.glob('torch_xla/csrc/*.cpp') + glob.glob('torch_xla/csrc/ops/*.cpp') + @@ -308,16 +283,18 @@ def run(self): base_dir, ] for ipath in [ - 'tensorflow/bazel-bin', - 'tensorflow/bazel-tensorflow', - 'tensorflow/bazel-tensorflow/external/protobuf_archive/src', - 'tensorflow/bazel-tensorflow/external/com_google_protobuf/src', - 'tensorflow/bazel-tensorflow/external/eigen_archive', - 'tensorflow/bazel-tensorflow/external/com_google_absl', - 'tensorflow/bazel-tensorflow/external/com_googlesource_code_re2', - 'tensorflow/bazel-tensorflow/external/com_github_grpc_grpc/include', + 'bazel-bin', + 'bazel-xla', + 'bazel-bin/external/org_tensorflow/', + 'bazel-xla/external/org_tensorflow/', + 'bazel-xla/external/com_github_grpc_grpc/include', + 'bazel-xla/external/com_google_protobuf/src', + 'bazel-xla/external/eigen_archive', + 'bazel-xla/external/com_google_absl', + 'bazel-xla/external/com_googlesource_code_re2', + 'bazel-xla/com_github_grpc_grpc/include', ]: - include_dirs.append(os.path.join(third_party_path, ipath)) + include_dirs.append(os.path.join(base_dir, ipath)) include_dirs += [ pytorch_source_path, os.path.join(pytorch_source_path, 'torch/csrc'), diff --git a/test/cpp/CMakeLists.txt b/test/cpp/CMakeLists.txt index 3eb988f2d3d..ece2dd3caae 100644 --- a/test/cpp/CMakeLists.txt +++ b/test/cpp/CMakeLists.txt @@ -8,7 +8,6 @@ set(GTEST_DIR "${CMAKE_BINARY_DIR}/gtest") get_filename_component(PTXLA_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../.." ABSOLUTE) get_filename_component(PT_DIR "${PTXLA_DIR}/.." ABSOLUTE) -set(TFDIR "${PTXLA_DIR}/third_party/tensorflow") file(GLOB PTXLA_LIBDIRS "${PTXLA_DIR}/build/lib.*") list(GET PTXLA_LIBDIRS 0 PTXLA_LIBDIR) @@ -96,12 +95,14 @@ target_include_directories( test_ptxla SYSTEM PUBLIC "${SOURCE_DIR}/googletest/include" - "${TFDIR}/bazel-tensorflow" - "${TFDIR}/bazel-bin" - "${TFDIR}/bazel-tensorflow/external/protobuf_archive/src" - "${TFDIR}/bazel-tensorflow/external/com_google_protobuf/src" - "${TFDIR}/bazel-tensorflow/external/eigen_archive" - "${TFDIR}/bazel-tensorflow/external/com_google_absl" + "${PTXLA_DIR}/bazel-xla" + "${PTXLA_DIR}/bazel-bin" + "${PTXLA_DIR}/bazel-bin/external/org_tensorflow", + "${PTXLA_DIR}/bazel-xla/external/org_tensorflow", + "${PTXLA_DIR}/bazel-xla/external/protobuf_archive/src" + "${PTXLA_DIR}/bazel-xla/external/com_google_protobuf/src" + "${PTXLA_DIR}/bazel-xla/external/eigen_archive" + "${PTXLA_DIR}/bazel-xla/external/com_google_absl" "${PYTHON_INCLUDE_DIR}" ) diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp index 97d85a60657..47e7428c911 100644 --- a/test/cpp/cpp_test_util.cpp +++ b/test/cpp/cpp_test_util.cpp @@ -3,8 +3,8 @@ #include #include -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/sys_util.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/test/cpp/cpp_test_util.h b/test/cpp/cpp_test_util.h index 7f724843c85..dc7ebf9395c 100644 --- a/test/cpp/cpp_test_util.h +++ b/test/cpp/cpp_test_util.h @@ -9,7 +9,7 @@ #include #include "absl/types/span.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" +#include "third_party/xla_client/computation_client.h" #include "torch_xla/csrc/debug_util.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/ir.h" diff --git a/test/cpp/metrics_snapshot.cpp b/test/cpp/metrics_snapshot.cpp index 3da3b6e2879..df4209a32db 100644 --- a/test/cpp/metrics_snapshot.cpp +++ b/test/cpp/metrics_snapshot.cpp @@ -2,8 +2,8 @@ #include -#include "tensorflow/compiler/xla/xla_client/tf_logging.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/tf_logging.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/metrics.h" namespace torch_xla { diff --git a/test/cpp/metrics_snapshot.h b/test/cpp/metrics_snapshot.h index f92d32dbc71..6b25445b281 100644 --- a/test/cpp/metrics_snapshot.h +++ b/test/cpp/metrics_snapshot.h @@ -7,7 +7,7 @@ #include #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/metrics.h" +#include "third_party/xla_client/metrics.h" namespace torch_xla { namespace cpp_test { diff --git a/test/cpp/test_async_task.cpp b/test/cpp/test_async_task.cpp index 717298b05ae..9fffd9807dd 100644 --- a/test/cpp/test_async_task.cpp +++ b/test/cpp/test_async_task.cpp @@ -3,7 +3,7 @@ #include #include "cpp_test_util.h" -#include "tensorflow/compiler/xla/xla_client/async_task.h" +#include "third_party/xla_client/async_task.h" namespace torch_xla { namespace cpp_test { diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index e58954c0f91..58f1ebe1564 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -6,7 +6,7 @@ #include "cpp_test_util.h" #include "tensorflow/compiler/xla/permutation_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_client/metrics.h" +#include "third_party/xla_client/metrics.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/dynamic_ir.h" diff --git a/test/cpp/test_mayberef.cpp b/test/cpp/test_mayberef.cpp index dfc241ed6e9..79d0b84b16b 100644 --- a/test/cpp/test_mayberef.cpp +++ b/test/cpp/test_mayberef.cpp @@ -3,7 +3,7 @@ #include #include "cpp_test_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" namespace torch_xla { namespace cpp_test { diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index 24e6e2b7b98..f66fd9d18a1 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -6,10 +6,10 @@ #include "cpp_test_util.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/multi_wait.h" -#include "tensorflow/compiler/xla/xla_client/thread_pool.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/multi_wait.h" +#include "third_party/xla_client/thread_pool.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/tensor_util.h" diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index b4195997ad5..aa6c12f44a4 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -5,9 +5,9 @@ #include #include "cpp_test_util.h" -#include "tensorflow/compiler/xla/xla_client/env_vars.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "third_party/xla_client/env_vars.h" +#include "third_party/xla_client/sys_util.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/tensor.h" diff --git a/test/cpp/test_xla_util_cache.cpp b/test/cpp/test_xla_util_cache.cpp index b32f6850af8..66f57bd9b5b 100644 --- a/test/cpp/test_xla_util_cache.cpp +++ b/test/cpp/test_xla_util_cache.cpp @@ -3,8 +3,8 @@ #include #include "cpp_test_util.h" -#include "tensorflow/compiler/xla/xla_client/cache.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/cache.h" +#include "third_party/xla_client/util.h" namespace torch_xla { namespace cpp_test { diff --git a/test/cpp/torch_xla_test.cpp b/test/cpp/torch_xla_test.cpp index e81beb23abb..d848993faf8 100644 --- a/test/cpp/torch_xla_test.cpp +++ b/test/cpp/torch_xla_test.cpp @@ -3,8 +3,8 @@ #include #include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/tf_logging.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/tf_logging.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/tensor.h" diff --git a/tf_patches/BUILD b/tf_patches/BUILD new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tf_patches/bazel.diff b/tf_patches/bazel.diff new file mode 100644 index 00000000000..4b670953bcb --- /dev/null +++ b/tf_patches/bazel.diff @@ -0,0 +1,13 @@ +diff --git i/tensorflow/tensorflow.bzl w/tensorflow/tensorflow.bzl +index 649c8e22dcc..a85f4bc3af3 100644 +--- i/tensorflow/tensorflow.bzl ++++ w/tensorflow/tensorflow.bzl +@@ -315,7 +315,7 @@ def if_libtpu(if_true, if_false = []): + def if_with_tpu_support(if_true, if_false = []): + """Shorthand for select()ing whether to build API support for TPUs when building TensorFlow""" + return select({ +- "//tensorflow:with_tpu_support": if_true, ++ clean_dep("//tensorflow:with_tpu_support"): if_true, + "//conditions:default": if_false, + }) + \ No newline at end of file diff --git a/tf_patches/cache_urls.diff b/tf_patches/cache_urls.diff new file mode 100644 index 00000000000..e99ccdab424 --- /dev/null +++ b/tf_patches/cache_urls.diff @@ -0,0 +1,30 @@ +diff --git i/tensorflow/compiler/xla/mlir_hlo/WORKSPACE w/tensorflow/compiler/xla/mlir_hlo/WORKSPACE +index cc9eeb64f02..b290eb4556c 100644 +--- i/tensorflow/compiler/xla/mlir_hlo/WORKSPACE ++++ w/tensorflow/compiler/xla/mlir_hlo/WORKSPACE +@@ -35,7 +35,10 @@ http_archive( + build_file_content = "# empty", + sha256 = LLVM_SHA256, + strip_prefix = "llvm-project-" + LLVM_COMMIT, +- urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], ++ urls = [ ++ "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), ++ "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT), ++ ], + ) + + load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps") +diff --git i/third_party/llvm/workspace.bzl w/third_party/llvm/workspace.bzl +index 02a0c926c99..caa0f5cbed9 100644 +--- i/third_party/llvm/workspace.bzl ++++ w/third_party/llvm/workspace.bzl +@@ -13,7 +13,9 @@ def repo(name): + strip_prefix = "llvm-project-{commit}".format(commit = LLVM_COMMIT), + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), ++ "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT), + "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), ++ "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT), + ], + build_file = "//third_party/llvm:llvm.BUILD", + patch_file = [ diff --git a/third_party/tensorflow b/third_party/tensorflow deleted file mode 160000 index f7759359f84..00000000000 --- a/third_party/tensorflow +++ /dev/null @@ -1 +0,0 @@ -Subproject commit f7759359f8420d3ca7b9fd19493f2a01bd47b4ef diff --git a/third_party/xla_client/BUILD b/third_party/xla_client/BUILD index c0e8a61d126..fb1023bce8f 100644 --- a/third_party/xla_client/BUILD +++ b/third_party/xla_client/BUILD @@ -1,21 +1,20 @@ -load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") load( - "//tensorflow:tensorflow.bzl", + "@org_tensorflow//tensorflow:tensorflow.bzl", "if_with_tpu_support", "tf_cc_shared_object", ) load( - "//tensorflow/tsl/platform/default:build_config.bzl", + "@org_tensorflow//tensorflow/tsl/platform/default:build_config.bzl", "tf_proto_library_cc", ) load( - "//tensorflow/tsl/platform/default:cuda_build_defs.bzl", + "@org_tensorflow//tensorflow/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) licenses(["notice"]) # Apache 2.0 -package(default_visibility = ["//tensorflow:internal"]) +package(default_visibility = ["@org_tensorflow//tensorflow:internal"]) exports_files( [ @@ -31,36 +30,36 @@ tf_proto_library_cc( cc_api_version = 2, cc_grpc_version = 1, protodeps = [ - "//tensorflow/core/protobuf/tpu:topology_proto", + "@org_tensorflow//tensorflow/core/protobuf/tpu:topology_proto", ], ) tf_cc_shared_object( name = "libxla_computation_client.so", linkopts = select({ - "//tensorflow:windows": [], + "@org_tensorflow//tensorflow:windows": [], "//conditions:default": [ "-z defs", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow/compiler/xla/xla_client:tf_version_script.lds)", + "$(location :tf_version_script.lds)", ], }), visibility = ["//visibility:public"], deps = [ ":computation_client_impl", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:svd", - "//tensorflow/compiler/xla/rpc:grpc_stub", - "//tensorflow/compiler/xla/xla_client:tf_exported_symbols.lds", - "//tensorflow/compiler/xla/xla_client:tf_version_script.lds", - "//tensorflow/core:lib", - "//tensorflow/core/platform/cloud:gcs_file_system", - "//tensorflow/python/profiler/internal:profiler_pywrap_impl", + ":tf_exported_symbols.lds", + ":tf_version_script.lds", "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/compiler/xla:literal_util", + "@org_tensorflow//tensorflow/compiler/xla/client", + "@org_tensorflow//tensorflow/compiler/xla/client:global_data", + "@org_tensorflow//tensorflow/compiler/xla/client:xla_builder", + "@org_tensorflow//tensorflow/compiler/xla/client:xla_computation", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:svd", + "@org_tensorflow//tensorflow/compiler/xla/rpc:grpc_stub", + "@org_tensorflow//tensorflow/core:lib", + "@org_tensorflow//tensorflow/core/platform/cloud:gcs_file_system", + "@org_tensorflow//tensorflow/python/profiler/internal:profiler_pywrap_impl", ], ) @@ -118,91 +117,92 @@ cc_library( ], deps = [ ":mesh_service_proto_cc", - "//tensorflow:grpc", - "//tensorflow:grpc++", - "//tensorflow/cc:client_session", - "//tensorflow/cc:scope", - "//tensorflow/compiler/jit:xla_cpu_device", - "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:xla_proto_cc", - "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/lib:comparators", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:logdet", - "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:pooling", - "//tensorflow/compiler/xla/client/lib:qr", - "//tensorflow/compiler/xla/client/lib:slicing", - "//tensorflow/compiler/xla/client/lib:sorting", - "//tensorflow/compiler/xla/client/lib:svd", - "//tensorflow/compiler/xla/client/lib:tridiagonal", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/pjrt/gpu:se_gpu_pjrt_client", - "//tensorflow/compiler/xla/pjrt/distributed:distributed", - "//tensorflow/compiler/xla/pjrt:tpu_client", - "//tensorflow/compiler/xla/pjrt:pjrt_client", - "//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client", - "//tensorflow/compiler/xla/pjrt:pjrt_c_api_client", - "//tensorflow/compiler/xla/rpc:grpc_stub", - "//tensorflow/compiler/xla/service:cpu_plugin", - "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/service:hlo_proto_cc", - "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/service/spmd:spmd_partitioner", - "//tensorflow/compiler/xrt:xrt_proto_cc", - "//tensorflow/compiler/xrt:xrt_server", - "//tensorflow/compiler/xrt:xrt_utils", - "//tensorflow/compiler/xrt/cc:xrt_ops", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/distributed_runtime:server_lib", - "//tensorflow/core/distributed_runtime/rpc:grpc_runtime", - "//tensorflow/core/kernels:data_flow", - "//tensorflow/core/profiler/rpc:profiler_server_impl", - "//tensorflow/core/profiler/rpc/client:profiler_client", - "//tensorflow/core/protobuf/tpu:topology_proto_cc", - "//tensorflow/compiler/xla/stream_executor:stream_executor_impl", "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/cc:client_session", + "@org_tensorflow//tensorflow/cc:scope", + "@org_tensorflow//tensorflow/compiler/jit:xla_cpu_device", + "@org_tensorflow//tensorflow/compiler/xla/client", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:arithmetic", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:comparators", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:constants", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:logdet", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:math", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:matrix", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:pooling", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:qr", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:slicing", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:sorting", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:svd", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:tridiagonal", + "@org_tensorflow//tensorflow/compiler/xla/client:global_data", + "@org_tensorflow//tensorflow/compiler/xla/client:xla_computation", + "@org_tensorflow//tensorflow/compiler/xla/hlo/ir:hlo", + "@org_tensorflow//tensorflow/compiler/xla/pjrt/distributed:distributed", + "@org_tensorflow//tensorflow/compiler/xla/pjrt/gpu:se_gpu_pjrt_client", + "@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_c_api_client", + "@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client", + "@org_tensorflow//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client", + "@org_tensorflow//tensorflow/compiler/xla/pjrt:tpu_client", + "@org_tensorflow//tensorflow/compiler/xla/rpc:grpc_stub", + "@org_tensorflow//tensorflow/compiler/xla/service/spmd:spmd_partitioner", + "@org_tensorflow//tensorflow/compiler/xla/service:cpu_plugin", + "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc", + "@org_tensorflow//tensorflow/compiler/xla/service:platform_util", + "@org_tensorflow//tensorflow/compiler/xla/stream_executor:stream_executor_impl", + "@org_tensorflow//tensorflow/compiler/xla:debug_options_flags", + "@org_tensorflow//tensorflow/compiler/xla:literal", + "@org_tensorflow//tensorflow/compiler/xla:literal_util", + "@org_tensorflow//tensorflow/compiler/xla:shape_util", + "@org_tensorflow//tensorflow/compiler/xla:statusor", + "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto_cc", + "@org_tensorflow//tensorflow/compiler/xla:xla_proto_cc", + "@org_tensorflow//tensorflow/compiler/xrt/cc:xrt_ops", + "@org_tensorflow//tensorflow/compiler/xrt:xrt_proto_cc", + "@org_tensorflow//tensorflow/compiler/xrt:xrt_server", + "@org_tensorflow//tensorflow/compiler/xrt:xrt_utils", + "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_runtime", + "@org_tensorflow//tensorflow/core/distributed_runtime:server_lib", + "@org_tensorflow//tensorflow/core/kernels:data_flow", + "@org_tensorflow//tensorflow/core/profiler/rpc/client:profiler_client", + "@org_tensorflow//tensorflow/core/profiler/rpc:profiler_server_impl", + "@org_tensorflow//tensorflow/core/protobuf/tpu:topology_proto_cc", + "@org_tensorflow//tensorflow/core:core_cpu", + "@org_tensorflow//tensorflow/core:framework_internal", + "@org_tensorflow//tensorflow/core:lib", + "@org_tensorflow//tensorflow/core:protos_all_cc", + "@org_tensorflow//tensorflow:grpc", + "@org_tensorflow//tensorflow:grpc++", ] + if_cuda_is_configured([ "@local_config_nccl//:nccl", - "//tensorflow/compiler/jit:xla_gpu_device", - "//tensorflow/compiler/xla/stream_executor:cuda_platform", + "@org_tensorflow//tensorflow/compiler/jit:xla_gpu_device", + "@org_tensorflow//tensorflow/compiler/xla/stream_executor:cuda_platform", ]) + if_with_tpu_support([ - "//tensorflow/compiler/jit:xla_tpu_device", - "//tensorflow/compiler/jit:xla_tpu_jit", + "@org_tensorflow//tensorflow/compiler/jit:xla_tpu_device", + "@org_tensorflow//tensorflow/compiler/jit:xla_tpu_jit", ]), alwayslink = 1, ) -xla_cc_test( - name = "pjrt_computation_client_test", - srcs = ["pjrt_computation_client_test.cc"], - deps = [ - ":computation_client_impl", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tools:hlo_module_loader", - "//tensorflow/core/platform:logging", - "//tensorflow/tsl/lib/core:status_test_util", - "//tensorflow/tsl/platform:env", - "//tensorflow/tsl/platform:test", - "//tensorflow/tsl/platform:test_main", - ], -) +# TODO(goranpetrovic): reenable when `xla_cc_test` is fixed upstream. +# xla_cc_test( +# name = "pjrt_computation_client_test", +# srcs = ["pjrt_computation_client_test.cc"], +# deps = [ +# ":computation_client_impl", +# "@org_tensorflow//tensorflow/compiler/xla:literal", +# "@org_tensorflow//tensorflow/compiler/xla:literal_util", +# "@org_tensorflow//tensorflow/compiler/xla:shape_util", +# "@org_tensorflow//tensorflow/compiler/xla:status", +# "@org_tensorflow//tensorflow/compiler/xla:statusor", +# "@org_tensorflow//tensorflow/compiler/xla/client:xla_builder", +# "@org_tensorflow//tensorflow/compiler/xla/client:xla_computation", +# "@org_tensorflow//tensorflow/compiler/xla/tests:literal_test_util", +# "@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader", +# "@org_tensorflow//tensorflow/core/platform:logging", +# "@org_tensorflow//tensorflow/tsl/lib/core:status_test_util", +# "@org_tensorflow//tensorflow/tsl/platform:env", +# "@org_tensorflow//tensorflow/tsl/platform:test", +# "@org_tensorflow//tensorflow/tsl/platform:test_main", +# ], +# ) diff --git a/third_party/xla_client/async_task.h b/third_party/xla_client/async_task.h index af869cac338..9aad577b48b 100644 --- a/third_party/xla_client/async_task.h +++ b/third_party/xla_client/async_task.h @@ -8,8 +8,8 @@ #include #include "absl/types/optional.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/thread_pool.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/thread_pool.h" namespace xla { namespace util { diff --git a/third_party/xla_client/computation_client.cc b/third_party/xla_client/computation_client.cc index c10e9130c5c..516eed96242 100644 --- a/third_party/xla_client/computation_client.cc +++ b/third_party/xla_client/computation_client.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/computation_client.h" +#include "third_party/xla_client/computation_client.h" #include #include @@ -11,15 +11,15 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/env_vars.h" -#include "tensorflow/compiler/xla/xla_client/mesh_service.h" -#include "tensorflow/compiler/xla/xla_client/pjrt_computation_client.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/xrt_computation_client.h" #include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/stacktrace_handler.h" #include "tensorflow/core/util/device_name_utils.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/env_vars.h" +#include "third_party/xla_client/mesh_service.h" +#include "third_party/xla_client/pjrt_computation_client.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/xrt_computation_client.h" namespace xla { namespace { @@ -144,7 +144,7 @@ void AddXrtHostDevices(const std::string& worker_name, int task_no, struct Devices { const char* name; const char* tf_name; - int count; + int64_t count; } const devices[] = { {"TPU", "TPU", sys_util::GetEnvInt(env::kEnvNumTpu, device_counts.num_tpus)}, @@ -269,8 +269,7 @@ std::unique_ptr ComputationClient::Create() { std::unique_ptr client; if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { - client = - std::unique_ptr(new PjRtComputationClient()); + client = std::unique_ptr(new PjRtComputationClient()); } else { XrtComputationClient::Options options; std::unique_ptr topology_proto; diff --git a/third_party/xla_client/computation_client.h b/third_party/xla_client/computation_client.h index 9a540b90a90..5dc3f52261e 100644 --- a/third_party/xla_client/computation_client.h +++ b/third_party/xla_client/computation_client.h @@ -14,10 +14,10 @@ #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/metrics.h" -#include "tensorflow/compiler/xla/xla_client/types.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/metrics.h" +#include "third_party/xla_client/types.h" +#include "third_party/xla_client/util.h" namespace xla { diff --git a/third_party/xla_client/debug_macros.h b/third_party/xla_client/debug_macros.h index bf289a9b4fb..bf041427c12 100644 --- a/third_party/xla_client/debug_macros.h +++ b/third_party/xla_client/debug_macros.h @@ -2,8 +2,8 @@ #define XLA_CLIENT_DEBUG_MACROS_H_ #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla_client/tf_logging.h" #include "tensorflow/core/platform/stacktrace.h" +#include "third_party/xla_client/tf_logging.h" #define XLA_ERROR() TF_ERROR_STREAM() #define XLA_CHECK(c) TF_CHECK(c) << "\n" << tensorflow::CurrentStackTrace() diff --git a/third_party/xla_client/env_vars.cc b/third_party/xla_client/env_vars.cc index 48e2f70150d..c1046257500 100644 --- a/third_party/xla_client/env_vars.cc +++ b/third_party/xla_client/env_vars.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/env_vars.h" +#include "third_party/xla_client/env_vars.h" namespace xla { namespace env { diff --git a/third_party/xla_client/mesh_service.cc b/third_party/xla_client/mesh_service.cc index c413d74f742..82f8ccb01e1 100644 --- a/third_party/xla_client/mesh_service.cc +++ b/third_party/xla_client/mesh_service.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/mesh_service.h" +#include "third_party/xla_client/mesh_service.h" #include #include @@ -22,13 +22,13 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/mesh_service.grpc.pb.h" -#include "tensorflow/compiler/xla/xla_client/multi_wait.h" -#include "tensorflow/compiler/xla/xla_client/nccl_distributed.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/thread_pool.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/mesh_service.grpc.pb.h" +#include "third_party/xla_client/multi_wait.h" +#include "third_party/xla_client/nccl_distributed.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/thread_pool.h" +#include "third_party/xla_client/util.h" namespace xla { namespace service { diff --git a/third_party/xla_client/mesh_service.h b/third_party/xla_client/mesh_service.h index 5aa0eedd689..1e16a1f6ac7 100644 --- a/third_party/xla_client/mesh_service.h +++ b/third_party/xla_client/mesh_service.h @@ -7,7 +7,7 @@ #include "absl/types/span.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/mesh_service.pb.h" +#include "third_party/xla_client/mesh_service.pb.h" namespace xla { namespace service { diff --git a/third_party/xla_client/metrics.cc b/third_party/xla_client/metrics.cc index 3b236de5aeb..5560a13a8a4 100644 --- a/third_party/xla_client/metrics.cc +++ b/third_party/xla_client/metrics.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/metrics.h" +#include "third_party/xla_client/metrics.h" #include #include @@ -6,9 +6,9 @@ #include "absl/memory/memory.h" #include "absl/strings/str_split.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" #include "tensorflow/core/platform/macros.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" namespace xla { namespace metrics { diff --git a/third_party/xla_client/metrics.h b/third_party/xla_client/metrics.h index 8040feda377..ad2aaca4fcc 100644 --- a/third_party/xla_client/metrics.h +++ b/third_party/xla_client/metrics.h @@ -10,7 +10,7 @@ #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" +#include "third_party/xla_client/sys_util.h" namespace xla { namespace metrics { diff --git a/third_party/xla_client/metrics_analysis.cc b/third_party/xla_client/metrics_analysis.cc index 78d1a3c374e..93e0181ec27 100644 --- a/third_party/xla_client/metrics_analysis.cc +++ b/third_party/xla_client/metrics_analysis.cc @@ -1,10 +1,10 @@ -#include "tensorflow/compiler/xla/xla_client/metrics_analysis.h" +#include "third_party/xla_client/metrics_analysis.h" #include "absl/types/variant.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/metrics.h" -#include "tensorflow/compiler/xla/xla_client/tf_logging.h" -#include "tensorflow/compiler/xla/xla_client/types.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/metrics.h" +#include "third_party/xla_client/tf_logging.h" +#include "third_party/xla_client/types.h" namespace xla { namespace metrics { diff --git a/third_party/xla_client/metrics_reader.cc b/third_party/xla_client/metrics_reader.cc index db7e06387fe..13410ff19ab 100644 --- a/third_party/xla_client/metrics_reader.cc +++ b/third_party/xla_client/metrics_reader.cc @@ -1,11 +1,11 @@ -#include "tensorflow/compiler/xla/xla_client/metrics_reader.h" +#include "third_party/xla_client/metrics_reader.h" #include -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/metrics.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/metrics.h" +#include "third_party/xla_client/util.h" namespace xla { namespace metrics_reader { diff --git a/third_party/xla_client/multi_wait.cc b/third_party/xla_client/multi_wait.cc index 77d7a633344..d0db93bd0a1 100644 --- a/third_party/xla_client/multi_wait.cc +++ b/third_party/xla_client/multi_wait.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/multi_wait.h" +#include "third_party/xla_client/multi_wait.h" #include #include diff --git a/third_party/xla_client/nccl_distributed.cc b/third_party/xla_client/nccl_distributed.cc index c9ca43ad4e1..eb6def54e97 100644 --- a/third_party/xla_client/nccl_distributed.cc +++ b/third_party/xla_client/nccl_distributed.cc @@ -1,10 +1,10 @@ -#include "tensorflow/compiler/xla/xla_client/nccl_distributed.h" +#include "third_party/xla_client/nccl_distributed.h" #include #include #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #if XLA_CUDA #include "third_party/nccl/nccl.h" #endif diff --git a/third_party/xla_client/pjrt_computation_client.cc b/third_party/xla_client/pjrt_computation_client.cc index 468051b7ed0..39123acaad4 100644 --- a/third_party/xla_client/pjrt_computation_client.cc +++ b/third_party/xla_client/pjrt_computation_client.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/pjrt_computation_client.h" +#include "third_party/xla_client/pjrt_computation_client.h" #include @@ -18,11 +18,11 @@ #include "tensorflow/compiler/xla/pjrt/tpu_client.h" #include "tensorflow/compiler/xla/pjrt/distributed/distributed.h" #include "tensorflow/compiler/xla/shape.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/env_vars.h" -#include "tensorflow/compiler/xla/xla_client/tf_logging.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/env_vars.h" +#include "third_party/xla_client/tf_logging.h" namespace xla { diff --git a/third_party/xla_client/pjrt_computation_client.h b/third_party/xla_client/pjrt_computation_client.h index 60ad22a3c3d..06de1253027 100644 --- a/third_party/xla_client/pjrt_computation_client.h +++ b/third_party/xla_client/pjrt_computation_client.h @@ -9,9 +9,9 @@ #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/pjrt/pjrt_executable.h" #include "tensorflow/compiler/xla/shape.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" namespace xla { diff --git a/third_party/xla_client/pjrt_computation_client_test.cc b/third_party/xla_client/pjrt_computation_client_test.cc index 27c33344db3..5143d85d417 100644 --- a/third_party/xla_client/pjrt_computation_client_test.cc +++ b/third_party/xla_client/pjrt_computation_client_test.cc @@ -1,4 +1,6 @@ -#include "tensorflow/compiler/xla/xla_client/pjrt_computation_client.h" +#include "third_party/xla_client/pjrt_computation_client.h" + +#include #include #include @@ -11,11 +13,11 @@ #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/tsl/lib/core/status_test_util.h" #include "tensorflow/tsl/platform/env.h" #include "tensorflow/tsl/platform/test.h" +#include "third_party/xla_client/computation_client.h" namespace xla { @@ -84,4 +86,4 @@ TEST(PjRtComputationClientTest, Init) { result_literals[0])); } -} // namespace xla \ No newline at end of file +} // namespace xla diff --git a/third_party/xla_client/profiler.cc b/third_party/xla_client/profiler.cc index 8aef088806b..578e119b91b 100644 --- a/third_party/xla_client/profiler.cc +++ b/third_party/xla_client/profiler.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/profiler.h" +#include "third_party/xla_client/profiler.h" #include "tensorflow/core/profiler/rpc/profiler_server.h" diff --git a/third_party/xla_client/record_reader.cc b/third_party/xla_client/record_reader.cc index ddb4148f034..7bcd4f4dd37 100644 --- a/third_party/xla_client/record_reader.cc +++ b/third_party/xla_client/record_reader.cc @@ -1,9 +1,9 @@ -#include "tensorflow/compiler/xla/xla_client/record_reader.h" +#include "third_party/xla_client/record_reader.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" +#include "third_party/xla_client/debug_macros.h" namespace xla { namespace util { diff --git a/third_party/xla_client/sys_util.cc b/third_party/xla_client/sys_util.cc index 9d1b6057ada..7d0a9ac1aef 100644 --- a/third_party/xla_client/sys_util.cc +++ b/third_party/xla_client/sys_util.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/sys_util.h" +#include "third_party/xla_client/sys_util.h" #include #include diff --git a/third_party/xla_client/sys_util.h b/third_party/xla_client/sys_util.h index 2a785f5555a..617bb0b08f6 100644 --- a/third_party/xla_client/sys_util.h +++ b/third_party/xla_client/sys_util.h @@ -8,6 +8,7 @@ namespace xla { namespace sys_util { +// Gets the string environmental variable by `name`, or `defval` if unset. std::string GetEnvString(const char* name, const std::string& defval); std::string GetEnvOrdinalPath(const char* name, const std::string& defval, @@ -17,10 +18,13 @@ std::string GetEnvOrdinalPath( const char* name, const std::string& defval, const char* ordinal_env = "XRT_SHARD_LOCAL_ORDINAL"); +// Gets the integer environmental variable by `name`, or `defval` if unset. int64_t GetEnvInt(const char* name, int64_t defval); +// Gets the double environmental variable by `name`, or `defval` if unset. double GetEnvDouble(const char* name, double defval); +// Gets the boolean environmental variable by `name`, or `defval` if unset. bool GetEnvBool(const char* name, bool defval); // Retrieves the current EPOCH time in nanoseconds. diff --git a/third_party/xla_client/sys_util_test.cc b/third_party/xla_client/sys_util_test.cc new file mode 100644 index 00000000000..eead5190f13 --- /dev/null +++ b/third_party/xla_client/sys_util_test.cc @@ -0,0 +1,33 @@ +#include "third_party/xla_client/sys_util.h" + +#include + +namespace xla { +namespace sys_util { + +TEST(SysUtilTest, Env) { + EXPECT_EQ(GetEnvInt("does-not-exist-hopefully", 42), 42); + EXPECT_EQ(GetEnvString("does-not-exist-hopefully", "42"), "42"); + EXPECT_EQ(GetEnvDouble("does-not-exist-hopefully", 42.0f), 42.0f); + + setenv("ordinal", "42", true); + EXPECT_EQ(GetEnvOrdinalPath("does-not-exist-hopefully", "/path/to/test/data", + "ordinal"), + "/path/to/test/data.42"); + + EXPECT_EQ(GetEnvBool("does-not-exist-hopefully", true), true); + setenv("existing-bool", "true", true); + EXPECT_EQ(GetEnvBool("existing-bool", false), true); + setenv("existing-bool", "false", true); + EXPECT_EQ(GetEnvBool("existing-bool", true), false); + + setenv("existing-bool", "0", true); + EXPECT_EQ(GetEnvBool("existing-bool", true), false); + setenv("existing-bool", "7", true); + EXPECT_EQ(GetEnvBool("existing-bool", false), true); + + EXPECT_GT(NowNs(), 0); +} + +} // namespace sys_util +} // namespace xla \ No newline at end of file diff --git a/third_party/xla_client/tf_logging.cc b/third_party/xla_client/tf_logging.cc index afb83982d8e..8c017a06dbc 100644 --- a/third_party/xla_client/tf_logging.cc +++ b/third_party/xla_client/tf_logging.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/tf_logging.h" +#include "third_party/xla_client/tf_logging.h" #include diff --git a/third_party/xla_client/thread_pool.cc b/third_party/xla_client/thread_pool.cc index baf5923d8b8..c5a1a486480 100644 --- a/third_party/xla_client/thread_pool.cc +++ b/third_party/xla_client/thread_pool.cc @@ -1,12 +1,12 @@ -#include "tensorflow/compiler/xla/xla_client/thread_pool.h" +#include "third_party/xla_client/thread_pool.h" #include #include #include #include -#include "tensorflow/compiler/xla/xla_client/metrics.h" -#include "tensorflow/compiler/xla/xla_client/tf_logging.h" +#include "third_party/xla_client/metrics.h" +#include "third_party/xla_client/tf_logging.h" namespace xla { namespace env { diff --git a/third_party/xla_client/triggered_task.cc b/third_party/xla_client/triggered_task.cc index 62d74834703..7145397e4cd 100644 --- a/third_party/xla_client/triggered_task.cc +++ b/third_party/xla_client/triggered_task.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/triggered_task.h" +#include "third_party/xla_client/triggered_task.h" namespace xla { namespace util { diff --git a/third_party/xla_client/unique.h b/third_party/xla_client/unique.h index cf5cdeb6599..081319e5cbf 100644 --- a/third_party/xla_client/unique.h +++ b/third_party/xla_client/unique.h @@ -5,7 +5,7 @@ #include #include "absl/types/optional.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" namespace xla { namespace util { diff --git a/third_party/xla_client/util.cc b/third_party/xla_client/util.cc index a880aa4114c..c22a87f457f 100644 --- a/third_party/xla_client/util.cc +++ b/third_party/xla_client/util.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include diff --git a/third_party/xla_client/util.h b/third_party/xla_client/util.h index 00857006b35..1f73d7a133c 100644 --- a/third_party/xla_client/util.h +++ b/third_party/xla_client/util.h @@ -16,9 +16,9 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/xla_client/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" +#include "third_party/xla_client/types.h" namespace xla { namespace util { diff --git a/third_party/xla_client/util_test.cc b/third_party/xla_client/util_test.cc new file mode 100644 index 00000000000..efe8f81155e --- /dev/null +++ b/third_party/xla_client/util_test.cc @@ -0,0 +1,73 @@ +#include "third_party/xla_client/util.h" + +#include +#include + +#include + +namespace xla { +namespace util { + +using ::testing::ElementsAre; + +TEST(UtilTest, Cleanup) { + bool notify = false; + + // Set to true. + { + Cleanup c([¬ify](bool b) { notify = b; }); + c.SetStatus(true); + } + EXPECT_TRUE(notify); + + // Set to false. + { + Cleanup c([¬ify](bool b) { notify = b; }); + c.SetStatus(false); + } + EXPECT_FALSE(notify); + + // Releasing the cleanup will not change the `notify` to true. + { + Cleanup c([¬ify](bool b) { notify = b; }); + c.SetStatus(true); + c.Release(); + } + EXPECT_FALSE(notify); +} + +TEST(UtilTest, Iota) { + EXPECT_THAT(Iota(5, 0, 2), ElementsAre(0, 2, 4, 6, 8)); +} + +TEST(UtilTest, Range) { + EXPECT_THAT(Range(0, 10, 2), ElementsAre(0, 2, 4, 6, 8)); + EXPECT_THAT(Range(10, 0, -2), ElementsAre(10, 8, 6, 4, 2)); +} + +TEST(UtilTest, ToVector) { + EXPECT_THAT(ToVector(std::string("char")), + ElementsAre('c', 'h', 'a', 'r')); +} + +TEST(UtilTest, Equal) { + EXPECT_TRUE(Equal(std::string("this"), std::string("this"))); + EXPECT_FALSE(Equal(std::string("this"), std::string("that"))); +} + +TEST(UtilTest, FindOr) { + std::unordered_map v = {{1, 1}, {2, 2}, {3, 3}}; + EXPECT_EQ(FindOr(v, 1, 7), 1); + EXPECT_EQ(FindOr(v, 2, 7), 2); + EXPECT_EQ(FindOr(v, 3, 7), 3); + EXPECT_EQ(FindOr(v, 10, 7), 7); +} + +TEST(UtilTest, MapInsert) { + std::unordered_map v; + EXPECT_EQ(MapInsert(&v, 1, [] { return 1; }), 1); + EXPECT_EQ(MapInsert(&v, 1, [] { return 7; }), 1); +} + +} // namespace util +} // namespace xla diff --git a/third_party/xla_client/xla_util.cc b/third_party/xla_client/xla_util.cc index 693bb00e57b..7c809894076 100644 --- a/third_party/xla_client/xla_util.cc +++ b/third_party/xla_client/xla_util.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/xla_util.h" +#include "third_party/xla_client/xla_util.h" #include #include @@ -8,11 +8,11 @@ #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/tf_logging.h" -#include "tensorflow/compiler/xla/xla_client/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/stacktrace.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/tf_logging.h" +#include "third_party/xla_client/util.h" namespace xla { namespace util { diff --git a/third_party/xla_client/xla_util.h b/third_party/xla_client/xla_util.h index 3b1e7ad4807..a2677072b0e 100644 --- a/third_party/xla_client/xla_util.h +++ b/third_party/xla_client/xla_util.h @@ -8,7 +8,7 @@ #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/xla_client/types.h" +#include "third_party/xla_client/types.h" namespace xla { namespace util { diff --git a/third_party/xla_client/xrt_computation_client.cc b/third_party/xla_client/xrt_computation_client.cc index 7cd90beb533..9ea39792ca5 100644 --- a/third_party/xla_client/xrt_computation_client.cc +++ b/third_party/xla_client/xrt_computation_client.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/xrt_computation_client.h" +#include "third_party/xla_client/xrt_computation_client.h" #include #include @@ -15,18 +15,18 @@ #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_client/env_vars.h" -#include "tensorflow/compiler/xla/xla_client/multi_wait.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/thread_pool.h" -#include "tensorflow/compiler/xla/xla_client/unique.h" -#include "tensorflow/compiler/xla/xla_client/util.h" -#include "tensorflow/compiler/xla/xla_client/xla_util.h" #include "tensorflow/compiler/xrt/xrt_util.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/tsl/lib/math/math_util.h" +#include "third_party/xla_client/env_vars.h" +#include "third_party/xla_client/multi_wait.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/thread_pool.h" +#include "third_party/xla_client/unique.h" +#include "third_party/xla_client/util.h" +#include "third_party/xla_client/xla_util.h" namespace xla { namespace { @@ -75,8 +75,7 @@ class TensorAllocator : public tensorflow::Allocator { // to store a pointer to its AllocBlocks. alignment = std::max(alignment, sizeof(void*)); // To call aligned_alloc(), num_bytes must be multiple of alignment. - num_bytes = - tsl::MathUtil::CeilOfRatio(num_bytes, alignment) * alignment; + num_bytes = tsl::MathUtil::CeilOfRatio(num_bytes, alignment) * alignment; AllocKey alloc_key = {alignment, num_bytes}; void* block = nullptr; @@ -1121,8 +1120,9 @@ std::unique_ptr XrtComputationClient::CreateXrtComputation( tensorflow::Tensor XrtComputationClient::GetArgumentsInputs( absl::Span arguments, const std::string& device) { - tensorflow::Tensor inputs_tensor(tensorflow::DT_INT64, - tensorflow::TensorShape({arguments.size()})); + tensorflow::Tensor inputs_tensor( + tensorflow::DT_INT64, + tensorflow::TensorShape({static_cast(arguments.size())})); for (size_t i = 0; i < arguments.size(); ++i) { const XrtData& xrt_data = dynamic_cast(*arguments[i]); XLA_CHECK_EQ(device, xrt_data.device()); @@ -1226,7 +1226,7 @@ void XrtComputationClient::ReleaseHandles( session_and_handles.second; tensorflow::Tensor handles_tensor( tensorflow::DT_INT64, - tensorflow::TensorShape({session_handles.size()})); + tensorflow::TensorShape({static_cast(session_handles.size())})); auto flat_handles_tensor = handles_tensor.flat(); for (size_t i = 0; i < session_handles.size(); ++i) { flat_handles_tensor(i) = session_handles[i].handle; diff --git a/third_party/xla_client/xrt_computation_client.h b/third_party/xla_client/xrt_computation_client.h index 7ffa447c38b..620d5a7bd26 100644 --- a/third_party/xla_client/xrt_computation_client.h +++ b/third_party/xla_client/xrt_computation_client.h @@ -14,16 +14,6 @@ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/xla/xla_client/cache.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/mesh_service.h" -#include "tensorflow/compiler/xla/xla_client/metrics.h" -#include "tensorflow/compiler/xla/xla_client/triggered_task.h" -#include "tensorflow/compiler/xla/xla_client/util.h" -#include "tensorflow/compiler/xla/xla_client/xrt_local_service.h" -#include "tensorflow/compiler/xla/xla_client/xrt_session.h" -#include "tensorflow/compiler/xla/xla_client/xrt_session_cache.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" #include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" @@ -31,6 +21,16 @@ #include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" +#include "third_party/xla_client/cache.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/mesh_service.h" +#include "third_party/xla_client/metrics.h" +#include "third_party/xla_client/triggered_task.h" +#include "third_party/xla_client/util.h" +#include "third_party/xla_client/xrt_local_service.h" +#include "third_party/xla_client/xrt_session.h" +#include "third_party/xla_client/xrt_session_cache.h" namespace xla { diff --git a/third_party/xla_client/xrt_local_service.cc b/third_party/xla_client/xrt_local_service.cc index 5095575fd03..0e46052532c 100644 --- a/third_party/xla_client/xrt_local_service.cc +++ b/third_party/xla_client/xrt_local_service.cc @@ -1,15 +1,15 @@ -#include "tensorflow/compiler/xla/xla_client/xrt_local_service.h" +#include "third_party/xla_client/xrt_local_service.h" #include #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" #include "tensorflow/core/public/session_options.h" -#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.h" namespace xla { namespace { diff --git a/third_party/xla_client/xrt_local_service.h b/third_party/xla_client/xrt_local_service.h index 44b695bd44f..2168894eaf2 100644 --- a/third_party/xla_client/xrt_local_service.h +++ b/third_party/xla_client/xrt_local_service.h @@ -5,8 +5,8 @@ #include #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/core/distributed_runtime/server_lib.h" +#include "third_party/xla_client/debug_macros.h" namespace xla { diff --git a/third_party/xla_client/xrt_session.cc b/third_party/xla_client/xrt_session.cc index 47d88cae4ea..f134f9d2bb8 100644 --- a/third_party/xla_client/xrt_session.cc +++ b/third_party/xla_client/xrt_session.cc @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/xrt_session.h" +#include "third_party/xla_client/xrt_session.h" #include "absl/strings/str_cat.h" diff --git a/third_party/xla_client/xrt_session.h b/third_party/xla_client/xrt_session.h index 71ade20eac1..186bab549c0 100644 --- a/third_party/xla_client/xrt_session.h +++ b/third_party/xla_client/xrt_session.h @@ -13,7 +13,7 @@ #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" namespace xla { diff --git a/third_party/xla_client/xrt_session_cache.cc b/third_party/xla_client/xrt_session_cache.cc index 6abda1f28ab..70ec609bc61 100644 --- a/third_party/xla_client/xrt_session_cache.cc +++ b/third_party/xla_client/xrt_session_cache.cc @@ -1,7 +1,7 @@ -#include "tensorflow/compiler/xla/xla_client/xrt_session_cache.h" +#include "third_party/xla_client/xrt_session_cache.h" -#include "tensorflow/compiler/xla/xla_client/metrics.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" +#include "third_party/xla_client/metrics.h" +#include "third_party/xla_client/sys_util.h" namespace xla { diff --git a/third_party/xla_client/xrt_session_cache.h b/third_party/xla_client/xrt_session_cache.h index 18567859668..4be9e711eba 100644 --- a/third_party/xla_client/xrt_session_cache.h +++ b/third_party/xla_client/xrt_session_cache.h @@ -10,7 +10,7 @@ #include #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/xrt_session.h" +#include "third_party/xla_client/xrt_session.h" namespace xla { diff --git a/torch_xla/csrc/aten_cpu_fallback.cpp b/torch_xla/csrc/aten_cpu_fallback.cpp index f9b08020dd3..6bd8bbaf219 100644 --- a/torch_xla/csrc/aten_cpu_fallback.cpp +++ b/torch_xla/csrc/aten_cpu_fallback.cpp @@ -1,8 +1,8 @@ #include "torch_xla/csrc/aten_cpu_fallback.h" -#include -#include -#include +#include +#include +#include #include #include diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 5cd8e890b52..54d4f810027 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -5,8 +5,8 @@ #include #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/tensor_impl.h" #include "torch_xla/csrc/torch_util.h" diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 5f40025d247..df76e6dff6c 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -9,10 +9,10 @@ #include -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/metrics.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/metrics.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/shape_inference.h" #include "torch/csrc/lazy/core/tensor_util.h" #include "torch/csrc/lazy/core/util.h" diff --git a/torch_xla/csrc/computation.cpp b/torch_xla/csrc/computation.cpp index efb6d1a5c8a..64f3f90c44c 100644 --- a/torch_xla/csrc/computation.cpp +++ b/torch_xla/csrc/computation.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/computation.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" namespace torch_xla { diff --git a/torch_xla/csrc/computation.h b/torch_xla/csrc/computation.h index bf923c6ecf3..d5ff2f2370b 100644 --- a/torch_xla/csrc/computation.h +++ b/torch_xla/csrc/computation.h @@ -6,9 +6,9 @@ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/types.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/types.h" #include "torch/csrc/lazy/backend/lowering_context.h" #include "torch/csrc/lazy/core/hash.h" diff --git a/torch_xla/csrc/convert_ops.cpp b/torch_xla/csrc/convert_ops.cpp index 9e2261fd9f7..58096a0d3b5 100644 --- a/torch_xla/csrc/convert_ops.cpp +++ b/torch_xla/csrc/convert_ops.cpp @@ -6,7 +6,7 @@ #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/tensor_util.h" diff --git a/torch_xla/csrc/convolution.cpp b/torch_xla/csrc/convolution.cpp index c6c9ac251a0..a125ff686e3 100644 --- a/torch_xla/csrc/convolution.cpp +++ b/torch_xla/csrc/convolution.cpp @@ -2,7 +2,6 @@ #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "third_party/xla_client/debug_macros.h" diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp index 109a8ead910..3706c25a311 100644 --- a/torch_xla/csrc/cross_replica_reduces.cpp +++ b/torch_xla/csrc/cross_replica_reduces.cpp @@ -3,8 +3,8 @@ #include #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/device.h" diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index ac0f79723cd..2d14e0b967c 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -9,9 +9,9 @@ #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/tensor_util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/convert_ops.h" diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 732f9c5a661..11266ea51d5 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -7,9 +7,9 @@ #include "absl/memory/memory.h" #include "absl/strings/str_split.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/unique.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/unique.h" #include "torch/csrc/lazy/core/hash.h" #include "torch/csrc/lazy/python/python_util.h" #include "torch_xla/csrc/device.h" diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index e4484802e4a..02867f764b1 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -3,8 +3,8 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/debug_macros.h" namespace torch_xla { namespace { diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index d63bc86b0e4..5a31b19d283 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -3,7 +3,7 @@ #include #include -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/backend/backend_device.h" #include "torch/csrc/lazy/core/hash.h" #include "torch/csrc/lazy/core/util.h" diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp index c1c1e155319..67045fc7881 100644 --- a/torch_xla/csrc/elementwise.cpp +++ b/torch_xla/csrc/elementwise.cpp @@ -2,7 +2,7 @@ #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" diff --git a/torch_xla/csrc/function_call_tracker.cpp b/torch_xla/csrc/function_call_tracker.cpp index 3b918b4580c..887e38c6cde 100644 --- a/torch_xla/csrc/function_call_tracker.cpp +++ b/torch_xla/csrc/function_call_tracker.cpp @@ -8,8 +8,8 @@ #include #include "absl/strings/str_split.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" #include "tensorflow/core/platform/stacktrace.h" +#include "third_party/xla_client/sys_util.h" #include "torch/csrc/lazy/python/python_util.h" namespace torch_xla { diff --git a/torch_xla/csrc/generated_file_include.h b/torch_xla/csrc/generated_file_include.h index bec32a984bd..878934b4ac3 100644 --- a/torch_xla/csrc/generated_file_include.h +++ b/torch_xla/csrc/generated_file_include.h @@ -1,5 +1,5 @@ -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/metrics.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/metrics.h" #include "torch/csrc/lazy/core/shape.h" #include "torch_xla/csrc/aten_cpu_fallback.h" #include "torch_xla/csrc/aten_xla_bridge.h" diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 665a0c5aadf..674922f16e9 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -6,10 +6,10 @@ #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/tf_logging.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/tf_logging.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/helpers.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/convert_ops.h" diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index 7b2b142251a..cb96261c834 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -13,9 +13,9 @@ #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/permutation_util.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/shape.h" #include "torch/csrc/lazy/core/util.h" diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 5c0f37b217d..1e0d74db66c 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -22,23 +22,23 @@ #include "tensorflow/compiler/xla/pjrt/distributed/distributed.h" #include "tensorflow/compiler/xla/python/profiler/internal/traceme_wrapper.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/mesh_service.h" -#include "tensorflow/compiler/xla/xla_client/metrics.h" -#include "tensorflow/compiler/xla/xla_client/metrics_analysis.h" -#include "tensorflow/compiler/xla/xla_client/metrics_reader.h" -#include "tensorflow/compiler/xla/xla_client/multi_wait.h" -#include "tensorflow/compiler/xla/xla_client/profiler.h" -#include "tensorflow/compiler/xla/xla_client/record_reader.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/thread_pool.h" -#include "tensorflow/compiler/xla/xla_client/util.h" -#include "tensorflow/compiler/xla/xla_client/xla_util.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/python/profiler/internal/profiler_pywrap_impl.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/mesh_service.h" +#include "third_party/xla_client/metrics.h" +#include "third_party/xla_client/metrics_analysis.h" +#include "third_party/xla_client/metrics_reader.h" +#include "third_party/xla_client/multi_wait.h" +#include "third_party/xla_client/profiler.h" +#include "third_party/xla_client/record_reader.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/thread_pool.h" +#include "third_party/xla_client/util.h" +#include "third_party/xla_client/xla_util.h" #include "torch/csrc/autograd/utils/wrap_outputs.h" #include "torch/csrc/autograd/variable.h" #include "torch/csrc/jit/python/pybind.h" diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 527feee80c5..666f8ff3ea7 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -4,9 +4,9 @@ #include #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/xla_client/cache.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" +#include "third_party/xla_client/cache.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/sys_util.h" #include "torch/csrc/lazy/core/config.h" #include "torch/csrc/lazy/core/hash.h" #include "torch/csrc/lazy/core/ir_metadata.h" diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index 5d9521cd475..c4eedb2881f 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -15,8 +15,8 @@ #include "absl/hash/hash.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_client/types.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "third_party/xla_client/types.h" #include "torch/csrc/lazy/core/hash.h" #include "torch/csrc/lazy/core/ir.h" #include "torch/csrc/lazy/core/ir_builder.h" diff --git a/torch_xla/csrc/ir_builder.h b/torch_xla/csrc/ir_builder.h index 5f2254a9f0d..35096e6f603 100644 --- a/torch_xla/csrc/ir_builder.h +++ b/torch_xla/csrc/ir_builder.h @@ -1,4 +1,4 @@ -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch/csrc/lazy/core/ir.h" #include "torch/csrc/lazy/core/ir_builder.h" #include "torch_xla/csrc/device.h" diff --git a/torch_xla/csrc/ir_dump_util.cpp b/torch_xla/csrc/ir_dump_util.cpp index 42adda26219..7aa4657b912 100644 --- a/torch_xla/csrc/ir_dump_util.cpp +++ b/torch_xla/csrc/ir_dump_util.cpp @@ -6,8 +6,8 @@ #include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/xla_util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/xla_util.h" #include "torch/csrc/lazy/core/ir_util.h" #include "torch_xla/csrc/ir_util.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ir_util.cpp b/torch_xla/csrc/ir_util.cpp index 6defffb1a16..c03c0a3e322 100644 --- a/torch_xla/csrc/ir_util.cpp +++ b/torch_xla/csrc/ir_util.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ir_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" namespace torch_xla { diff --git a/torch_xla/csrc/layout_manager.cpp b/torch_xla/csrc/layout_manager.cpp index 63eb77eb103..cece37f6a99 100644 --- a/torch_xla/csrc/layout_manager.cpp +++ b/torch_xla/csrc/layout_manager.cpp @@ -10,10 +10,10 @@ #include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/tf_logging.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/tf_logging.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" namespace torch_xla { diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index aac2f2d6123..420891d18cf 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -7,8 +7,8 @@ #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/sys_util.h" #include "torch/csrc/lazy/core/ir_metadata.h" #include "torch_xla/csrc/computation.h" #include "torch_xla/csrc/helpers.h" diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index 0259f792ed1..7d539dbd950 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -9,8 +9,8 @@ #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" #include "tensorflow/core/platform/macros.h" +#include "third_party/xla_client/computation_client.h" #include "torch/csrc/lazy/backend/backend_data.h" #include "torch/csrc/lazy/backend/lowering_context.h" #include "torch/csrc/lazy/core/ir_util.h" diff --git a/torch_xla/csrc/nms_op.cpp b/torch_xla/csrc/nms_op.cpp index 9019d1a5bfc..cca328ff4a6 100644 --- a/torch_xla/csrc/nms_op.cpp +++ b/torch_xla/csrc/nms_op.cpp @@ -8,8 +8,8 @@ #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/helpers.h" diff --git a/torch_xla/csrc/op_by_op_executor.cpp b/torch_xla/csrc/op_by_op_executor.cpp index 07f7b02b468..5e154b05de7 100644 --- a/torch_xla/csrc/op_by_op_executor.cpp +++ b/torch_xla/csrc/op_by_op_executor.cpp @@ -5,10 +5,10 @@ #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/metrics.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/xla_util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/metrics.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/xla_util.h" #include "torch/csrc/lazy/core/hash.h" #include "torch/csrc/lazy/core/ir_util.h" #include "torch_xla/csrc/device.h" diff --git a/torch_xla/csrc/op_by_op_executor.h b/torch_xla/csrc/op_by_op_executor.h index 584cd57714b..bab3102de85 100644 --- a/torch_xla/csrc/op_by_op_executor.h +++ b/torch_xla/csrc/op_by_op_executor.h @@ -5,10 +5,10 @@ #include "absl/types/span.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/async_task.h" -#include "tensorflow/compiler/xla/xla_client/cache.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/async_task.h" +#include "third_party/xla_client/cache.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/util.h" #include "torch_xla/csrc/ir.h" namespace torch_xla { diff --git a/torch_xla/csrc/ops/adaptive_max_pool2d.cpp b/torch_xla/csrc/ops/adaptive_max_pool2d.cpp index d47896bf908..33f56adf6bf 100644 --- a/torch_xla/csrc/ops/adaptive_max_pool2d.cpp +++ b/torch_xla/csrc/ops/adaptive_max_pool2d.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/adaptive_max_pool2d.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/all_gather.cpp b/torch_xla/csrc/ops/all_gather.cpp index b3f3ea26e6e..ff08c4de830 100644 --- a/torch_xla/csrc/ops/all_gather.cpp +++ b/torch_xla/csrc/ops/all_gather.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/all_reduce.cpp b/torch_xla/csrc/ops/all_reduce.cpp index 2da8bf45cf3..3129ef87e6d 100644 --- a/torch_xla/csrc/ops/all_reduce.cpp +++ b/torch_xla/csrc/ops/all_reduce.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/as_strided.cpp b/torch_xla/csrc/ops/as_strided.cpp index 4151e8f1ac1..03e19496785 100644 --- a/torch_xla/csrc/ops/as_strided.cpp +++ b/torch_xla/csrc/ops/as_strided.cpp @@ -4,7 +4,7 @@ #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" diff --git a/torch_xla/csrc/ops/as_strided_view_update.cpp b/torch_xla/csrc/ops/as_strided_view_update.cpp index cfaef1f3b58..92e9055cdce 100644 --- a/torch_xla/csrc/ops/as_strided_view_update.cpp +++ b/torch_xla/csrc/ops/as_strided_view_update.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/as_strided_view_update.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/avg_pool_nd.cpp b/torch_xla/csrc/ops/avg_pool_nd.cpp index 13e65b5116b..1c9938a5c3c 100644 --- a/torch_xla/csrc/ops/avg_pool_nd.cpp +++ b/torch_xla/csrc/ops/avg_pool_nd.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/avg_pool_nd.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/avg_pool_nd_backward.cpp b/torch_xla/csrc/ops/avg_pool_nd_backward.cpp index 746deff8247..32b4a533eab 100644 --- a/torch_xla/csrc/ops/avg_pool_nd_backward.cpp +++ b/torch_xla/csrc/ops/avg_pool_nd_backward.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/avg_pool_nd_backward.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/constant_pad_nd.cpp b/torch_xla/csrc/ops/constant_pad_nd.cpp index 2617a54bb39..2a6c6338f34 100644 --- a/torch_xla/csrc/ops/constant_pad_nd.cpp +++ b/torch_xla/csrc/ops/constant_pad_nd.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp index 3b033327e4d..f37a4cc6a17 100644 --- a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp +++ b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/convolution_backward_overrideable.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/convolution.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/convolution_overrideable.cpp b/torch_xla/csrc/ops/convolution_overrideable.cpp index 1376cc7252e..91f34beab74 100644 --- a/torch_xla/csrc/ops/convolution_overrideable.cpp +++ b/torch_xla/csrc/ops/convolution_overrideable.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/convolution_overrideable.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/convolution.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/device_data.h b/torch_xla/csrc/ops/device_data.h index 987511c85bf..e9ed59a3477 100644 --- a/torch_xla/csrc/ops/device_data.h +++ b/torch_xla/csrc/ops/device_data.h @@ -1,6 +1,6 @@ #pragma once -#include "tensorflow/compiler/xla/xla_client/computation_client.h" +#include "third_party/xla_client/computation_client.h" #include "torch/csrc/lazy/backend/backend_data.h" #include "torch_xla/csrc/ir.h" diff --git a/torch_xla/csrc/ops/diagonal.cpp b/torch_xla/csrc/ops/diagonal.cpp index 615ad47a3c9..a8ec6f7a1fb 100644 --- a/torch_xla/csrc/ops/diagonal.cpp +++ b/torch_xla/csrc/ops/diagonal.cpp @@ -3,7 +3,7 @@ #include #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/matrix.h" diff --git a/torch_xla/csrc/ops/discrete_uniform.cpp b/torch_xla/csrc/ops/discrete_uniform.cpp index 529a4c45bd1..83fec55ffe5 100644 --- a/torch_xla/csrc/ops/discrete_uniform.cpp +++ b/torch_xla/csrc/ops/discrete_uniform.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/discrete_uniform.h" -#include "tensorflow/compiler/xla/xla_client/xla_util.h" +#include "third_party/xla_client/xla_util.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/random.h" diff --git a/torch_xla/csrc/ops/dynamic_ir.cpp b/torch_xla/csrc/ops/dynamic_ir.cpp index 32dd7ce6c62..653cf8ad8c1 100644 --- a/torch_xla/csrc/ops/dynamic_ir.cpp +++ b/torch_xla/csrc/ops/dynamic_ir.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/dynamic_ir.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/tensor.h" diff --git a/torch_xla/csrc/ops/einsum_utilities.h b/torch_xla/csrc/ops/einsum_utilities.h index 90a8e2100f2..c22a159cf35 100644 --- a/torch_xla/csrc/ops/einsum_utilities.h +++ b/torch_xla/csrc/ops/einsum_utilities.h @@ -3,7 +3,7 @@ #include #include -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" namespace torch_xla { diff --git a/torch_xla/csrc/ops/expand_symint.cpp b/torch_xla/csrc/ops/expand_symint.cpp index 34ec1dae920..8427b189c3e 100644 --- a/torch_xla/csrc/ops/expand_symint.cpp +++ b/torch_xla/csrc/ops/expand_symint.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch/csrc/lazy/core/helpers.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/data_ops.h" diff --git a/torch_xla/csrc/ops/index_get.cpp b/torch_xla/csrc/ops/index_get.cpp index ae0dc7415d4..60f90e5821d 100644 --- a/torch_xla/csrc/ops/index_get.cpp +++ b/torch_xla/csrc/ops/index_get.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/index_get.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/index_ops.cpp b/torch_xla/csrc/ops/index_ops.cpp index 5aabf622b26..3243024d3ce 100644 --- a/torch_xla/csrc/ops/index_ops.cpp +++ b/torch_xla/csrc/ops/index_ops.cpp @@ -4,8 +4,8 @@ #include #include "tensorflow/compiler/xla/permutation_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/helpers.h" diff --git a/torch_xla/csrc/ops/log_softmax_backward.cpp b/torch_xla/csrc/ops/log_softmax_backward.cpp index df54248d60f..b126affbc25 100644 --- a/torch_xla/csrc/ops/log_softmax_backward.cpp +++ b/torch_xla/csrc/ops/log_softmax_backward.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/log_softmax_backward.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/softmax_builder.h" diff --git a/torch_xla/csrc/ops/max_pool_nd.cpp b/torch_xla/csrc/ops/max_pool_nd.cpp index 5973f2392a5..c8707fbdee8 100644 --- a/torch_xla/csrc/ops/max_pool_nd.cpp +++ b/torch_xla/csrc/ops/max_pool_nd.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/max_pool_nd.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/max_pool_nd_backward.cpp b/torch_xla/csrc/ops/max_pool_nd_backward.cpp index 8e1004c1705..89a37ee740b 100644 --- a/torch_xla/csrc/ops/max_pool_nd_backward.cpp +++ b/torch_xla/csrc/ops/max_pool_nd_backward.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/max_pool_nd_backward.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/max_unpool_nd.cpp b/torch_xla/csrc/ops/max_unpool_nd.cpp index af4c1dad87f..101f3e68c4f 100644 --- a/torch_xla/csrc/ops/max_unpool_nd.cpp +++ b/torch_xla/csrc/ops/max_unpool_nd.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/max_unpool_nd.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/mse_loss.cpp b/torch_xla/csrc/ops/mse_loss.cpp index a45d97228cf..676e9fbb05a 100644 --- a/torch_xla/csrc/ops/mse_loss.cpp +++ b/torch_xla/csrc/ops/mse_loss.cpp @@ -2,8 +2,8 @@ #include -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/mse_loss_backward.cpp b/torch_xla/csrc/ops/mse_loss_backward.cpp index 51edf2312f4..04db23fc11a 100644 --- a/torch_xla/csrc/ops/mse_loss_backward.cpp +++ b/torch_xla/csrc/ops/mse_loss_backward.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/mse_loss_backward.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/native_batch_norm_backward.cpp b/torch_xla/csrc/ops/native_batch_norm_backward.cpp index 06dfc6dcfd5..2c9ae91b812 100644 --- a/torch_xla/csrc/ops/native_batch_norm_backward.cpp +++ b/torch_xla/csrc/ops/native_batch_norm_backward.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/native_batch_norm_backward.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/batch_norm.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/native_batch_norm_forward.cpp b/torch_xla/csrc/ops/native_batch_norm_forward.cpp index 3b808ccbc17..5d07576aa3a 100644 --- a/torch_xla/csrc/ops/native_batch_norm_forward.cpp +++ b/torch_xla/csrc/ops/native_batch_norm_forward.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/native_batch_norm_forward.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/batch_norm.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/nll_loss.cpp b/torch_xla/csrc/ops/nll_loss.cpp index 6d95276cf15..919415bd1d7 100644 --- a/torch_xla/csrc/ops/nll_loss.cpp +++ b/torch_xla/csrc/ops/nll_loss.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/nll_loss.h" #include "absl/types/span.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/nll_loss.h" diff --git a/torch_xla/csrc/ops/nll_loss2d.cpp b/torch_xla/csrc/ops/nll_loss2d.cpp index ba0ebb95943..81bb826f733 100644 --- a/torch_xla/csrc/ops/nll_loss2d.cpp +++ b/torch_xla/csrc/ops/nll_loss2d.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/nll_loss2d.h" #include "absl/types/span.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/nll_loss.h" diff --git a/torch_xla/csrc/ops/nll_loss2d_backward.cpp b/torch_xla/csrc/ops/nll_loss2d_backward.cpp index 05b91b4c149..283c8341519 100644 --- a/torch_xla/csrc/ops/nll_loss2d_backward.cpp +++ b/torch_xla/csrc/ops/nll_loss2d_backward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/nll_loss2d_backward.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/nll_loss.h" diff --git a/torch_xla/csrc/ops/nll_loss_backward.cpp b/torch_xla/csrc/ops/nll_loss_backward.cpp index 02b7cd34f3d..f2b9c1917ca 100644 --- a/torch_xla/csrc/ops/nll_loss_backward.cpp +++ b/torch_xla/csrc/ops/nll_loss_backward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/nll_loss_backward.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/nll_loss.h" diff --git a/torch_xla/csrc/ops/nms.cpp b/torch_xla/csrc/ops/nms.cpp index 40e4e30f8f9..26193801df6 100644 --- a/torch_xla/csrc/ops/nms.cpp +++ b/torch_xla/csrc/ops/nms.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/nms.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/nms_op.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/not_supported.cpp b/torch_xla/csrc/ops/not_supported.cpp index 4757a2e94e2..107dce9dd77 100644 --- a/torch_xla/csrc/ops/not_supported.cpp +++ b/torch_xla/csrc/ops/not_supported.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/not_supported.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 1c44e6d0d37..3b2678d3f00 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -6,8 +6,8 @@ #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/helpers.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/convert_ops.h" diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp index 15c9c43a3e4..fa8ae4c08d0 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp @@ -2,7 +2,7 @@ #include "tensorflow/compiler/xla/client/lib/logdet.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/permute.cpp b/torch_xla/csrc/ops/permute.cpp index 73bcff5274c..4022ca559e1 100644 --- a/torch_xla/csrc/ops/permute.cpp +++ b/torch_xla/csrc/ops/permute.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/permute.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/recv.cpp b/torch_xla/csrc/ops/recv.cpp index 6cfe2bedcf7..7847a5f208b 100644 --- a/torch_xla/csrc/ops/recv.cpp +++ b/torch_xla/csrc/ops/recv.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/recv.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/repeat.cpp b/torch_xla/csrc/ops/repeat.cpp index 529d2ebcf8e..ebe0e7f2317 100644 --- a/torch_xla/csrc/ops/repeat.cpp +++ b/torch_xla/csrc/ops/repeat.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/repeat.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/replication_pad.cpp b/torch_xla/csrc/ops/replication_pad.cpp index cd60bfc150a..9e48fbe5ae4 100644 --- a/torch_xla/csrc/ops/replication_pad.cpp +++ b/torch_xla/csrc/ops/replication_pad.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/replication_pad.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/resize.cpp b/torch_xla/csrc/ops/resize.cpp index 1b3b1cf4457..7ce264c2c3a 100644 --- a/torch_xla/csrc/ops/resize.cpp +++ b/torch_xla/csrc/ops/resize.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/resize.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/scalar.cpp b/torch_xla/csrc/ops/scalar.cpp index ba26f13fb72..01f2f7666fb 100644 --- a/torch_xla/csrc/ops/scalar.cpp +++ b/torch_xla/csrc/ops/scalar.cpp @@ -4,7 +4,7 @@ #include #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/scalar.h b/torch_xla/csrc/ops/scalar.h index 1c67cf6de8d..7d485100314 100644 --- a/torch_xla/csrc/ops/scalar.h +++ b/torch_xla/csrc/ops/scalar.h @@ -5,7 +5,7 @@ #include -#include "tensorflow/compiler/xla/xla_client/types.h" +#include "third_party/xla_client/types.h" #include "torch_xla/csrc/ir.h" namespace torch_xla { diff --git a/torch_xla/csrc/ops/select.cpp b/torch_xla/csrc/ops/select.cpp index 9b9a21f3e8b..6b8b3039627 100644 --- a/torch_xla/csrc/ops/select.cpp +++ b/torch_xla/csrc/ops/select.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/select.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/send.cpp b/torch_xla/csrc/ops/send.cpp index 5aa7fa3ca1b..4013a427491 100644 --- a/torch_xla/csrc/ops/send.cpp +++ b/torch_xla/csrc/ops/send.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/send.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/softmax_backward.cpp b/torch_xla/csrc/ops/softmax_backward.cpp index 565fb9c2893..7900386eb52 100644 --- a/torch_xla/csrc/ops/softmax_backward.cpp +++ b/torch_xla/csrc/ops/softmax_backward.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/softmax_backward.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/softmax_builder.h" diff --git a/torch_xla/csrc/ops/split.cpp b/torch_xla/csrc/ops/split.cpp index 20c58591467..6bf922b79b3 100644 --- a/torch_xla/csrc/ops/split.cpp +++ b/torch_xla/csrc/ops/split.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/split.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/squeeze.cpp b/torch_xla/csrc/ops/squeeze.cpp index 7685783cd72..0a1adba0793 100644 --- a/torch_xla/csrc/ops/squeeze.cpp +++ b/torch_xla/csrc/ops/squeeze.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/squeeze.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/svd.cpp b/torch_xla/csrc/ops/svd.cpp index 1b2910f3334..0a931531081 100644 --- a/torch_xla/csrc/ops/svd.cpp +++ b/torch_xla/csrc/ops/svd.cpp @@ -3,7 +3,7 @@ #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/svd.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" diff --git a/torch_xla/csrc/ops/uniform.cpp b/torch_xla/csrc/ops/uniform.cpp index 0b0d1e84208..a7482d3a4e6 100644 --- a/torch_xla/csrc/ops/uniform.cpp +++ b/torch_xla/csrc/ops/uniform.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/uniform.h" -#include "tensorflow/compiler/xla/xla_client/xla_util.h" +#include "third_party/xla_client/xla_util.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/random.h" diff --git a/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp b/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp index 181fa08650a..e670a8623aa 100644 --- a/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp +++ b/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/upsample_bilinear2d_backward.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/resize_ops.h" diff --git a/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp b/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp index 85b61b3b4a1..61ec3560fb6 100644 --- a/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp +++ b/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/upsample_nearest2d_backward.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/resize_ops.h" diff --git a/torch_xla/csrc/pooling.cpp b/torch_xla/csrc/pooling.cpp index 7b955787bbf..0ec4e04ff5a 100644 --- a/torch_xla/csrc/pooling.cpp +++ b/torch_xla/csrc/pooling.cpp @@ -5,8 +5,8 @@ #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/pooling.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/tensor_util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/data_ops.h" diff --git a/torch_xla/csrc/random.cpp b/torch_xla/csrc/random.cpp index 7083426e74b..a1203a6f690 100644 --- a/torch_xla/csrc/random.cpp +++ b/torch_xla/csrc/random.cpp @@ -5,8 +5,8 @@ #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/prng.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/sys_util.h" #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/helpers.h" diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index b774fa58c75..9558d6b3197 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -9,7 +9,7 @@ #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch/csrc/lazy/core/helpers.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/convert_ops.h" diff --git a/torch_xla/csrc/resize_ops.cpp b/torch_xla/csrc/resize_ops.cpp index 0ebfe06be89..95c19f92b1f 100644 --- a/torch_xla/csrc/resize_ops.cpp +++ b/torch_xla/csrc/resize_ops.cpp @@ -4,8 +4,8 @@ #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/sys_util.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/shape_builder.h" diff --git a/torch_xla/csrc/softmax_builder.cpp b/torch_xla/csrc/softmax_builder.cpp index 2d6f7ba5e21..f839effe16a 100644 --- a/torch_xla/csrc/softmax_builder.cpp +++ b/torch_xla/csrc/softmax_builder.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/softmax_builder.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/helpers.h" namespace torch_xla { diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 964a3f7493b..a3abf55551f 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -16,15 +16,15 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/cache.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/env_vars.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/thread_pool.h" -#include "tensorflow/compiler/xla/xla_client/unique.h" -#include "tensorflow/compiler/xla/xla_client/xla_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "third_party/xla_client/cache.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/env_vars.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/thread_pool.h" +#include "third_party/xla_client/unique.h" +#include "third_party/xla_client/xla_util.h" #include "torch/csrc/autograd/variable.h" #include "torch/csrc/lazy/core/hash.h" #include "torch/csrc/lazy/core/helpers.h" diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 58f0d350e87..05b7b6f83c0 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -9,11 +9,11 @@ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/async_task.h" -#include "tensorflow/compiler/xla/xla_client/cache.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/multi_wait.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/async_task.h" +#include "third_party/xla_client/cache.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/multi_wait.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/autograd/variable.h" #include "torch/csrc/lazy/core/ir_util.h" #include "torch_xla/csrc/computation.h" diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index 96963e547f2..cf74443f781 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -4,8 +4,8 @@ #include #include -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/debug_macros.h" #include "torch/csrc/lazy/backend/backend_interface.h" #include "torch/csrc/lazy/core/tensor.h" #include "torch/csrc/lazy/core/tensor_util.h" diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index fe9fe41e384..3f450cfe7c6 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -8,10 +8,10 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/metrics.h" -#include "tensorflow/compiler/xla/xla_client/util.h" -#include "tensorflow/compiler/xla/xla_client/xla_util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/metrics.h" +#include "third_party/xla_client/util.h" +#include "third_party/xla_client/xla_util.h" #include "torch/csrc/autograd/variable.h" #include "torch/csrc/lazy/core/helpers.h" #include "torch/csrc/lazy/core/util.h" diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp index 61c99fb12a6..a7df4202d19 100644 --- a/torch_xla/csrc/tensor_ops.cpp +++ b/torch_xla/csrc/tensor_ops.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/tensor_ops.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/helpers.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/helpers.h" diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 901e37ce551..55a6aa4aac9 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -12,14 +12,14 @@ #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/multi_wait.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/tf_logging.h" -#include "tensorflow/compiler/xla/xla_client/thread_pool.h" -#include "tensorflow/compiler/xla/xla_client/util.h" -#include "tensorflow/compiler/xla/xla_client/xrt_computation_client.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/multi_wait.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/tf_logging.h" +#include "third_party/xla_client/thread_pool.h" +#include "third_party/xla_client/util.h" +#include "third_party/xla_client/xrt_computation_client.h" #include "torch/csrc/lazy/core/hash.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/helpers.h" diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h index ad12be9f88d..4b71ff1a7c4 100644 --- a/torch_xla/csrc/tensor_util.h +++ b/torch_xla/csrc/tensor_util.h @@ -7,7 +7,7 @@ #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" +#include "third_party/xla_client/computation_client.h" #include "torch/csrc/autograd/variable.h" #include "torch/csrc/lazy/core/hash.h" #include "torch_xla/csrc/device.h" diff --git a/torch_xla/csrc/token_handler.cpp b/torch_xla/csrc/token_handler.cpp index d0cb89138fa..a147c0fbbd6 100644 --- a/torch_xla/csrc/token_handler.cpp +++ b/torch_xla/csrc/token_handler.cpp @@ -2,7 +2,7 @@ #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" +#include "third_party/xla_client/sys_util.h" #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/helpers.h" diff --git a/torch_xla/csrc/torch_util.cpp b/torch_xla/csrc/torch_util.cpp index ba85518f93b..1f1db4a934f 100644 --- a/torch_xla/csrc/torch_util.cpp +++ b/torch_xla/csrc/torch_util.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/torch_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/xla_util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/xla_util.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/constant.h" #include "torch_xla/csrc/tensor.h" diff --git a/torch_xla/csrc/torch_util.h b/torch_xla/csrc/torch_util.h index c3a03509799..0ab1b49d993 100644 --- a/torch_xla/csrc/torch_util.h +++ b/torch_xla/csrc/torch_util.h @@ -5,7 +5,7 @@ #include #include "tensorflow/compiler/xla/shape.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch/csrc/lazy/core/dynamic_ir.h" #include "torch/csrc/lazy/core/hash.h" #include "torch/csrc/lazy/core/tensor.h" diff --git a/torch_xla/csrc/view.cpp b/torch_xla/csrc/view.cpp index 5f2aab16370..86099d988d8 100644 --- a/torch_xla/csrc/view.cpp +++ b/torch_xla/csrc/view.cpp @@ -6,8 +6,8 @@ #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/as_strided.h" diff --git a/torch_xla/csrc/xla_backend_impl.h b/torch_xla/csrc/xla_backend_impl.h index 2131cb81758..6d7d5b1d726 100644 --- a/torch_xla/csrc/xla_backend_impl.h +++ b/torch_xla/csrc/xla_backend_impl.h @@ -3,7 +3,7 @@ #include #include -#include "tensorflow/compiler/xla/xla_client/computation_client.h" +#include "third_party/xla_client/computation_client.h" #include "torch/csrc/lazy/backend/backend_interface.h" #include "torch_xla/csrc/device.h" diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index f10c5d75365..2e9a10ecc36 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -18,15 +18,15 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/cache.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/env_vars.h" -#include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/thread_pool.h" -#include "tensorflow/compiler/xla/xla_client/unique.h" -#include "tensorflow/compiler/xla/xla_client/xla_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "third_party/xla_client/cache.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/env_vars.h" +#include "third_party/xla_client/sys_util.h" +#include "third_party/xla_client/thread_pool.h" +#include "third_party/xla_client/unique.h" +#include "third_party/xla_client/xla_util.h" #include "torch/csrc/autograd/variable.h" #include "torch/csrc/lazy/core/hash.h" #include "torch/csrc/lazy/core/helpers.h" diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index b198a49b84d..b834a175523 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -9,11 +9,11 @@ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_client/async_task.h" -#include "tensorflow/compiler/xla/xla_client/cache.h" -#include "tensorflow/compiler/xla/xla_client/computation_client.h" -#include "tensorflow/compiler/xla/xla_client/multi_wait.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/async_task.h" +#include "third_party/xla_client/cache.h" +#include "third_party/xla_client/computation_client.h" +#include "third_party/xla_client/multi_wait.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/autograd/variable.h" #include "torch/csrc/lazy/core/ir_util.h" #include "torch_xla/csrc/computation.h" diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 463e5f0e271..15f55336ecd 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -11,8 +11,8 @@ #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/stream_executor/dnn.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" +#include "third_party/xla_client/debug_macros.h" +#include "third_party/xla_client/util.h" #include "torch/csrc/lazy/core/helpers.h" #include "torch/csrc/lazy/core/util.h" #include "torch_xla/csrc/convert_ops.h" diff --git a/torch_xla/csrc/xla_op_builder.cpp b/torch_xla/csrc/xla_op_builder.cpp index a1e43cdcf7a..10fe8726bbb 100644 --- a/torch_xla/csrc/xla_op_builder.cpp +++ b/torch_xla/csrc/xla_op_builder.cpp @@ -9,7 +9,7 @@ #include "tensorflow/compiler/xla/client/lib/pooling.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "third_party/xla_client/debug_macros.h" #include "torch_xla/csrc/computation.h" #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/helpers.h"