diff --git a/setup.py b/setup.py
index 1215498a673c..a9ec34680c66 100644
--- a/setup.py
+++ b/setup.py
@@ -72,7 +72,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__))
 
-_libtpu_version = '0.1.dev20231130'
+_libtpu_version = '0.1.dev20240117'  # need to be 20240118
 _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
diff --git a/test/stablehlo/test_implicit_broadcasting.py b/test/stablehlo/test_implicit_broadcasting.py
index 8b35813d6b02..24c8a80c77d7 100644
--- a/test/stablehlo/test_implicit_broadcasting.py
+++ b/test/stablehlo/test_implicit_broadcasting.py
@@ -71,10 +71,6 @@ def test_same_rank_broadcast_with_unbounded_dynamic_shapes_1(self):
         re.search(
             r'dynamic_broadcast_in_dim.*=.*\[0\].*: \(tensor<\?xf32>, tensor<1xi32>\) -> tensor<10xf32>',
             stablehlo_text) is not None)
-    self.assertTrue(
-        re.search(
-            r'dynamic_broadcast_in_dim.*=.*\[0\].*: \(tensor<10xf32>, tensor<1xi32>\) -> tensor<10xf32>',
-            stablehlo_text) is not None)
 
   ### (?,?) * (?,1)
   def test_same_rank_broadcast_with_unbounded_dynamic_shapes_2(self):
@@ -175,10 +171,6 @@ def test_different_rank_broadcast_with_unbounded_dynamic_shapes_3(self):
     torch_xla._XLAC._xla_mark_dynamic(b, 0)
     c = a * b
     stablehlo_text = xm.get_stablehlo([c])
-    self.assertTrue(
-        re.search(
-            r'dynamic_broadcast_in_dim.*=.*\[0, 1\].*: \(tensor<2x5xf32>, tensor<2xi32>\) -> tensor<2x5xf32>',
-            stablehlo_text) is not None)
     self.assertTrue(
         re.search(
             r'dynamic_broadcast_in_dim.*=.*\[1\].*: \(tensor<\?xf32>, tensor<2xi32>\) -> tensor<2x5xf32>',
diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc
index f1b6a6711370..ce8b3f029e5c 100644
--- a/torch_xla/csrc/runtime/pjrt_registry.cc
+++ b/torch_xla/csrc/runtime/pjrt_registry.cc
@@ -5,7 +5,9 @@
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "torch_xla/csrc/runtime/xla_coordinator.h"
 #include "xla/pjrt/c/pjrt_c_api.h"
+#include "xla/pjrt/distributed/client.h"
 #include "xla/pjrt/distributed/distributed.h"
+#include "xla/pjrt/distributed/in_memory_key_value_store.h"
 #include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
 #include "xla/pjrt/pjrt_api.h"
 #include "xla/pjrt/pjrt_c_api_client.h"
@@ -95,33 +97,25 @@ InitializePjRt(const std::string& device_type) {
     std::string port = runtime::sys_util::GetEnvString(
         "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);
 
-    xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
-    xla::PjRtClient::KeyValuePutCallback kv_put = nullptr;
     bool spmd = sys_util::GetEnvBool("XLA_USE_SPMD", false);
     std::optional<std::set<int>> allowed_devices;
     if (!spmd) {
       allowed_devices = std::set{local_process_rank};
     }
+
+    std::shared_ptr<xla::KeyValueStoreInterface> kv_store;
     if (global_world_size > 1) {
-      // Use the XlaCoordinator as the distributed key-value store.
+      // Use the distributed key-value store from DistributedRuntimeClient.
       coordinator = std::make_unique<XlaCoordinator>(
          global_process_rank, global_world_size, master_addr, port);
       std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
          coordinator->GetClient();
-      std::string key_prefix = "gpu:";
-      kv_get = [distributed_client, key_prefix](
-                   std::string_view k,
-                   absl::Duration timeout) -> xla::StatusOr<std::string> {
-        return distributed_client->BlockingKeyValueGet(
-            absl::StrCat(key_prefix, k), timeout);
-      };
-      kv_put = [distributed_client, key_prefix](
-                   std::string_view k, std::string_view v) -> xla::Status {
-        return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
-      };
+      kv_store = xla::GetDistributedKeyValueStore(distributed_client,
+                                                  /*key_prefix=*/"gpu:");
     }
     TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
                << global_process_rank << ", num_nodes=" << global_world_size;
+
     xla::GpuClientOptions options;
     options.allocator_config = GetGpuAllocatorConfig();
     options.node_id = global_process_rank;
@@ -129,8 +123,7 @@ InitializePjRt(const std::string& device_type) {
     options.allowed_devices = allowed_devices;
     options.platform_name = "gpu";
     options.should_stage_host_to_device_transfers = true;
-    options.kv_get = kv_get;
-    options.kv_put = kv_put;
+    options.kv_store = kv_store;
     client = std::move(xla::GetStreamExecutorGpuClient(options).value());
   } else if (device_type == "XPU") {
     TF_VLOG(1) << "Initializing PjRt XPU client...";
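
Note: the pjrt_registry.cc hunks track an upstream PJRT API change in this pin: the pair of per-key callbacks (kv_get/kv_put) is replaced by a single xla::KeyValueStoreInterface handle, and the newly included in_memory_key_value_store.h provides an in-process implementation of that interface. Below is a minimal sketch of exercising the new interface, assuming the Get/Set signatures from this pin's key_value_store_interface.h; the key and value strings are hypothetical:

#include <memory>
#include <string>

#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"

int main() {
  // Single-process stand-in for the coordinator-backed store that
  // xla::GetDistributedKeyValueStore() returns in the diff above.
  std::shared_ptr<xla::KeyValueStoreInterface> kv_store =
      std::make_shared<xla::InMemoryKeyValueStore>();

  // Publish a value under the same "gpu:" prefix the registry uses.
  absl::Status put = kv_store->Set("gpu:example_key", "example_value");

  // Get() blocks until the key is present or the timeout expires.
  absl::StatusOr<std::string> got =
      kv_store->Get("gpu:example_key", absl::Seconds(10));
  return (put.ok() && got.ok() && *got == "example_value") ? 0 : 1;
}

Passing one store object through xla::GpuClientOptions::kv_store lets PJRT hand the same handle to every component that needs cross-host rendezvous (e.g. topology and NCCL id exchange), rather than threading two callbacks through each call site.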