Skip to content

Commit

Permalink
[XeTileToXeGPU] Add generic lowering pattern for Xetile elementwise o…
Browse files Browse the repository at this point in the history
…ps (arith, math dialect ops).

Allows Arith/math dialect pre/post-ops to be decomposed (blocked and updated).

Add a generic element-wise conversion pattern that can handle all of these operations.
It replaces the multiple patterns previously needed to handle different kinds of ops.
It also supports passing attributes (e.g., fastmath) through to the new ops.

Add support for all the operations in Arith dialect.
  • Loading branch information
mshahneo committed Nov 25, 2024
1 parent 5d72a84 commit 7f0c8c0
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 238 deletions.
304 changes: 139 additions & 165 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -998,131 +998,88 @@ struct SgTransposeOpPattern : public XeOneToNConversion<OpTy> {
};

// Reports whether an element-wise op counts as legal, judged by the vector
// rank of its results.
// NOTE(review): this span comes from a rendered diff; the single-result check
// directly below and the all-results loop after it look like the before/after
// bodies of the same function — confirm against the repository.
bool isLegalElementWiseOp(mlir::Operation *op) {
auto res = op->getResult(0);
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
if (resType && resType.getRank() > 2)
return false;
return true;
// Scan every result: return true as soon as one result is not a vector or is
// a vector of rank <= 2. False is returned only when no result satisfies that,
// i.e. all results are vectors of rank > 2 (or the op has no results).
auto numResults = op->getNumResults();
for (unsigned i = 0; i < numResults; i++) {
auto res = op->getResult(i);
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
if (!resType || resType.getRank() <= 2)
return true;
}
return false;
}

template <typename Op, int numOperands>
Op createOp(XeOneToNPatternRewriter &rewriter, mlir::Location loc,
llvm::SmallVector<llvm::SmallVector<mlir::Value>> operands, int i) {
static_assert(numOperands >= 1 && numOperands <= 3,
"Unsupported number of operands");

if constexpr (numOperands == 1) {
return rewriter.create<Op>(loc, operands[0][i]);
} else if constexpr (numOperands == 2) {
return rewriter.create<Op>(loc, operands[0][i], operands[1][i]);
} else if constexpr (numOperands == 3) {
return rewriter.create<Op>(loc, operands[0][i], operands[1][i],
operands[2][i]);
// Transposes the per-operand value lists handed out by the 1:N adaptor into
// per-block operand lists.
//
// `operands[k]` is the range of values standing in for the k-th operand of the
// original op. The result holds one vector per block; entry `k` of a block's
// vector is the value of operand `k` for that block, so the vector can be used
// directly as the operand list of the op created for that block.
//
// Edge cases:
//  * An empty `operands` yields an empty result (previously `operands[0]` was
//    read unconditionally).
//  * The number of blocks is the largest range size, so a short first range
//    can no longer cause out-of-bounds writes into the result.
//  * A range holding exactly one value is reused for every block (presumably
//    an operand that was not split by the type converter, e.g. a scalar —
//    TODO confirm against the converter).
//  * Ranges of any other mismatching size are not padded: blocks past their
//    end simply do not receive that operand.
llvm::SmallVector<llvm::SmallVector<mlir::Value>>
verticalToHorizontalToValueRange(llvm::ArrayRef<mlir::ValueRange> operands) {
  llvm::SmallVector<llvm::SmallVector<mlir::Value>> values;
  if (operands.empty())
    return values;

  // The block count is dictated by the longest operand range.
  size_t numBlocks = 0;
  for (mlir::ValueRange operand : operands) {
    if (operand.size() > numBlocks)
      numBlocks = operand.size();
  }

  values.resize(numBlocks);
  for (auto &blockOperands : values)
    blockOperands.reserve(operands.size());

  for (mlir::ValueRange operand : operands) {
    // A single value is shared by all blocks; otherwise take one per block.
    const bool sharedByAllBlocks = operand.size() == 1;
    const size_t count = sharedByAllBlocks ? numBlocks : operand.size();
    for (size_t i = 0; i < count; ++i)
      values[i].push_back(sharedByAllBlocks ? operand[0] : operand[i]);
  }
  return values;
}

template <typename Op, int numOperands>
template <typename Op>
struct ElementWiseOpPattern : public XeOneToNConversion<Op> {

using XeOneToNConversion<Op>::XeOneToNConversion;
using RangeT = llvm::ArrayRef<mlir::ValueRange>;
using OpAdaptor = typename Op::template GenericAdaptor<RangeT>;

mlir::LogicalResult
matchAndRewrite(Op op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {

auto res = op.getResult();
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
// non-vector ops, or 1D/2D vector ops generated during lowering.
if (!resType || resType.getRank() <= 2)
return mlir::failure();

// For non 2D vector ops, we expect 4D vector ops only
if (resType.getRank() != 4) {
op.emitOpError() << "type is not 4D vector";
return mlir::failure();
auto numResults = op.getOperation()->getNumResults();

llvm::SmallVector<mlir::Type> newResultTypes;
for (unsigned i = 0; i < numResults; i++) {
mlir::Value result = ::llvm::cast<::mlir::TypedValue<::mlir::Type>>(
*op.getODSResults(i).begin());
auto resultType = mlir::dyn_cast<mlir::VectorType>(result.getType());
// Check if the result types are 4D vectors, if any of the result type is
// not a 4D vector, return failure.
if (!resultType || resultType.getRank() != 4)
return mlir::failure();

auto shape = resultType.getShape();
// Get the new result type, this is the type of the result of the new
// blocked op that works on the 2-D vector.
auto vecTy = mlir::VectorType::get({shape[2], shape[3]},
resultType.getElementType());
newResultTypes.push_back(vecTy);
}

auto shape = resType.getShape();
auto newTy =
mlir::VectorType::get({shape[2], shape[3]}, resType.getElementType());

// Get all the slices of Operands
// Get the operands
auto operands = adaptor.getOperands();

llvm::SmallVector<llvm::SmallVector<mlir::Value>> operand;
if (numOperands == 1)
operand.push_back(operands[0]);
else if (numOperands == 2) {
operand.push_back(operands[0]);
operand.push_back(operands[1]);
} else {
operand.push_back(operands[0]);
operand.push_back(operands[1]);
operand.push_back(operands[2]);
}

// The operands are in the form of llvm::ArrayRef<mlir::ValueRange>, where
// each ValueRange consists of a list of same operand. However, to use the
// operands in the new op, the operands of the same block should together in
// a ValueRange (Vector of operands of the each block should be in a
// vector).
auto horizontalOperands = verticalToHorizontalToValueRange(operands);
// Get the attributes
auto attributes = op.getOperation()->getAttrs();
Op newOp;
llvm::SmallVector<mlir::Value> newOps;
for (int i = 0; i < shape[0] * shape[1]; i++) {
auto newOp = createOp<Op, numOperands>(rewriter, op.getLoc(), operand, i);
newOp->getResult(0).setType(newTy);
newOps.push_back(newOp);
}

rewriter.replaceOp(op, newOps);
return mlir::success();
}
};

// Lowers a cast op whose result is a 4-D vector: `CastOp` is re-created once
// per value in the adaptor's `in` range, each new op producing a 2-D vector
// shaped by the two innermost dimensions of the original result type. Ops with
// any other result type are rejected with failure().
template <typename CastOp>
struct TypecastOpPattern : public XeOneToNConversion<CastOp> {
  using XeOneToNConversion<CastOp>::XeOneToNConversion;
  using RangeT = llvm::ArrayRef<mlir::ValueRange>;
  using OpAdaptor = typename CastOp::template GenericAdaptor<RangeT>;

  mlir::LogicalResult
  matchAndRewrite(CastOp op, OpAdaptor adaptor,
                  XeOneToNPatternRewriter &rewriter) const override {
    // Guard: only 4-D vector results are handled here.
    auto resultTy = mlir::dyn_cast<mlir::VectorType>(op.getType());
    const bool is4DVector = resultTy && resultTy.getRank() == 4;
    if (!is4DVector)
      return mlir::failure();

    // Result type of each new op: innermost two dims, same element type.
    auto dims = resultTy.getShape();
    auto blockTy = mlir::VectorType::get({dims[2], dims[3]},
                                         resultTy.getElementType());

    llvm::SmallVector<mlir::Value> replacements;
    for (auto blockInput : adaptor.getIn()) {
      auto blockCast = rewriter.create<CastOp>(op.getLoc(), blockTy, blockInput);
      replacements.push_back(blockCast);
    }

    rewriter.replaceOp(op, replacements);
    return mlir::success();
  }
};

struct SgArithCmpIOpPattern : public XeOneToNConversion<mlir::arith::CmpIOp> {
using XeOneToNConversion<mlir::arith::CmpIOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(mlir::arith::CmpIOp op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {
auto res = op.getResult();
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
if (!resType || resType.getRank() != 4)
return mlir::failure();

auto lhs = adaptor.getLhs();
auto rhs = adaptor.getRhs();

llvm::SmallVector<mlir::Value> newOps;
for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
auto newOp = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(), op.getPredicate(), l, r);
newOps.push_back(newOp);
for (auto newOperands : horizontalOperands) {
// We are using the generic builder that is supported by all ops.
// static void build(::mlir::OpBuilder &, ::mlir::OperationState
// &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands,
// ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
newOp = rewriter.create<Op>(op.getLoc(), newResultTypes, newOperands,
attributes);
for (unsigned i = 0; i < numResults; i++) {
mlir::Value result = ::llvm::cast<::mlir::TypedValue<::mlir::Type>>(
*newOp.getODSResults(i).begin());
newOps.push_back(result);
}
}

rewriter.replaceOp(op, newOps);
return mlir::success();
}
Expand Down Expand Up @@ -1226,10 +1183,10 @@ struct SgVectorCreateMaskOpPattern : public XeOneToNConversion<CreateMaskOp> {
if (constDef && constDef.value() == shape[0]) {
// Case 1: all rows are enabled.
// See assumptions about the supported create_mask op in
// VectorCreateMaskOpPattern in xetile blocking pass. The second and forth
// operands are the same. This value is the mask of the inner dimension of
// the original shape. Different masks are created based on the new inner
// dimension size.
// VectorCreateMaskOpPattern in xetile blocking pass. The second and
// forth operands are the same. This value is the mask of the inner
// dimension of the original shape. Different masks are created based on
// the new inner dimension size.
auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
llvm::SmallVector<llvm::SmallVector<mlir::Value>> newOperands;
mlir::Value mask = adaptor.getOperands()[3][0];
Expand Down Expand Up @@ -1287,63 +1244,80 @@ struct SgVectorSplatOpPattern : public XeOneToNConversion<SplatOp> {
// Registers the XeTile op conversion patterns with `patterns`. Each group of
// patterns is constructed from the pattern set's context, `converter` and
// `analysis`.
// NOTE(review): this listing comes from a rendered diff, so the
// `patterns.insert<...>` groups and the `patterns.add<...>` groups appear to be
// the before/after versions of the same registrations — confirm against the
// repository before relying on either list.
void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter,
mlir::RewritePatternSet &patterns,
TileUsageAnalysis &analysis) {
patterns.insert<
SgInitTileOpPattern, SgPrefetchTileOpPattern, SgTileUnpackOpPattern,
SgTilePackOpPattern, SgLoadTileOpPattern, SgStoreTileOpPattern,
SgLoadGatherOpPattern, SgStoreScatterOpPattern, SgTileMMAOpPattern,
SgVectorSplatOpPattern, SgUpdateTileOffsetOpPattern,
SgTransposeOpPattern<mlir::vector::TransposeOp>,
SgTransposeOpPattern<xetile::TransposeOp>, SgBroadcastOpPattern,
SgTileReductionOpPattern, SgVectorCreateMaskOpPattern,
SgArithCmpIOpPattern>(patterns.getContext(), converter, analysis);
patterns
.add<SgInitTileOpPattern, SgPrefetchTileOpPattern, SgTileUnpackOpPattern,
SgTilePackOpPattern, SgLoadTileOpPattern, SgStoreTileOpPattern,
SgLoadGatherOpPattern, SgStoreScatterOpPattern, SgTileMMAOpPattern,
SgVectorSplatOpPattern, SgUpdateTileOffsetOpPattern,
SgTransposeOpPattern<mlir::vector::TransposeOp>,
SgTransposeOpPattern<xetile::TransposeOp>, SgBroadcastOpPattern,
SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>(
patterns.getContext(), converter, analysis);

// Element-wise math operations
patterns.insert<ElementWiseOpPattern<mlir::arith::NegFOp, 1>,
ElementWiseOpPattern<mlir::math::ExpOp, 1>,
ElementWiseOpPattern<mlir::math::SinOp, 1>,
ElementWiseOpPattern<mlir::math::CosOp, 1>,
ElementWiseOpPattern<mlir::math::SqrtOp, 1>,
ElementWiseOpPattern<mlir::math::TanhOp, 1>,
ElementWiseOpPattern<mlir::math::LogOp, 1>,
ElementWiseOpPattern<mlir::math::RsqrtOp, 1>,
ElementWiseOpPattern<mlir::math::ErfOp, 1>,
ElementWiseOpPattern<mlir::math::PowFOp, 2>>(
patterns.getContext(), converter, analysis);

// Element-wise arithmetic operations
patterns.insert<ElementWiseOpPattern<mlir::arith::AddFOp, 2>,
ElementWiseOpPattern<mlir::arith::AddIOp, 2>,
ElementWiseOpPattern<mlir::arith::DivFOp, 2>,
ElementWiseOpPattern<mlir::arith::DivSIOp, 2>,
ElementWiseOpPattern<mlir::arith::DivUIOp, 2>,
ElementWiseOpPattern<mlir::arith::MulFOp, 2>,
ElementWiseOpPattern<mlir::arith::MulIOp, 2>,
ElementWiseOpPattern<mlir::arith::MaximumFOp, 2>,
ElementWiseOpPattern<mlir::arith::MaxSIOp, 2>,
ElementWiseOpPattern<mlir::arith::MaxUIOp, 2>,
ElementWiseOpPattern<mlir::arith::MinimumFOp, 2>,
ElementWiseOpPattern<mlir::arith::MinSIOp, 2>,
ElementWiseOpPattern<mlir::arith::MinUIOp, 2>,
ElementWiseOpPattern<mlir::arith::RemFOp, 2>,
ElementWiseOpPattern<mlir::arith::RemSIOp, 2>,
ElementWiseOpPattern<mlir::arith::RemUIOp, 2>,
ElementWiseOpPattern<mlir::arith::SubFOp, 2>,
ElementWiseOpPattern<mlir::arith::SubIOp, 2>,
ElementWiseOpPattern<mlir::arith::AndIOp, 2>,
ElementWiseOpPattern<mlir::arith::XOrIOp, 2>,
ElementWiseOpPattern<mlir::arith::SelectOp, 3>>(
patterns.getContext(), converter, analysis);
patterns.insert<TypecastOpPattern<mlir::arith::ExtFOp>,
TypecastOpPattern<mlir::arith::ExtSIOp>,
TypecastOpPattern<mlir::arith::ExtUIOp>,
TypecastOpPattern<mlir::arith::FPToSIOp>,
TypecastOpPattern<mlir::arith::FPToUIOp>,
TypecastOpPattern<mlir::arith::IndexCastOp>,
TypecastOpPattern<mlir::arith::IndexCastUIOp>,
TypecastOpPattern<mlir::arith::SIToFPOp>,
TypecastOpPattern<mlir::arith::UIToFPOp>,
TypecastOpPattern<mlir::arith::TruncFOp>,
TypecastOpPattern<mlir::arith::TruncIOp>>(
patterns.add<ElementWiseOpPattern<mlir::math::ExpOp>,
ElementWiseOpPattern<mlir::math::SinOp>,
ElementWiseOpPattern<mlir::math::CosOp>,
ElementWiseOpPattern<mlir::math::SqrtOp>,
ElementWiseOpPattern<mlir::math::TanhOp>,
ElementWiseOpPattern<mlir::math::LogOp>,
ElementWiseOpPattern<mlir::math::RsqrtOp>,
ElementWiseOpPattern<mlir::math::ErfOp>,
ElementWiseOpPattern<mlir::math::PowFOp>>(patterns.getContext(),
converter, analysis);

// Arithmetic operations
patterns.add<ElementWiseOpPattern<mlir::arith::AddFOp>,
ElementWiseOpPattern<mlir::arith::AddIOp>,
ElementWiseOpPattern<mlir::arith::AddUIExtendedOp>,
ElementWiseOpPattern<mlir::arith::AndIOp>,
ElementWiseOpPattern<mlir::arith::CeilDivSIOp>,
ElementWiseOpPattern<mlir::arith::CeilDivUIOp>,
ElementWiseOpPattern<mlir::arith::CmpFOp>,
ElementWiseOpPattern<mlir::arith::CmpIOp>,
ElementWiseOpPattern<mlir::arith::DivFOp>,
ElementWiseOpPattern<mlir::arith::DivSIOp>,
ElementWiseOpPattern<mlir::arith::DivUIOp>,
ElementWiseOpPattern<mlir::arith::FloorDivSIOp>,
ElementWiseOpPattern<mlir::arith::MaximumFOp>,
ElementWiseOpPattern<mlir::arith::MaxNumFOp>,
ElementWiseOpPattern<mlir::arith::MaxSIOp>,
ElementWiseOpPattern<mlir::arith::MaxUIOp>,
ElementWiseOpPattern<mlir::arith::MinimumFOp>,
ElementWiseOpPattern<mlir::arith::MinNumFOp>,
ElementWiseOpPattern<mlir::arith::MinSIOp>,
ElementWiseOpPattern<mlir::arith::MinUIOp>,
ElementWiseOpPattern<mlir::arith::MulFOp>,
ElementWiseOpPattern<mlir::arith::MulIOp>,
ElementWiseOpPattern<mlir::arith::MulSIExtendedOp>,
ElementWiseOpPattern<mlir::arith::MulUIExtendedOp>,
ElementWiseOpPattern<mlir::arith::NegFOp>,
ElementWiseOpPattern<mlir::arith::OrIOp>,
ElementWiseOpPattern<mlir::arith::RemFOp>,
ElementWiseOpPattern<mlir::arith::RemSIOp>,
ElementWiseOpPattern<mlir::arith::RemUIOp>,
ElementWiseOpPattern<mlir::arith::SelectOp>,
ElementWiseOpPattern<mlir::arith::ShLIOp>,
ElementWiseOpPattern<mlir::arith::ShRSIOp>,
ElementWiseOpPattern<mlir::arith::ShRUIOp>,
ElementWiseOpPattern<mlir::arith::SubFOp>,
ElementWiseOpPattern<mlir::arith::SubIOp>,
ElementWiseOpPattern<mlir::arith::XOrIOp>>(patterns.getContext(),
converter, analysis);

// Typecast operations
patterns.add<ElementWiseOpPattern<mlir::arith::BitcastOp>,
ElementWiseOpPattern<mlir::arith::ExtFOp>,
ElementWiseOpPattern<mlir::arith::ExtSIOp>,
ElementWiseOpPattern<mlir::arith::ExtUIOp>,
ElementWiseOpPattern<mlir::arith::FPToSIOp>,
ElementWiseOpPattern<mlir::arith::FPToUIOp>,
ElementWiseOpPattern<mlir::arith::IndexCastOp>,
ElementWiseOpPattern<mlir::arith::IndexCastUIOp>,
ElementWiseOpPattern<mlir::arith::SIToFPOp>,
ElementWiseOpPattern<mlir::arith::TruncFOp>,
ElementWiseOpPattern<mlir::arith::TruncIOp>,
ElementWiseOpPattern<mlir::arith::UIToFPOp>>(
patterns.getContext(), converter, analysis);
}

Expand Down
Loading

0 comments on commit 7f0c8c0

Please sign in to comment.