Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XeTileToXeGPU] Add generic lowering pattern for Xetile elementwise o… #970

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
304 changes: 139 additions & 165 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -998,131 +998,88 @@ struct SgTransposeOpPattern : public XeOneToNConversion<OpTy> {
};

bool isLegalElementWiseOp(mlir::Operation *op) {
auto res = op->getResult(0);
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
if (resType && resType.getRank() > 2)
return false;
return true;
// Check that all results are of vector type and has rank > 2.
auto numResults = op->getNumResults();
for (unsigned i = 0; i < numResults; i++) {
auto res = op->getResult(i);
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
if (!resType || resType.getRank() <= 2)
return true;
}
return false;
}

template <typename Op, int numOperands>
Op createOp(XeOneToNPatternRewriter &rewriter, mlir::Location loc,
llvm::SmallVector<llvm::SmallVector<mlir::Value>> operands, int i) {
static_assert(numOperands >= 1 && numOperands <= 3,
"Unsupported number of operands");

if constexpr (numOperands == 1) {
return rewriter.create<Op>(loc, operands[0][i]);
} else if constexpr (numOperands == 2) {
return rewriter.create<Op>(loc, operands[0][i], operands[1][i]);
} else if constexpr (numOperands == 3) {
return rewriter.create<Op>(loc, operands[0][i], operands[1][i],
operands[2][i]);
// Convert a llvm::ArrayRef of operands range, where each range consists of a
// list of same operand, To a llvm::ArrayRef of operand range, where the range
// is created from element from each list of operand.

llvm::SmallVector<llvm::SmallVector<mlir::Value>>
verticalToHorizontalToValueRange(llvm::ArrayRef<mlir::ValueRange> operands) {
auto numBlocks = operands[0].size();
llvm::SmallVector<llvm::SmallVector<mlir::Value>> values(numBlocks);
for (auto operand : operands) {
for (unsigned i = 0; i < operand.size(); i++) {
values[i].push_back(operand[i]);
}
}
return values;
}

template <typename Op, int numOperands>
template <typename Op>
struct ElementWiseOpPattern : public XeOneToNConversion<Op> {

using XeOneToNConversion<Op>::XeOneToNConversion;
using RangeT = llvm::ArrayRef<mlir::ValueRange>;
using OpAdaptor = typename Op::template GenericAdaptor<RangeT>;

mlir::LogicalResult
matchAndRewrite(Op op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {

auto res = op.getResult();
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
// non-vector ops, or 1D/2D vector ops generated during lowering.
if (!resType || resType.getRank() <= 2)
return mlir::failure();

// For non 2D vector ops, we expect 4D vector ops only
if (resType.getRank() != 4) {
op.emitOpError() << "type is not 4D vector";
return mlir::failure();
auto numResults = op.getOperation()->getNumResults();

llvm::SmallVector<mlir::Type> newResultTypes;
for (unsigned i = 0; i < numResults; i++) {
mlir::Value result = ::llvm::cast<::mlir::TypedValue<::mlir::Type>>(
*op.getODSResults(i).begin());
auto resultType = mlir::dyn_cast<mlir::VectorType>(result.getType());
// Check if the result types are 4D vectors, if any of the result type is
// not a 4D vector, return failure.
if (!resultType || resultType.getRank() != 4)
return mlir::failure();

auto shape = resultType.getShape();
// Get the new result type, this is the type of the result of the new
// blocked op that works on the 2-D vector.
auto vecTy = mlir::VectorType::get({shape[2], shape[3]},
resultType.getElementType());
newResultTypes.push_back(vecTy);
}

auto shape = resType.getShape();
auto newTy =
mlir::VectorType::get({shape[2], shape[3]}, resType.getElementType());

// Get all the slices of Operands
// Get the operands
auto operands = adaptor.getOperands();

llvm::SmallVector<llvm::SmallVector<mlir::Value>> operand;
if (numOperands == 1)
operand.push_back(operands[0]);
else if (numOperands == 2) {
operand.push_back(operands[0]);
operand.push_back(operands[1]);
} else {
operand.push_back(operands[0]);
operand.push_back(operands[1]);
operand.push_back(operands[2]);
}

// The operands are in the form of llvm::ArrayRef<mlir::ValueRange>, where
// each ValueRange consists of a list of same operand. However, to use the
// operands in the new op, the operands of the same block should together in
// a ValueRange (Vector of operands of the each block should be in a
// vector).
auto horizontalOperands = verticalToHorizontalToValueRange(operands);
// Get the attributes
auto attributes = op.getOperation()->getAttrs();
Op newOp;
llvm::SmallVector<mlir::Value> newOps;
for (int i = 0; i < shape[0] * shape[1]; i++) {
auto newOp = createOp<Op, numOperands>(rewriter, op.getLoc(), operand, i);
newOp->getResult(0).setType(newTy);
newOps.push_back(newOp);
}

rewriter.replaceOp(op, newOps);
return mlir::success();
}
};

template <typename CastOp>
struct TypecastOpPattern : public XeOneToNConversion<CastOp> {
using XeOneToNConversion<CastOp>::XeOneToNConversion;
using RangeT = llvm::ArrayRef<mlir::ValueRange>;
using OpAdaptor = typename CastOp::template GenericAdaptor<RangeT>;

mlir::LogicalResult
matchAndRewrite(CastOp op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {
auto out = mlir::dyn_cast<mlir::VectorType>(op.getType());
if (!out || out.getRank() != 4)
return mlir::failure();

auto shape = out.getShape();
auto vecTy =
mlir::VectorType::get({shape[2], shape[3]}, out.getElementType());
auto inputs = adaptor.getIn();
llvm::SmallVector<mlir::Value> newOps;
for (auto in : inputs) {
auto newOp = rewriter.create<CastOp>(op.getLoc(), vecTy, in);
newOps.push_back(newOp);
}
rewriter.replaceOp(op, newOps);
return mlir::success();
}
};

struct SgArithCmpIOpPattern : public XeOneToNConversion<mlir::arith::CmpIOp> {
using XeOneToNConversion<mlir::arith::CmpIOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(mlir::arith::CmpIOp op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {
auto res = op.getResult();
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
if (!resType || resType.getRank() != 4)
return mlir::failure();

auto lhs = adaptor.getLhs();
auto rhs = adaptor.getRhs();

llvm::SmallVector<mlir::Value> newOps;
for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
auto newOp = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(), op.getPredicate(), l, r);
newOps.push_back(newOp);
for (auto newOperands : horizontalOperands) {
// We are using the generic builder that is supported by all ops.
// static void build(::mlir::OpBuilder &, ::mlir::OperationState
// &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands,
// ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
newOp = rewriter.create<Op>(op.getLoc(), newResultTypes, newOperands,
attributes);
for (unsigned i = 0; i < numResults; i++) {
mlir::Value result = ::llvm::cast<::mlir::TypedValue<::mlir::Type>>(
*newOp.getODSResults(i).begin());
newOps.push_back(result);
}
}

rewriter.replaceOp(op, newOps);
return mlir::success();
}
Expand Down Expand Up @@ -1226,10 +1183,10 @@ struct SgVectorCreateMaskOpPattern : public XeOneToNConversion<CreateMaskOp> {
if (constDef && constDef.value() == shape[0]) {
// Case 1: all rows are enabled.
// See assumptions about the supported create_mask op in
// VectorCreateMaskOpPattern in xetile blocking pass. The second and forth
// operands are the same. This value is the mask of the inner dimension of
// the original shape. Different masks are created based on the new inner
// dimension size.
// VectorCreateMaskOpPattern in xetile blocking pass. The second and
// forth operands are the same. This value is the mask of the inner
// dimension of the original shape. Different masks are created based on
// the new inner dimension size.
auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
llvm::SmallVector<llvm::SmallVector<mlir::Value>> newOperands;
mlir::Value mask = adaptor.getOperands()[3][0];
Expand Down Expand Up @@ -1287,63 +1244,80 @@ struct SgVectorSplatOpPattern : public XeOneToNConversion<SplatOp> {
void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter,
mlir::RewritePatternSet &patterns,
TileUsageAnalysis &analysis) {
patterns.insert<
SgInitTileOpPattern, SgPrefetchTileOpPattern, SgTileUnpackOpPattern,
SgTilePackOpPattern, SgLoadTileOpPattern, SgStoreTileOpPattern,
SgLoadGatherOpPattern, SgStoreScatterOpPattern, SgTileMMAOpPattern,
SgVectorSplatOpPattern, SgUpdateTileOffsetOpPattern,
SgTransposeOpPattern<mlir::vector::TransposeOp>,
SgTransposeOpPattern<xetile::TransposeOp>, SgBroadcastOpPattern,
SgTileReductionOpPattern, SgVectorCreateMaskOpPattern,
SgArithCmpIOpPattern>(patterns.getContext(), converter, analysis);
patterns
.add<SgInitTileOpPattern, SgPrefetchTileOpPattern, SgTileUnpackOpPattern,
SgTilePackOpPattern, SgLoadTileOpPattern, SgStoreTileOpPattern,
SgLoadGatherOpPattern, SgStoreScatterOpPattern, SgTileMMAOpPattern,
SgVectorSplatOpPattern, SgUpdateTileOffsetOpPattern,
SgTransposeOpPattern<mlir::vector::TransposeOp>,
SgTransposeOpPattern<xetile::TransposeOp>, SgBroadcastOpPattern,
SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>(
patterns.getContext(), converter, analysis);

// Element-wise math operations
patterns.insert<ElementWiseOpPattern<mlir::arith::NegFOp, 1>,
ElementWiseOpPattern<mlir::math::ExpOp, 1>,
ElementWiseOpPattern<mlir::math::SinOp, 1>,
ElementWiseOpPattern<mlir::math::CosOp, 1>,
ElementWiseOpPattern<mlir::math::SqrtOp, 1>,
ElementWiseOpPattern<mlir::math::TanhOp, 1>,
ElementWiseOpPattern<mlir::math::LogOp, 1>,
ElementWiseOpPattern<mlir::math::RsqrtOp, 1>,
ElementWiseOpPattern<mlir::math::ErfOp, 1>,
ElementWiseOpPattern<mlir::math::PowFOp, 2>>(
patterns.getContext(), converter, analysis);

// Element-wise arithmetic operations
patterns.insert<ElementWiseOpPattern<mlir::arith::AddFOp, 2>,
ElementWiseOpPattern<mlir::arith::AddIOp, 2>,
ElementWiseOpPattern<mlir::arith::DivFOp, 2>,
ElementWiseOpPattern<mlir::arith::DivSIOp, 2>,
ElementWiseOpPattern<mlir::arith::DivUIOp, 2>,
ElementWiseOpPattern<mlir::arith::MulFOp, 2>,
ElementWiseOpPattern<mlir::arith::MulIOp, 2>,
ElementWiseOpPattern<mlir::arith::MaximumFOp, 2>,
ElementWiseOpPattern<mlir::arith::MaxSIOp, 2>,
ElementWiseOpPattern<mlir::arith::MaxUIOp, 2>,
ElementWiseOpPattern<mlir::arith::MinimumFOp, 2>,
ElementWiseOpPattern<mlir::arith::MinSIOp, 2>,
ElementWiseOpPattern<mlir::arith::MinUIOp, 2>,
ElementWiseOpPattern<mlir::arith::RemFOp, 2>,
ElementWiseOpPattern<mlir::arith::RemSIOp, 2>,
ElementWiseOpPattern<mlir::arith::RemUIOp, 2>,
ElementWiseOpPattern<mlir::arith::SubFOp, 2>,
ElementWiseOpPattern<mlir::arith::SubIOp, 2>,
ElementWiseOpPattern<mlir::arith::AndIOp, 2>,
ElementWiseOpPattern<mlir::arith::XOrIOp, 2>,
ElementWiseOpPattern<mlir::arith::SelectOp, 3>>(
patterns.getContext(), converter, analysis);
patterns.insert<TypecastOpPattern<mlir::arith::ExtFOp>,
TypecastOpPattern<mlir::arith::ExtSIOp>,
TypecastOpPattern<mlir::arith::ExtUIOp>,
TypecastOpPattern<mlir::arith::FPToSIOp>,
TypecastOpPattern<mlir::arith::FPToUIOp>,
TypecastOpPattern<mlir::arith::IndexCastOp>,
TypecastOpPattern<mlir::arith::IndexCastUIOp>,
TypecastOpPattern<mlir::arith::SIToFPOp>,
TypecastOpPattern<mlir::arith::UIToFPOp>,
TypecastOpPattern<mlir::arith::TruncFOp>,
TypecastOpPattern<mlir::arith::TruncIOp>>(
patterns.add<ElementWiseOpPattern<mlir::math::ExpOp>,
ElementWiseOpPattern<mlir::math::SinOp>,
ElementWiseOpPattern<mlir::math::CosOp>,
ElementWiseOpPattern<mlir::math::SqrtOp>,
ElementWiseOpPattern<mlir::math::TanhOp>,
ElementWiseOpPattern<mlir::math::LogOp>,
ElementWiseOpPattern<mlir::math::RsqrtOp>,
ElementWiseOpPattern<mlir::math::ErfOp>,
ElementWiseOpPattern<mlir::math::PowFOp>>(patterns.getContext(),
converter, analysis);

// Arithmetic operations
patterns.add<ElementWiseOpPattern<mlir::arith::AddFOp>,
ElementWiseOpPattern<mlir::arith::AddIOp>,
ElementWiseOpPattern<mlir::arith::AddUIExtendedOp>,
ElementWiseOpPattern<mlir::arith::AndIOp>,
ElementWiseOpPattern<mlir::arith::CeilDivSIOp>,
ElementWiseOpPattern<mlir::arith::CeilDivUIOp>,
ElementWiseOpPattern<mlir::arith::CmpFOp>,
ElementWiseOpPattern<mlir::arith::CmpIOp>,
ElementWiseOpPattern<mlir::arith::DivFOp>,
ElementWiseOpPattern<mlir::arith::DivSIOp>,
ElementWiseOpPattern<mlir::arith::DivUIOp>,
ElementWiseOpPattern<mlir::arith::FloorDivSIOp>,
ElementWiseOpPattern<mlir::arith::MaximumFOp>,
ElementWiseOpPattern<mlir::arith::MaxNumFOp>,
ElementWiseOpPattern<mlir::arith::MaxSIOp>,
ElementWiseOpPattern<mlir::arith::MaxUIOp>,
ElementWiseOpPattern<mlir::arith::MinimumFOp>,
ElementWiseOpPattern<mlir::arith::MinNumFOp>,
ElementWiseOpPattern<mlir::arith::MinSIOp>,
ElementWiseOpPattern<mlir::arith::MinUIOp>,
ElementWiseOpPattern<mlir::arith::MulFOp>,
ElementWiseOpPattern<mlir::arith::MulIOp>,
ElementWiseOpPattern<mlir::arith::MulSIExtendedOp>,
ElementWiseOpPattern<mlir::arith::MulUIExtendedOp>,
ElementWiseOpPattern<mlir::arith::NegFOp>,
ElementWiseOpPattern<mlir::arith::OrIOp>,
ElementWiseOpPattern<mlir::arith::RemFOp>,
ElementWiseOpPattern<mlir::arith::RemSIOp>,
ElementWiseOpPattern<mlir::arith::RemUIOp>,
ElementWiseOpPattern<mlir::arith::SelectOp>,
ElementWiseOpPattern<mlir::arith::ShLIOp>,
ElementWiseOpPattern<mlir::arith::ShRSIOp>,
ElementWiseOpPattern<mlir::arith::ShRUIOp>,
ElementWiseOpPattern<mlir::arith::SubFOp>,
ElementWiseOpPattern<mlir::arith::SubIOp>,
ElementWiseOpPattern<mlir::arith::XOrIOp>>(patterns.getContext(),
converter, analysis);

// Typecast operations
patterns.add<ElementWiseOpPattern<mlir::arith::BitcastOp>,
ElementWiseOpPattern<mlir::arith::ExtFOp>,
ElementWiseOpPattern<mlir::arith::ExtSIOp>,
ElementWiseOpPattern<mlir::arith::ExtUIOp>,
ElementWiseOpPattern<mlir::arith::FPToSIOp>,
ElementWiseOpPattern<mlir::arith::FPToUIOp>,
ElementWiseOpPattern<mlir::arith::IndexCastOp>,
ElementWiseOpPattern<mlir::arith::IndexCastUIOp>,
ElementWiseOpPattern<mlir::arith::SIToFPOp>,
ElementWiseOpPattern<mlir::arith::TruncFOp>,
ElementWiseOpPattern<mlir::arith::TruncIOp>,
ElementWiseOpPattern<mlir::arith::UIToFPOp>>(
patterns.getContext(), converter, analysis);
}

Expand Down
Loading
Loading