Skip to content

Commit

Permalink
Update OpenXLA-pin to Nov24 (#6012)
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored and bhavya01 committed Apr 22, 2024
1 parent 071a0c4 commit 60eb358
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 212 deletions.
5 changes: 2 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,12 @@ http_archive(
"//openxla_patches:constexpr_return.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:gpu_topk_rewriter.diff",
"//openxla_patches:quant_dequant_converter.diff",
"//openxla_patches:stablehlo_quant_seralization.diff",
],
strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478",
strip_prefix = "xla-8744c9a94782cd7804f015e6d29df253437af3cb",
urls = [
"https://github.com/openxla/xla/archive/4f8381651977dff16b1d86bb4b198eb733c5f478.tar.gz",
"https://github.com/openxla/xla/archive/8744c9a94782cd7804f015e6d29df253437af3cb.tar.gz",
],
)

Expand Down
3 changes: 2 additions & 1 deletion openxla_patches/cache_urls.diff
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ index a4574d75d..f9ce37094 100644
+ "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT),
],
build_file = "//third_party/llvm:llvm.BUILD",
patch_file = [
patch_file = [

6 changes: 3 additions & 3 deletions openxla_patches/constexpr_return.diff
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
diff --git a/xla/primitive_util.h b/xla/primitive_util.h
index 696147844..dfea15a4d 100644
index 63fa4e193..ab352626c 100644
--- a/xla/primitive_util.h
+++ b/xla/primitive_util.h
@@ -748,6 +748,7 @@ inline bool FitsInIntegralType(int64_t x, PrimitiveType ty) {
@@ -706,6 +706,7 @@ inline bool FitsInIntegralType(int64_t x, PrimitiveType ty) {
std::numeric_limits<NativeT>::max() >= x;
}
LOG(FATAL) << "Invalid primitive type " << PrimitiveType_Name(ty);
+ return false;
+ return false;
},
ty);
}
8 changes: 4 additions & 4 deletions openxla_patches/gpu_race_condition.diff
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc
index 242961dd1..787275868 100644
index 1f9903cb3..763b7fc23 100644
--- a/xla/service/gpu/gpu_executable.cc
+++ b/xla/service/gpu/gpu_executable.cc
@@ -563,8 +563,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStreamImpl(
@@ -589,8 +589,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStreamImpl(
}

// Force synchronous execution if the allocator requires it.
- const bool block_host_until_done =
- !memory_allocator->AllowsAsynchronousDeallocation();
+ const bool block_host_until_done = true;


// Lock the GPU with a shared lock so that we don't interfere with autotuning
// Lock the GPU with a shared lock so that we don't interfere with autotuning
// that may be running during JIT compilation while allowing multiple XLA
184 changes: 0 additions & 184 deletions openxla_patches/gpu_topk_rewriter.diff

This file was deleted.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_libtpu_version = '0.1.dev20231022'
_libtpu_version = '0.1.dev20231125'
_libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'


Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ class ComputationClient {
virtual int GetNumProcesses() const = 0;

using DeviceAttribute =
std::variant<std::string, int64_t, std::vector<int64_t>, float, bool>;
std::variant<std::string, bool, int64_t, std::vector<int64_t>, float>;

virtual const absl::flat_hash_map<
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>&
Expand Down
30 changes: 15 additions & 15 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,29 +153,29 @@ PjRtComputationClient::PjRtComputationClient() {
xla::PjRtClient::KeyValuePutCallback kv_put = nullptr;
if (distributed_client != nullptr) {
std::string key_prefix = "gpu:";
kv_get = [distributed_client, key_prefix](const std::string& k,
absl::Duration timeout) {
kv_get = [distributed_client, key_prefix](
std::string_view k,
absl::Duration timeout) -> xla::StatusOr<std::string> {
return distributed_client->BlockingKeyValueGet(
absl::StrCat(key_prefix, k), timeout);
};
kv_put = [distributed_client, key_prefix](const std::string& k,
const std::string& v) {
kv_put = [distributed_client, key_prefix](
std::string_view k, std::string_view v) -> xla::Status {
return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
};
}
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
<< global_process_rank << ", num_nodes=" << global_world_size;
client_ = std::move(xla::GetStreamExecutorGpuClient(
/*asynchronous=*/async,
/*allocator_config=*/GetGpuAllocatorConfig(),
/*node_id=*/global_process_rank,
/*num_nodes=*/global_world_size,
/*allowed_devices=*/allowed_devices,
/*platform_name=*/"gpu",
/*should_stage_host_to_device_transfers=*/true,
/*kv_get=*/kv_get,
/*kv_put=*/kv_put)
.value());
xla::GpuClientOptions options;
options.allocator_config = GetGpuAllocatorConfig();
options.node_id = global_process_rank;
options.num_nodes = global_world_size;
options.allowed_devices = allowed_devices;
options.platform_name = "gpu";
options.should_stage_host_to_device_transfers = true;
options.kv_get = kv_get;
options.kv_put = kv_put;
client_ = std::move(xla::GetStreamExecutorGpuClient(options).value());
} else if (device_type == "XPU") {
TF_VLOG(1) << "Initializing PjRt XPU client...";
XLA_CHECK_OK(
Expand Down

0 comments on commit 60eb358

Please sign in to comment.