From 023d763bb274eec69a86726a3f1350d029df86ae Mon Sep 17 00:00:00 2001 From: stgpetrovic Date: Tue, 21 Feb 2023 23:51:45 +0100 Subject: [PATCH] Use GCC and remove sandboxing mechanisms. (#4658) * Use GCC and remove sandboxing mechanisms. * Split `libxla_computation_client.so` into libraries. (#4659) * Split `libxla_computation_client.so` into libraries. To do this split, reworked the metrics analysis. Currently, metrics analysis would get a global singleton of the computation client to get the metrics. This change switches to injection, so the python bindings init can use the singleton to pass the metrics down to the analysis, removing the dependency from the analysis to the whole client. Add some tests. * Remove running tests; they are not cached and are slow. They are not run anyway as is. --- .bazelrc | 25 +- .vscode/settings.json | 12 +- WORKSPACE | 2 +- build_torch_xla_libs.sh | 16 +- docker/Dockerfile | 4 - docker/experimental/Dockerfile | 4 - third_party/xla_client/BUILD | 503 +++++++++++++++++---- third_party/xla_client/metrics_analysis.cc | 15 +- third_party/xla_client/metrics_analysis.h | 12 +- third_party/xla_client/metrics_reader.cc | 12 +- third_party/xla_client/metrics_reader.h | 7 +- third_party/xla_client/util_test.cc | 35 ++ third_party/xla_client/xla_util_test.cc | 119 +++++ torch_xla/csrc/init_python_bindings.cpp | 6 +- 14 files changed, 617 insertions(+), 155 deletions(-) create mode 100644 third_party/xla_client/xla_util_test.cc diff --git a/.bazelrc b/.bazelrc index 0d3ed8c228e..0468b5ada94 100644 --- a/.bazelrc +++ b/.bazelrc @@ -10,9 +10,6 @@ build --announce_rc # TODO(goranpetrovic): figure out visibility of tensorflow libraries. build --nocheck_visibility -# We can set this to `standalone` after https://github.com/bazelbuild/bazel/issues/15359 is resolved. 
-build --spawn_strategy=sandboxed - build --enable_platform_specific_config build --experimental_cc_shared_library @@ -20,6 +17,9 @@ build --experimental_cc_shared_library # Disable enabled-by-default TensorFlow features that we don't care about. build --define=no_aws_support=true build --define=no_hdfs_support=true +build --define=no_hdfs_support=true +build --define=no_kafka_support=true +build --define=no_ignite_support=true build --define=grpc_no_ares=true @@ -27,6 +27,11 @@ build -c opt build --config=short_logs +# Force GCC because clang/bazel has issues. +common --action_env=CC=gcc +common --action_env=CXX=g++ +common --spawn_strategy=standalone + ########################################################################### build:posix --copt=-Wno-sign-compare @@ -78,13 +83,12 @@ try-import %workspace%/.bazelrc.user # Compile database generation config. build:compdb --features=-layering_check -# Test requires Java. -test --java_runtime_version=remotejdk_11 +# Compiling tests requires Java. +common --java_runtime_version=remotejdk_11 # Coverage requires Java and GCC. 
coverage --config=coverage coverage --build_tests_only -build:coverage --java_runtime_version=remotejdk_11 build:coverage --copt=-DNDEBUG build:coverage --combined_report=lcov build:coverage --strategy=TestRunner=sandboxed,local @@ -92,8 +96,6 @@ build:coverage --strategy=CoverageReport=sandboxed,local build:coverage --experimental_use_llvm_covmap build:coverage --collect_code_coverage build:coverage --test_tag_filters=-nocoverage -build:coverage --action_env=CC=gcc -build:coverage --action_env=CXX=g++ ############################################################################ ############## TensorFlow .bazelrc greatest hits ########################### @@ -114,9 +116,10 @@ build:linux --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include build:linux --config=dynamic_kernels # For projects which use TensorFlow as part of a Bazel build process, putting -# nothing in a bazelrc will default to a monolithic build. The following line -# opts in to modular op registration support by default. -build --define framework_shared_object=true +# nothing in a bazelrc will default to a monolithic build. Here we force +# the monolith build because otherwise there are missing dependencies and +# linking fails. 
+build --define framework_shared_object=false build --define tsl_protobuf_header_only=true build --define=use_fast_cpp_protos=true diff --git a/.vscode/settings.json b/.vscode/settings.json index e926c7af017..127a3c9ad79 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,7 +1,6 @@ { "bsv.bazel.buildFlags": [ "--config=compdb", - "--sandbox_base=/dev/shm", ], "bsv.cc.compdb.targets": [ "//third_party/xla_client/...", @@ -10,9 +9,12 @@ "coverage-gutters.showLineCoverage": false, "coverage-gutters.showGutterCoverage": true, "coverage-gutters.coverageReportFileName": "./genhtml/index.html", - "coverage-gutters.coverageFileNames": [ "./bazel-out/_coverage/_coverage_report.dat" ], - "lcov.path": [ "./.bazel-out/_coverage/_coverage_report.dat"], - + "coverage-gutters.coverageFileNames": [ + "./bazel-out/_coverage/_coverage_report.dat" + ], + "lcov.path": [ + "./.bazel-out/_coverage/_coverage_report.dat" + ], "python.formatting.provider": "yapf", "editor.formatOnSave": true -} +} \ No newline at end of file diff --git a/WORKSPACE b/WORKSPACE index ae695ab99d7..c8bd6bc7900 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -18,8 +18,8 @@ http_archive( "//tf_patches:cudnn_int8x32.diff", "//tf_patches:f16_abi_clang.diff", "//tf_patches:gpu_race_condition.diff", - "//tf_patches:stream_executor.diff", "//tf_patches:grpc_version.diff", + "//tf_patches:stream_executor.diff", "//tf_patches:thread_local_random.diff", "//tf_patches:xplane.diff", ], diff --git a/build_torch_xla_libs.sh b/build_torch_xla_libs.sh index f88fbe4dc0f..e1051e4930c 100755 --- a/build_torch_xla_libs.sh +++ b/build_torch_xla_libs.sh @@ -34,20 +34,6 @@ if [[ "$XLA_BAZEL_VERBOSE" == "1" ]]; then VERBOSE="-s" fi -SANDBOX_BASE="${XLA_SANDBOX_BASE}" -if [ -z "$XLA_SANDBOX_BASE" ]; then - SANDBOX_BASE="/tmp" -fi -if [[ "$XLA_SANDBOX_BUILD" == "1" ]]; then - BUILD_STRATEGY="sandboxed --sandbox_base=${SANDBOX_BASE}" -else - # We can remove this after https://github.com/bazelbuild/bazel/issues/15359 is 
resolved - # Use GCC locally since clang does not work except with sanboxing, and sandboxing causes pjrt crashes. - unset CXX - unset CC - BUILD_STRATEGY="local" -fi - if [[ "$TPUVM_MODE" == "1" ]]; then OPTS+=(--config=tpu) fi @@ -73,7 +59,7 @@ fi # TensorFlow and its dependencies may introduce warning flags from newer compilers # that PyTorch and PyTorch/XLA's default compilers don't recognize. They become error # while '-Werror' is used. Therefore, surpress the warnings in .bazelrc or here. -bazel build $MAX_JOBS $VERBOSE --spawn_strategy=$BUILD_STRATEGY --show_progress_rate_limit=20 \ +bazel build $MAX_JOBS $VERBOSE --show_progress_rate_limit=20 \ --define framework_shared_object=false -c "$MODE" "${OPTS[@]}" \ $XLA_CUDA_CFG //third_party/xla_client:libxla_computation_client.so diff --git a/docker/Dockerfile b/docker/Dockerfile index b015fc62b22..0911701916b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -39,10 +39,6 @@ ENV BUNDLE_LIBTPU "${tpuvm}" # Maximum number of jobs to use for bazel build ENV BAZEL_JOBS "${bazel_jobs}" -# This makes the bazel build behave more consistently, but runs slower. -ENV XLA_SANDBOX_BUILD "0" -ENV XLA_SANDBOX_BASE "/dev/shm" - # To get around issue of Cloud Build with recursive submodule update # clone recursively from pytorch/xla if building docker image with # cloud build. Otherwise, just use local. diff --git a/docker/experimental/Dockerfile b/docker/experimental/Dockerfile index a9f9b0e86df..c236cc8c491 100644 --- a/docker/experimental/Dockerfile +++ b/docker/experimental/Dockerfile @@ -75,10 +75,6 @@ COPY .bazelversion . COPY WORKSPACE . COPY build_torch_xla_libs.sh . -# TODO: Remove this when it's not required anymore -ENV XLA_SANDBOX_BUILD=0 -ENV XLA_SANDBOX_BASE "/dev/shm" - COPY torch_xla/ torch_xla/ COPY setup.py . COPY xla_native_functions.yaml . 
diff --git a/third_party/xla_client/BUILD b/third_party/xla_client/BUILD index fb1023bce8f..b3b4950bd4f 100644 --- a/third_party/xla_client/BUILD +++ b/third_party/xla_client/BUILD @@ -14,14 +14,12 @@ load( licenses(["notice"]) # Apache 2.0 -package(default_visibility = ["@org_tensorflow//tensorflow:internal"]) +package(default_visibility = ["//visibility:public"]) -exports_files( - [ - "tf_version_script.lds", - "tf_exported_symbols.lds", - ], -) +exports_files([ + "tf_version_script.lds", + "tf_exported_symbols.lds", +]) tf_proto_library_cc( name = "mesh_service_proto", @@ -34,94 +32,60 @@ tf_proto_library_cc( ], ) -tf_cc_shared_object( - name = "libxla_computation_client.so", - linkopts = select({ - "@org_tensorflow//tensorflow:windows": [], - "//conditions:default": [ - "-z defs", - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location :tf_version_script.lds)", - ], - }), - visibility = ["//visibility:public"], +cc_library( + name = "async_task", + hdrs = ["async_task.h"], deps = [ - ":computation_client_impl", - ":tf_exported_symbols.lds", - ":tf_version_script.lds", - "@com_google_absl//absl/strings", - "@org_tensorflow//tensorflow/compiler/xla:literal_util", - "@org_tensorflow//tensorflow/compiler/xla/client", - "@org_tensorflow//tensorflow/compiler/xla/client:global_data", - "@org_tensorflow//tensorflow/compiler/xla/client:xla_builder", - "@org_tensorflow//tensorflow/compiler/xla/client:xla_computation", - "@org_tensorflow//tensorflow/compiler/xla/client/lib:svd", - "@org_tensorflow//tensorflow/compiler/xla/rpc:grpc_stub", - "@org_tensorflow//tensorflow/core:lib", - "@org_tensorflow//tensorflow/core/platform/cloud:gcs_file_system", - "@org_tensorflow//tensorflow/python/profiler/internal:profiler_pywrap_impl", + ":debug_macros", + ":thread_pool", + "@com_google_absl//absl/types:optional", ], + alwayslink = True, ) cc_library( - name = "computation_client_impl", + name = "computation_client", srcs = [ 
"computation_client.cc", - "env_vars.cc", - "mesh_service.cc", - "metrics.cc", - "metrics_analysis.cc", - "metrics_reader.cc", - "multi_wait.cc", - "nccl_distributed.cc", "pjrt_computation_client.cc", - "profiler.cc", - "record_reader.cc", - "sys_util.cc", - "tf_logging.cc", - "thread_pool.cc", - "triggered_task.cc", - "util.cc", - "xla_util.cc", "xrt_computation_client.cc", - "xrt_local_service.cc", - "xrt_session.cc", - "xrt_session_cache.cc", ], hdrs = [ - "cache.h", "computation_client.h", - "debug_macros.h", - "env_vars.h", - "mesh_service.h", - "metrics.h", - "metrics_analysis.h", - "metrics_reader.h", - "multi_wait.h", - "nccl_distributed.h", "pjrt_computation_client.h", - "profiler.h", - "record_reader.h", - "sys_util.h", - "tf_logging.h", - "thread_pool.h", - "triggered_task.h", - "types.h", - "unique.h", - "util.h", - "xla_util.h", "xrt_computation_client.h", - "xrt_local_service.h", - "xrt_session.h", - "xrt_session_cache.h", ], deps = [ - ":mesh_service_proto_cc", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/strings", + ":cache", + ":debug_macros", + ":env_vars", + ":mesh_service", + ":metrics_analysis", + ":metrics_reader", + ":metrics", + ":multi_wait", + ":profiler", + ":record_reader", + ":sys_util", + ":tf_logging", + ":thread_pool", + ":triggered_task", + ":types", + ":unique", + ":util", + ":xla_util", + ":xrt_local_service", + ":xrt_session", + ":xrt_session_cache", + "@org_tensorflow//tensorflow:grpc++", "@org_tensorflow//tensorflow/cc:client_session", "@org_tensorflow//tensorflow/cc:scope", "@org_tensorflow//tensorflow/compiler/jit:xla_cpu_device", + "@org_tensorflow//tensorflow/compiler/xla:debug_options_flags", + "@org_tensorflow//tensorflow/compiler/xla:literal", + "@org_tensorflow//tensorflow/compiler/xla:literal_util", + "@org_tensorflow//tensorflow/compiler/xla:shape_util", + "@org_tensorflow//tensorflow/compiler/xla:xla_proto_cc", "@org_tensorflow//tensorflow/compiler/xla/client", 
"@org_tensorflow//tensorflow/compiler/xla/client/lib:arithmetic", "@org_tensorflow//tensorflow/compiler/xla/client/lib:comparators", @@ -137,51 +101,390 @@ cc_library( "@org_tensorflow//tensorflow/compiler/xla/client/lib:tridiagonal", "@org_tensorflow//tensorflow/compiler/xla/client:global_data", "@org_tensorflow//tensorflow/compiler/xla/client:xla_computation", - "@org_tensorflow//tensorflow/compiler/xla/hlo/ir:hlo", - "@org_tensorflow//tensorflow/compiler/xla/pjrt/distributed:distributed", + "@org_tensorflow//tensorflow/compiler/xla/pjrt/distributed", "@org_tensorflow//tensorflow/compiler/xla/pjrt/gpu:se_gpu_pjrt_client", - "@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_c_api_client", + "@org_tensorflow//tensorflow/compiler/xla/pjrt:tpu_client", "@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client", "@org_tensorflow//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client", - "@org_tensorflow//tensorflow/compiler/xla/pjrt:tpu_client", + "@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_c_api_client", "@org_tensorflow//tensorflow/compiler/xla/rpc:grpc_stub", - "@org_tensorflow//tensorflow/compiler/xla/service/spmd:spmd_partitioner", "@org_tensorflow//tensorflow/compiler/xla/service:cpu_plugin", - "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc", "@org_tensorflow//tensorflow/compiler/xla/service:platform_util", - "@org_tensorflow//tensorflow/compiler/xla/stream_executor:stream_executor_impl", - "@org_tensorflow//tensorflow/compiler/xla:debug_options_flags", - "@org_tensorflow//tensorflow/compiler/xla:literal", - "@org_tensorflow//tensorflow/compiler/xla:literal_util", - "@org_tensorflow//tensorflow/compiler/xla:shape_util", "@org_tensorflow//tensorflow/compiler/xla:statusor", "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto_cc", - "@org_tensorflow//tensorflow/compiler/xla:xla_proto_cc", - "@org_tensorflow//tensorflow/compiler/xrt/cc:xrt_ops", + "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc", + 
"@org_tensorflow//tensorflow/compiler/xla/hlo/ir:hlo", + "@org_tensorflow//tensorflow/compiler/xla/service/spmd:spmd_partitioner", "@org_tensorflow//tensorflow/compiler/xrt:xrt_proto_cc", "@org_tensorflow//tensorflow/compiler/xrt:xrt_server", "@org_tensorflow//tensorflow/compiler/xrt:xrt_utils", - "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_runtime", - "@org_tensorflow//tensorflow/core/distributed_runtime:server_lib", - "@org_tensorflow//tensorflow/core/kernels:data_flow", - "@org_tensorflow//tensorflow/core/profiler/rpc/client:profiler_client", - "@org_tensorflow//tensorflow/core/profiler/rpc:profiler_server_impl", - "@org_tensorflow//tensorflow/core/protobuf/tpu:topology_proto_cc", + "@org_tensorflow//tensorflow/compiler/xrt/cc:xrt_ops", "@org_tensorflow//tensorflow/core:core_cpu", "@org_tensorflow//tensorflow/core:framework_internal", "@org_tensorflow//tensorflow/core:lib", "@org_tensorflow//tensorflow/core:protos_all_cc", - "@org_tensorflow//tensorflow:grpc", - "@org_tensorflow//tensorflow:grpc++", + "@org_tensorflow//tensorflow/core/distributed_runtime:server_lib", + "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_runtime", + "@org_tensorflow//tensorflow/core/kernels:data_flow", + "@org_tensorflow//tensorflow/core/profiler/rpc:profiler_server_impl", + "@org_tensorflow//tensorflow/core/profiler/rpc/client:profiler_client", + "@org_tensorflow//tensorflow/core/protobuf/tpu:topology_proto_cc", + "@org_tensorflow//tensorflow/compiler/xla/stream_executor:stream_executor_impl", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/strings", ] + if_cuda_is_configured([ - "@local_config_nccl//:nccl", "@org_tensorflow//tensorflow/compiler/jit:xla_gpu_device", "@org_tensorflow//tensorflow/compiler/xla/stream_executor:cuda_platform", ]) + if_with_tpu_support([ "@org_tensorflow//tensorflow/compiler/jit:xla_tpu_device", "@org_tensorflow//tensorflow/compiler/jit:xla_tpu_jit", ]), - alwayslink = 1, + alwayslink = True, +) + 
+cc_library( + name = "cache", + hdrs = ["cache.h"], + alwayslink = True, +) + +cc_library( + name = "debug_macros", + hdrs = ["debug_macros.h"], + deps = [ + ":tf_logging", + "@org_tensorflow//tensorflow/compiler/xla:statusor", + "@org_tensorflow//tensorflow/core/platform:stacktrace", + ], + alwayslink = True, +) + +cc_library( + name = "env_vars", + srcs = ["env_vars.cc"], + hdrs = ["env_vars.h"], + alwayslink = True, +) + +cc_library( + name = "mesh_service", + srcs = ["mesh_service.cc"], + hdrs = ["mesh_service.h"], + deps = [ + "nccl_distributed", + ":debug_macros", + ":mesh_service_proto_cc", + ":multi_wait", + ":sys_util", + ":thread_pool", + ":util", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/compiler/xla:statusor", + ], + alwayslink = True, +) + +cc_library( + name = "metrics_analysis", + srcs = ["metrics_analysis.cc"], + hdrs = ["metrics_analysis.h"], + deps = [ + ":metrics", + ":tf_logging", + ":types", + "@com_google_absl//absl/types:variant", + ], + alwayslink = True, +) + +cc_library( + name = "metrics_reader", + srcs = ["metrics_reader.cc"], + hdrs = ["metrics_reader.h"], + deps = [ + ":debug_macros", + ":metrics", + ":util", + ], + alwayslink = True, +) + +cc_library( + name = "metrics", + srcs = ["metrics.cc"], + hdrs = ["metrics.h"], + deps = [ + ":debug_macros", + ":sys_util", + ":util", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/compiler/xla:types", + "@org_tensorflow//tensorflow/compiler/xla/service:platform_util", + ], + alwayslink = True, +) + +cc_library( + name = "multi_wait", + srcs = ["multi_wait.cc"], + hdrs = ["multi_wait.h"], + deps = [ + "@org_tensorflow//tensorflow/compiler/xla:types", + ], + alwayslink = True, +) + +cc_library( + name = "nccl_distributed", + srcs = ["nccl_distributed.cc"], + hdrs = ["nccl_distributed.h"], + deps = [ + ":debug_macros", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + 
"@org_tensorflow//tensorflow/compiler/xla:types", + ] + if_cuda_is_configured([ + "@local_config_nccl//:nccl", + ]), + alwayslink = True, +) + +cc_library( + name = "profiler", + srcs = ["profiler.cc"], + hdrs = ["profiler.h"], + deps = [ + "@org_tensorflow//tensorflow/core/profiler/rpc:profiler_server_impl", + ], + alwayslink = True, +) + +cc_library( + name = "record_reader", + srcs = ["record_reader.cc"], + hdrs = ["record_reader.h"], + deps = [ + ":debug_macros", + "@org_tensorflow//tensorflow/compiler/xla:types", + "@org_tensorflow//tensorflow/core/lib/core:errors", + "@org_tensorflow//tensorflow/core/lib/io:record_reader", + "@org_tensorflow//tensorflow/core/lib/strings:strcat", + ], + alwayslink = True, +) + +cc_library( + name = "sys_util", + srcs = ["sys_util.cc"], + hdrs = ["sys_util.h"], + deps = [ + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/compiler/xla:types", + ], + alwayslink = True, +) + +cc_test( + name = "sys_util_test", + srcs = ["sys_util_test.cc"], + deps = [ + ":sys_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "tf_logging", + srcs = ["tf_logging.cc"], + hdrs = ["tf_logging.h"], + deps = [ + "@org_tensorflow//tensorflow/compiler/xla:statusor", + "@org_tensorflow//tensorflow/compiler/xla/service:platform_util", + ], + alwayslink = True, +) + +cc_library( + name = "thread_pool", + srcs = ["thread_pool.cc"], + hdrs = ["thread_pool.h"], + deps = [ + ":metrics", + ":tf_logging", + ], + alwayslink = True, +) + +cc_library( + name = "triggered_task", + srcs = ["triggered_task.cc"], + hdrs = ["triggered_task.h"], + alwayslink = True, +) + +cc_library( + name = "types", + hdrs = ["types.h"], + deps = [ + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/types:optional", + "@org_tensorflow//tensorflow/compiler/xla:types", + ], + alwayslink = True, +) + +cc_library( + name = "unique", + hdrs = ["unique.h"], + deps = [ + ":debug_macros", + "@com_google_absl//absl/types:optional", + 
], + alwayslink = True, +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + ":types", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@org_tensorflow//tensorflow/compiler/xla:statusor", + "@org_tensorflow//tensorflow/compiler/xla:types", + "@org_tensorflow//tensorflow/core:lib", + ], + alwayslink = True, +) + +cc_test( + name = "util_test", + srcs = ["util_test.cc"], + deps = [ + ":util", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "xla_util", + srcs = ["xla_util.cc"], + hdrs = ["xla_util.h"], + deps = [ + ":metrics", + ":sys_util", + ":tf_logging", + ":types", + ":util", + ":xrt_session", + "@com_google_absl//absl/types:span", + "@org_tensorflow//tensorflow/compiler/xla:shape_util", + "@org_tensorflow//tensorflow/compiler/xla:status_macros", + "@org_tensorflow//tensorflow/compiler/xla:types", + "@org_tensorflow//tensorflow/compiler/xla/client:xla_computation", + "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc", + "@org_tensorflow//tensorflow/compiler/xla/service:platform_util", + "@org_tensorflow//tensorflow/compiler/xla/service/spmd:spmd_partitioner", + "@org_tensorflow//tensorflow/core/lib/core:errors", + ], + alwayslink = True, +) + +cc_test( + name = "xla_util_test", + srcs = ["xla_util_test.cc"], + deps = [ + ":xla_util", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@org_tensorflow//tensorflow/compiler/xla:shape_util", + "@org_tensorflow//tensorflow/compiler/xla/client:xla_builder", + "@org_tensorflow//tensorflow/compiler/xla/client:xla_computation", + "@org_tensorflow//tensorflow/tsl/lib/core:status_test_util", + "@org_tensorflow//tensorflow/tsl/platform:errors", + "@org_tensorflow//tensorflow/tsl/platform:status_matchers", + ], +) + +cc_library( + name = "xrt_local_service", + srcs = ["xrt_local_service.cc"], + hdrs = 
["xrt_local_service.h"], + deps = [ + ":debug_macros", + ":xrt_session", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@org_tensorflow//tensorflow/compiler/xla:types", + "@org_tensorflow//tensorflow/compiler/xla/stream_executor/tpu:tpu_initializer_helper", + "@org_tensorflow//tensorflow/core:lib", + "@org_tensorflow//tensorflow/core/distributed_runtime:server_lib", + "@org_tensorflow//tensorflow/core/lib/core:errors", + "@org_tensorflow//tensorflow/core/lib/core:status", + ], + alwayslink = True, +) + +cc_library( + name = "xrt_session_cache", + srcs = ["xrt_session_cache.cc"], + hdrs = ["xrt_session_cache.h"], + deps = [ + ":metrics", + ":sys_util", + ":xrt_session", + "@org_tensorflow//tensorflow/compiler/xla:types", + ], + alwayslink = True, +) + +cc_library( + name = "xrt_session", + srcs = ["xrt_session.cc"], + hdrs = ["xrt_session.h"], + deps = [ + ":debug_macros", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@org_tensorflow//tensorflow/cc:cc_ops", + "@org_tensorflow//tensorflow/cc:client_session", + "@org_tensorflow//tensorflow/cc:scope", + "@org_tensorflow//tensorflow/compiler/xla:types", + ], + alwayslink = True, +) + +tf_cc_shared_object( + name = "libxla_computation_client.so", + linkopts = select({ + "@org_tensorflow//tensorflow:windows": [], + "//conditions:default": [ + "-z defs", + "-Wl,--version-script", # This line must be directly followed by the version_script.lds file + "$(location :tf_version_script.lds)", + ], + }), + visibility = ["//visibility:public"], + deps = [ + ":computation_client", + ":tf_exported_symbols.lds", + ":tf_version_script.lds", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/compiler/xla:literal_util", + "@org_tensorflow//tensorflow/compiler/xla/client", + "@org_tensorflow//tensorflow/compiler/xla/client:global_data", + "@org_tensorflow//tensorflow/compiler/xla/client:xla_builder", + 
"@org_tensorflow//tensorflow/compiler/xla/client:xla_computation", + "@org_tensorflow//tensorflow/compiler/xla/client/lib:svd", + "@org_tensorflow//tensorflow/compiler/xla/rpc:grpc_stub", + "@org_tensorflow//tensorflow/core:lib", + "@org_tensorflow//tensorflow/core/platform/cloud:gcs_file_system", + "@org_tensorflow//tensorflow/python/profiler/internal:profiler_pywrap_impl", + ], ) # TODO(goranpetrovic): reenable when `xla_cc_test` is fixed upstream. @@ -189,7 +492,7 @@ cc_library( # name = "pjrt_computation_client_test", # srcs = ["pjrt_computation_client_test.cc"], # deps = [ -# ":computation_client_impl", +# ":computation_client", # "@org_tensorflow//tensorflow/compiler/xla:literal", # "@org_tensorflow//tensorflow/compiler/xla:literal_util", # "@org_tensorflow//tensorflow/compiler/xla:shape_util", diff --git a/third_party/xla_client/metrics_analysis.cc b/third_party/xla_client/metrics_analysis.cc index 93e0181ec27..13176683eba 100644 --- a/third_party/xla_client/metrics_analysis.cc +++ b/third_party/xla_client/metrics_analysis.cc @@ -1,7 +1,8 @@ #include "third_party/xla_client/metrics_analysis.h" +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" #include "absl/types/variant.h" -#include "third_party/xla_client/computation_client.h" #include "third_party/xla_client/metrics.h" #include "third_party/xla_client/tf_logging.h" #include "third_party/xla_client/types.h" @@ -96,6 +97,10 @@ class XrtMetricFrequency : public Analyzer { counter_(0) {} Analysis Run() override { + LOG(FATAL) << "For XrtMetricFrequency, use the metrics overload"; + } + + Analysis Run(const std::map& xrt_metrics) override { // XRT GetMetrics call is relatively expensive. 
if (counter_++ != run_every_n_) { return {Analysis::Symptom::kNormal}; @@ -108,8 +113,7 @@ class XrtMetricFrequency : public Analyzer { std::stringstream ss; int64_t step_count = step->Value(); - auto xrt_metrics = ComputationClient::Get()->GetMetrics(); - for (auto const& kv : metric_name_thresholds_) { + for (const auto& kv : metric_name_thresholds_) { auto it = xrt_metrics.find(kv.first); if (it == xrt_metrics.end()) { continue; @@ -193,11 +197,12 @@ std::vector* GetAnalyzers() { } // namespace -std::string CreatePerformanceReport() { +std::string CreatePerformanceReport( + const std::map& xrt_metrics) { std::stringstream ss; std::vector* analyzers = GetAnalyzers(); for (auto const& analyzer : *analyzers) { - Analysis result = analyzer->Run(); + Analysis result = analyzer->Run(xrt_metrics); if (result.symptom != Analysis::Symptom::kNormal) { ss << result.repr << std::endl; } diff --git a/third_party/xla_client/metrics_analysis.h b/third_party/xla_client/metrics_analysis.h index e910b1bf1c6..846920a3798 100644 --- a/third_party/xla_client/metrics_analysis.h +++ b/third_party/xla_client/metrics_analysis.h @@ -2,9 +2,13 @@ #define XLA_CLIENT_METRICS_ANALYSIS_H_ #include +#include #include +#include #include +#include "third_party/xla_client/types.h" + namespace xla { namespace metrics { @@ -34,10 +38,16 @@ struct Analysis { class Analyzer { public: + virtual ~Analyzer() = default; + virtual Analysis Run() = 0; + virtual Analysis Run(const std::map& metrics) { + return Run(); + } }; -std::string CreatePerformanceReport(); +std::string CreatePerformanceReport( + const std::map& metrics); } // namespace metrics } // namespace xla diff --git a/third_party/xla_client/metrics_reader.cc b/third_party/xla_client/metrics_reader.cc index 13410ff19ab..1aa41bb307c 100644 --- a/third_party/xla_client/metrics_reader.cc +++ b/third_party/xla_client/metrics_reader.cc @@ -2,7 +2,6 @@ #include -#include "third_party/xla_client/computation_client.h" #include 
"third_party/xla_client/debug_macros.h" #include "third_party/xla_client/metrics.h" #include "third_party/xla_client/util.h" @@ -28,10 +27,10 @@ MetricFnInfo GetMetricRenderInfo(const Percentile& percentile) { } } -std::string CreateXrtMetricReport() { - auto xrt_metrics = ComputationClient::Get()->GetMetrics(); +std::string CreateXrtMetricReport( + const std::map& xrt_metrics) { std::stringstream ss; - for (auto& name_metric : xrt_metrics) { + for (const auto& name_metric : xrt_metrics) { if (name_metric.second.percentile) { const Percentile& percentile = *name_metric.second.percentile; MetricFnInfo minfo = GetMetricRenderInfo(percentile); @@ -70,8 +69,9 @@ std::string CreateXrtMetricReport() { } // namespace -std::string CreateMetricReport() { - return metrics::CreateMetricReport() + CreateXrtMetricReport(); +std::string CreateMetricReport( + const std::map& xrt_metrics) { + return metrics::CreateMetricReport() + CreateXrtMetricReport(xrt_metrics); } std::string CreateMetricReport(const std::vector& counter_names, diff --git a/third_party/xla_client/metrics_reader.h b/third_party/xla_client/metrics_reader.h index ffc6267e39f..31bc7a18d54 100644 --- a/third_party/xla_client/metrics_reader.h +++ b/third_party/xla_client/metrics_reader.h @@ -1,14 +1,19 @@ #ifndef XLA_CLIENT_METRICS_READER_H_ #define XLA_CLIENT_METRICS_READER_H_ +#include #include #include +#include "third_party/xla_client/metrics.h" +#include "third_party/xla_client/types.h" + namespace xla { namespace metrics_reader { // Creates a report with the current metrics statistics. -std::string CreateMetricReport(); +std::string CreateMetricReport( + const std::map& xrt_metrics); // Creates a report with the selected metrics statistics. 
std::string CreateMetricReport(const std::vector& counter_names, diff --git a/third_party/xla_client/util_test.cc b/third_party/xla_client/util_test.cc index efe8f81155e..3b29ad26228 100644 --- a/third_party/xla_client/util_test.cc +++ b/third_party/xla_client/util_test.cc @@ -3,7 +3,11 @@ #include #include +#include #include +#include + +#include "absl/types/span.h" namespace xla { namespace util { @@ -67,6 +71,37 @@ TEST(UtilTest, MapInsert) { std::unordered_map v; EXPECT_EQ(MapInsert(&v, 1, [] { return 1; }), 1); EXPECT_EQ(MapInsert(&v, 1, [] { return 7; }), 1); + EXPECT_EQ(MapInsert(&v, 1, [] { return 12; }), 1); +} + +TEST(UtilTest, GetEnumValue) { + enum E { A = 0, B, C, D }; + EXPECT_EQ(GetEnumValue(E::A), 0); + EXPECT_EQ(GetEnumValue(E::B), 1); + EXPECT_EQ(GetEnumValue(E::C), 2); + EXPECT_EQ(GetEnumValue(E::D), 3); +} + +TEST(UtilTest, Multiply) { + std::vector t = {1, 2, 3, 4, 5}; + EXPECT_EQ(Multiply(t), 120); + t.push_back(6); + EXPECT_EQ(Multiply(t), 720); +} + +TEST(UtilTest, Hash) { + std::pair temp = {"hello", 3}; + EXPECT_EQ(Hash(std::pair{"hello", 3}), Hash(temp)); + EXPECT_EQ(HexHash(Hash(std::pair{"hello", 3})), + HexHash(Hash(temp))); + + std::vector t = {1, 2, 3, 4, 5}; + EXPECT_EQ(Hash({1, 2, 3, 4, 5}), Hash({1, 2, 3, 4, 5})); + EXPECT_EQ(Hash(std::set{1, 2, 3}), Hash(std::set{1, 2, 3})); + EXPECT_EQ(Hash(t), Hash(std::vector{1, 2, 3, 4, 5})); + + EXPECT_EQ(StdDataHash(t.data(), t.size()), + StdDataHash(std::vector{1, 2, 3, 4, 5}.data(), t.size())); } } // namespace util diff --git a/third_party/xla_client/xla_util_test.cc b/third_party/xla_client/xla_util_test.cc new file mode 100644 index 00000000000..dd39cc7cdba --- /dev/null +++ b/third_party/xla_client/xla_util_test.cc @@ -0,0 +1,119 @@ +#include "third_party/xla_client/xla_util.h" + +#include +#include + +#include +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" 
+#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/tsl/lib/core/status_test_util.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/status_matchers.h" +#include "tensorflow/tsl/protobuf/error_codes.pb.h" +#include "xla_util.h" + +namespace xla { +namespace util { + +using ::testing::AllOf; +using ::testing::HasSubstr; +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; +using ::tsl::testing::StatusIs; + +TEST(XlaUtilTest, ShapeHash) { + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2}); + EXPECT_EQ(ShapeHash(shape), ShapeHash(shape)); +} + +template +StatusOr ParseTextProto(const std::string& text_proto) { + tensorflow::protobuf::TextFormat::Parser parser; + MessageType parsed_proto; + tensorflow::protobuf::io::ArrayInputStream input_stream(text_proto.data(), + text_proto.size()); + if (!parser.Parse(&input_stream, &parsed_proto)) { + return tensorflow::errors::InvalidArgument("Could not parse text proto: ", + text_proto); + } + return parsed_proto; +} + +TEST(XlaUtilrest, CreateModule) { + TF_ASSERT_OK_AND_ASSIGN( + HloModuleProto hlo_module_proto, + ParseTextProto( + R"pb( + name: "myname" + id: 7 + entry_computation_name: "mycomp" + entry_computation_id: 0 + computations { + id: 0 + name: "c1" + instructions: { + name: "i1" + id: 1 + opcode: "constant" + shape: { + element_type: S32 + layout {} + } + literal: { + shape: { + element_type: S32 + layout {} + } + s32s: 0 + } + } + instructions: { + name: "constant.3" + id: 0 + opcode: "constant" + shape: { + element_type: S32 + layout {} + } + literal: { + shape: { + element_type: S32 + layout {} + } + s32s: 0 + } + } + root_id: 1 + } + host_program_shape: { result: { element_type: 4 } } + )pb")); + + HloModule m("cool_module", {}); + auto got = CreateModuleFromProto(hlo_module_proto); + EXPECT_THAT(got, IsOk()); + EXPECT_EQ((*got)->name(), "myname"); + 
EXPECT_EQ((*got)->computation_count(), 1); +} + +TEST(XlaUtilTest, XlaToHlo) { + xla::Shape input_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2}); + xla::XlaBuilder builder("AddComputation"); + xla::XlaOp x = xla::Parameter(&builder, 0, input_shape, "x"); + xla::XlaOp y = xla::Parameter(&builder, 1, input_shape, "y"); + xla::XlaOp sum = xla::Add(x, y); + ASSERT_THAT(GetComputationHloText(*builder.Build()), + IsOkAndHolds(AllOf( + HasSubstr("HloModule AddComputation.4"), + HasSubstr("%AddComputation.4 (x.1: f32[2,2], y.2: f32[2,2])"), + HasSubstr("ROOT %add.3")))); +} + +} // namespace util +} // namespace xla diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 03292e23449..8a7e936dd83 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -352,7 +352,8 @@ void StepMarker(const std::string& device_str, XLAGraphExecutor::Get()->MarkStep(device); bool debug_mode = xla::sys_util::GetEnvBool("PT_XLA_DEBUG", false); if (TF_PREDICT_FALSE(debug_mode)) { - std::string report = xla::metrics::CreatePerformanceReport(); + std::string report = xla::metrics::CreatePerformanceReport( + xla::ComputationClient::Get()->GetMetrics()); if (!report.empty()) { std::string fout = xla::sys_util::GetEnvString("PT_XLA_DEBUG_FILE", ""); if (TF_PREDICT_FALSE(!fout.empty())) { @@ -1307,7 +1308,8 @@ void InitXlaModuleBindings(py::module m) { // TODO(jwtan): Unify them once ComputationClient becomes a standalone // library. return torch::lazy::CreateMetricReport() + - xla::metrics_reader::CreateMetricReport(); + xla::metrics_reader::CreateMetricReport( + xla::ComputationClient::Get()->GetMetrics()); }); m.def("_short_xla_metrics_report", [](const py::list& counter_names, const py::list& metric_names) {