[TEST] openxla-pin nov24 test on tpu #5940

Closed · wants to merge 10 commits
Changes from all commits
5 changes: 2 additions & 3 deletions WORKSPACE
@@ -42,13 +42,12 @@ http_archive(
         "//openxla_patches:constexpr_return.diff",
         "//openxla_patches:gpu_race_condition.diff",
         "//openxla_patches:f16_abi_clang.diff",
-        "//openxla_patches:gpu_topk_rewriter.diff",
         "//openxla_patches:quant_dequant_converter.diff",
         "//openxla_patches:stablehlo_quant_seralization.diff",
     ],
-    strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478",
+    strip_prefix = "xla-8744c9a94782cd7804f015e6d29df253437af3cb",
     urls = [
-        "https://github.com/openxla/xla/archive/4f8381651977dff16b1d86bb4b198eb733c5f478.tar.gz",
+        "https://github.com/openxla/xla/archive/8744c9a94782cd7804f015e6d29df253437af3cb.tar.gz",
     ],
 )
33 changes: 16 additions & 17 deletions openxla_patches/cache_urls.diff
@@ -1,5 +1,19 @@
+diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
+index d7f3a8093..a7af9c68a 100644
+--- a/third_party/llvm/workspace.bzl
++++ b/third_party/llvm/workspace.bzl
+@@ -13,7 +13,9 @@ def repo(name):
+         strip_prefix = "llvm-project-{commit}".format(commit = LLVM_COMMIT),
+         urls = [
+             "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
++            "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT),
+             "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
++            "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT),
+         ],
+         build_file = "//third_party/llvm:llvm.BUILD",
+         patch_file = [
 diff --git a/xla/mlir_hlo/WORKSPACE b/xla/mlir_hlo/WORKSPACE
-index cc9eeb64f..b290eb455 100644
+index c3115e33d..d315ad745 100644
 --- a/xla/mlir_hlo/WORKSPACE
 +++ b/xla/mlir_hlo/WORKSPACE
 @@ -35,7 +35,10 @@ http_archive(
@@ -13,19 +27,4 @@ index cc9eeb64f..b290eb455 100644
 +    ],
  )
 
-load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps")
-load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps")
-diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
-index a4574d75d..f9ce37094 100644
---- a/third_party/llvm/workspace.bzl
-+++ b/third_party/llvm/workspace.bzl
-@@ -13,7 +13,9 @@ def repo(name):
-        strip_prefix = "llvm-project-{commit}".format(commit = LLVM_COMMIT),
-        urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
-+           "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT),
-            "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
-+           "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT),
-        ],
-        build_file = "//third_party/llvm:llvm.BUILD",
-        patch_file = [
+load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
6 changes: 3 additions & 3 deletions openxla_patches/constexpr_return.diff
@@ -1,12 +1,12 @@
 diff --git a/xla/primitive_util.h b/xla/primitive_util.h
-index 696147844..dfea15a4d 100644
+index 63fa4e193..ab352626c 100644
 --- a/xla/primitive_util.h
 +++ b/xla/primitive_util.h
-@@ -748,6 +748,7 @@ inline bool FitsInIntegralType(int64_t x, PrimitiveType ty) {
+@@ -706,6 +706,7 @@ inline bool FitsInIntegralType(int64_t x, PrimitiveType ty) {
                 std::numeric_limits<NativeT>::max() >= x;
         }
         LOG(FATAL) << "Invalid primitive type " << PrimitiveType_Name(ty);
-+        return false;
++        return false;
       },
       ty);
  }
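For context on why this patch exists at all: the dispatch lambda in `primitive_util.h` ends in `LOG(FATAL)`, which never returns, but the compiler still requires every control path of a value-returning lambda to return, especially in constexpr-friendly code. A minimal standalone sketch of the same shape (toy types, not XLA's code):

```cpp
#include <cstdio>
#include <cstdlib>

// Toy version of the pattern the patch fixes: the lambda's last statement
// aborts, but one path still "falls off the end" as far as the compiler is
// concerned, so an explicit (unreachable) return is added.
enum class Ty { kInt32, kFloat32, kInvalid };

bool FitsInIntegralType(long long x, Ty ty) {
  auto visit = [x](Ty t) -> bool {
    if (t == Ty::kInt32) return x >= -2147483648LL && x <= 2147483647LL;
    if (t == Ty::kFloat32) return true;  // toy rule, for shape only
    std::fprintf(stderr, "Invalid primitive type\n");
    std::abort();
    return false;  // unreachable, but keeps every path returning a value
  };
  return visit(ty);
}

int main() { std::printf("%d\n", FitsInIntegralType(1LL << 40, Ty::kInt32)); }
```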
2 changes: 1 addition & 1 deletion openxla_patches/f16_abi_clang.diff
@@ -16,4 +16,4 @@ index 3f7af5197..ce4491c5d 100644
 +#if defined(__x86_64__)
  // Older versions of Clang don't have _Float16. Since both float and _Float16
  // are passed in the same register we can use the wider type and careful casting
-// to conform to x86_64 psABI. This only works with the assumption that we're
+// to conform to x86_64 psABI. This only works with the assumption that we're
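This hunk only touches a trailing comment line, but the surrounding patch is about working around older Clang's missing `_Float16` by passing half values through the wider `float` register class on x86_64. As self-contained background (not XLA's code), widening IEEE binary16 bits to a `float`, normal values only:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Expand IEEE-754 binary16 to binary32. Handles normal values only;
// zeros, subnormals, infinities, and NaNs are omitted for brevity.
static float HalfBitsToFloat(uint16_t h) {
  uint32_t sign = (static_cast<uint32_t>(h) & 0x8000u) << 16;
  uint32_t exp = (h >> 10) & 0x1Fu;   // 5-bit exponent, bias 15
  uint32_t frac = h & 0x3FFu;         // 10-bit mantissa
  // Rebias 15 -> 127 (add 112) and widen the mantissa to 23 bits.
  uint32_t bits = sign | ((exp + 112u) << 23) | (frac << 13);
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}

int main() {
  std::printf("%f %f\n", HalfBitsToFloat(0x3C00), HalfBitsToFloat(0xC000));
  // prints: 1.000000 -2.000000
}
```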
8 changes: 4 additions & 4 deletions openxla_patches/gpu_race_condition.diff
@@ -1,14 +1,14 @@
 diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc
-index 242961dd1..787275868 100644
+index 1f9903cb3..763b7fc23 100644
 --- a/xla/service/gpu/gpu_executable.cc
 +++ b/xla/service/gpu/gpu_executable.cc
-@@ -563,8 +563,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStreamImpl(
+@@ -589,8 +589,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStreamImpl(
    }
 
    // Force synchronous execution if the allocator requires it.
 -  const bool block_host_until_done =
 -      !memory_allocator->AllowsAsynchronousDeallocation();
 +  const bool block_host_until_done = true;
-
+
-  // Lock the GPU with a shared lock so that we don't interfere with autotuning
+  // Lock the GPU with a shared lock so that we don't interfere with autotuning
   // that may be running during JIT compilation while allowing multiple XLA
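The underlying patch replaces XLA's per-allocator decision with unconditional host blocking. A hypothetical sketch of that policy change (illustrative types, not XLA's):

```cpp
#include <iostream>

// Upstream XLA blocks the host only when the allocator cannot deallocate
// asynchronously; the patch hard-codes blocking to sidestep a GPU race,
// trading some dispatch latency for deterministic teardown.
struct MemoryAllocator {
  bool allows_async_dealloc;
  bool AllowsAsynchronousDeallocation() const { return allows_async_dealloc; }
};

bool BlockHostUntilDone(const MemoryAllocator& alloc, bool force_sync) {
  if (force_sync) return true;  // patched behavior: always block
  return !alloc.AllowsAsynchronousDeallocation();  // upstream behavior
}

int main() {
  MemoryAllocator bfc{/*allows_async_dealloc=*/true};
  std::cout << BlockHostUntilDone(bfc, /*force_sync=*/false) << "\n";  // 0
  std::cout << BlockHostUntilDone(bfc, /*force_sync=*/true) << "\n";   // 1
}
```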
14 changes: 6 additions & 8 deletions openxla_patches/quant_dequant_converter.diff
@@ -1,8 +1,5 @@
-// TODO(lsy323): This is a patch on the HLO->StableHLO converter, this allows the custom call to
-// stablehlo.uniform_quantize/dequantize to be converted to stablehlo.uniform_quantize/dequantize.
-// The patch can be removed after quantize/dequantize, quantized dtype support is added to HLO.
 diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD
-index f74973ae1..8e3f0e06b 100644
+index 0f0c5e842..59a30c585 100644
 --- a/xla/translate/hlo_to_mhlo/BUILD
 +++ b/xla/translate/hlo_to_mhlo/BUILD
 @@ -67,6 +67,7 @@ cc_library(
@@ -14,10 +11,10 @@ index f74973ae1..8e3f0e06b 100644
          "@llvm-project//mlir:SparseTensorDialect",
          "@tsl//tsl/platform:statusor",
 diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
-index 08d5f49c8..2f9ad1e0b 100644
+index cc7aa9e9e..0eaa68ff2 100644
 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
 +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
-@@ -664,6 +664,70 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
+@@ -664,6 +664,71 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
    return importer.ImportInstructionWithLayout(instr, operands, builder, mode);
  }
 
@@ -84,11 +81,12 @@ index 08d5f49c8..2f9ad1e0b 100644
 +                                     storage_min, storage_max);
 +  }
 +}
 +
++
 StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
      const HloInstruction* instruction,
      const llvm::SmallVectorImpl<mlir::Value>& operands,
-@@ -933,6 +997,25 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
+@@ -933,6 +998,25 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
          "Couldn't parse backend config into a dictionary attribute");
 
      attributes.push_back(builder_->getNamedAttr("backend_config", attr));
@@ -115,7 +113,7 @@ index 08d5f49c8..2f9ad1e0b 100644
  } else {
      attributes.push_back(builder_->getNamedAttr(
 diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
-index 9f05992c8..03cf4840d 100644
+index 1494efd9e..dcb3d9e89 100644
 --- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
 +++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
 @@ -19,6 +19,8 @@ limitations under the License.
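For reference, the `stablehlo.uniform_quantize`/`stablehlo.uniform_dequantize` ops this patch routes custom calls to implement affine (uniform) quantization. A self-contained sketch of the underlying arithmetic — not the MLIR converter itself:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Affine quantization: q = clamp(round(x / scale) + zero_point, qmin, qmax);
// dequantization: x' = (q - zero_point) * scale.
int8_t UniformQuantize(float x, float scale, int zero_point) {
  int q = static_cast<int>(std::lround(x / scale)) + zero_point;
  return static_cast<int8_t>(std::clamp(q, -128, 127));
}

float UniformDequantize(int8_t q, float scale, int zero_point) {
  return (static_cast<int>(q) - zero_point) * scale;
}

int main() {
  float scale = 0.05f;
  int zero_point = 10;
  int8_t q = UniformQuantize(1.0f, scale, zero_point);  // 30
  std::printf("%d -> %f\n", q, UniformDequantize(q, scale, zero_point));
}
```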
4 changes: 1 addition & 3 deletions openxla_patches/stablehlo_quant_seralization.diff
@@ -1,5 +1,3 @@
-// TODO(lsy323): This patch is needed to serialize stablehlo.uniform_quantize/dequantize in bytecode format
-// This patch can be removed after https://github.com/openxla/stablehlo/issues/1812 is fixed.
 diff --git a/third_party/stablehlo/stablehlo_quant_seralization.patch b/third_party/stablehlo/stablehlo_quant_seralization.patch
 new file mode 100644
 index 000000000..24e23b67d
@@ -33,7 +31,7 @@ index 000000000..24e23b67d
 +  } // namespace
 +
 diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl
-index 9f4494aac..64fa072bb 100644
+index 80ab0e479..caaa11080 100644
 --- a/third_party/stablehlo/workspace.bzl
 +++ b/third_party/stablehlo/workspace.bzl
 @@ -15,5 +15,6 @@ def repo():
2 changes: 1 addition & 1 deletion setup.py
@@ -72,7 +72,7 @@
 
 base_dir = os.path.dirname(os.path.abspath(__file__))
 
-_libtpu_version = '0.1.dev20231022'
+_libtpu_version = '0.1.dev20231125'
 _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
 
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/computation_client.h
@@ -321,7 +321,7 @@ class ComputationClient {
   virtual int GetNumProcesses() const = 0;
 
   using DeviceAttribute =
-      std::variant<std::string, int64_t, std::vector<int64_t>, float, bool>;
+      std::variant<std::string, bool, int64_t, std::vector<int64_t>, float>;
 
   virtual const absl::flat_hash_map<
       std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>&
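This change reorders the `std::variant` alternatives (moving `bool` ahead of `int64_t`), presumably to match the upstream PJRT attribute type. A standalone illustration (not torch_xla code) of what reordering changes: the alternative order fixes each type's `index()`, so any logic keyed on discriminant indices can behave differently after the swap, even though both variants hold the same set of types.

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <variant>
#include <vector>

using OldAttr =
    std::variant<std::string, int64_t, std::vector<int64_t>, float, bool>;
using NewAttr =
    std::variant<std::string, bool, int64_t, std::vector<int64_t>, float>;

int main() {
  OldAttr a = true;  // overload resolution still picks bool in both layouts
  NewAttr b = true;
  std::printf("bool index: old=%zu new=%zu\n", a.index(), b.index());  // 4 vs 1
  std::printf("holds bool: %d %d\n",
              static_cast<int>(std::holds_alternative<bool>(a)),
              static_cast<int>(std::holds_alternative<bool>(b)));  // 1 1
}
```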
29 changes: 14 additions & 15 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -152,29 +152,28 @@ PjRtComputationClient::PjRtComputationClient() {
     xla::PjRtClient::KeyValuePutCallback kv_put = nullptr;
     if (distributed_client != nullptr) {
       std::string key_prefix = "gpu:";
-      kv_get = [distributed_client, key_prefix](const std::string& k,
-                                                absl::Duration timeout) {
+      kv_get = [distributed_client, key_prefix](
+                   std::string_view k,
+                   absl::Duration timeout) -> xla::StatusOr<std::string> {
         return distributed_client->BlockingKeyValueGet(
             absl::StrCat(key_prefix, k), timeout);
       };
-      kv_put = [distributed_client, key_prefix](const std::string& k,
-                                                const std::string& v) {
+      kv_put = [distributed_client, key_prefix](
+                   std::string_view k, std::string_view v) -> xla::Status {
         return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
       };
     }
     TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
                << global_process_rank << ", num_nodes=" << global_world_size;
-    client_ = std::move(xla::GetStreamExecutorGpuClient(
-                            /*asynchronous=*/async,
-                            /*allocator_config=*/GetGpuAllocatorConfig(),
-                            /*node_id=*/global_process_rank,
-                            /*num_nodes=*/global_world_size,
-                            /*allowed_devices=*/allowed_devices,
-                            /*platform_name=*/"gpu",
-                            /*should_stage_host_to_device_transfers=*/true,
-                            /*kv_get=*/kv_get,
-                            /*kv_put=*/kv_put)
-                            .value());
+    xla::GpuClientOptions options;
+    options.allocator_config = GetGpuAllocatorConfig();
+    options.node_id = global_process_rank;
+    options.num_nodes = global_world_size;
+    options.allowed_devices = allowed_devices;
+    options.platform_name = "gpu";
+    options.kv_get = kv_get;
+    options.kv_put = kv_put;
+    client_ = std::move(xla::GetStreamExecutorGpuClient(options).value());
   } else if (device_type == "XPU") {
     TF_VLOG(1) << "Initializing PjRt XPU client...";
     XLA_CHECK_OK(
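Two things change here: the GPU client is now configured through an options struct rather than a long positional argument list, and the key-value callbacks take `std::string_view` instead of `const std::string&`. A toy sketch of the key-value store those callbacks talk to — illustrative only, using an in-memory map and bools where the real code returns `xla::StatusOr`/`xla::Status` and forwards to the distributed runtime client with a `"gpu:"` key prefix:

```cpp
#include <chrono>
#include <iostream>
#include <map>
#include <string>
#include <string_view>

class ToyKeyValueStore {
 public:
  bool Put(std::string_view key, std::string_view value) {
    store_.insert_or_assign(std::string(key), std::string(value));
    return true;
  }
  bool Get(std::string_view key, std::chrono::milliseconds /*timeout*/,
           std::string* out) {
    auto it = store_.find(key);  // heterogeneous lookup via std::less<>
    if (it == store_.end()) return false;  // real impl blocks until timeout
    *out = it->second;
    return true;
  }

 private:
  std::map<std::string, std::string, std::less<>> store_;
};

int main() {
  ToyKeyValueStore kv;
  kv.Put("gpu:rank0", "coordinator-address");
  std::string v;
  if (kv.Get("gpu:rank0", std::chrono::milliseconds(100), &v))
    std::cout << v << "\n";
}
```

The options-struct style also means new fields (like `should_stage_host_to_device_transfers`, dropped from the call site here) can take defaults instead of forcing every caller to pass them positionally.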