Skip to content

Commit

Permalink
Migrate StableHLO to use properties (openxla#2221)
Browse files Browse the repository at this point in the history
Migrate StableHLO to use properties with a few additional changes:

- Constant op's custom fallback parsing needs to be updated to handle
properties _or_ attributes.
- There are PjRT who require handling version skew, and are yet to
migrate to VHLO (in progress, est ~Jun next release), in the meantime we
require inherent IR attribute downgrades (DenseArray->DenseElements in
[openxla/xla/pjrt/mlir_to_hlo.cc](
https://github.com/openxla/xla/blob/aefa3ab3a0613b538e14b449817dce986a765e84/xla/pjrt/mlir_to_hlo.cc#L180)).
This isn't possible in a dialect that is using properties since
properties require that property types are never invalid, so this PR
introduces a DenseArray backed by generic Attribute storage. Cleanup
tasks for this are tracked by openxla#2216

Closes openxla#1584
  • Loading branch information
GleasonK authored Apr 16, 2024
1 parent 126eaf5 commit 8adf5b8
Show file tree
Hide file tree
Showing 14 changed files with 231 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,10 @@ struct TransposeOpToTransposeConverter final
Value emptyTensor =
getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands());

// TODO(#2216) Cleanup Attribute -> DenseArrayAttr
rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
op, adaptor.getOperand(), emptyTensor, op.getPermutationAttr(),
op, adaptor.getOperand(), emptyTensor,
op.getPermutationAttr().dyn_cast_or_null<DenseI64ArrayAttr>(),
linalg::getPrunedAttributeList(op));
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
// Implements logic for lowering StableHLO convolution ops to Linalg dialect.

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -30,12 +31,12 @@ namespace {
/// Apply dilation and padding to the input of a convolution.
Value applyConvolutionPadding(Location loc, Value input,
DenseIntElementsAttr padding,
DenseI64ArrayAttr lhsDilation,
std::optional<ArrayRef<int64_t>> lhsDilation,
llvm::ArrayRef<int64_t> dimMappings,
OpBuilder &rewriter) {
SmallVector<int64_t> lhsDilationValues;
if (lhsDilation)
lhsDilationValues = llvm::to_vector(lhsDilation.asArrayRef());
if (lhsDilation.has_value())
lhsDilationValues = llvm::to_vector(lhsDilation.value());
bool noPadding = !padding || isSplatValue(padding, 0);
bool noDilation = !lhsDilation || hlo::isSplatArray(lhsDilationValues, 1);
if (noPadding && noDilation) return input;
Expand Down Expand Up @@ -230,7 +231,7 @@ struct NormalConvolutionOpConversion final
llvm::SmallVector<int64_t> spatialDimMapping(rank - 2);
std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1);
input = applyConvolutionPadding(loc, input, op.getPaddingAttr(),
op.getLhsDilationAttr(), spatialDimMapping,
op.getLhsDilation(), spatialDimMapping,
rewriter);

switch (rank) {
Expand Down Expand Up @@ -350,10 +351,10 @@ struct ConvolutionOpGeneralConversion final
// Decompose the convolution into an initial padding
Value modifiedLhs = applyConvolutionPadding(
op.getLoc(), adaptor.getLhs(), adaptor.getPaddingAttr(),
adaptor.getLhsDilationAttr(),
adaptor.getLhsDilation(),
op.getDimensionNumbers().getInputSpatialDimensions(), rewriter);
Value modifiedRhs = applyConvolutionPadding(
op.getLoc(), adaptor.getRhs(), nullptr, adaptor.getRhsDilationAttr(),
op.getLoc(), adaptor.getRhs(), nullptr, adaptor.getRhsDilation(),
op.getDimensionNumbers().getKernelSpatialDimensions(), rewriter);
modifiedRhs = applyConvolutionReversal(loc, rewriter, op, modifiedRhs);

Expand Down Expand Up @@ -591,7 +592,7 @@ struct DepthwiseConvolutionOpConversion final
// Make sure that this is depthwise convolution.
int64_t inputFeatureDim = dimensionNumbers.getInputFeatureDimension();
int64_t inputFeatureCount =
op.getLhs().getType().getDimSize(inputFeatureDim);
cast<ShapedType>(op.getLhs().getType()).getDimSize(inputFeatureDim);
if (static_cast<int64_t>(op.getFeatureGroupCount()) != inputFeatureCount) {
return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
}
Expand Down Expand Up @@ -643,10 +644,11 @@ struct DepthwiseConvolutionOpConversion final
llvm::SmallVector<int64_t> spatialDimMapping(spatialRank);
std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1);
input = applyConvolutionPadding(loc, input, op.getPaddingAttr(),
op.getLhsDilationAttr(), spatialDimMapping,
op.getLhsDilation(), spatialDimMapping,
rewriter);

auto filterDims = llvm::to_vector(op.getRhs().getType().getShape());
auto filterDims =
llvm::to_vector(cast<ShapedType>(op.getRhs().getType()).getShape());

auto getReassociationIndicesToCollapseLastTwoDims = [](Value v) {
SmallVector<ReassociationIndices> reassociations;
Expand Down Expand Up @@ -679,7 +681,8 @@ struct DepthwiseConvolutionOpConversion final
reshapedFilterDims[kernelOutputFeatureDimension] /=
op.getFeatureGroupCount();
auto reshapedFilterType = RankedTensorType::get(
reshapedFilterDims, op.getRhs().getType().getElementType());
reshapedFilterDims,
cast<ShapedType>(op.getRhs().getType()).getElementType());

reshapedFilter = rewriter.create<mlir::stablehlo::ReshapeOp>(
loc, reshapedFilterType, filter);
Expand Down
58 changes: 40 additions & 18 deletions stablehlo/dialect/AssemblyFormat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,16 @@ ParseResult parseConstantOp(OpAsmParser& parser, OperationState& result) {
// Parse the generic form.
if (succeeded(parser.parseOptionalLParen())) {
if (parser.parseRParen()) return failure();
// Parse optional properties
if (succeeded(parser.parseOptionalLess()) &&
(failed(parser.parseAttribute(result.propertiesAttr)) ||
failed(parser.parseGreater())))
return failure();

// Parse optional attributes
if (parser.parseOptionalAttrDict(result.attributes)) return failure();

// Parse type signature
if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() ||
parser.parseArrow())
return failure();
Expand Down Expand Up @@ -685,39 +694,41 @@ ParseResult parseWhileOp(OpAsmParser& parser, OperationState& result) {
//===----------------------------------------------------------------------===//

void printSliceRanges(OpAsmPrinter& p, Operation* op,
ArrayRef<int64_t> startIndices,
ArrayRef<int64_t> limitIndices,
ArrayRef<int64_t> strides) {
Attribute startIndicesAttr, Attribute limitIndicesAttr,
Attribute stridesAttr) {
auto startIndices = cast<DenseI64ArrayAttr>(startIndicesAttr);
auto limitIndices = cast<DenseI64ArrayAttr>(limitIndicesAttr);
auto strides = cast<DenseI64ArrayAttr>(stridesAttr);
p << "[";
// Let's be safe if we're printing invalid IR somehow: this can't be parsed
// back!
if (startIndices.size() != limitIndices.size() ||
startIndices.size() != strides.size()) {
p << "start_indices: ";
llvm::interleaveComma(startIndices, p);
llvm::interleaveComma(startIndices.asArrayRef(), p);
p << ", limit_indices: ";
llvm::interleaveComma(limitIndices, p);
llvm::interleaveComma(limitIndices.asArrayRef(), p);
p << ", strides: ";
llvm::interleaveComma(strides, p);
llvm::interleaveComma(strides.asArrayRef(), p);
p << "]";
return;
}

llvm::interleaveComma(llvm::zip(startIndices, limitIndices, strides), p,
[&](std::tuple<int64_t, int64_t, int64_t> pack) {
auto [start, limit, stride] = pack;
p << start << ":" << limit;
if (stride != 1) {
p << ":" << stride;
}
});
llvm::interleaveComma(
llvm::zip(startIndices.asArrayRef(), limitIndices.asArrayRef(),
strides.asArrayRef()),
p, [&](std::tuple<int64_t, int64_t, int64_t> pack) {
auto [start, limit, stride] = pack;
p << start << ":" << limit;
if (stride != 1) {
p << ":" << stride;
}
});
p << "]";
}

ParseResult parseSliceRanges(OpAsmParser& parser,
DenseI64ArrayAttr& startIndices,
DenseI64ArrayAttr& limitIndices,
DenseI64ArrayAttr& strides) {
ParseResult parseSliceRanges(OpAsmParser& parser, Attribute& startIndices,
Attribute& limitIndices, Attribute& strides) {
if (parser.parseLSquare()) return failure();
// Parse groups of comma-separated: `start`:`limit`[:`stride`]
// If the stride isn't provided it'll be 1.
Expand Down Expand Up @@ -747,6 +758,17 @@ ParseResult parseSliceRanges(OpAsmParser& parser,
return success();
}

void printDenseI64Array(OpAsmPrinter& p, Operation* op, Attribute attr) {
cast<DenseI64ArrayAttr>(attr).print(p);
}

ParseResult parseDenseI64Array(OpAsmParser& parser, Attribute& attr) {
if ((attr = DenseI64ArrayAttr::parse(parser, Type{}))) {
return success();
}
return failure();
}

ParseResult dimSizeFromString(AsmParser& parser, int64_t& result) {
if (succeeded(parser.parseOptionalQuestion())) {
result = ShapedType::kDynamic;
Expand Down
25 changes: 16 additions & 9 deletions stablehlo/dialect/AssemblyFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,23 @@ ParseResult parseWhileOp(OpAsmParser& parser, OperationState& result);
// Attribute Printers and Parsers
//===----------------------------------------------------------------------===//

// TODO(#2216) Cleanup Attribute -> DenseArrayAttr for print/parse.
// SliceRanges - Used to print multi-dimensional ranges for slice.
void printSliceRanges(OpAsmPrinter& p, Operation* op,
ArrayRef<int64_t> startIndices,
ArrayRef<int64_t> limitIndices,
ArrayRef<int64_t> strides);

ParseResult parseSliceRanges(OpAsmParser& parser,
DenseI64ArrayAttr& startIndices,
DenseI64ArrayAttr& limitIndices,
DenseI64ArrayAttr& strides);
void printSliceRanges(OpAsmPrinter& p, Operation* op, Attribute startIndices,
Attribute limitIndices, Attribute strides);

ParseResult parseSliceRanges(OpAsmParser& parser, Attribute& startIndices,
Attribute& limitIndices, Attribute& strides);

// GenericI64DenseArray - Used to print an attr that can be either
//
// Dense elements:
// { dense<[1, 2]> : tensor<2xi64> }
// Array:
// { array<i64: 1, 2> }
void printDenseI64Array(OpAsmPrinter& p, Operation* op, Attribute attr);

ParseResult parseDenseI64Array(OpAsmParser& parser, Attribute& attr);

// DimSizes - Print an array of ints. Dynamic dimensions printed as `?`.
//
Expand Down
26 changes: 22 additions & 4 deletions stablehlo/dialect/StablehloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,24 @@ def StableHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> {
let printer = "printDimSizes($_printer, $_self)";
}

def GenericDenseI64ArrayAttr : Attr<DenseI64ArrayAttr.predicate, "DenseI64ArrayAttr with generic Attribute storage"> {
let storageType = "Attribute";
let valueType = DenseI64ArrayAttr.valueType;
let returnType = "::llvm::ArrayRef<int64_t>";
let baseAttr = DenseI64ArrayAttr;
let convertFromStorage = "$_self.cast<DenseI64ArrayAttr>().asArrayRef()";
let constBuilderCall = "$_builder.getDenseI64ArrayAttr($0)";
}

def GenericDenseBoolArrayAttr : Attr<DenseBoolArrayAttr.predicate, "DenseBoolArrayAttr with generic Attribute storage"> {
let storageType = "Attribute";
let valueType = DenseBoolArrayAttr.valueType;
let returnType = "::llvm::ArrayRef<bool>";
let convertFromStorage = "$_self.cast<DenseBoolArrayAttr>().asArrayRef()";
let constBuilderCall = "$_builder.getDenseBoolArrayAttr($0)";
}


def StableHLO_ScatterDimensionNumbers : AttrDef<StableHLO_Dialect, "ScatterDimensionNumbers"> {
let mnemonic = "scatter";
let summary = "Attribute that models the dimension information for scatter";
Expand Down Expand Up @@ -182,15 +200,15 @@ def StableHLO_ConvDimensionNumbers : AttrDef<StableHLO_Dialect, "ConvDimensionNu
def StableHLO_ConvolutionAttributes {
dag attributes = (ins
// Default value: one for each of the spatial dimension.
OptionalAttr<DenseI64ArrayAttr>:$window_strides,
OptionalAttr<GenericDenseI64ArrayAttr>:$window_strides,
// Default value: two zeros for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<DenseI64ArrayAttr>:$lhs_dilation,
OptionalAttr<GenericDenseI64ArrayAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<DenseI64ArrayAttr>:$rhs_dilation,
OptionalAttr<GenericDenseI64ArrayAttr>:$rhs_dilation,
// Default value: false for each of the spatial dimension.
OptionalAttr<DenseBoolArrayAttr>:$window_reversal,
OptionalAttr<GenericDenseBoolArrayAttr>:$window_reversal,
StableHLO_ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
Expand Down
19 changes: 10 additions & 9 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2346,6 +2346,7 @@ LogicalResult UniformDequantizeOp::inferReturnTypeComponents(

using mlir::hlo::parseComplexOpType;
using mlir::hlo::parseCustomCallTarget;
using mlir::hlo::parseDenseI64Array;
using mlir::hlo::parseDotDimensionNumbers;
using mlir::hlo::parseExponentMantissa;
using mlir::hlo::parsePairwiseOpType;
Expand All @@ -2357,6 +2358,7 @@ using mlir::hlo::parseVariadicOperandWithAttribute;
using mlir::hlo::parseVariadicSameOperandsAndResultType;
using mlir::hlo::printComplexOpType;
using mlir::hlo::printCustomCallTarget;
using mlir::hlo::printDenseI64Array;
using mlir::hlo::printDotDimensionNumbers;
using mlir::hlo::printExponentMantissa;
using mlir::hlo::printPairwiseOpType;
Expand Down Expand Up @@ -3028,11 +3030,11 @@ void printWindowPadding(OpAsmPrinter& p, DenseElementsAttr padding) {
} // namespace

void printWindowAttributes(OpAsmPrinter& p, Operation* /*op*/,
std::optional<DenseI64ArrayAttr> windowStrides,
std::optional<Attribute> windowStrides,
std::optional<DenseIntElementsAttr> padding,
std::optional<DenseI64ArrayAttr> lhsDilation,
std::optional<DenseI64ArrayAttr> rhsDilation,
std::optional<DenseBoolArrayAttr> windowReversal) {
std::optional<Attribute> lhsDilation,
std::optional<Attribute> rhsDilation,
std::optional<Attribute> windowReversal) {
using pair_t = std::pair<Attribute, StringRef>;
std::array<pair_t, 5> printedAttributes = {{
{windowStrides ? *windowStrides : nullptr, "stride"},
Expand Down Expand Up @@ -3064,12 +3066,11 @@ void printWindowAttributes(OpAsmPrinter& p, Operation* /*op*/,
});
}

ParseResult parseWindowAttributes(OpAsmParser& parser,
DenseI64ArrayAttr& windowStrides,
ParseResult parseWindowAttributes(OpAsmParser& parser, Attribute& windowStrides,
DenseIntElementsAttr& padding,
DenseI64ArrayAttr& lhsDilation,
DenseI64ArrayAttr& rhsDilation,
DenseBoolArrayAttr& windowReversal) {
Attribute& lhsDilation,
Attribute& rhsDilation,
Attribute& windowReversal) {
StringRef attributeName;

llvm::StringSet<> allowedAttributeNames{
Expand Down
18 changes: 9 additions & 9 deletions stablehlo/dialect/StablehloOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,20 +120,20 @@ void printConvolutionDimensions(AsmPrinter &p, Operation *,
ParseResult parseConvolutionDimensions(AsmParser &parser,
ConvDimensionNumbersAttr &dimNums);

// TODO(#2216) Cleanup Attribute -> DenseArrayAttr for print/parse.
// Custom formatting for convolution window attributes.
void printWindowAttributes(OpAsmPrinter &p, Operation *op,
std::optional<DenseI64ArrayAttr> windowStrides,
std::optional<Attribute> windowStrides,
std::optional<DenseIntElementsAttr> padding,
std::optional<DenseI64ArrayAttr> lhsDilation,
std::optional<DenseI64ArrayAttr> rhsDilation,
std::optional<DenseBoolArrayAttr> windowReversal);
std::optional<Attribute> lhsDilation,
std::optional<Attribute> rhsDilation,
std::optional<Attribute> windowReversal);

ParseResult parseWindowAttributes(OpAsmParser &parser,
DenseI64ArrayAttr &windowStrides,
ParseResult parseWindowAttributes(OpAsmParser &parser, Attribute &windowStrides,
DenseIntElementsAttr &padding,
DenseI64ArrayAttr &lhsDilation,
DenseI64ArrayAttr &rhsDilation,
DenseBoolArrayAttr &windowReversal);
Attribute &lhsDilation,
Attribute &rhsDilation,
Attribute &windowReversal);

} // end namespace stablehlo
} // end namespace mlir
Expand Down
Loading

0 comments on commit 8adf5b8

Please sign in to comment.