diff --git a/lib/Dialect/XeTile/Transforms/WgToSg.cpp b/lib/Dialect/XeTile/Transforms/WgToSg.cpp index 6d0ecadc5..5264a8621 100644 --- a/lib/Dialect/XeTile/Transforms/WgToSg.cpp +++ b/lib/Dialect/XeTile/Transforms/WgToSg.cpp @@ -208,7 +208,6 @@ class WGToSGInitTileOpPattern : public XeOneToNConversion { mlir::OneToNTypeMapping newMapping(op.getResult().getType()); newMapping.addInputs(0, newResultTypes); rewriter.replaceOp(op, newInitTileOps, newMapping); - return mlir::success(); } }; @@ -358,7 +357,6 @@ struct WGToSGSCFYieldOpPattern : public XeOneToNConversion { mlir::LogicalResult matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor, imex::XeOneToNPatternRewriter &rewriter) const override { - llvm::SmallVector convertedResults; llvm::SmallVector newResultTypes; for (auto &values : adaptor.getResults()) @@ -383,7 +381,6 @@ class WGToSGUpdateTileOffsetOpPattern mlir::LogicalResult matchAndRewrite(xetile::UpdateTileOffsetOp op, OpAdaptor adaptor, XeOneToNPatternRewriter &rewriter) const override { - llvm::SmallVector<::mlir::Value> newUpdateTileOffsetOps; llvm::SmallVector newResultTypes; for (auto tile : adaptor.getTile()) { @@ -582,6 +579,47 @@ class WGToSGVectorTranspose }; + +class WGToSGVectorBroadcast + :public XeOneToNConversion { + using XeOneToNConversion::XeOneToNConversion; + + mlir::LogicalResult + matchAndRewrite(mlir::vector::BroadcastOp op, OpAdaptor adaptor, + XeOneToNPatternRewriter &rewriter) const override { + if (op.getVector().getType().getRank() != 2) + return mlir::failure(); + + auto res = op.getResult(); + auto resType = mlir::dyn_cast(res.getType()); + + auto srcTy = mlir::dyn_cast((adaptor.getSource()[0]).getType()); + auto srcShape = srcTy.getShape(); + + auto mapAttr = + llvm::dyn_cast_or_null(op->getAttr("map")); + + if (!mapAttr) { + return mlir::failure(); + } + + auto sgData = mapAttr.getSgData(); + auto newTy = mlir::VectorType::get({sgData[0], sgData[1]}, + resType.getElementType()); + auto dstShape = newTy.getShape(); + + if (!(srcShape[0] == 1 && srcShape[1] == dstShape[1]) && + !(srcShape[1] == 1 && srcShape[0] == dstShape[0])) + return mlir::failure(); + + auto newOp = rewriter.create( + op.getLoc(), newTy, adaptor.getSource()[0]); + rewriter.replaceOp(op, newOp); + return mlir::success(); + } +}; + + // TODO: Add more pre-ops bool isElementWiseOp(mlir::Operation *op) { return llvm::isa(op) || @@ -639,18 +677,15 @@ void analyzeInitTileOps(mlir::Operation *op) { llvm::cast(*loadUser->user_begin()); ops.push_back(transposeOp); - // Check if the transpose has only one user and that user is a TileMMAOp - // or a pre-op followed by TileMMA - if (!transposeOp->hasOneUse()) - return mlir::WalkResult::skip(); - auto consumerOp = *transposeOp->user_begin(); // Check if vector.transpose is consumed by TileMMA directly or // is consumed by some pre-op and then TileMMA. if(!llvm::isa(consumerOp)){ - if(!isElementWiseOp(consumerOp)) + if(!isElementWiseOp(consumerOp) && + !(llvm::isa(consumerOp))) { return mlir::WalkResult::skip(); + } else { if (!(consumerOp->hasOneUse() && llvm::isa(*consumerOp->user_begin()))) @@ -676,7 +711,8 @@ void populateXeTileWgToSgPatterns(imex::XeOneToNTypeConverter &converter, patterns.insert(patterns.getContext(), converter, + WGToSGSCFYieldOpPattern, WGToSGVectorTranspose, + WGToSGVectorBroadcast>(patterns.getContext(), converter, analysis); patterns.insert, WGToSGElementWiseOpPattern, @@ -777,7 +813,8 @@ class XeTileWgToSgPass }); target.addDynamicallyLegalOp( + mlir::math::ExpOp, mlir::vector::TransposeOp, + mlir::vector::BroadcastOp>( [&](mlir::Operation *op) -> bool { auto mapAttr = llvm::dyn_cast_or_null( op->getAttr("map")); diff --git a/test/Dialect/XeTile/Transforms/wg_to_sg_broadcast.mlir b/test/Dialect/XeTile/Transforms/wg_to_sg_broadcast.mlir new file mode 100644 index 000000000..090d135df --- /dev/null +++ b/test/Dialect/XeTile/Transforms/wg_to_sg_broadcast.mlir @@ -0,0 +1,38 @@ +// RUN: imex-opt --split-input-file --xetile-wg-to-sg --cse %s -verify-diagnostics | FileCheck %s + +gpu.module @test_broadcast { + gpu.func @test_kernel(%arg0: memref<256x384xf16>, %arg1: memref<1x384xf16>, %arg2: memref<256x512xf32>) attributes {gemm_tiles_b = 1 : i64, gemm_tiles_x = dense<[1, 1, 1, 4]> : vector<4xi64>, gemm_tiles_y = dense<[1, 1, 1, 8]> : vector<4xi64>, habana_runner.num_inputs = 2 : i64, habana_runner.tests = [{inputs = [dense<1.000000e+00> : tensor<256x384xf16>, dense<1.000000e+00> : tensor<1x384xf16>], outputs = [dense<3.840000e+02> : tensor<256x512xf32>]}], physical_nd_range = dense<1> : vector<2xi64>, region_partition = 0 : i64, region_size = 1 : i64, syn.fusion_successful, syn.tensor_signature = (tensor<256x384xf16>, tensor<1x384xf16>) -> tensor<256x512xf32>, synFusionGenOps = 6 : i64, synFusionRequiredBeamSize = 1 : i64, synFusionTotalCost = 1000015571.16 : f64} { + %c1 = arith.constant 1 : index + %c1_0 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c1_1 = arith.constant 1 : index + gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1_0, %arg11 = %c1_1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c4, %arg13 = %c8, %arg14 = %c1_1) { + %c384 = arith.constant 384 : index + %c32 = arith.constant 32 : index + %cst = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<256x512xf32> + %c0 = arith.constant 0 : index + %0 = xetile.init_tile %arg0[%c0, %c0] : memref<256x384xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = []>> + %1 = xetile.init_tile %arg1[%c0, %c0] : memref<1x384xf16> -> !xetile.tile<1x32xf16, #xetile.tile_attr, inner_blocks = []>> + %2:3 = scf.for %arg15 = %c0 to %c384 step %c32 iter_args(%arg16 = %cst, %arg17 = %0, %arg18 = %1) -> (vector<256x512xf32>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = []>>, !xetile.tile<1x32xf16, #xetile.tile_attr, inner_blocks = []>>) { + %4 = xetile.update_tile_offset %arg18, [%c0, %c32] : !xetile.tile<1x32xf16, #xetile.tile_attr, inner_blocks = []>> + %5 = xetile.update_tile_offset %arg17, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = []>> + %6 = xetile.load_tile %arg17 { padding = 0.000000e+00 : f32 } : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = []>> -> vector<256x32xf16> + %7 = xetile.load_tile %arg18 { padding = 0.000000e+00 : f32 } : !xetile.tile<1x32xf16, #xetile.tile_attr, inner_blocks = []>> -> vector<1x32xf16> + //CHECK: %[[TRANSPOSE:.*]] = vector.transpose {{%.*}}, [1, 0] : vector<1x32xf16> to vector<32x1xf16> + %8 = vector.transpose %7, [1, 0] {map = #xetile.wg_map} : vector<1x32xf16> to vector<32x1xf16> + //CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[TRANSPOSE]] : vector<32x1xf16> to vector<32x64xf16> + %9 = vector.broadcast %8 {map = #xetile.wg_map} : vector<32x1xf16> to vector<32x512xf16> + xegpu.compile_hint + %10 = xetile.tile_mma %6, %9, %cst {wg_map_a =#xetile.wg_map, wg_map_b =#xetile.wg_map, wg_map_c =#xetile.wg_map} : vector<256x32xf16>, vector<32x512xf16>, vector<256x512xf32> -> vector<256x512xf32> + xegpu.compile_hint + %11 = arith.addf %arg16, %10 {map = #xetile.wg_map} : vector<256x512xf32> + scf.yield %11, %5, %4 : vector<256x512xf32>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = []>>, !xetile.tile<1x32xf16, #xetile.tile_attr, inner_blocks = []>> + } + %3 = xetile.init_tile %arg2[%c0, %c0] : memref<256x512xf32> -> !xetile.tile<256x512xf32, #xetile.tile_attr, inner_blocks = []>> + xetile.store_tile %2#0, %3 : vector<256x512xf32>, !xetile.tile<256x512xf32, #xetile.tile_attr, inner_blocks = []>> + gpu.terminator + } + gpu.return + } +}