Update OpenXLA-pin to Jun27 (#7334)

ManfeiBai authored Jul 1, 2024
1 parent 020f3f5 · commit 9c32c66
Showing 26 changed files with 55 additions and 45 deletions.
2 changes: 1 addition & 1 deletion WORKSPACE
@@ -50,7 +50,7 @@ new_local_repository(
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update the sha256 with the result.
 
-xla_hash = 'ef3ae8863519edff6e7c18ada8ed0672b9d9f158'
+xla_hash = '8533a6869ae02fb3b15a8a12739a982fc3c9f6e7'
 
 http_archive(
     name = "xla",
2 changes: 1 addition & 1 deletion setup.py
@@ -64,7 +64,7 @@
 
 base_dir = os.path.dirname(os.path.abspath(__file__))
 
-_date = '20240618'
+_date = '20240628'
 _libtpu_version = f'0.1.dev{_date}'
 _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
 _jax_version = f'0.4.30.dev{_date}'
5 changes: 3 additions & 2 deletions torch_xla/csrc/convolution_helper.cpp
@@ -1,5 +1,6 @@
 #include "torch_xla/csrc/convolution_helper.h"
 
+#include "absl/status/status.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/tensor_float_32_utils.h"
 #include "xla/client/xla_builder.h"
@@ -39,7 +40,7 @@ std::string ToString(TensorFormat format) {
 
 // Performs some basic checks on ConvOpAttrs that are true for all kinds of
 // XLA convolutions (as currently implemented).
-xla::Status CheckConvAttrs(const ConvOpAttrs& attrs) {
+absl::Status CheckConvAttrs(const ConvOpAttrs& attrs) {
   const int num_dims = attrs.num_spatial_dims + 2;
   const int attrs_strides_size = attrs.strides.size();
   if (attrs_strides_size != num_dims) {
@@ -94,7 +95,7 @@ xla::Shape GroupedFilterShapeForDepthwiseConvolution(
 // These helpers were originally adapted from
 // https://github.com/tensorflow/tensorflow/blob/7f39a389d5b82d6aca13240c21f2647c3ebdb765/tensorflow/core/framework/kernel_shape_util.cc
 
-xla::Status GetWindowedOutputSizeVerboseV2(
+absl::Status GetWindowedOutputSizeVerboseV2(
     int64_t input_size, int64_t filter_size, int64_t dilation_rate,
     int64_t stride, Padding padding_type, int64_t* output_size,
     int64_t* padding_before, int64_t* padding_after) {
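The absl::Status returns above follow Abseil's standard validate-and-return pattern. A minimal self-contained sketch of that pattern (Attrs and CheckAttrs below are hypothetical stand-ins for ConvOpAttrs and CheckConvAttrs, not the real types):

#include <cstdint>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"

// Hypothetical stand-in for ConvOpAttrs, reduced to the fields used here.
struct Attrs {
  int num_spatial_dims;
  std::vector<int64_t> strides;
};

// Same shape as CheckConvAttrs: absl::OkStatus() on success, a descriptive
// absl::InvalidArgumentError otherwise.
absl::Status CheckAttrs(const Attrs& attrs) {
  const int num_dims = attrs.num_spatial_dims + 2;
  if (static_cast<int>(attrs.strides.size()) != num_dims) {
    return absl::InvalidArgumentError(absl::StrCat(
        "expected ", num_dims, " strides, got ", attrs.strides.size()));
  }
  return absl::OkStatus();
}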
4 changes: 2 additions & 2 deletions torch_xla/csrc/dl_convertor.cpp
@@ -2,6 +2,7 @@
 
 #include <ATen/DLConvertor.h>
 
+#include "absl/status/status.h"
 #include "absl/types/span.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/ops/device_data.h"
@@ -16,7 +17,6 @@
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/pjrt_future.h"
 #include "xla/pjrt/pjrt_layout.h"
-#include "xla/status.h"
 
 namespace torch_xla {
 
@@ -325,7 +325,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
   if (dlmt->deleter) {
     on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
   }
-  xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> pjrt_buffer =
+  absl::StatusOr<std::unique_ptr<xla::PjRtBuffer>> pjrt_buffer =
       device->client()->CreateViewOfDeviceBuffer(
           static_cast<char*>(dlmt->dl_tensor.data) +
               dlmt->dl_tensor.byte_offset,
3 changes: 2 additions & 1 deletion torch_xla/csrc/helpers.cpp
@@ -6,6 +6,7 @@
 #include <iterator>
 #include <limits>
 
+#include "absl/status/status.h"
 #include "absl/strings/str_join.h"
 #include "torch_xla/csrc/convert_ops.h"
 #include "torch_xla/csrc/dtype.h"
@@ -971,7 +972,7 @@ xla::XlaOp XlaHelpers::PromotedLogicalUnaryOp(
   return unary_op(op);
 }
 
-xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
+absl::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
    const xla::XlaComputation& computation,
    const std::vector<xla::Shape>& parameter_shapes,
    const std::vector<size_t>& buffer_donor_indices) {
3 changes: 2 additions & 1 deletion torch_xla/csrc/helpers.h
@@ -10,6 +10,7 @@
 #include <tuple>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
@@ -386,7 +387,7 @@ class XlaHelpers {
     s_mat_mul_precision = precision;
   }
 
-  static xla::StatusOr<xla::XlaComputation> WrapXlaComputation(
+  static absl::StatusOr<xla::XlaComputation> WrapXlaComputation(
      const xla::XlaComputation& computation,
      const std::vector<xla::Shape>& parameter_shapes,
      const std::vector<size_t>& buffer_donor_indices);
3 changes: 2 additions & 1 deletion torch_xla/csrc/init_python_bindings.cpp
@@ -20,6 +20,7 @@
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "absl/types/variant.h"
@@ -870,7 +871,7 @@ void BuildProfilerSubmodule(py::module* m) {
   absl::flat_hash_map<std::string, std::variant<int, std::string>> opts =
       ConvertDictToMap(options);
   std::chrono::seconds sleep_s(interval_s);
-  xla::Status status;
+  absl::Status status;
   {
     NoGilSection nogil;
     for (int i = 0; i <= timeout_s / interval_s; i++) {
7 changes: 4 additions & 3 deletions torch_xla/csrc/lowering_context.cpp
@@ -7,6 +7,7 @@
 #include <stdexcept>
 #include <string_view>
 
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_replace.h"
@@ -152,8 +153,8 @@ void LoweringContext::SetResult(size_t index, xla::XlaOp op) {
   root_tuple_.at(index) = std::move(op);
 }
 
-xla::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
-  xla::StatusOr<xla::XlaComputation> xla;
+absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
+  absl::StatusOr<xla::XlaComputation> xla;
 
   // Check whether we are building a cond/body computation; if so, skip the
  // Tuple step.
@@ -175,7 +176,7 @@ xla::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
   return xla;
 }
 
-xla::StatusOr<xla::XlaComputation> LoweringContext::BuildXla(xla::XlaOp root) {
+absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla(xla::XlaOp root) {
   XLA_CHECK(root_tuple_.empty());
   auto xla = builder()->Build(root);
 
5 changes: 3 additions & 2 deletions torch_xla/csrc/lowering_context.h
@@ -12,6 +12,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "absl/types/span.h"
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/ir.h"
@@ -73,13 +74,13 @@ class LoweringContext : public torch::lazy::LoweringContext {
 
   // Build the XLA computation capturing all the operations created with the
   // embedded XLA builder (returned by the builder() API).
-  xla::StatusOr<xla::XlaComputation> BuildXla();
+  absl::StatusOr<xla::XlaComputation> BuildXla();
 
   // Build the XLA computation capturing all the operations created with the
   // embedded XLA builder (returned by the builder() API).
   // Uses root as return value for the computation. It is an error to use this
   // API after having called the AddResult() API.
-  xla::StatusOr<xla::XlaComputation> BuildXla(xla::XlaOp root);
+  absl::StatusOr<xla::XlaComputation> BuildXla(xla::XlaOp root);
 
   // Lowers a single IR node. All the inputs to the node must have a lowering
   // before calling this API. Returns the generated XLA operations.
5 changes: 3 additions & 2 deletions torch_xla/csrc/pooling.cpp
@@ -3,6 +3,7 @@
 #include <torch/csrc/lazy/core/tensor_util.h>
 #include <torch/csrc/lazy/core/util.h>
 
+#include "absl/status/status.h"
 #include "torch_xla/csrc/data_ops.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
@@ -314,12 +315,12 @@ xla::XlaOp ComputeMaxPoolIndices(
       xla::ShapeUtil::MakeShape(kIndicesType, {pool_elements})));
 
   auto cond_fn = [&](absl::Span<const xla::XlaOp> init,
-                     xla::XlaBuilder* builder) -> xla::StatusOr<xla::XlaOp> {
+                     xla::XlaBuilder* builder) -> absl::StatusOr<xla::XlaOp> {
     return xla::Lt(init[counter_id], init[limit_id]);
   };
   auto body_fn =
       [&](absl::Span<const xla::XlaOp> init,
-          xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
+          xla::XlaBuilder* builder) -> absl::StatusOr<std::vector<xla::XlaOp>> {
     PoolSliceIndices slice_indices =
         ComputeSliceIndices(init[counter_id], pool_result_shape.dimensions(),
                             pooling_op_attributes.stride);
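The cond_fn/body_fn pair above has the callback shape expected by XLA's while-loop helper. A self-contained sketch under that assumption, using xla::WhileLoopHelper from "xla/client/lib/loops.h" with a plain counter loop standing in for the real pooling body:

#include <cstdint>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/client/lib/loops.h"
#include "xla/client/xla_builder.h"

// Builds while (counter < limit) { counter += 1; } as an XLA computation.
absl::StatusOr<std::vector<xla::XlaOp>> CountToTen(xla::XlaBuilder* builder) {
  auto cond_fn = [](absl::Span<const xla::XlaOp> vals,
                    xla::XlaBuilder*) -> absl::StatusOr<xla::XlaOp> {
    return xla::Lt(vals[0], vals[1]);  // counter < limit
  };
  auto body_fn = [](absl::Span<const xla::XlaOp> vals, xla::XlaBuilder* b)
      -> absl::StatusOr<std::vector<xla::XlaOp>> {
    xla::XlaOp one = xla::ConstantR0<int32_t>(b, 1);
    return std::vector<xla::XlaOp>{xla::Add(vals[0], one), vals[1]};
  };
  xla::XlaOp counter = xla::ConstantR0<int32_t>(builder, 0);
  xla::XlaOp limit = xla::ConstantR0<int32_t>(builder, 10);
  return xla::WhileLoopHelper(cond_fn, body_fn, {counter, limit},
                              "count_to_ten", builder);
}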
3 changes: 2 additions & 1 deletion torch_xla/csrc/runtime/debug_macros.h
@@ -1,6 +1,7 @@
 #ifndef XLA_CLIENT_DEBUG_MACROS_H_
 #define XLA_CLIENT_DEBUG_MACROS_H_
 
+#include "absl/status/status.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "tsl/platform/stacktrace.h"
 #include "xla/statusor.h"
@@ -16,7 +17,7 @@
 #define XLA_CHECK_GT(a, b) TF_CHECK_GT(a, b) << "\n" << tsl::CurrentStackTrace()
 
 template <typename T>
-T ConsumeValue(xla::StatusOr<T>&& status) {
+T ConsumeValue(absl::StatusOr<T>&& status) {
   XLA_CHECK_OK(status.status());
   return std::move(status).value();
 }
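ConsumeValue is the unwrap-or-crash helper this commit migrates to absl::StatusOr. A hypothetical usage sketch (MaybeAnswer is illustrative; ConsumeValue and XLA_CHECK_OK come from the header above):

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "torch_xla/csrc/runtime/debug_macros.h"

// Hypothetical producer that can fail.
absl::StatusOr<int> MaybeAnswer(bool ok) {
  if (!ok) return absl::InternalError("no answer available");
  return 42;
}

void Demo() {
  // ConsumeValue CHECK-fails with a stack trace when the status is not OK;
  // otherwise it moves the value out of the StatusOr.
  int answer = ConsumeValue(MaybeAnswer(true));
  (void)answer;
}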
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -569,7 +569,7 @@ IfrtComputationClient::ExecuteReplicated(
           .value();
 
   result.status.OnReady(std::move([timed, op_tracker = std::move(op_tracker)](
-                                      xla::Status status) mutable {
+                                      absl::Status status) mutable {
     timed.reset();
     TF_VLOG(3)
         << "ExecuteReplicated returned_future->OnReady finished with status "
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/ifrt_computation_client_test.cc
@@ -6,6 +6,7 @@
 #include <string>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/tensor_source.h"
 #include "tsl/lib/core/status_test_util.h"
@@ -16,14 +17,13 @@
 #include "xla/client/xla_computation.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
-#include "xla/status.h"
 #include "xla/statusor.h"
 #include "xla/tests/literal_test_util.h"
 
 namespace torch_xla {
 namespace runtime {
 
-xla::StatusOr<xla::XlaComputation> MakeComputation() {
+absl::StatusOr<xla::XlaComputation> MakeComputation() {
   xla::Shape input_shape =
       xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2});
   xla::XlaBuilder builder("AddComputation");
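The MakeComputation body is cut off by the diff view; given the "AddComputation" builder name, it plausibly continues along these lines (a hedged guess, not taken from the commit):

  // Hypothetical continuation: add two {2, 2} parameters elementwise.
  xla::XlaOp x = xla::Parameter(&builder, 0, input_shape, "x");
  xla::XlaOp y = xla::Parameter(&builder, 1, input_shape, "y");
  xla::Add(x, y);
  return builder.Build();
}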
9 changes: 5 additions & 4 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -5,6 +5,7 @@
 #include <unordered_set>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "absl/strings/ascii.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "absl/types/span.h"
@@ -316,7 +317,7 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
   XLA_CHECK(dst_device->IsAddressable()) << dst << " is not addressable.";
 
   // Returns error if the buffer is already on `dst_device`.
-  xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> status_or =
+  absl::StatusOr<std::unique_ptr<xla::PjRtBuffer>> status_or =
       pjrt_data->buffer->CopyToDevice(dst_device);
   if (!status_or.ok()) {
     return data;
@@ -472,7 +473,7 @@ std::uintptr_t PjRtComputationClient::UnsafeBufferPointer(
   XLA_CHECK(pjrt_data) << "handle must be PjRtData, got " << handle->ToString();
   XLA_CHECK(pjrt_data->buffer != nullptr)
       << "PjRt buffer is null in " << __FUNCTION__;
-  xla::StatusOr<std::uintptr_t> ptr =
+  absl::StatusOr<std::uintptr_t> ptr =
       client_->UnsafeBufferPointer(pjrt_data->buffer.get());
   XLA_CHECK(ptr.ok());
   return ptr.value();
@@ -744,7 +745,7 @@ PjRtComputationClient::ExecuteComputation(
           .value();
 
   returned_future->OnReady(std::move(
-      [timed, op_tracker = std::move(op_tracker)](xla::Status unused) mutable {
+      [timed, op_tracker = std::move(op_tracker)](absl::Status unused) mutable {
         timed.reset();
         TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished";
       }));
@@ -850,7 +851,7 @@ PjRtComputationClient::ExecuteReplicated(
 
   (*returned_futures)[0].OnReady(
       std::move([timed, op_tracker = std::move(op_tracker)](
-                    xla::Status unused) mutable {
+                    absl::Status unused) mutable {
         timed.reset();
         TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished";
       }));
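The OnReady callbacks above now receive absl::Status instead of xla::Status. A minimal sketch of that callback shape, assuming xla::PjRtFuture<> (the void specialization) from "xla/pjrt/pjrt_future.h":

#include "absl/status/status.h"
#include "xla/pjrt/pjrt_future.h"

// Attaches a completion callback; the real client resets a timer here and
// logs via TF_VLOG(3).
void AttachCompletionLog(xla::PjRtFuture<>& future) {
  future.OnReady([](absl::Status status) {
    if (!status.ok()) {
      // Illustrative only: surface or log the error here.
    }
  });
}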
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client_test.cc
@@ -6,6 +6,7 @@
 #include <string>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/tensor_source.h"
@@ -17,14 +18,13 @@
 #include "xla/client/xla_computation.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
-#include "xla/status.h"
 #include "xla/statusor.h"
 #include "xla/tests/literal_test_util.h"
 
 namespace torch_xla {
 namespace runtime {
 
-xla::StatusOr<xla::XlaComputation> MakeComputation() {
+absl::StatusOr<xla::XlaComputation> MakeComputation() {
   xla::Shape input_shape =
       xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2});
   xla::XlaBuilder builder("AddComputation");
3 changes: 2 additions & 1 deletion torch_xla/csrc/runtime/pjrt_registry.cc
@@ -1,6 +1,7 @@
 #include "torch_xla/csrc/runtime/pjrt_registry.h"
 
 #include "absl/log/initialize.h"
+#include "absl/status/status.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/profiler.h"
@@ -138,7 +139,7 @@ InitializePjRt(const std::string& device_type) {
         env::kEnvTpuLibraryPath,
         sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so"));
     XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status());
-    xla::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
+    absl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
     XLA_CHECK_OK(tpu_status);
     client = std::move(xla::GetCApiClient("TPU").value());
     const PJRT_Api* c_api =
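The TPU branch above loads and initializes the PJRT plugin, asserting on failure via XLA_CHECK_OK. The same sequence with explicit error propagation, assuming the pjrt:: helpers declared in "xla/pjrt/pjrt_api.h" (InitTpuPlugin is a hypothetical wrapper):

#include <string>

#include "absl/status/status.h"
#include "xla/pjrt/pjrt_api.h"

// Load the plugin library, then run its one-time initialization; return the
// first error encountered instead of CHECK-failing.
absl::Status InitTpuPlugin(const std::string& library_path) {
  absl::Status load = pjrt::LoadPjrtPlugin("tpu", library_path).status();
  if (!load.ok()) return load;
  return pjrt::InitializePjrtPlugin("tpu");
}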
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/profiler.cc
@@ -1,14 +1,14 @@
 #include "torch_xla/csrc/runtime/profiler.h"
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "tsl/profiler/lib/profiler_factory.h"
 #include "tsl/profiler/rpc/client/capture_profile.h"
 #include "tsl/profiler/rpc/profiler_server.h"
 #include "xla/backends/profiler/plugin/plugin_tracer.h"
 #include "xla/backends/profiler/plugin/profiler_c_api.h"
 #include "xla/pjrt/c/pjrt_c_api_profiler_extension.h"
-#include "xla/status.h"
 
 namespace torch_xla {
 namespace runtime {
@@ -45,7 +45,7 @@ void ProfilerServer::Start(int port) {
 
 ProfilerServer::~ProfilerServer() {}
 
-xla::Status Trace(
+absl::Status Trace(
     const char* service_addr, const char* logdir, int duration_ms,
     int num_tracing_attempts,
     const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/profiler.h
@@ -4,8 +4,8 @@
 #include <memory>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
 #include "xla/pjrt/c/pjrt_c_api.h"
-#include "xla/status.h"
 
 namespace torch_xla {
 namespace runtime {
@@ -23,7 +23,7 @@ class ProfilerServer {
   std::unique_ptr<Impl> impl_;
 };
 
-xla::Status Trace(
+absl::Status Trace(
     const char* service_addr, const char* logdir, int duration_ms,
     int num_tracing_attempts,
     const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
1 change: 0 additions & 1 deletion torch_xla/csrc/runtime/tf_logging.h
@@ -4,7 +4,6 @@
 #include <sstream>
 
 #include "tsl/platform/logging.h"
-#include "xla/status.h"
 
 namespace torch_xla {
 namespace runtime {
1 change: 0 additions & 1 deletion torch_xla/csrc/runtime/util.h
@@ -18,7 +18,6 @@
 #include "torch_xla/csrc/runtime/types.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/hash.h"
-#include "xla/status.h"
 
 namespace torch_xla {
 namespace runtime {
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/xla_coordinator.cc
@@ -25,7 +25,7 @@ XlaCoordinator::XlaCoordinator(int global_rank, int world_size,
       sys_util::GetEnvInt(env::kEnvDistSvcShutdownTimeoutInMin, 5);
   service_options.shutdown_timeout = absl::Minutes(shutdown_timeout);
 
-  xla::StatusOr<std::unique_ptr<xla::DistributedRuntimeService>>
+  absl::StatusOr<std::unique_ptr<xla::DistributedRuntimeService>>
       dist_runtime_service = xla::GetDistributedRuntimeService(
           dist_service_addr, service_options);
   XLA_CHECK(dist_runtime_service.ok())