diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp index c0607cf34..4cbd8815b 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -998,32 +998,35 @@ struct SgTransposeOpPattern : public XeOneToNConversion { }; bool isLegalElementWiseOp(mlir::Operation *op) { - auto res = op->getResult(0); - auto resType = mlir::dyn_cast(res.getType()); - if (resType && resType.getRank() > 2) - return false; - return true; + // Check that all results are of vector type and has rank > 2. + auto numResults = op->getNumResults(); + for (unsigned i = 0; i < numResults; i++) { + auto res = op->getResult(i); + auto resType = mlir::dyn_cast(res.getType()); + if (!resType || resType.getRank() <= 2) + return true; + } + return false; } -template -Op createOp(XeOneToNPatternRewriter &rewriter, mlir::Location loc, - llvm::SmallVector> operands, int i) { - static_assert(numOperands >= 1 && numOperands <= 3, - "Unsupported number of operands"); - - if constexpr (numOperands == 1) { - return rewriter.create(loc, operands[0][i]); - } else if constexpr (numOperands == 2) { - return rewriter.create(loc, operands[0][i], operands[1][i]); - } else if constexpr (numOperands == 3) { - return rewriter.create(loc, operands[0][i], operands[1][i], - operands[2][i]); +// Convert a llvm::ArrayRef of operands range, where each range consists of a +// list of same operand, To a llvm::ArrayRef of operand range, where the range +// is created from element from each list of operand. + +llvm::SmallVector> +verticalToHorizontalToValueRange(llvm::ArrayRef operands) { + auto numBlocks = operands[0].size(); + llvm::SmallVector> values(numBlocks); + for (auto operand : operands) { + for (unsigned i = 0; i < operand.size(); i++) { + values[i].push_back(operand[i]); + } } + return values; } -template +template struct ElementWiseOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; using RangeT = llvm::ArrayRef; using OpAdaptor = typename Op::template GenericAdaptor; @@ -1031,98 +1034,52 @@ struct ElementWiseOpPattern : public XeOneToNConversion { mlir::LogicalResult matchAndRewrite(Op op, OpAdaptor adaptor, XeOneToNPatternRewriter &rewriter) const override { - - auto res = op.getResult(); - auto resType = mlir::dyn_cast(res.getType()); - // non-vector ops, or 1D/2D vector ops generated during lowering. - if (!resType || resType.getRank() <= 2) - return mlir::failure(); - - // For non 2D vector ops, we expect 4D vector ops only - if (resType.getRank() != 4) { - op.emitOpError() << "type is not 4D vector"; - return mlir::failure(); + auto numResults = op.getOperation()->getNumResults(); + + llvm::SmallVector newResultTypes; + for (unsigned i = 0; i < numResults; i++) { + mlir::Value result = ::llvm::cast<::mlir::TypedValue<::mlir::Type>>( + *op.getODSResults(i).begin()); + auto resultType = mlir::dyn_cast(result.getType()); + // Check if the result types are 4D vectors, if any of the result type is + // not a 4D vector, return failure. + if (!resultType || resultType.getRank() != 4) + return mlir::failure(); + + auto shape = resultType.getShape(); + // Get the new result type, this is the type of the result of the new + // blocked op that works on the 2-D vector. + auto vecTy = mlir::VectorType::get({shape[2], shape[3]}, + resultType.getElementType()); + newResultTypes.push_back(vecTy); } - auto shape = resType.getShape(); - auto newTy = - mlir::VectorType::get({shape[2], shape[3]}, resType.getElementType()); - - // Get all the slices of Operands + // Get the operands auto operands = adaptor.getOperands(); - llvm::SmallVector> operand; - if (numOperands == 1) - operand.push_back(operands[0]); - else if (numOperands == 2) { - operand.push_back(operands[0]); - operand.push_back(operands[1]); - } else { - operand.push_back(operands[0]); - operand.push_back(operands[1]); - operand.push_back(operands[2]); - } - + // The operands are in the form of llvm::ArrayRef, where + // each ValueRange consists of a list of same operand. However, to use the + // operands in the new op, the operands of the same block should together in + // a ValueRange (Vector of operands of the each block should be in a + // vector). + auto horizontalOperands = verticalToHorizontalToValueRange(operands); + // Get the attributes + auto attributes = op.getOperation()->getAttrs(); + Op newOp; llvm::SmallVector newOps; - for (int i = 0; i < shape[0] * shape[1]; i++) { - auto newOp = createOp(rewriter, op.getLoc(), operand, i); - newOp->getResult(0).setType(newTy); - newOps.push_back(newOp); - } - - rewriter.replaceOp(op, newOps); - return mlir::success(); - } -}; - -template -struct TypecastOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - using RangeT = llvm::ArrayRef; - using OpAdaptor = typename CastOp::template GenericAdaptor; - - mlir::LogicalResult - matchAndRewrite(CastOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto out = mlir::dyn_cast(op.getType()); - if (!out || out.getRank() != 4) - return mlir::failure(); - - auto shape = out.getShape(); - auto vecTy = - mlir::VectorType::get({shape[2], shape[3]}, out.getElementType()); - auto inputs = adaptor.getIn(); - llvm::SmallVector newOps; - for (auto in : inputs) { - auto newOp = rewriter.create(op.getLoc(), vecTy, in); - newOps.push_back(newOp); - } - rewriter.replaceOp(op, newOps); - return mlir::success(); - } -}; - -struct SgArithCmpIOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(mlir::arith::CmpIOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto res = op.getResult(); - auto resType = mlir::dyn_cast(res.getType()); - if (!resType || resType.getRank() != 4) - return mlir::failure(); - - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); - - llvm::SmallVector newOps; - for (auto [l, r] : llvm::zip_equal(lhs, rhs)) { - auto newOp = rewriter.create( - op.getLoc(), op.getPredicate(), l, r); - newOps.push_back(newOp); + for (auto newOperands : horizontalOperands) { + // We are using the generic builder that is supported by all ops. + // static void build(::mlir::OpBuilder &, ::mlir::OperationState + // &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, + // ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); + newOp = rewriter.create(op.getLoc(), newResultTypes, newOperands, + attributes); + for (unsigned i = 0; i < numResults; i++) { + mlir::Value result = ::llvm::cast<::mlir::TypedValue<::mlir::Type>>( + *newOp.getODSResults(i).begin()); + newOps.push_back(result); + } } - rewriter.replaceOp(op, newOps); return mlir::success(); } @@ -1226,10 +1183,10 @@ struct SgVectorCreateMaskOpPattern : public XeOneToNConversion { if (constDef && constDef.value() == shape[0]) { // Case 1: all rows are enabled. // See assumptions about the supported create_mask op in - // VectorCreateMaskOpPattern in xetile blocking pass. The second and forth - // operands are the same. This value is the mask of the inner dimension of - // the original shape. Different masks are created based on the new inner - // dimension size. + // VectorCreateMaskOpPattern in xetile blocking pass. The second and + // forth operands are the same. This value is the mask of the inner + // dimension of the original shape. Different masks are created based on + // the new inner dimension size. auto one = rewriter.create(loc, 1); llvm::SmallVector> newOperands; mlir::Value mask = adaptor.getOperands()[3][0]; @@ -1287,63 +1244,80 @@ struct SgVectorSplatOpPattern : public XeOneToNConversion { void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter, mlir::RewritePatternSet &patterns, TileUsageAnalysis &analysis) { - patterns.insert< - SgInitTileOpPattern, SgPrefetchTileOpPattern, SgTileUnpackOpPattern, - SgTilePackOpPattern, SgLoadTileOpPattern, SgStoreTileOpPattern, - SgLoadGatherOpPattern, SgStoreScatterOpPattern, SgTileMMAOpPattern, - SgVectorSplatOpPattern, SgUpdateTileOffsetOpPattern, - SgTransposeOpPattern, - SgTransposeOpPattern, SgBroadcastOpPattern, - SgTileReductionOpPattern, SgVectorCreateMaskOpPattern, - SgArithCmpIOpPattern>(patterns.getContext(), converter, analysis); + patterns + .add, + SgTransposeOpPattern, SgBroadcastOpPattern, + SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>( + patterns.getContext(), converter, analysis); // Element-wise math operations - patterns.insert, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern>( - patterns.getContext(), converter, analysis); - - // Element-wise arithmetic operations - patterns.insert, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern>( - patterns.getContext(), converter, analysis); - patterns.insert, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern, - TypecastOpPattern>( + patterns.add, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern>(patterns.getContext(), + converter, analysis); + + // Arithmetic operations + patterns.add, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern>(patterns.getContext(), + converter, analysis); + + // Typecast operations + patterns.add, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern>( patterns.getContext(), converter, analysis); } diff --git a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp index c74d9cf3d..4b7197511 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp @@ -88,77 +88,12 @@ class XeTileConversionTarget : public mlir::ConversionTarget { mlir::succeeded(uArchInterface->isLegalPrefetch2dOp(op))); }); - // Arith ops - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); + // Arith ops, since we support all the arith ops, we can dynamically make + // the whole dialect legal. + addDynamicallyLegalDialect( + [&](mlir::Operation *op) -> std::optional { + return isLegalElementWiseOp(op); + }); // Math Ops addDynamicallyLegalOp( diff --git a/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir b/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir index a88da2a34..429a64f33 100644 --- a/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir +++ b/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir @@ -5,20 +5,22 @@ %1 = arith.constant dense<2.3>: vector<64x4x1x16xf16> %2 = xetile.tile_unpack %0 {inner_blocks = array}: vector<4x4x16x16xf16> -> vector<64x64xf16> %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<64x64xf16> -> vector<64x4x1x16xf16> - // CHECK-COUNT-256: arith.addf {{.*}}, {{.*}} : vector<1x16xf16> + // CHECK-COUNT-256: arith.addf {{.*}}, {{.*}} fastmath : vector<1x16xf16> // CHECK-COUNT-256: arith.sub // CHECK-COUNT-256: arith.mulf // CHECK-COUNT-256: arith.maximumf // CHECK-COUNT-256: arith.minimumf // CHECK-COUNT-256: arith.divf // CHECK-COUNT-256: arith.remf - %result = arith.addf %3, %1 : vector<64x4x1x16xf16> + // CHECK-COUNT-256: arith.cmpf + %result = arith.addf %3, %1 fastmath : vector<64x4x1x16xf16> %subf_result = arith.subf %result, %1 : vector<64x4x1x16xf16> %mulf_result = arith.mulf %subf_result, %1 : vector<64x4x1x16xf16> %maxf_result = arith.maximumf %mulf_result, %1 : vector<64x4x1x16xf16> %minf_result = arith.minimumf %maxf_result, %mulf_result : vector<64x4x1x16xf16> %divf_result = arith.divf %minf_result, %1 : vector<64x4x1x16xf16> %remf_result = arith.remf %minf_result, %divf_result : vector<64x4x1x16xf16> + %cmpf_result = arith.cmpf ult, %remf_result, %divf_result : vector<64x4x1x16xf16> gpu.return } @@ -39,6 +41,7 @@ // CHECK-COUNT-256: arith.remsi // CHECK-COUNT-256: arith.remui // CHECK-COUNT-256: arith.andi + // CHECK-COUNT-256: arith.addui_extended %result = arith.addi %3, %1 : vector<64x4x1x16xi16> %subi_result = arith.subi %3, %1 : vector<64x4x1x16xi16> %muli_result = arith.muli %subi_result, %1 : vector<64x4x1x16xi16> @@ -51,6 +54,8 @@ %remsi_result = arith.remsi %minsi_result, %divsi_result : vector<64x4x1x16xi16> %remui_result = arith.remui %minui_result, %divui_result : vector<64x4x1x16xi16> %and_result = arith.andi %remsi_result, %remui_result : vector<64x4x1x16xi16> + %addui_sum, %addui_overflow = arith.addui_extended %3, %1 : vector<64x4x1x16xi16>, vector<64x4x1x16xi1> + %addui_extented_result:2 = arith.addui_extended %3, %1 : vector<64x4x1x16xi16>, vector<64x4x1x16xi1> gpu.return }