Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable passing down dynamic dimensions from torch to XLA #5790

Merged
merged 10 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,10 @@ def bazel_build(self, ext):

bazel_argv = [
'bazel', 'build', ext.bazel_target,
f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}",
'\n'.join(['--cxxopt=%s' % opt for opt in extra_compile_args])
f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}"
]
for opt in extra_compile_args:
bazel_argv.append("--cxxopt={}".format(opt))

# Debug build.
if DEBUG:
Expand Down
58 changes: 58 additions & 0 deletions test/stablehlo/test_unbounded_dynamism.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.stablehlo import exported_program_to_stablehlo

# Note: Unbounded dynamism is under development. It works with unmerged
# XLA changes. Experimental XLA branch: https://github.com/lsy323/openxla-xla/tree/lsiyuan/sandeep-dynamism-rebased

device = xm.xla_device()


class UnboundedDynamismExportTest(unittest.TestCase):

def test_simply_add(self):
a = torch.tensor([[1, 2], [2, 4]], device=device)
torch_xla._XLAC._xla_mark_dynamic(a, 0)
b = torch.tensor([[1, 2], [2, 4]], device=device)
torch_xla._XLAC._xla_mark_dynamic(b, 0)
c = a * b
hlo_content = torch_xla._XLAC._get_xla_tensors_hlo([c])
self.assertTrue(
"(p0.1: s64[?,2], p1.2: s64[?,2]) -> (s64[?,2])" in hlo_content)

def test_export_dynamism(self):

class M(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, x, y):
return x * y

example_args = (torch.tensor([[1, 2], [2, 4]], device=device),
torch.tensor([[1, 2], [2, 4]], device=device))
constraints = [
# First dimension of each input is a dynamic batch size
torch.export.dynamic_dim(example_args[0], 0),
torch.export.dynamic_dim(example_args[1], 0),
# The dynamic batch size between the inputs are equal
torch.export.dynamic_dim(example_args[0],
0) == torch.export.dynamic_dim(
example_args[1], 0),
]
ep = torch.export.export(M(), args=example_args, constraints=constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text("forward")
self.assertTrue(
"(%arg0: tensor<?x2xi64>, %arg1: tensor<?x2xi64>) -> tensor<?x2xi64>" in
shlo_text)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
13 changes: 11 additions & 2 deletions torch_xla/csrc/elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/random.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/xla_lower_util.h"
Expand Down Expand Up @@ -66,8 +67,16 @@ xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output,

xla::XlaOp BuildRelu(xla::XlaOp input) {
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
return xla::Max(input, XlaHelpers::ScalarValue<float>(
0, input_shape.element_type(), input.builder()));
xla::XlaOp scalar = XlaHelpers::ScalarValue<float>(
0, input_shape.element_type(), input.builder());
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
// xla::Max doesn't do implicit broadcasting for unbounded dynamism now.
// TODO(lsy323): Remove this branch once the support is added in XLA.
auto promoted = XlaHelpers::Promote(input, scalar);
return xla::Max(promoted.first, promoted.second);
} else {
return xla::Max(input, scalar);
}
}

xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) {
Expand Down
65 changes: 64 additions & 1 deletion torch_xla/csrc/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "torch_xla/csrc/convert_ops.h"
#include "torch_xla/csrc/dtype.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/shape_helper.h"
Expand All @@ -21,6 +20,9 @@
namespace torch_xla {
namespace {

// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2,
xla::XlaOp result) {
xla::PrimitiveType type1 = XlaHelpers::TypeOfXlaOp(op1);
Expand Down Expand Up @@ -63,6 +65,9 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input,
std::vector<int64_t> bcast_sizes = SizesOfXlaOp(input);
for (size_t i = 0; i < dimensions.size(); ++i) {
bcast_sizes.at(dimensions[i]) = sizes[i];
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
XLA_CHECK(sizes[i] != kUnboundedSize);
}
}
return xla::BroadcastInDim(input, bcast_sizes,
GetAllDimensions(bcast_sizes.size()));
Expand Down Expand Up @@ -322,6 +327,59 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
: xla::Reshape(input, shape.dimensions());
}

bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled())
<< "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded "
"dynamism workload.";
const absl::Span<const int64_t> dims = shape.dimensions();
return std::any_of(dims.begin(), dims.end(),
[](int64_t size) { return size == kUnboundedSize; });
}

xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please prepare an opset issue to track supporting this mode of dynamism?

Similar work example: #5764

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I think we have a list internally

Copy link
Collaborator

@miladm miladm Nov 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it work to maintain the list on GH?

xla::XlaOp input, xla::XlaOp aux_input,
absl::Span<const int64_t> output_sizes) {
XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled())
<< "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded "
"dynamism workload.";
const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
XLA_CHECK(output_sizes.size() == aux_input_shape.rank())
<< "XlaHelpers::DynamicUnboundedReshape constrainled failed!";
std::vector<xla::XlaOp> get_dim_ops;
std::vector<xla::XlaOp> reshaped_ops;
bool all_static = true;
std::vector<bool> output_dynamic(output_sizes.size(), false);

for (int i = 0; i < output_sizes.size(); i++) {
if (output_sizes[i] == kUnboundedSize) {
output_dynamic[i] = true;
get_dim_ops.push_back(xla::GetDimensionSize(aux_input, i));
all_static = false;
} else {
get_dim_ops.push_back(XlaHelpers::ScalarValue<int32_t>(
output_sizes[i], aux_input.builder()));
}
}

if (all_static) {
return xla::Reshape(input, output_sizes);
}

// Create the reshape from scalar to 1-D vector
for (auto get_dim_op : get_dim_ops) {
reshaped_ops.push_back(xla::Reshape(get_dim_op, {1}));
}

// Create Concatenate op
auto concat_op = xla::ConcatInDim(input.builder(), reshaped_ops, {0});
return xla::CustomCall(
aux_input.builder(), "stablehlo.dynamic_reshape", {input, concat_op},
xla::ShapeUtil::MakeShape(aux_input_shape.element_type(), output_sizes,
output_dynamic));

return input;
}

bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1,
const xla::Shape& shape2) {
return shape1.is_static() && shape2.is_static() &&
Expand Down Expand Up @@ -485,6 +543,11 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1,
runtime::util::ToVector<int64_t>(shape1.dimensions()),
runtime::util::ToVector<int64_t>(shape2.dimensions())));
}
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) &&
!XlaHelpers::IsUnboundedDynamic(shape2))
<< "Unreachable for unbounded dynamic code\n";
}
return GetPromotedDynamicShape(shape1, shape2);
}

Expand Down
12 changes: 12 additions & 0 deletions torch_xla/csrc/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/util.h"
#include "tsl/platform/bfloat16.h"
#include "xla/client/xla_builder.h"
Expand Down Expand Up @@ -158,6 +159,17 @@ class XlaHelpers {
static xla::XlaOp DynamicReshape(xla::XlaOp input,
absl::Span<const int64_t> output_sizes);

static bool IsUnboundedDynamic(const xla::Shape& shape);

static bool IsUnboundedDynamismEnabled() {
return runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM",
false);
}

static xla::XlaOp DynamicUnboundedReshape(
xla::XlaOp input, xla::XlaOp aux_input,
absl::Span<const int64_t> output_sizes);

static xla::XlaOp DynamicReshapeAs(xla::XlaOp input, const xla::Shape& shape);

static bool SameStaticDimensions(const xla::Shape& shape1,
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1968,6 +1968,12 @@ void InitXlaModuleBindings(py::module m) {
return handles;
});

m.def("_xla_mark_dynamic", [](const at::Tensor& input, uint32_t dim) {
TORCH_LAZY_COUNTER("XlaMarkDynamic", 1);
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xtensor->MarkDynamicDimension(dim);
});

// -------------Dynamo Integration API Start-------------------------
/*
* Return tensor ids and at::tensors for all DeviceData nodes that is needed
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ xla::Shape XlaNode::GetOpShape(
std::string XlaNode::ToString() const {
std::stringstream ss;
ss << torch::lazy::Node::ToString() << ", xla_shape=" << xla_shape_;
ss << ", dynamic_dims: (" << absl::StrJoin(unbounded_dynamic_dims_, ", ")
<< ')';
return ss.str();
}

Expand Down
13 changes: 12 additions & 1 deletion torch_xla/csrc/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -138,6 +138,17 @@ class XlaNode : public torch::lazy::Node {

std::string ToString() const override;

void MarkDynamicDimension(uint32_t dim) {
unbounded_dynamic_dims_.insert(dim);
}

const std::unordered_set<uint32_t>& dynamic_dims() const {
return unbounded_dynamic_dims_;
}

protected:
std::unordered_set<uint32_t> unbounded_dynamic_dims_;

private:
xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;

Expand Down
38 changes: 33 additions & 5 deletions torch_xla/csrc/lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,31 @@ LoweringContext::LoweringContext(
}
}

// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

xla::XlaOp LoweringContext::GetParameter(
const std::shared_ptr<torch::lazy::BackendData>& data) {
const std::shared_ptr<torch::lazy::BackendData>& data,
const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
torch::lazy::BackendData::Handle handle = data->GetHandle();
auto it = parameters_map_.find(handle);
if (it == parameters_map_.end()) {
xla::XlaOp param = xla::Parameter(
builder(), parameters_.size(),
xla::Shape shape =
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data)
->shape(),
absl::StrCat("p", parameters_.size()));
->shape();
for (const int dim : unbounded_dynamic_dims) {
shape.set_dynamic_dimension(dim, true);
shape.set_dimensions(dim, kUnboundedSize);
}
xla::XlaOp param = xla::Parameter(builder(), parameters_.size(), shape,
absl::StrCat("p", parameters_.size()));
it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
.first;
parameters_.push_back(data);
} else {
XLA_CHECK(unbounded_dynamic_dims.empty())
lsy323 marked this conversation as resolved.
Show resolved Hide resolved
<< "The unbounded dynamic dims can only be set when Parameter is "
"created.";
}
parameter_sequence_.push_back(it->second.index);
return it->second.param;
Expand Down Expand Up @@ -170,6 +182,22 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) {

const XlaNode* casted = dynamic_cast<const XlaNode*>(node);
result_ops = casted->Lower(this);
if (!casted->dynamic_dims().empty()) {
xla::internal::XlaBuilderFriend builder_friend;
auto* inst = builder_friend.GetInstruction(result_ops[0]);
auto* mutable_dynamic =
inst->mutable_shape()->mutable_is_dynamic_dimension();
if (mutable_dynamic->empty()) {
for (int i = 0; i < inst->dimensions_size(); i++) {
mutable_dynamic->Add(false);
}
}
auto* mutable_dims = inst->mutable_shape()->mutable_dimensions();
for (const auto dim : casted->dynamic_dims()) {
mutable_dynamic->Set(dim, true);
mutable_dims->Set(dim, kUnboundedSize);
}
}
} catch (const std::exception& ex) {
ReportBuilderError(node, ex.what());
}
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/lowering_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class LoweringContext : public torch::lazy::LoweringContext {
// returned. Otherwise a new one will be created, associated with the tensor
// held in data.
xla::XlaOp GetParameter(
const std::shared_ptr<torch::lazy::BackendData>& data);
const std::shared_ptr<torch::lazy::BackendData>& data,
const std::unordered_set<uint32_t>& dynamic_dims = {});

// Retrieves the vector holding all the tensors associated with the parameter
// instructions which have been created.
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/device_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ torch::lazy::NodePtr DeviceData::Clone(torch::lazy::OpList operands) const {
}

XlaOpVector DeviceData::Lower(LoweringContext* loctx) const {
return ReturnOp(loctx->GetParameter(data_), loctx);
return ReturnOp(loctx->GetParameter(data_, unbounded_dynamic_dims_), loctx);
}

DeviceData* DeviceData::Cast(const torch::lazy::Node* node) {
Expand Down
21 changes: 18 additions & 3 deletions torch_xla/csrc/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,15 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count,
xla::XlaOp scale = xla::Select(xla::Ne(count, zero),
one / xla::ConvertElementType(count, type),
xla::NanValue(input.builder(), type));
return input * scale;

if (XlaHelpers::IsUnboundedDynamismEnabled()) {
// XLA Multiply doesn't do implicit broadcasting for unbounded dynamism now.
// TODO(lsy323): Remove this branch once the support is added in XLA.
auto promoted = XlaHelpers::Promote(input, scale);
return promoted.first * promoted.second;
} else {
return input * scale;
}
}

xla::XlaOp AverageValue(xla::XlaOp input, xla::XlaOp reduced) {
Expand Down Expand Up @@ -109,8 +117,15 @@ SummationResult CreateSummation(xla::XlaOp input,
result.result, result.rinfo.element_count.size, shape.element_type());
}
if (keep_reduced_dimensions) {
result.result =
XlaHelpers::DynamicReshape(result.result, result.rinfo.new_dimensions);
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
// TODO(lsy323): Use XLA DynamicReshape once unbounded dynamism support is
// added.
result.result = XlaHelpers::DynamicUnboundedReshape(
result.result, input, result.rinfo.new_dimensions);
} else {
result.result = XlaHelpers::DynamicReshape(result.result,
result.rinfo.new_dimensions);
}
}
return result;
}
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -893,4 +893,9 @@ int64_t XLATensor::GetHandle() const {
}
}

void XLATensor::MarkDynamicDimension(uint32_t dim) {
auto* xla_node = dynamic_cast<XlaNode*>(GetIrValue().node.get());
xla_node->MarkDynamicDimension(dim);
}

} // namespace torch_xla
Loading