From 7f0c8c0f5df0c21472734db69246275cf20e9ff0 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Mon, 25 Nov 2024 17:15:10 +0000 Subject: [PATCH] [XeTileToXeGPU] Add generic lowering pattern for Xetile elementwise ops (arith, math dialect ops). Allows Arith/math dialect pre/post-ops to be decomposed (blocked and updated). Add a generic element-wise conversion pattern that can tackle all operations. It removes multiple patterns necessary to handle different kind of ops. It also supports passing attributes (e.g., fastmath). Add support for all the operations in Arith dialect. --- .../XeTileToXeGPU/XeTileOpConversion.cpp | 304 ++++++++---------- .../XeTileToXeGPU/XeTileToXeGPU.cpp | 77 +---- .../XeTileToXeGPU/elementwise_ops.mlir | 9 +- 3 files changed, 152 insertions(+), 238 deletions(-) diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp index c0607cf34..4cbd8815b 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -998,32 +998,35 @@ struct SgTransposeOpPattern : public XeOneToNConversion { }; bool isLegalElementWiseOp(mlir::Operation *op) { - auto res = op->getResult(0); - auto resType = mlir::dyn_cast(res.getType()); - if (resType && resType.getRank() > 2) - return false; - return true; + // Check that all results are of vector type and has rank > 2. 
+ auto numResults = op->getNumResults(); + for (unsigned i = 0; i < numResults; i++) { + auto res = op->getResult(i); + auto resType = mlir::dyn_cast(res.getType()); + if (!resType || resType.getRank() <= 2) + return true; + } + return false; } -template -Op createOp(XeOneToNPatternRewriter &rewriter, mlir::Location loc, - llvm::SmallVector> operands, int i) { - static_assert(numOperands >= 1 && numOperands <= 3, - "Unsupported number of operands"); - - if constexpr (numOperands == 1) { - return rewriter.create(loc, operands[0][i]); - } else if constexpr (numOperands == 2) { - return rewriter.create(loc, operands[0][i], operands[1][i]); - } else if constexpr (numOperands == 3) { - return rewriter.create(loc, operands[0][i], operands[1][i], - operands[2][i]); +// Convert a llvm::ArrayRef of operands range, where each range consists of a +// list of same operand, To a llvm::ArrayRef of operand range, where the range +// is created from element from each list of operand. + +llvm::SmallVector> +verticalToHorizontalToValueRange(llvm::ArrayRef operands) { + auto numBlocks = operands[0].size(); + llvm::SmallVector> values(numBlocks); + for (auto operand : operands) { + for (unsigned i = 0; i < operand.size(); i++) { + values[i].push_back(operand[i]); + } } + return values; } -template +template struct ElementWiseOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; using RangeT = llvm::ArrayRef; using OpAdaptor = typename Op::template GenericAdaptor; @@ -1031,98 +1034,52 @@ struct ElementWiseOpPattern : public XeOneToNConversion { mlir::LogicalResult matchAndRewrite(Op op, OpAdaptor adaptor, XeOneToNPatternRewriter &rewriter) const override { - - auto res = op.getResult(); - auto resType = mlir::dyn_cast(res.getType()); - // non-vector ops, or 1D/2D vector ops generated during lowering. 
- if (!resType || resType.getRank() <= 2) - return mlir::failure(); - - // For non 2D vector ops, we expect 4D vector ops only - if (resType.getRank() != 4) { - op.emitOpError() << "type is not 4D vector"; - return mlir::failure(); + auto numResults = op.getOperation()->getNumResults(); + + llvm::SmallVector newResultTypes; + for (unsigned i = 0; i < numResults; i++) { + mlir::Value result = ::llvm::cast<::mlir::TypedValue<::mlir::Type>>( + *op.getODSResults(i).begin()); + auto resultType = mlir::dyn_cast(result.getType()); + // Check if the result types are 4D vectors, if any of the result type is + // not a 4D vector, return failure. + if (!resultType || resultType.getRank() != 4) + return mlir::failure(); + + auto shape = resultType.getShape(); + // Get the new result type, this is the type of the result of the new + // blocked op that works on the 2-D vector. + auto vecTy = mlir::VectorType::get({shape[2], shape[3]}, + resultType.getElementType()); + newResultTypes.push_back(vecTy); } - auto shape = resType.getShape(); - auto newTy = - mlir::VectorType::get({shape[2], shape[3]}, resType.getElementType()); - - // Get all the slices of Operands + // Get the operands auto operands = adaptor.getOperands(); - llvm::SmallVector> operand; - if (numOperands == 1) - operand.push_back(operands[0]); - else if (numOperands == 2) { - operand.push_back(operands[0]); - operand.push_back(operands[1]); - } else { - operand.push_back(operands[0]); - operand.push_back(operands[1]); - operand.push_back(operands[2]); - } - + // The operands are in the form of llvm::ArrayRef, where + // each ValueRange consists of a list of same operand. However, to use the + // operands in the new op, the operands of the same block should together in + // a ValueRange (Vector of operands of the each block should be in a + // vector). 
+ auto horizontalOperands = verticalToHorizontalToValueRange(operands); + // Get the attributes + auto attributes = op.getOperation()->getAttrs(); + Op newOp; llvm::SmallVector newOps; - for (int i = 0; i < shape[0] * shape[1]; i++) { - auto newOp = createOp(rewriter, op.getLoc(), operand, i); - newOp->getResult(0).setType(newTy); - newOps.push_back(newOp); - } - - rewriter.replaceOp(op, newOps); - return mlir::success(); - } -}; - -template -struct TypecastOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - using RangeT = llvm::ArrayRef; - using OpAdaptor = typename CastOp::template GenericAdaptor; - - mlir::LogicalResult - matchAndRewrite(CastOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto out = mlir::dyn_cast(op.getType()); - if (!out || out.getRank() != 4) - return mlir::failure(); - - auto shape = out.getShape(); - auto vecTy = - mlir::VectorType::get({shape[2], shape[3]}, out.getElementType()); - auto inputs = adaptor.getIn(); - llvm::SmallVector newOps; - for (auto in : inputs) { - auto newOp = rewriter.create(op.getLoc(), vecTy, in); - newOps.push_back(newOp); - } - rewriter.replaceOp(op, newOps); - return mlir::success(); - } -}; - -struct SgArithCmpIOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(mlir::arith::CmpIOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto res = op.getResult(); - auto resType = mlir::dyn_cast(res.getType()); - if (!resType || resType.getRank() != 4) - return mlir::failure(); - - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); - - llvm::SmallVector newOps; - for (auto [l, r] : llvm::zip_equal(lhs, rhs)) { - auto newOp = rewriter.create( - op.getLoc(), op.getPredicate(), l, r); - newOps.push_back(newOp); + for (auto newOperands : horizontalOperands) { + // We are using the generic builder that is supported by all ops. 
+ // static void build(::mlir::OpBuilder &, ::mlir::OperationState + // &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, + // ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + newOp = rewriter.create(op.getLoc(), newResultTypes, newOperands, + attributes); + for (unsigned i = 0; i < numResults; i++) { + mlir::Value result = ::llvm::cast<::mlir::TypedValue<::mlir::Type>>( + *newOp.getODSResults(i).begin()); + newOps.push_back(result); + } } - rewriter.replaceOp(op, newOps); return mlir::success(); } @@ -1226,10 +1183,10 @@ struct SgVectorCreateMaskOpPattern : public XeOneToNConversion { if (constDef && constDef.value() == shape[0]) { // Case 1: all rows are enabled. // See assumptions about the supported create_mask op in - // VectorCreateMaskOpPattern in xetile blocking pass. The second and forth - // operands are the same. This value is the mask of the inner dimension of - // the original shape. Different masks are created based on the new inner - // dimension size. + // VectorCreateMaskOpPattern in xetile blocking pass. The second and + // forth operands are the same. This value is the mask of the inner + // dimension of the original shape. Different masks are created based on + // the new inner dimension size. 
auto one = rewriter.create(loc, 1); llvm::SmallVector> newOperands; mlir::Value mask = adaptor.getOperands()[3][0]; @@ -1287,63 +1244,80 @@ struct SgVectorSplatOpPattern : public XeOneToNConversion { void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter, mlir::RewritePatternSet &patterns, TileUsageAnalysis &analysis) { - patterns.insert< - SgInitTileOpPattern, SgPrefetchTileOpPattern, SgTileUnpackOpPattern, - SgTilePackOpPattern, SgLoadTileOpPattern, SgStoreTileOpPattern, - SgLoadGatherOpPattern, SgStoreScatterOpPattern, SgTileMMAOpPattern, - SgVectorSplatOpPattern, SgUpdateTileOffsetOpPattern, - SgTransposeOpPattern, - SgTransposeOpPattern, SgBroadcastOpPattern, - SgTileReductionOpPattern, SgVectorCreateMaskOpPattern, - SgArithCmpIOpPattern>(patterns.getContext(), converter, analysis); + patterns + .add, + SgTransposeOpPattern, SgBroadcastOpPattern, + SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>( + patterns.getContext(), converter, analysis); // Element-wise math operations - patterns.insert, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern>( - patterns.getContext(), converter, analysis); - - // Element-wise arithmetic operations - patterns.insert, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern>( - patterns.getContext(), converter, analysis); - patterns.insert, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern, 
- TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern>( + patterns.add, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern>(patterns.getContext(), + converter, analysis); + + // Arithmetic operations + patterns.add, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern>(patterns.getContext(), + converter, analysis); + + // Typecast operations + patterns.add, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern>( patterns.getContext(), converter, analysis); } diff --git a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp index c74d9cf3d..4b7197511 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp @@ -88,77 +88,12 @@ class XeTileConversionTarget : public mlir::ConversionTarget 
{ mlir::succeeded(uArchInterface->isLegalPrefetch2dOp(op))); }); - // Arith ops - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - 
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); + // Arith ops, since we support all the arith ops, we can dynamically make + // the whole dialect legal. 
+ addDynamicallyLegalDialect( + [&](mlir::Operation *op) -> std::optional { + return isLegalElementWiseOp(op); + }); // Math Ops addDynamicallyLegalOp( diff --git a/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir b/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir index a88da2a34..429a64f33 100644 --- a/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir +++ b/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir @@ -5,20 +5,22 @@ %1 = arith.constant dense<2.3>: vector<64x4x1x16xf16> %2 = xetile.tile_unpack %0 {inner_blocks = array}: vector<4x4x16x16xf16> -> vector<64x64xf16> %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<64x64xf16> -> vector<64x4x1x16xf16> - // CHECK-COUNT-256: arith.addf {{.*}}, {{.*}} : vector<1x16xf16> + // CHECK-COUNT-256: arith.addf {{.*}}, {{.*}} fastmath : vector<1x16xf16> // CHECK-COUNT-256: arith.sub // CHECK-COUNT-256: arith.mulf // CHECK-COUNT-256: arith.maximumf // CHECK-COUNT-256: arith.minimumf // CHECK-COUNT-256: arith.divf // CHECK-COUNT-256: arith.remf - %result = arith.addf %3, %1 : vector<64x4x1x16xf16> + // CHECK-COUNT-256: arith.cmpf + %result = arith.addf %3, %1 fastmath : vector<64x4x1x16xf16> %subf_result = arith.subf %result, %1 : vector<64x4x1x16xf16> %mulf_result = arith.mulf %subf_result, %1 : vector<64x4x1x16xf16> %maxf_result = arith.maximumf %mulf_result, %1 : vector<64x4x1x16xf16> %minf_result = arith.minimumf %maxf_result, %mulf_result : vector<64x4x1x16xf16> %divf_result = arith.divf %minf_result, %1 : vector<64x4x1x16xf16> %remf_result = arith.remf %minf_result, %divf_result : vector<64x4x1x16xf16> + %cmpf_result = arith.cmpf ult, %remf_result, %divf_result : vector<64x4x1x16xf16> gpu.return } @@ -39,6 +41,7 @@ // CHECK-COUNT-256: arith.remsi // CHECK-COUNT-256: arith.remui // CHECK-COUNT-256: arith.andi + // CHECK-COUNT-256: arith.addui_extended %result = arith.addi %3, %1 : vector<64x4x1x16xi16> %subi_result = arith.subi %3, %1 : vector<64x4x1x16xi16> %muli_result = arith.muli %subi_result, %1 : 
vector<64x4x1x16xi16> @@ -51,6 +54,8 @@ %remsi_result = arith.remsi %minsi_result, %divsi_result : vector<64x4x1x16xi16> %remui_result = arith.remui %minui_result, %divui_result : vector<64x4x1x16xi16> %and_result = arith.andi %remsi_result, %remui_result : vector<64x4x1x16xi16> + %addui_sum, %addui_overflow = arith.addui_extended %3, %1 : vector<64x4x1x16xi16>, vector<64x4x1x16xi1> + %addui_extended_result:2 = arith.addui_extended %3, %1 : vector<64x4x1x16xi16>, vector<64x4x1x16xi1> gpu.return }