
Commit 0f48030
* Use xla::KeyValueStoreInterface in GpuClientOptions.
* Update StableHLO unbounded dynamism tests
yeounoh committed Jan 18, 2024
1 parent 4fa29e5 commit 0f48030
Showing 3 changed files with 10 additions and 25 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -72,7 +72,7 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_libtpu_version = '0.1.dev20231130'
_libtpu_version = '0.1.dev20240117' #need to be 20240118
_libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'


8 changes: 0 additions & 8 deletions test/stablehlo/test_implicit_broadcasting.py
@@ -71,10 +71,6 @@ def test_same_rank_broadcast_with_unbounded_dynamic_shapes_1(self):
re.search(
r'dynamic_broadcast_in_dim.*=.*\[0\].*: \(tensor<\?xf32>, tensor<1xi32>\) -> tensor<10xf32>',
stablehlo_text) is not None)
self.assertTrue(
re.search(
r'dynamic_broadcast_in_dim.*=.*\[0\].*: \(tensor<10xf32>, tensor<1xi32>\) -> tensor<10xf32>',
stablehlo_text) is not None)

### (?,?) * (?,1)
def test_same_rank_broadcast_with_unbounded_dynamic_shapes_2(self):
@@ -175,10 +171,6 @@ def test_different_rank_broadcast_with_unbounded_dynamic_shapes_3(self):
torch_xla._XLAC._xla_mark_dynamic(b, 0)
c = a * b
stablehlo_text = xm.get_stablehlo([c])
self.assertTrue(
re.search(
r'dynamic_broadcast_in_dim.*=.*\[0, 1\].*: \(tensor<2x5xf32>, tensor<2xi32>\) -> tensor<2x5xf32>',
stablehlo_text) is not None)
self.assertTrue(
re.search(
r'dynamic_broadcast_in_dim.*=.*\[1\].*: \(tensor<\?xf32>, tensor<2xi32>\) -> tensor<2x5xf32>',
25 changes: 9 additions & 16 deletions torch_xla/csrc/runtime/pjrt_registry.cc
@@ -5,7 +5,9 @@
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_c_api_client.h"
@@ -95,42 +97,33 @@ InitializePjRt(const std::string& device_type) {
std::string port = runtime::sys_util::GetEnvString(
"XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);

xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
xla::PjRtClient::KeyValuePutCallback kv_put = nullptr;
bool spmd = sys_util::GetEnvBool("XLA_USE_SPMD", false);
std::optional<std::set<int>> allowed_devices;
if (!spmd) {
allowed_devices = std::set{local_process_rank};
}

std::shared_ptr<xla::KeyValueStoreInterface> kv_store;
if (global_world_size > 1) {
// Use the XlaCoordinator as the distributed key-value store.
// Use the distributed key-value store from DistributedRuntimeClient.
coordinator = std::make_unique<XlaCoordinator>(
global_process_rank, global_world_size, master_addr, port);
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
coordinator->GetClient();
std::string key_prefix = "gpu:";
kv_get = [distributed_client, key_prefix](
std::string_view k,
absl::Duration timeout) -> xla::StatusOr<std::string> {
return distributed_client->BlockingKeyValueGet(
absl::StrCat(key_prefix, k), timeout);
};
kv_put = [distributed_client, key_prefix](
std::string_view k, std::string_view v) -> xla::Status {
return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
};
kv_store = xla::GetDistributedKeyValueStore(distributed_client,
/*key_prefix=*/"gpu:");
}
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
<< global_process_rank << ", num_nodes=" << global_world_size;

xla::GpuClientOptions options;
options.allocator_config = GetGpuAllocatorConfig();
options.node_id = global_process_rank;
options.num_nodes = global_world_size;
options.allowed_devices = allowed_devices;
options.platform_name = "gpu";
options.should_stage_host_to_device_transfers = true;
options.kv_get = kv_get;
options.kv_put = kv_put;
options.kv_store = kv_store;
client = std::move(xla::GetStreamExecutorGpuClient(options).value());
} else if (device_type == "XPU") {
TF_VLOG(1) << "Initializing PjRt XPU client...";
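For readers skimming the pjrt_registry.cc diff, here is a minimal sketch of the new wiring. The helper name MakeGpuClientSketch, its parameter list, and the out-parameter for the coordinator are illustrative assumptions, and allocator, SPMD, and allowed_devices handling is omitted; the types and calls themselves (XlaCoordinator, xla::GetDistributedKeyValueStore, xla::GpuClientOptions, xla::GetStreamExecutorGpuClient) are the ones used in the changed file above.

#include <memory>
#include <string>

#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"

namespace torch_xla {
namespace runtime {

// Hypothetical helper: builds a GPU PjRt client, threading the distributed
// key-value store through xla::GpuClientOptions::kv_store instead of the
// removed kv_get/kv_put callbacks. `coordinator` is an out-parameter so the
// caller can keep the DistributedRuntimeClient alive for the lifetime of the
// returned client, as InitializePjRt does.
std::unique_ptr<xla::PjRtClient> MakeGpuClientSketch(
    int global_rank, int world_size, const std::string& master_addr,
    const std::string& port, std::unique_ptr<XlaCoordinator>& coordinator) {
  // Stays null for single-process runs, mirroring the changed file above.
  std::shared_ptr<xla::KeyValueStoreInterface> kv_store;
  if (world_size > 1) {
    coordinator = std::make_unique<XlaCoordinator>(global_rank, world_size,
                                                   master_addr, port);
    // The returned store wraps the DistributedRuntimeClient and prefixes every
    // key with "gpu:", matching the absl::StrCat(key_prefix, k) of the old
    // kv_get/kv_put callbacks.
    kv_store = xla::GetDistributedKeyValueStore(coordinator->GetClient(),
                                                /*key_prefix=*/"gpu:");
  }

  xla::GpuClientOptions options;
  options.node_id = global_rank;
  options.num_nodes = world_size;
  options.kv_store = kv_store;  // one field replaces kv_get + kv_put
  return xla::GetStreamExecutorGpuClient(options).value();
}

}  // namespace runtime
}  // namespace torch_xla

Collapsing the two callbacks into a single std::shared_ptr<xla::KeyValueStoreInterface> keeps ownership of the key-value plumbing in one object instead of two captured lambdas; the newly included xla/pjrt/distributed/in_memory_key_value_store.h also suggests an in-memory store could be substituted where no coordinator exists, though this diff does not use it.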
