Enable passing down dynamic dimensions from torch to XLA #5790

Merged: 10 commits, Nov 15, 2023
Changes from 6 commits
5 changes: 3 additions & 2 deletions setup.py
@@ -244,9 +244,10 @@ def bazel_build(self, ext):

bazel_argv = [
'bazel', 'build', ext.bazel_target,
f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}",
'\n'.join(['--cxxopt=%s' % opt for opt in extra_compile_args])
f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}"
]
for opt in extra_compile_args:
bazel_argv.append("--cxxopt={}".format(opt))

# Debug build.
if DEBUG:
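As a quick sketch of the build-flag change above (the flag values below are illustrative assumptions, not taken from the PR): appending one `--cxxopt` per compile flag gives Bazel separate, well-formed arguments, whereas the previous newline-join collapsed every flag into a single argv entry.

```python
extra_compile_args = ["-std=c++17", "-Wall"]  # illustrative example flags

# Previous construction: all flags collapsed into one argv entry.
joined = '\n'.join(['--cxxopt=%s' % opt for opt in extra_compile_args])
# -> "--cxxopt=-std=c++17\n--cxxopt=-Wall" (a single argument handed to Bazel)

# New construction: one argument per flag.
per_flag = ["--cxxopt={}".format(opt) for opt in extra_compile_args]
# -> ["--cxxopt=-std=c++17", "--cxxopt=-Wall"]
```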
16 changes: 14 additions & 2 deletions torch_xla/csrc/elementwise.cpp
@@ -5,6 +5,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/random.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/xla_lower_util.h"
@@ -14,6 +15,9 @@
namespace torch_xla {
namespace {

static const bool experimental_unbounded_dynamism =
runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);

xla::XlaOp Between(xla::XlaOp input, const at::Scalar& min_val,
const at::Scalar& max_val) {
const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input);
@@ -66,8 +70,16 @@ xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output,

xla::XlaOp BuildRelu(xla::XlaOp input) {
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
return xla::Max(input, XlaHelpers::ScalarValue<float>(
0, input_shape.element_type(), input.builder()));
xla::XlaOp scalar = XlaHelpers::ScalarValue<float>(
0, input_shape.element_type(), input.builder());
if (experimental_unbounded_dynamism) {
Review comment (Collaborator): Can this path execute in the non-export path? What do you think about limiting this experimental condition to the torch.export path?
Reply (Collaborator, author): Per the discussion below, it seems there is no better solution than using an env variable to enable the unbounded dynamism lowering. But we can put more thought into this one.
Reply (Collaborator): SG. This issue doesn't block this PR. Let's prepare a proposal in parallel.

// xla::Max doesn't do implicit broadcasting for unbounded dynamism now.
// TODO(lsy323): Remove this branch once the support is added in XLA.
auto promoted = XlaHelpers::Promote(input, scalar);
return xla::Max(promoted.first, promoted.second);
} else {
return xla::Max(input, scalar);
}
}

xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) {
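A small conceptual sketch of the BuildRelu change, with numpy standing in for XLA (this is not the actual lowering): when unbounded dynamism is enabled, the Max cannot rely on implicit scalar broadcasting, so both operands are first promoted to a common shape.

```python
import numpy as np

x = np.random.randn(4, 8).astype(np.float32)  # imagine dim 0 is unbounded at trace time
zero = np.float32(0)

implicit = np.maximum(x, zero)                            # implicit broadcast (static path)
explicit = np.maximum(x, np.broadcast_to(zero, x.shape))  # explicitly promoted operands
assert np.array_equal(implicit, explicit)                 # same values, explicit shapes
```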
65 changes: 65 additions & 0 deletions torch_xla/csrc/helpers.cpp
@@ -21,6 +21,12 @@
namespace torch_xla {
namespace {

static const bool experimental_unbounded_dynamism =
runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);

// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2,
xla::XlaOp result) {
xla::PrimitiveType type1 = XlaHelpers::TypeOfXlaOp(op1);
@@ -63,6 +69,9 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input,
std::vector<int64_t> bcast_sizes = SizesOfXlaOp(input);
for (size_t i = 0; i < dimensions.size(); ++i) {
bcast_sizes.at(dimensions[i]) = sizes[i];
if (experimental_unbounded_dynamism) {
XLA_CHECK(sizes[i] != kUnboundedSize);
}
}
return xla::BroadcastInDim(input, bcast_sizes,
GetAllDimensions(bcast_sizes.size()));
@@ -322,6 +331,57 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
: xla::Reshape(input, shape.dimensions());
}

bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
XLA_CHECK(experimental_unbounded_dynamism)
<< "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
const absl::Span<const int64_t> dims = shape.dimensions();
return std::any_of(dims.begin(), dims.end(),
[](int64_t size) { return size == kUnboundedSize; });
}

xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
Review comment (Collaborator): Can you please prepare an opset issue to track supporting this mode of dynamism? Similar work example: #5764
Reply (Collaborator, author): Sure, I think we have a list internally.
Reply (Collaborator, @miladm, Nov 15, 2023): Does it work to maintain the list on GH?

xla::XlaOp input, xla::XlaOp aux_input,
absl::Span<const int64_t> output_sizes) {
XLA_CHECK(experimental_unbounded_dynamism)
<< "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
XLA_CHECK(output_sizes.size() == aux_input_shape.rank())
<< "XlaHelpers::DynamicUnboundedReshape constrainled failed!";
std::vector<xla::XlaOp> get_dim_ops;
std::vector<xla::XlaOp> reshaped_ops;
bool all_static = true;
std::vector<bool> output_dynamic(output_sizes.size(), false);

for (int i = 0; i < output_sizes.size(); i++) {
if (output_sizes[i] == kUnboundedSize) {
output_dynamic[i] = true;
get_dim_ops.push_back(xla::GetDimensionSize(aux_input, i));
all_static = false;
} else {
get_dim_ops.push_back(XlaHelpers::ScalarValue<int32_t>(
output_sizes[i], aux_input.builder()));
}
}

if (all_static) {
return xla::Reshape(input, output_sizes);
}

// Create the reshape from scalar to 1-D vector
for (auto get_dim_op : get_dim_ops) {
reshaped_ops.push_back(xla::Reshape(get_dim_op, {1}));
}

// Create Concatenate op
auto concat_op = xla::ConcatInDim(input.builder(), reshaped_ops, {0});
return xla::CustomCall(
aux_input.builder(), "stablehlo.dynamic_reshape", {input, concat_op},
xla::ShapeUtil::MakeShape(aux_input_shape.element_type(), output_sizes,
output_dynamic));

return input;
}

bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1,
const xla::Shape& shape2) {
return shape1.is_static() && shape2.is_static() &&
@@ -485,6 +545,11 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1,
runtime::util::ToVector<int64_t>(shape1.dimensions()),
runtime::util::ToVector<int64_t>(shape2.dimensions())));
}
if (experimental_unbounded_dynamism) {
XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) &&
!XlaHelpers::IsUnboundedDynamic(shape2))
<< "Unreachable for unbounded dynamic code\n";
}
return GetPromotedDynamicShape(shape1, shape2);
}

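A plain-Python sketch of the shape assembly performed by DynamicUnboundedReshape above (names and the sentinel are illustrative; the real code emits XLA ops and a stablehlo.dynamic_reshape custom call rather than Python values):

```python
K_UNBOUNDED = -(2**63)  # stands in for kUnboundedSize (std::numeric_limits<int64_t>::min())

def build_output_shape(output_sizes, aux_input_shape):
    """Assemble the reshape target, reading unbounded dims from aux_input."""
    assert len(output_sizes) == len(aux_input_shape)
    shape = []
    for i, size in enumerate(output_sizes):
        if size == K_UNBOUNDED:
            # Real lowering: xla::GetDimensionSize(aux_input, i), reshaped to 1-D.
            shape.append(aux_input_shape[i])
        else:
            # Real lowering: a scalar constant, reshaped to 1-D.
            shape.append(size)
    # Real lowering: the per-dim ops are concatenated and passed, together with
    # the input, to a "stablehlo.dynamic_reshape" custom call.
    return shape

print(build_output_shape([K_UNBOUNDED, 1], [4, 8]))  # -> [4, 1]
```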
6 changes: 6 additions & 0 deletions torch_xla/csrc/helpers.h
@@ -158,6 +158,12 @@ class XlaHelpers {
static xla::XlaOp DynamicReshape(xla::XlaOp input,
absl::Span<const int64_t> output_sizes);

static bool IsUnboundedDynamic(const xla::Shape& shape);

static xla::XlaOp DynamicUnboundedReshape(
xla::XlaOp input, xla::XlaOp aux_input,
absl::Span<const int64_t> output_sizes);

static xla::XlaOp DynamicReshapeAs(xla::XlaOp input, const xla::Shape& shape);

static bool SameStaticDimensions(const xla::Shape& shape1,
6 changes: 6 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1968,6 +1968,12 @@ void InitXlaModuleBindings(py::module m) {
return handles;
});

m.def("_xla_mark_dynamic", [](const at::Tensor& input, uint32_t dim) {
TORCH_LAZY_COUNTER("XlaMarkDynamic", 1);
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xtensor->MarkDynamicDimension(dim);
});

// -------------Dynamo Integration API Start-------------------------
/*
* Return tensor ids and at::tensors for all DeviceData nodes that is needed
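A minimal usage sketch of the new binding, assuming a tensor already placed on an XLA device:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(4, 8).to(xm.xla_device())
# Mark dim 0 of the tensor's underlying IR node as dynamic; this also bumps
# the XlaMarkDynamic counter.
torch_xla._XLAC._xla_mark_dynamic(t, 0)
```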
4 changes: 4 additions & 0 deletions torch_xla/csrc/ir.cpp
@@ -174,6 +174,10 @@ xla::Shape XlaNode::GetOpShape(
std::string XlaNode::ToString() const {
std::stringstream ss;
ss << torch::lazy::Node::ToString() << ", xla_shape=" << xla_shape_;
ss << ", dynamic_dims: ";
for (const auto dim : dynamic_dims_) {
ss << dim;
}
return ss.str();
}

11 changes: 10 additions & 1 deletion torch_xla/csrc/ir.h
@@ -9,9 +9,9 @@
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

@@ -138,6 +138,15 @@ class XlaNode : public torch::lazy::Node {

std::string ToString() const override;

void MarkDynamicDimension(uint32_t dim) { dynamic_dims_.insert(dim); }

const std::unordered_set<uint32_t>& dynamic_dims() const {
return dynamic_dims_;
}

protected:
std::unordered_set<uint32_t> dynamic_dims_;

private:
xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;

38 changes: 33 additions & 5 deletions torch_xla/csrc/lowering_context.cpp
@@ -93,19 +93,31 @@ LoweringContext::LoweringContext(
}
}

// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

xla::XlaOp LoweringContext::GetParameter(
const std::shared_ptr<torch::lazy::BackendData>& data) {
const std::shared_ptr<torch::lazy::BackendData>& data,
const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
torch::lazy::BackendData::Handle handle = data->GetHandle();
auto it = parameters_map_.find(handle);
if (it == parameters_map_.end()) {
xla::XlaOp param = xla::Parameter(
builder(), parameters_.size(),
xla::Shape shape =
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data)
->shape(),
absl::StrCat("p", parameters_.size()));
->shape();
for (const int dim : unbounded_dynamic_dims) {
shape.set_dynamic_dimension(dim, true);
shape.set_dimensions(dim, kUnboundedSize);
}
xla::XlaOp param = xla::Parameter(builder(), parameters_.size(), shape,
absl::StrCat("p", parameters_.size()));
it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
.first;
parameters_.push_back(data);
} else {
XLA_CHECK(unbounded_dynamic_dims.empty())
<< "The unbounded dynamic dims can only be set when Parameter is "
"created.";
}
parameter_sequence_.push_back(it->second.index);
return it->second.param;
@@ -170,6 +182,22 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) {

const XlaNode* casted = dynamic_cast<const XlaNode*>(node);
result_ops = casted->Lower(this);
if (!casted->dynamic_dims().empty()) {
xla::internal::XlaBuilderFriend builder_friend;
auto* inst = builder_friend.GetInstruction(result_ops[0]);
auto* mutable_dynamic =
inst->mutable_shape()->mutable_is_dynamic_dimension();
if (mutable_dynamic->empty()) {
for (int i = 0; i < inst->dimensions_size(); i++) {
mutable_dynamic->Add(false);
}
}
auto* mutable_dims = inst->mutable_shape()->mutable_dimensions();
for (const auto dim : casted->dynamic_dims()) {
mutable_dynamic->Set(dim, true);
mutable_dims->Set(dim, kUnboundedSize);
}
}
} catch (const std::exception& ex) {
ReportBuilderError(node, ex.what());
}
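A plain-Python sketch of what GetParameter does with the marked dims (illustrative only; the real code mutates an xla::Shape): every marked dimension becomes dynamic and its size is replaced by the reserved sentinel.

```python
K_UNBOUNDED = -(2**63)  # stands in for kUnboundedSize (std::numeric_limits<int64_t>::min())

def parameter_shape(static_dims, unbounded_dynamic_dims):
    dims = list(static_dims)
    dynamic = [False] * len(dims)
    for d in unbounded_dynamic_dims:
        dims[d] = K_UNBOUNDED       # shape.set_dimensions(dim, kUnboundedSize)
        dynamic[d] = True           # shape.set_dynamic_dimension(dim, true)
    return dims, dynamic

print(parameter_shape([4, 8], {0}))  # ([-9223372036854775808, 8], [True, False])
```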
3 changes: 2 additions & 1 deletion torch_xla/csrc/lowering_context.h
@@ -37,7 +37,8 @@ class LoweringContext : public torch::lazy::LoweringContext {
// returned. Otherwise a new one will be created, associated with the tensor
// held in data.
xla::XlaOp GetParameter(
const std::shared_ptr<torch::lazy::BackendData>& data);
const std::shared_ptr<torch::lazy::BackendData>& data,
const std::unordered_set<uint32_t>& dynamic_dims = {});

// Retrieves the vector holding all the tensors associated with the parameter
// instructions which have been created.
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/device_data.cpp
@@ -36,7 +36,7 @@ torch::lazy::NodePtr DeviceData::Clone(torch::lazy::OpList operands) const {
}

XlaOpVector DeviceData::Lower(LoweringContext* loctx) const {
return ReturnOp(loctx->GetParameter(data_), loctx);
return ReturnOp(loctx->GetParameter(data_, dynamic_dims_), loctx);
}

DeviceData* DeviceData::Cast(const torch::lazy::Node* node) {
24 changes: 21 additions & 3 deletions torch_xla/csrc/reduction.cpp
@@ -31,6 +31,9 @@ struct SummationResult {
xla::XlaOp result;
};

static const bool experimental_unbounded_dynamism =
runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);
Review comment (Collaborator): Env variables cause a poor user experience. What's the plan to clean this up with a better solution (potentially via the upstream torch API level)? cc @JackCaoG @qihqi @lsy323
Review comment (Collaborator): Also, for adding env variables, please keep this file up to date: https://github.com/pytorch/xla/blob/master/configuration.yaml
Reply (Collaborator): The current thought is to use the env variable to limit the code path to "Export Only", as you've pointed out in many places. If there are better mechanisms to accomplish that, we can also use something different.


ReductionInfo GetReductionInfo(xla::XlaOp input, const xla::Shape& shape,
absl::Span<const int64_t> dimensions,
bool keep_reduced_dimensions) {
@@ -81,7 +84,15 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count,
xla::XlaOp scale = xla::Select(xla::Ne(count, zero),
one / xla::ConvertElementType(count, type),
xla::NanValue(input.builder(), type));
return input * scale;

if (experimental_unbounded_dynamism) {
// XLA Multiply doesn't do implicit broadcasting for unbounded dynamism now.
// TODO(lsy323): Remove this branch once the support is added in XLA.
auto promoted = XlaHelpers::Promote(input, scale);
return promoted.first * promoted.second;
} else {
return input * scale;
}
}

xla::XlaOp AverageValue(xla::XlaOp input, xla::XlaOp reduced) {
@@ -109,8 +120,15 @@ SummationResult CreateSummation(xla::XlaOp input,
result.result, result.rinfo.element_count.size, shape.element_type());
}
if (keep_reduced_dimensions) {
result.result =
XlaHelpers::DynamicReshape(result.result, result.rinfo.new_dimensions);
if (experimental_unbounded_dynamism) {
// TODO(lsy323): Use XLA DynamicReshape once unbounded dynamism support is
// added.
result.result = XlaHelpers::DynamicUnboundedReshape(
result.result, input, result.rinfo.new_dimensions);
} else {
result.result = XlaHelpers::DynamicReshape(result.result,
result.rinfo.new_dimensions);
}
}
return result;
}
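A conceptual numpy sketch of the keepdim branch above (not the actual XLA lowering): when a dimension is unbounded, the reshape target is taken from the input's runtime size rather than a compile-time constant.

```python
import numpy as np

x = np.random.randn(4, 8)                 # pretend dim 0 is unbounded
reduced = x.sum(axis=1)                    # reduced shape: (4,)
runtime_batch = x.shape[0]                 # stands in for xla::GetDimensionSize(x, 0)
kept = reduced.reshape(runtime_batch, 1)   # keepdim output built from the runtime size
assert kept.shape == (4, 1)
```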
5 changes: 5 additions & 0 deletions torch_xla/csrc/tensor.cpp
@@ -893,4 +893,9 @@ int64_t XLATensor::GetHandle() const {
}
}

void XLATensor::MarkDynamicDimension(uint32_t dim) {
auto* xla_node = dynamic_cast<XlaNode*>(GetIrValue().node.get());
xla_node->MarkDynamicDimension(dim);
}

} // namespace torch_xla
1 change: 1 addition & 0 deletions torch_xla/csrc/tensor.h
@@ -201,6 +201,7 @@ class XLATensor : public torch::lazy::LazyTensor {
// Set logical_element_type which is visible to upstream PyTorch.
void SetScalarType(c10::optional<at::ScalarType> logical_element_type);

void MarkDynamicDimension(uint32_t dim);
// We don't use the upstream shape to provide xla::shape.
runtime::util::MaybeRef<xla::Shape> shape() const;

10 changes: 10 additions & 0 deletions torch_xla/stablehlo.py
@@ -212,6 +212,16 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
new_kwargs['device'] = self._device
return super().call_function(target, args, new_kwargs)

def run_node(self, n) -> Any:
if n.op == 'placeholder':
fake_t = n.meta['val']
res = super().run_node(n)
for i, x in enumerate(fake_t.shape):
if not isinstance(x, int):
torch_xla._XLAC._xla_mark_dynamic(res, i)
return res
return super().run_node(n)


def _extract_input_args(exported_model, options):
if options.override_tracing_arguments is not None:
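A hedged end-to-end sketch of the intended flow. Only `_xla_mark_dynamic` and the `run_node` override come from this PR; the export and StableHLO conversion calls below are assumed from current torch and torch_xla APIs. torch.export records a symbolic batch dimension, `run_node` marks it on the XLA tensor, and with `EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1` the lowering should emit an unbounded dynamic dimension in the resulting StableHLO.

```python
import os
import torch
from torch_xla import stablehlo

os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"

class Model(torch.nn.Module):
    def forward(self, x):
        # relu exercises the elementwise path; mean(keepdim=True) the reduction path.
        return torch.relu(x).mean(dim=1, keepdim=True)

batch = torch.export.Dim("batch")
exported = torch.export.export(Model(), (torch.randn(4, 8),),
                               dynamic_shapes=({0: batch},))
shlo = stablehlo.exported_program_to_stablehlo(exported)
# The dynamic batch should appear as a `?` dimension (e.g. tensor<?x8xf32>)
# in the generated StableHLO text.
print(shlo.get_stablehlo_text('forward'))
```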