Enable passing down dynamic dimensions from torch to XLA #5790
Changes from 6 commits
First changed file:
@@ -21,6 +21,12 @@
 namespace torch_xla {
 namespace {
 
+static const bool experimental_unbounded_dynamism =
+    runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);
+
+// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
+static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();
+
 xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2,
                                  xla::XlaOp result) {
   xla::PrimitiveType type1 = XlaHelpers::TypeOfXlaOp(op1);
@@ -63,6 +69,9 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input,
   std::vector<int64_t> bcast_sizes = SizesOfXlaOp(input);
   for (size_t i = 0; i < dimensions.size(); ++i) {
     bcast_sizes.at(dimensions[i]) = sizes[i];
+    if (experimental_unbounded_dynamism) {
+      XLA_CHECK(sizes[i] != kUnboundedSize);
+    }
   }
   return xla::BroadcastInDim(input, bcast_sizes,
                              GetAllDimensions(bcast_sizes.size()));
@@ -322,6 +331,57 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
       : xla::Reshape(input, shape.dimensions());
 }
 
+bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
+  XLA_CHECK(experimental_unbounded_dynamism)
+      << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
+  const absl::Span<const int64_t> dims = shape.dimensions();
+  return std::any_of(dims.begin(), dims.end(),
+                     [](int64_t size) { return size == kUnboundedSize; });
+}
+
+xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
+    xla::XlaOp input, xla::XlaOp aux_input,
+    absl::Span<const int64_t> output_sizes) {
+  XLA_CHECK(experimental_unbounded_dynamism)
+      << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
+  const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
+  XLA_CHECK(output_sizes.size() == aux_input_shape.rank())
+      << "XlaHelpers::DynamicUnboundedReshape constraint failed!";
+  std::vector<xla::XlaOp> get_dim_ops;
+  std::vector<xla::XlaOp> reshaped_ops;
+  bool all_static = true;
+  std::vector<bool> output_dynamic(output_sizes.size(), false);
+
+  for (int i = 0; i < output_sizes.size(); i++) {
+    if (output_sizes[i] == kUnboundedSize) {
+      output_dynamic[i] = true;
+      get_dim_ops.push_back(xla::GetDimensionSize(aux_input, i));
+      all_static = false;
+    } else {
+      get_dim_ops.push_back(XlaHelpers::ScalarValue<int32_t>(
+          output_sizes[i], aux_input.builder()));
+    }
+  }
+
+  if (all_static) {
+    return xla::Reshape(input, output_sizes);
+  }
+
+  // Reshape each scalar dimension size to a 1-D vector.
+  for (auto get_dim_op : get_dim_ops) {
+    reshaped_ops.push_back(xla::Reshape(get_dim_op, {1}));
+  }
+
+  // Concatenate the per-dimension sizes into one shape tensor and emit a
+  // stablehlo.dynamic_reshape custom call that consumes it.
+  auto concat_op = xla::ConcatInDim(input.builder(), reshaped_ops, 0);
+  return xla::CustomCall(
+      aux_input.builder(), "stablehlo.dynamic_reshape", {input, concat_op},
+      xla::ShapeUtil::MakeShape(aux_input_shape.element_type(), output_sizes,
+                                output_dynamic));
+}
+
 bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1,
                                       const xla::Shape& shape2) {
   return shape1.is_static() && shape2.is_static() &&

Review thread on DynamicUnboundedReshape:

Reviewer: Can you please prepare an opset issue to track supporting this mode of dynamism? Similar work example: #5764
Author: Sure, I think we have a list internally.
Reviewer: Does it work to maintain the list on GH?
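To make the DynamicUnboundedReshape lowering concrete, here is a hedged sketch (not from the PR; the include paths, helper name, and concrete sizes are assumptions) of the same GetDimensionSize → Reshape → ConcatInDim → CustomCall sequence for a rank-2 f32 input whose dim 0 is unbounded:

#include <cstdint>
#include <limits>

#include "xla/client/xla_builder.h"  // assumed include path
#include "xla/shape_util.h"          // assumed include path

constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

// Reshape `input` to {unbounded, 128}, taking the runtime size of the
// unbounded dimension from aux_input's dim 0, as the patch does.
xla::XlaOp UnboundedReshapeSketch(xla::XlaOp input, xla::XlaOp aux_input) {
  xla::XlaBuilder* b = input.builder();
  // Scalar s32 op holding aux_input's runtime size along dim 0.
  xla::XlaOp dim0 = xla::GetDimensionSize(aux_input, 0);
  // Static size for the second output dimension.
  xla::XlaOp dim1 = xla::ConstantR0<int32_t>(b, 128);
  // Scalars -> 1-D vectors -> a single rank-1 shape tensor.
  xla::XlaOp shape_tensor = xla::ConcatInDim(
      b, {xla::Reshape(dim0, {1}), xla::Reshape(dim1, {1})}, 0);
  // Output shape: dim 0 dynamic (sentinel size), dim 1 static.
  xla::Shape out_shape = xla::ShapeUtil::MakeShape(
      xla::F32, {kUnboundedSize, 128}, {true, false});
  return xla::CustomCall(b, "stablehlo.dynamic_reshape",
                         {input, shape_tensor}, out_shape);
}

The custom call is a placeholder the StableHLO pipeline recognizes; plain xla::Reshape cannot express the unbounded output dimension.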
@@ -485,6 +545,11 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1,
         runtime::util::ToVector<int64_t>(shape1.dimensions()),
         runtime::util::ToVector<int64_t>(shape2.dimensions())));
   }
+  if (experimental_unbounded_dynamism) {
+    XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) &&
+              !XlaHelpers::IsUnboundedDynamic(shape2))
+        << "Unreachable for unbounded dynamic code\n";
+  }
   return GetPromotedDynamicShape(shape1, shape2);
 }
Second changed file:
@@ -31,6 +31,9 @@ struct SummationResult {
   xla::XlaOp result;
 };
 
+static const bool experimental_unbounded_dynamism =
+    runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);
+
 ReductionInfo GetReductionInfo(xla::XlaOp input, const xla::Shape& shape,
                                absl::Span<const int64_t> dimensions,
                                bool keep_reduced_dimensions) {

Review thread on the new env variable:

Reviewer: Also, for adding env variables, please keep this file up to date: https://github.com/pytorch/xla/blob/master/configuration.yaml
Author: The current thought is to use an env variable to limit the code path to "export only", as you've pointed out in many places. If there are better mechanisms to accomplish that, we can also use something different.
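For context on what the gate does, a rough stand-in sketch (an assumption; the exact semantics of runtime::sys_util::GetEnvBool may differ). The flag is read once at static-initialization time, so EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 must be set in the environment before the process that triggers lowering starts:

#include <cstdlib>
#include <cstring>

// Stand-in for runtime::sys_util::GetEnvBool (assumption): treat the
// variable as true when set to "1" or "true", else fall back to defval.
bool GetEnvBoolSketch(const char* name, bool defval) {
  const char* env = std::getenv(name);
  if (env == nullptr) return defval;
  return std::strcmp(env, "1") == 0 || std::strcmp(env, "true") == 0;
}

// Usage mirroring the patch: evaluated once during static initialization,
// not per call, so toggling the variable mid-process has no effect.
static const bool experimental_unbounded_dynamism =
    GetEnvBoolSketch("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);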
@@ -81,7 +84,15 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count,
   xla::XlaOp scale = xla::Select(xla::Ne(count, zero),
                                  one / xla::ConvertElementType(count, type),
                                  xla::NanValue(input.builder(), type));
-  return input * scale;
+
+  if (experimental_unbounded_dynamism) {
+    // XLA Multiply doesn't do implicit broadcasting for unbounded dynamism now.
+    // TODO(lsy323): Remove this branch once the support is added in XLA.
+    auto promoted = XlaHelpers::Promote(input, scale);
+    return promoted.first * promoted.second;
+  } else {
+    return input * scale;
+  }
 }
 
 xla::XlaOp AverageValue(xla::XlaOp input, xla::XlaOp reduced) {
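For intuition about what GetScaleValue computes, a scalar model (not the XLA code, which does this elementwise via Select/Ne/NanValue): the mean is formed as input * (1/count), with NaN when the reduction is over zero elements. The new branch exists because, per the TODO, XLA's Multiply does not implicitly broadcast unbounded-dynamic operands, so both sides are promoted explicitly first.

#include <iostream>
#include <limits>

// Scalar model of GetScaleValue: scale is 1/count, or NaN for an empty
// reduction, and the result is input scaled by it.
float ScaleValueSketch(float input, float count) {
  float scale = (count != 0.0f)
                    ? 1.0f / count
                    : std::numeric_limits<float>::quiet_NaN();
  return input * scale;
}

int main() {
  std::cout << ScaleValueSketch(10.0f, 4.0f) << "\n";  // 2.5
  std::cout << ScaleValueSketch(10.0f, 0.0f) << "\n";  // nan
}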
@@ -109,8 +120,15 @@ SummationResult CreateSummation(xla::XlaOp input,
         result.result, result.rinfo.element_count.size, shape.element_type());
   }
   if (keep_reduced_dimensions) {
-    result.result =
-        XlaHelpers::DynamicReshape(result.result, result.rinfo.new_dimensions);
+    if (experimental_unbounded_dynamism) {
+      // TODO(lsy323): Use XLA DynamicReshape once unbounded dynamism support
+      // is added.
+      result.result = XlaHelpers::DynamicUnboundedReshape(
+          result.result, input, result.rinfo.new_dimensions);
+    } else {
+      result.result = XlaHelpers::DynamicReshape(result.result,
+                                                 result.rinfo.new_dimensions);
+    }
   }
   return result;
 }
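A worked example of why the keep_reduced_dimensions branch needs the new helper (illustrative shape bookkeeping only; the helper name below is hypothetical): for sum(input, dim=1, keepdim=True) on an {unbounded, 4} input, the keepdim target shape {unbounded, 1} still carries the sentinel, so a static reshape cannot be emitted and the runtime size must come from input, which is why input is passed as aux_input above.

#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

// True when a static xla::Reshape would be ill-formed because the target
// sizes still carry the unbounded sentinel (the all_static check, inverted).
bool NeedsDynamicReshape(const std::vector<int64_t>& new_dimensions) {
  return std::any_of(new_dimensions.begin(), new_dimensions.end(),
                     [](int64_t d) { return d == kUnboundedSize; });
}

int main() {
  // sum(input, dim=1, keepdim=True) on an {unbounded, 4} input:
  // the keepdim target shape is {unbounded, 1}.
  std::vector<int64_t> new_dimensions = {kUnboundedSize, 1};
  // Exit code 0 when a dynamic reshape is (correctly) required.
  return NeedsDynamicReshape(new_dimensions) ? 0 : 1;
}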
Review thread on gating via the env variable:

Reviewer: Can this path execute in the non-export path? Wdyt about limiting this experimental condition to the torch.export path?
Author: Per the discussion below, it seems there is no better solution than using an env variable to enable the unbounded dynamism lowering. But we can put more thought into this one.
Reviewer: SG. This issue doesn't block this PR. Let's prepare a proposal in parallel.