Skip to content

Commit

Permalink
Port *ArrayOr* -> Dense*ArrayAttr in CHLO and SHLO (openxla#1952)
Browse files Browse the repository at this point in the history
This completes the migration to `DenseI64ArrayAttr` started in
openxla#1658 (superseded by
openxla#1872) and announced in
https://groups.google.com/a/openxla.org/g/openxla-discuss/c/hEoA4V5DZF0

Fixes openxla#1578
  • Loading branch information
mlevesquedion authored Jan 26, 2024
1 parent e37b4b0 commit 7eb9902
Show file tree
Hide file tree
Showing 15 changed files with 79 additions and 217 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ namespace {
/// Apply dilation and padding to the input of a convolution.
Value applyConvolutionPadding(Location loc, Value input,
DenseIntElementsAttr padding,
Attribute lhsDilation,
DenseI64ArrayAttr lhsDilation,
llvm::ArrayRef<int64_t> dimMappings,
OpBuilder &rewriter) {
SmallVector<int64_t> lhsDilationValues;
if (lhsDilation) lhsDilationValues = hlo::getI64Array(lhsDilation);
if (lhsDilation)
lhsDilationValues = llvm::to_vector(lhsDilation.asArrayRef());
bool noPadding = !padding || isSplatValue(padding, 0);
bool noDilation = !lhsDilation || hlo::isSplatArray(lhsDilationValues, 1);
if (noPadding && noDilation) return input;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ struct ReduceOpToGenericConverter final
}
auto srcRank = cast<ShapedType>(adaptor.getInputs()[0].getType()).getRank();

SmallVector<int64_t> reductionDims = op.getDimensions();
SmallVector<int64_t> reductionDims = llvm::to_vector(op.getDimensions());

SmallVector<Type> resultTypes;
if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
Expand Down Expand Up @@ -226,7 +226,7 @@ struct ReduceOpToReduceConverter final
"unsupported reduce (noop or empty)");
}

auto reductionDims = op.getDimensions();
auto reductionDims = llvm::to_vector(op.getDimensions());
// stablehlo.reduce doesn't specify the order of the reduction dimensions.
llvm::sort(reductionDims);

Expand Down Expand Up @@ -346,17 +346,17 @@ struct ReduceWindowOpOnTensorsGenericConversion final

llvm::SmallVector<int64_t> baseDilations;
if (op.getBaseDilations()) {
baseDilations = *op.getBaseDilations();
baseDilations = llvm::to_vector(*op.getBaseDilations());
}

llvm::SmallVector<int64_t> windowStrides(windowDimensions.size(), 1);
if (op.getWindowStrides()) {
windowStrides = *op.getWindowStrides();
windowStrides = llvm::to_vector(*op.getWindowStrides());
}

llvm::SmallVector<int64_t> windowDilations(windowDimensions.size(), 1);
if (op.getWindowDilations()) {
windowDilations = *op.getWindowDilations();
windowDilations = llvm::to_vector(*op.getWindowDilations());
}

auto rank = static_cast<int64_t>(windowDimensions.size());
Expand Down
40 changes: 0 additions & 40 deletions stablehlo/dialect/AssemblyFormat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,46 +256,6 @@ ParseResult parseSelectOpType(OpAsmParser& parser, Type& pred, Type& onTrue,
// Attribute Printers and Parsers
//===----------------------------------------------------------------------===//

// Prints a rank-1 DenseIntElementsAttr using the compact DenseI64ArrayAttr
// syntax (e.g. `[1, 2]`) rather than the verbose `dense<...> : tensor<...>`
// form. Aborts with a fatal error if the attribute is not rank 1.
void printDenseI64Array(OpAsmPrinter& p, Operation* op,
                        DenseIntElementsAttr attr) {
  if (attr.getType().getRank() != 1)
    llvm::report_fatal_error("printDenseI64Array only supports rank-1 arrays");
  auto elements = llvm::to_vector(attr.getValues<int64_t>());
  DenseI64ArrayAttr::get(op->getContext(), elements).print(p);
}

// Parses the compact `[1, 2]` DenseI64ArrayAttr syntax and converts the
// result into an equivalent rank-1 i64 DenseIntElementsAttr.
ParseResult parseDenseI64Array(OpAsmParser& parser,
                               DenseIntElementsAttr& attr) {
  Attribute parsed = DenseI64ArrayAttr::parse(parser, Type{});
  auto arrayAttr = parsed.dyn_cast_or_null<DenseI64ArrayAttr>();
  if (!arrayAttr) return failure();

  ArrayRef<int64_t> values = arrayAttr.asArrayRef();
  auto tensorType =
      RankedTensorType::get(values.size(), parser.getBuilder().getI64Type());
  attr = DenseIntElementsAttr::get(tensorType, values);
  return success();
}

// Prints an attribute that may be backed by either a rank-1
// DenseIntElementsAttr or a DenseI64ArrayAttr, in both cases using the
// compact array syntax.
void printI64DenseArrayOrElements1D(OpAsmPrinter& p, Operation* op,
                                    Attribute attr) {
  if (auto elems = dyn_cast<DenseIntElementsAttr>(attr)) {
    printDenseI64Array(p, op, elems);
    return;
  }
  // Use cast<> rather than dyn_cast<>: if the attribute is neither supported
  // type, cast<> asserts in debug builds instead of calling print() through a
  // null attribute (a null dereference).
  cast<DenseI64ArrayAttr>(attr).print(p);
}

// Parses a DenseI64ArrayAttr into a generic Attribute slot; fails if the
// input is not valid DenseI64ArrayAttr syntax.
ParseResult parseI64DenseArrayOrElements1D(OpAsmParser& parser,
                                           Attribute& attr) {
  attr = DenseI64ArrayAttr::parse(parser, Type{});
  return attr ? success() : failure();
}

void printSliceRanges(OpAsmPrinter& p, Operation* op,
ArrayRef<int64_t> startIndices,
ArrayRef<int64_t> limitIndices,
Expand Down
25 changes: 0 additions & 25 deletions stablehlo/dialect/AssemblyFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,31 +174,6 @@ ParseResult parseSelectOpType(OpAsmParser& parser, Type& pred, Type& onTrue,
// Attribute Printers and Parsers
//===----------------------------------------------------------------------===//

// DenseI64Array - Used to print DenseIntElementsAttrs that are verified to have
// rank 1 as an i64 array without needing the dense specifier or type specifier.
//
// Generic:
// { dense<[1, 2]> : tensor<2xi64> }
// Custom:
// [1, 2]
void printDenseI64Array(OpAsmPrinter& p, Operation* op,
DenseIntElementsAttr attr);

ParseResult parseDenseI64Array(OpAsmParser& parser, DenseIntElementsAttr& attr);

// I64DenseArrayOrElements1D - Used to print an attr that can be either
// I64ElementsAttr (DenseIntElementsAttr) or DenseI64ArrayAttr.
//
// Dense elements:
// { dense<[1, 2]> : tensor<2xi64> }
// Array:
// { array<i64: 1, 2> }
void printI64DenseArrayOrElements1D(OpAsmPrinter& p, Operation* op,
Attribute attr);

ParseResult parseI64DenseArrayOrElements1D(OpAsmParser& parser,
Attribute& attr);

// SliceRanges - Used to print multi-dimensional ranges for slice.
void printSliceRanges(OpAsmPrinter& p, Operation* op,
ArrayRef<int64_t> startIndices,
Expand Down
25 changes: 0 additions & 25 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,30 +605,5 @@ bool isSplatArray(ArrayRef<int64_t> arr, int64_t val) {
[val](int64_t x) { return x == val; });
}

// Returns the int64 contents of an attribute that is either a
// DenseIntElementsAttr or a DenseI64ArrayAttr. A null attribute yields an
// empty vector; any other attribute kind is a fatal error.
SmallVector<int64_t> getI64Array(Attribute attr) {
  if (!attr) return {};
  if (auto array = attr.dyn_cast<DenseI64ArrayAttr>())
    return llvm::to_vector(array.asArrayRef());
  if (auto elements = attr.dyn_cast<DenseIntElementsAttr>())
    return llvm::to_vector(elements.getValues<int64_t>());
  llvm::report_fatal_error(
      "called getI64Array on Attribute that was neither a "
      "DenseIntElementsAttr or a DenseI64ArrayAttr",
      false);
}

// Returns the bool contents of an attribute that is either a
// DenseIntOrFPElementsAttr or a DenseBoolArrayAttr. A null attribute yields
// an empty vector; any other attribute kind is a fatal error.
SmallVector<bool> getBoolArray(Attribute attr) {
  if (!attr) return {};
  if (auto array = attr.dyn_cast<DenseBoolArrayAttr>())
    return SmallVector<bool>(array.asArrayRef());
  if (auto elements = attr.dyn_cast<DenseIntOrFPElementsAttr>())
    return llvm::to_vector(elements.getValues<bool>());
  llvm::report_fatal_error(
      "called getBoolArray on Attribute that was neither a "
      "DenseIntOrFPElementsAttr or a DenseBoolArrayAttr",
      false);
}

} // namespace hlo
} // namespace mlir
19 changes: 0 additions & 19 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,8 @@ inline static bool isStaticDimSize(int64_t val) {
}

// Checks whether every position in the given array contains the given value.
// This is especially useful for dealing with instances of
// I64DenseArrayOrElements1DAttr, which returns a SmallVector<int64_t> as its
// value no matter what actual attribute is backing it.
// TODO(#1578): Remove this code once all uses of I64DenseArrayOrElements1DAttr
// have been removed.
bool isSplatArray(ArrayRef<int64_t> arr, int64_t val);

// Returns a vector of the int64 values in a I64DenseArrayOrElements1DAttr.
// Such an Attr can be backed by either a 1-dimensional DenseIntElementsAttr or
// a DenseI64ArrayAttr.
// TODO(#1578): Remove this code once all uses of I64DenseArrayOrElements1DAttr
// have been removed.
SmallVector<int64_t> getI64Array(Attribute);

// Returns a vector of the bool values in a BoolDenseArrayOrElementsAttr.
// Such an Attr can be backed by either a DenseIntOrFPElementsAttr or
// a DenseBoolArrayAttr.
// TODO(#1578): Remove this code once all uses of BoolDenseArrayOrElementsAttr
// have been removed.
SmallVector<bool> getBoolArray(Attribute);

// Verifies that the two types have compatible shape with bounds but allows
// different element types.
LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2);
Expand Down
13 changes: 0 additions & 13 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -230,17 +230,4 @@ def HLO_BoundedAttrInterface : AttrInterface<"BoundedAttrInterface"> {
>];
}

//===----------------------------------------------------------------------===//
// Common attrs.
//===----------------------------------------------------------------------===//

def I64Elements1D : And<[I64ElementsAttr.predicate, CPred<"$_self.cast<DenseIntElementsAttr>().getType().getRank() == 1">]>;

// TODO(#1578) migrate uses to DenseI64ArrayAttr and delete this attr
def I64DenseArrayOrElements1DAttr : Attr<Or<[DenseI64ArrayAttr.predicate, I64Elements1D]>, "either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr."> {
let storageType = "mlir::Attribute";
let returnType = "SmallVector<int64_t>";
let convertFromStorage = "hlo::getI64Array($_self)";
}

#endif // STABLEHLO_DIALECT_BASE
2 changes: 1 addition & 1 deletion stablehlo/dialect/BroadcastUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace mlir {
namespace hlo {

bool isLegalNumpyRankedBroadcast(Value lhs, Value rhs,
ArrayRef<int64_t> broadcastDimensions) {
llvm::ArrayRef<int64_t> broadcastDimensions) {
RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsType || !rhsType) return false;
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/BroadcastUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace hlo {
// to the smaller ranked operand until it is of the same rank as the larger).
// See: https://docs.scipy.org/doc/numpy/reference/ufuncs.html
bool isLegalNumpyRankedBroadcast(Value lhs, Value rhs,
ArrayRef<int64_t> broadcastDims);
llvm::ArrayRef<int64_t> broadcastDims);

// Emits shape dialect ops to compute the result shape for a broadcasting
// binary/n-ary elementwise op which broadcasts according to "numpy" semantics
Expand Down
5 changes: 3 additions & 2 deletions stablehlo/dialect/ChloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/InliningUtils.h"
Expand Down Expand Up @@ -165,7 +166,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
auto broadcastDimensionsAttr = op->getAttr("broadcast_dimensions");
if (broadcastDimensionsAttr &&
!hlo::isLegalNumpyRankedBroadcast(
lhs, rhs, hlo::getI64Array(broadcastDimensionsAttr))) {
lhs, rhs, broadcastDimensionsAttr.cast<mlir::DenseI64ArrayAttr>())) {
// Note: It is unclear whether the general specification of explicit
// broadcast_dimensions on binary ops is a feature we want to carry
// forward. While it can technically be implemented for ranked-dynamic,
Expand Down Expand Up @@ -212,7 +213,7 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(

void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
Value lhs, Value rhs,
Attribute broadcastDimensions,
DenseI64ArrayAttr broadcastDimensions,
chlo::ComparisonDirection comparisonDirection,
chlo::ComparisonType compareType) {
build(builder, result, lhs, rhs, broadcastDimensions,
Expand Down
12 changes: 6 additions & 6 deletions stablehlo/dialect/ChloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class CHLO_BroadcastBinaryElementwiseOp<
HLO_Tensor:$rhs,
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
// padded rank-broadcast semantics if omitted.
OptionalAttr<I64DenseArrayOrElements1DAttr>:$broadcast_dimensions
OptionalAttr<DenseI64ArrayAttr>:$broadcast_dimensions
);

let results = (outs HLO_Tensor);
Expand Down Expand Up @@ -313,7 +313,7 @@ def CHLO_BroadcastZetaOp : CHLO_BroadcastBinaryElementwiseOp<
HLO_FpTensor:$rhs,
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
// padded rank-broadcast semantics if omitted.
OptionalAttr<I64DenseArrayOrElements1DAttr>:$broadcast_dimensions
OptionalAttr<DenseI64ArrayAttr>:$broadcast_dimensions
);
let results = (outs HLO_FpTensor);
}
Expand All @@ -331,7 +331,7 @@ class CHLO_BroadcastBinaryLogicalElementwiseOp<string mnemonic> :
HLO_PredOrIntTensor:$rhs,
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
// padded rank-broadcast semantics if omitted.
OptionalAttr<I64DenseArrayOrElements1DAttr>:$broadcast_dimensions
OptionalAttr<DenseI64ArrayAttr>:$broadcast_dimensions
);
}

Expand Down Expand Up @@ -448,7 +448,7 @@ def CHLO_BroadcastComplexOp : CHLO_BroadcastBinaryElementwiseOp<
HLO_FpTensor:$rhs,
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
// padded rank-broadcast semantics if omitted.
OptionalAttr<I64DenseArrayOrElements1DAttr>:$broadcast_dimensions
OptionalAttr<DenseI64ArrayAttr>:$broadcast_dimensions
);
let results = (outs HLO_ComplexTensor);
}
Expand Down Expand Up @@ -755,15 +755,15 @@ def CHLO_BroadcastCompareOp : CHLO_BroadcastBinaryElementwiseOp<
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
OptionalAttr<I64DenseArrayOrElements1DAttr>:$broadcast_dimensions,
OptionalAttr<DenseI64ArrayAttr>:$broadcast_dimensions,
CHLO_ComparisonDirectionAttr:$comparison_direction,
OptionalAttr<CHLO_ComparisonTypeAttr>:$compare_type
);
let results = (outs HLO_PredTensor);

let builders = [
OpBuilder<(ins "Value":$lhs, "Value":$rhs,
"Attribute":$broadcast_dimensions,
"DenseI64ArrayAttr":$broadcast_dimensions,
"::mlir::chlo::ComparisonDirection":$comparison_direction,
CArg<"::mlir::chlo::ComparisonType",
"::mlir::chlo::ComparisonType::NOTYPE">:$compare_type)>,
Expand Down
25 changes: 4 additions & 21 deletions stablehlo/dialect/StablehloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -179,35 +179,18 @@ def StableHLO_ConvDimensionNumbers : AttrDef<StableHLO_Dialect, "ConvDimensionNu
let hasCustomAssemblyFormat = 1;
}

def StableHLO_BoolElementsAttr :
ElementsAttrBase<
And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
"constant boolean vector/tensor attribute"> {
let storageType = [{ ::mlir::DenseElementsAttr }];
let returnType = [{ ::mlir::DenseElementsAttr }];

let convertFromStorage = "$_self";
}

def BoolDenseArrayOrElementsAttr : Attr<Or<[DenseBoolArrayAttr.predicate, StableHLO_BoolElementsAttr.predicate]>, "either a DenseBoolArrayAttr or a StableHLO_BoolElementsAttr"> {
let storageType = "mlir::Attribute";
let returnType = "SmallVector<bool>";
let convertFromStorage = "hlo::getBoolArray($_self)";
}

def StableHLO_ConvolutionAttributes {
dag attributes = (ins
// Default value: one for each of the spatial dimension.
OptionalAttr<I64DenseArrayOrElements1DAttr>:$window_strides,
OptionalAttr<DenseI64ArrayAttr>:$window_strides,
// Default value: two zeros for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64DenseArrayOrElements1DAttr>:$lhs_dilation,
OptionalAttr<DenseI64ArrayAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64DenseArrayOrElements1DAttr>:$rhs_dilation,
OptionalAttr<DenseI64ArrayAttr>:$rhs_dilation,
// Default value: false for each of the spatial dimension.
OptionalAttr<BoolDenseArrayOrElementsAttr>:$window_reversal,
OptionalAttr<DenseBoolArrayAttr>:$window_reversal,
StableHLO_ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
Expand Down
Loading

0 comments on commit 7eb9902

Please sign in to comment.