diff --git a/WORKSPACE b/WORKSPACE index 790f6d8f31c..1b51229132e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -43,9 +43,9 @@ http_archive( "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:constexpr_return.diff", ], - strip_prefix = "xla-cd2cf5c34931e4fc1cacf83bfc480a5b93f05f6d", + strip_prefix = "xla-7a371ed44aba34f83d6d3d1159d2e6d0d327c603", urls = [ - "https://github.com/openxla/xla/archive/cd2cf5c34931e4fc1cacf83bfc480a5b93f05f6d.tar.gz", + "https://github.com/openxla/xla/archive/7a371ed44aba34f83d6d3d1159d2e6d0d327c603.tar.gz", ], ) diff --git a/setup.py b/setup.py index 3e5c65d729a..1d60a7776d9 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_libtpu_version = '0.1.dev20230809' +_libtpu_version = '0.1.dev20230826' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index 89afbe1b9d2..a2865d41e5a 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -87,7 +87,7 @@ fi if [ "$LOGFILE" != "" ]; then - bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:all ${FILTER:+"$FILTER"} 2> $LOGFILE + bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:all --test_timeout 1000 ${FILTER:+"$FILTER"} 2> $LOGFILE else - bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:all ${FILTER:+"$FILTER"} + bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:all --test_timeout 1000 ${FILTER:+"$FILTER"} fi diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index e5a35030e66..eea50940b78 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -29,7 +29,6 @@ #include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/pjrt/tpu_client.h" #include "xla/shape.h" -#include "xla/stream_executor/tpu/tpu_initializer_framework_helper.h" using xla::internal::XlaBuilderFriend; @@ -105,7 +104,7 @@ PjRtComputationClient::PjRtComputationClient() { TF_VLOG(1) << "Initializing TFRT TPU client..."; XLA_CHECK_OK(pjrt::LoadPjrtPlugin( "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))); - tsl::Status tpu_status = tensorflow::tpu::FindAndLoadTpuLibrary(); + tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); XLA_CHECK(tpu_status.ok()); client_ = std::move(xla::GetCApiClient("TPU").value()); } else if (device_type == "TPU_LEGACY") {