From 706b38b5427c9bd86f2e6eeb407c34b652e92961 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 31 Oct 2024 17:50:11 +0000 Subject: [PATCH] Add lowering patterns for xetile gather/scatter version ops --- .../XeTileToXeGPU/XeTileOpConversion.cpp | 303 +++++++++++++----- .../XeTileToXeGPU/XeTileToXeGPU.cpp | 5 +- .../XeTileToXeGPU/XeTileToXeGPUConversion.cpp | 8 +- .../XeTile/Transforms/BlockingAnalysis.cpp | 6 +- lib/Transforms/VectorLinearize.cpp | 50 ++- lib/Transforms/VnniTransformation.cpp | 3 + .../XeTileToXeGPU/sg_scattered_ops.mlir | 78 +++++ .../XeTileToXeGPU/sg_tiled_scattered_ops.mlir | 69 ++++ .../Dialect/XeTile/sg_add_scattered_ops.mlir | 102 ++++++ 9 files changed, 531 insertions(+), 93 deletions(-) create mode 100644 test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir create mode 100644 test/Conversion/XeTileToXeGPU/sg_tiled_scattered_ops.mlir create mode 100644 test/Integration/Dialect/XeTile/sg_add_scattered_ops.mlir diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp index 6b5b4db3b..cbe6960be 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -34,6 +34,7 @@ using mlir::vector::ExtractOp; using mlir::vector::ExtractStridedSliceOp; using mlir::vector::ShapeCastOp; using mlir::vector::ShuffleOp; +using mlir::vector::SplatOp; using VectorTypedValue = mlir::TypedValue; using funcTy = VectorTypedValue(mlir::Value, mlir::Value, mlir::Location, @@ -396,74 +397,105 @@ class SgInitTileOpPattern : public XeOneToNConversion { // using array_length for load if dim1 of innerBlocks is smaller than // dim1 of shape. auto elemTy = tileTy.getElementType(); - auto array_length = isForLoad(op) && shape[1] > innerBlk[1] - ? getBlockArrayLength(op, elemTy, innerBlk[0], - innerBlk[1], shape[1]) - : 1; - // If this tile is used in load -> transpose -> DPASB chain, optimize - // transpose optimization requires array_length to be 1. - if (isForLoadTransposeDPASB(op)) - array_length = 1; - - auto width = array_length * innerBlk[1]; - - llvm::SmallVector blocks( - {shape[0] / innerBlk[0], shape[1] / width}); - - llvm::SmallVector offsets; - auto staticOffsets = op.getStaticOffsets(); - auto dynamicOffsets = op.getOffsets(); - for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) { - if (mlir::ShapedType::isDynamic(staticOffsets[i])) { - offsets.push_back(dynamicOffsets[j++]); - } else { - offsets.push_back(rewriter.create( - op.getLoc(), rewriter.getIndexAttr(staticOffsets[i]))); - } - } - - // For col-major memref initial offsets need to be swapped. 
- auto offsetsY = offsets.pop_back_val(); - auto offsetsX = offsets.pop_back_val(); - - auto tDescTy = mlir::xegpu::TensorDescType::get( - innerBlk, elemTy, array_length, true /*boundary_check*/, MemorySpace); - - auto createIndexConstant = [&](mlir::Type type, int64_t value) { - auto attr = rewriter.getIndexAttr(value); - return rewriter.create(loc, type, attr); - }; - rewriter.setInsertionPoint(op); - - llvm::SmallVector xegpuOps(blocks[0] * blocks[1], - mlir::Value()); - for (int i = 0; i < blocks[0]; i++) { - for (int j = 0; j < blocks[1]; j++) { - auto subOffX = createIndexConstant(indexType, (innerBlk[0] * i)); - auto subOffY = createIndexConstant(indexType, (width * j)); - auto tDescOffsetX = - rewriter.createOrFold(loc, subOffX, offsetsX); - auto tDescOffsetY = - rewriter.createOrFold(loc, subOffY, offsetsY); - mlir::SmallVector tDescOffsets = llvm::to_vector<4>( - llvm::map_range(offsets, [](mlir::Value v) -> mlir::OpFoldResult { - return v; - })); - tDescOffsets.push_back(tDescOffsetX); - tDescOffsets.push_back(tDescOffsetY); - - // TODO: this needs improvement, it assumes the source is static - // memeref. - if (auto MemRefTypedSource = - mlir::cast>(source)) { - auto createNdOp = rewriter.create( - op.getLoc(), tDescTy /*resultTy*/, MemRefTypedSource /*source*/, - tDescOffsets /*offsets*/); - - xegpuOps[blocks[1] * i + j] = createNdOp; + llvm::SmallVector xegpuOps; + // scattered tiles are lowered into create_tdesc ops with chunk_size = 1. + if (tileTy.getScatterAttr() == mlir::BoolAttr::get(op.getContext(), true)) { + llvm::SmallVector grids( + {shape[0] / innerBlk[0], shape[1] / innerBlk[1]}); + auto elems = innerBlk[0] * innerBlk[1]; + // TODO: get this from uArch. 32 is the max number of SIMD lanes. + assert(elems <= 32 && "Scattered tile size should be <= 32"); + mlir::xegpu::SGMapAttr sgMap = nullptr; + if (auto attr = tileTy.getSgMap()) { + llvm::SmallVector layout( + attr.getWiLayout().asArrayRef().begin(), + attr.getWiLayout().asArrayRef().end()); + llvm::SmallVector data(attr.getWiData().asArrayRef().begin(), + attr.getWiData().asArrayRef().end()); + sgMap = mlir::xegpu::SGMapAttr::get(op.getContext(), layout, data); + } + auto tdescTy = mlir::xegpu::TensorDescType::get( + elems, elemTy, 1 /* chunk_size */, MemorySpace, sgMap); + auto indiceTy = mlir::VectorType::get(elems, indexType); + auto indices = adaptor.getIndices(); + for (int64_t i = 0; i < grids[0]; i++) { + for (int64_t j = 0; j < grids[1]; j++) { + auto indice = indices[i * grids[1] + j]; + indice = rewriter.create(loc, indiceTy, indice); + auto createOp = rewriter.create( + loc, tdescTy, source, indice); + xegpuOps.push_back(createOp); + } + } + } else { + auto array_length = isForLoad(op) && shape[1] > innerBlk[1] + ? getBlockArrayLength(op, elemTy, innerBlk[0], + innerBlk[1], shape[1]) + : 1; + // If this tile is used in load -> transpose -> DPASB chain, optimize + // transpose optimization requires array_length to be 1. 
+      if (isForLoadTransposeDPASB(op))
+        array_length = 1;
+
+      auto width = array_length * innerBlk[1];
+
+      llvm::SmallVector blocks(
+          {shape[0] / innerBlk[0], shape[1] / width});
+
+      llvm::SmallVector offsets;
+      auto staticOffsets = op.getStaticOffsets();
+      auto dynamicOffsets = op.getOffsets();
+      for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) {
+        if (mlir::ShapedType::isDynamic(staticOffsets[i])) {
+          offsets.push_back(dynamicOffsets[j++]);
         } else {
-          return mlir::failure();
+          offsets.push_back(rewriter.create(
+              op.getLoc(), rewriter.getIndexAttr(staticOffsets[i])));
+        }
+      }
+
+      // For col-major memref initial offsets need to be swapped.
+      auto offsetsY = offsets.pop_back_val();
+      auto offsetsX = offsets.pop_back_val();
+
+      auto tDescTy = mlir::xegpu::TensorDescType::get(
+          innerBlk, elemTy, array_length, true /*boundary_check*/, MemorySpace);
+
+      auto createIndexConstant = [&](mlir::Type type, int64_t value) {
+        auto attr = rewriter.getIndexAttr(value);
+        return rewriter.create(loc, type, attr);
+      };
+
+      rewriter.setInsertionPoint(op);
+      xegpuOps.resize(blocks[0] * blocks[1]);
+      for (int i = 0; i < blocks[0]; i++) {
+        for (int j = 0; j < blocks[1]; j++) {
+          auto subOffX = createIndexConstant(indexType, (innerBlk[0] * i));
+          auto subOffY = createIndexConstant(indexType, (width * j));
+          auto tDescOffsetX = rewriter.createOrFold(
+              loc, subOffX, offsetsX);
+          auto tDescOffsetY = rewriter.createOrFold(
+              loc, subOffY, offsetsY);
+          mlir::SmallVector tDescOffsets =
+              llvm::to_vector<4>(llvm::map_range(
+                  offsets,
+                  [](mlir::Value v) -> mlir::OpFoldResult { return v; }));
+          tDescOffsets.push_back(tDescOffsetX);
+          tDescOffsets.push_back(tDescOffsetY);
+
+          // TODO: this needs improvement, it assumes the source is a static
+          // memref.
+          if (auto MemRefTypedSource =
+                  mlir::cast>(source)) {
+            auto createNdOp = rewriter.create(
+                op.getLoc(), tDescTy /*resultTy*/, MemRefTypedSource /*source*/,
+                tDescOffsets /*offsets*/);
+
+            xegpuOps[blocks[1] * i + j] = createNdOp;
+          } else {
+            return mlir::failure();
+          }
         }
       }
     }
@@ -608,6 +640,43 @@ struct SgLoadTileOpPattern : public XeOneToNConversion {
   }
 };
 
+// It lowers XeTile::load into one or more mlir::xegpu::load ops with
+// chunk_size=1. Since xetile::load typically works on a 2D representation of
+// the tile, while mlir::xegpu::load works on a 1D representation, a shape cast
+// is used to convert vector-typed operands to the 1D form.
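+// A rough sketch of the per-block rewrite, assuming a 1x16 inner block of f16
+// (the value names and elided tensor_desc attributes here are illustrative
+// only; see the added tests for the exact output):
+//   %m = vector.shape_cast %mask : vector<1x16xi1> to vector<16xi1>
+//   %v = xegpu.load %tdesc, %m : !xegpu.tensor_desc<16xf16, ...>, vector<16xi1> -> vector<16xf16>
+//   %r = vector.shape_cast %v : vector<16xf16> to vector<1x16xf16>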
+struct SgLoadGatherOpPattern : public XeOneToNConversion {
+  using XeOneToNConversion::XeOneToNConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(xetile::LoadGatherOp op, OpAdaptor adaptor,
+                  XeOneToNPatternRewriter &rewriter) const override {
+    auto tiles = adaptor.getTile();
+    auto masks = adaptor.getMask();
+    auto tileTy = op.getTile().getType();
+    auto innerBlk = tileTy.getInnerBlocks();
+    auto resTy =
+        mlir::VectorType::get(innerBlk.asArrayRef(), tileTy.getElementType());
+    auto vecTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1],
+                                       tileTy.getElementType());
+    auto maskTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1],
+                                        rewriter.getIntegerType(1));
+    llvm::SmallVector xegpuOps;
+    auto transposeAttr = mlir::UnitAttr();
+    auto cacheAttr = mlir::xegpu::CachePolicyAttr::get(
+        op.getContext(), mlir::xegpu::CachePolicy::CACHED);
+    for (auto [t, m] : llvm::zip(tiles, masks)) {
+      m = rewriter.create(op.getLoc(), maskTy, m);
+      auto ldOp = rewriter.create(
+          op.getLoc(), vecTy, t, m, transposeAttr, cacheAttr, cacheAttr,
+          cacheAttr);
+      auto v = rewriter.create(op.getLoc(), resTy, ldOp);
+      xegpuOps.push_back(v);
+    }
+    rewriter.replaceOp(op, xegpuOps);
+    return mlir::success();
+  }
+};
+
 // It lowers a XeTile::store_tile into one or more mlir::xegpu::store_2d
 // The adaptor will provide the set of xegpu.create_nd_desc lowered for
 // its input tile, and similar to its input vector value.
@@ -641,6 +710,40 @@ struct SgStoreTileOpPattern : public XeOneToNConversion {
   }
 };
 
+// It lowers XeTile::store into one or more mlir::xegpu::store ops with
+// chunk_size=1. Similar to xetile::load, shape casts are used to convert
+// vector-typed operands to the 1D representation.
+struct SgStoreScatterOpPattern
+    : public XeOneToNConversion {
+  using XeOneToNConversion::XeOneToNConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(xetile::StoreScatterOp op, OpAdaptor adaptor,
+                  XeOneToNPatternRewriter &rewriter) const override {
+    auto values = adaptor.getValue();
+    auto tdescs = adaptor.getTile();
+    auto masks = adaptor.getMask();
+
+    auto tileTy = op.getTile().getType();
+    auto innerBlk = tileTy.getInnerBlocks();
+    auto vecTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1],
+                                       tileTy.getElementType());
+    auto maskTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1],
+                                        rewriter.getIntegerType(1));
+    auto transposeAttr = mlir::UnitAttr();
+    auto cacheAttr = mlir::xegpu::CachePolicyAttr::get(
+        op.getContext(), mlir::xegpu::CachePolicy::WRITE_BACK);
+    for (auto [v, t, m] : llvm::zip(values, tdescs, masks)) {
+      m = rewriter.create(op.getLoc(), maskTy, m);
+      v = rewriter.create(op.getLoc(), vecTy, v);
+      rewriter.create(
+          op.getLoc(), v, t, m, transposeAttr, cacheAttr, cacheAttr, cacheAttr);
+    }
+    rewriter.eraseOp(op);
+    return mlir::success();
+  }
+};
+
 // It lowers a XeTile::tile_mma into one or more mlir::xegpu::dpas
 // The adaptor provides new inputs for each old input.
struct SgTileMMAOpPattern : public XeOneToNConversion { @@ -703,26 +806,34 @@ struct SgUpdateTileOffsetOpPattern mlir::LogicalResult matchAndRewrite(xetile::UpdateTileOffsetOp op, OpAdaptor adaptor, XeOneToNPatternRewriter &rewriter) const override { - auto offsetX = op.getOffsetX(); - auto offsetY = op.getOffsetY(); - auto tiles = adaptor.getTile(); - - bool hasColMajorTraversal = - op.getTile().getType().getOrder().asArrayRef() == - mlir::ArrayRef({0, 1}); - + auto tileTy = op.getTile().getType(); + auto tdescs = adaptor.getTile(); llvm::SmallVector newOps; - int64_t kDynamics[2] = {mlir::ShapedType::kDynamic, - mlir::ShapedType::kDynamic}; - for (const auto &tile : tiles) { - // if the traversal is col-major, we need to reverse the offsets at XeGPU - // level because only row-major traversal is supported. - auto xegpuTile = rewriter.create( - op.getLoc(), tile.getType(), tile, - hasColMajorTraversal ? mlir::ValueRange({offsetY, offsetX}) - : mlir::ValueRange({offsetX, offsetY}), - llvm::ArrayRef(kDynamics, 2)); - newOps.push_back(xegpuTile); + if (tileTy.getScatterAttr() == mlir::BoolAttr::get(op.getContext(), true)) { + auto indices = adaptor.getIndices(); + for (auto [tdesc, idx] : llvm::zip_equal(tdescs, indices)) { + auto type = mlir::cast(idx.getType()); + auto flatTy = + mlir::VectorType::get(type.getNumElements(), type.getElementType()); + idx = rewriter.create(op.getLoc(), flatTy, idx); + auto xegpuTile = rewriter.create( + op.getLoc(), tdesc.getType(), tdesc, idx); + newOps.push_back(xegpuTile); + } + } else { + auto offsetX = op.getOffsetX(); + auto offsetY = op.getOffsetY(); + int64_t kDynamics[2] = {mlir::ShapedType::kDynamic, + mlir::ShapedType::kDynamic}; + for (const auto &tdesc : tdescs) { + // if the traversal is col-major, we need to reverse the offsets at + // XeGPU level because only row-major traversal is supported. 
+ auto xegpuTile = rewriter.create( + op.getLoc(), tdesc.getType(), tdesc, + mlir::ValueRange({offsetX, offsetY}), + llvm::ArrayRef(kDynamics, 2)); + newOps.push_back(xegpuTile); + } } rewriter.replaceOp(op, newOps); return mlir::success(); @@ -1116,13 +1227,33 @@ struct SgVectorCreateMaskOpPattern : public XeOneToNConversion { } }; +struct SgVectorSplatOpPattern : public XeOneToNConversion { + using XeOneToNConversion::XeOneToNConversion; + + mlir::LogicalResult + matchAndRewrite(SplatOp op, OpAdaptor adaptor, + XeOneToNPatternRewriter &rewriter) const override { + auto type = op.getAggregate().getType(); + if (type.getRank() != 4) + return mlir::failure(); + auto shape = type.getShape(); + auto newType = + mlir::VectorType::get(shape.take_back(2), type.getElementType()); + auto newOp = rewriter.create(op.getLoc(), op.getInput(), newType); + llvm::SmallVector newOps(shape[0] * shape[1], newOp); + rewriter.replaceOp(op, newOps); + return mlir::success(); + } +}; + void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter, mlir::RewritePatternSet &patterns, TileUsageAnalysis &analysis) { patterns.insert< SgInitTileOpPattern, SgPrefetchTileOpPattern, SgTileUnpackOpPattern, SgTilePackOpPattern, SgLoadTileOpPattern, SgStoreTileOpPattern, - SgTileMMAOpPattern, SgUpdateTileOffsetOpPattern, + SgLoadGatherOpPattern, SgStoreScatterOpPattern, SgTileMMAOpPattern, + SgVectorSplatOpPattern, SgUpdateTileOffsetOpPattern, SgTransposeOpPattern, SgTransposeOpPattern, SgBroadcastOpPattern, SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>( diff --git a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp index 31c09ebbc..0564fbb4d 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp @@ -54,7 +54,6 @@ class XeTileConversionTarget : public mlir::ConversionTarget { addLegalOp(); addLegalOp(); addLegalOp(); - addLegalOp(); addLegalOp(); addLegalDialect(); @@ -168,6 +167,10 @@ class XeTileConversionTarget : public mlir::ConversionTarget { [](mlir::vector::TransposeOp op) { return op.getResult().getType().getRank() == 2; }); + + addDynamicallyLegalOp([&](mlir::vector::SplatOp op) { + return op.getAggregate().getType().getRank() != 4; + }); } private: diff --git a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.cpp index d951512b6..31f405ae3 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.cpp @@ -185,8 +185,12 @@ XeOneToNTypeConverter::computeTypeMapping(mlir::ValueRange original, return mlir::failure(); auto shape = tileTy.getShape(); auto blkSZ = tdescTy.getShape(); - auto arr_len = tdescTy.getArrayLength(); - auto size = shape[0] / blkSZ[0] * shape[1] / (blkSZ[1] * arr_len); + auto arr_len = tdescTy.isScattered() ? 
1 : tdescTy.getArrayLength(); + auto totalNumElems = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>{}); + auto blockNumElems = + std::accumulate(blkSZ.begin(), blkSZ.end(), 1, std::multiplies<>{}); + auto size = totalNumElems / blockNumElems / arr_len; llvm::ArrayRef types(convertedTypes.begin() + j, convertedTypes.begin() + j + size); resultMap.addInputs(i, types); diff --git a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp index 51b78c904..5de72a637 100644 --- a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp +++ b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp @@ -750,9 +750,11 @@ template Block BlockingAnalysisImpl::getInnerBlockSize( mlir::Operation *op, mlir::Type elemTy, llvm::ArrayRef &shape, int memorySpace) { - assert(elemTy.isIntOrFloat() && "only support int or float element type."); - int elemSize = elemTy.getIntOrFloatBitWidth(); + // TODO: is it safe to treat index as 32 bit integer? + // Expecting index vector is mainly used for gather/scatter ops on SLM. + // in which the address is 32-bit. + int elemSize = elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : 32; const int64_t subgroupSize = uArch->getOneGRFSizeBits() / elemSize; int maxHeight = 0, minHeight = 0, maxWidth = 0, minWidth = 0; diff --git a/lib/Transforms/VectorLinearize.cpp b/lib/Transforms/VectorLinearize.cpp index ffa005dc2..bf357b92c 100644 --- a/lib/Transforms/VectorLinearize.cpp +++ b/lib/Transforms/VectorLinearize.cpp @@ -33,6 +33,34 @@ namespace imex { } // namespace imex namespace { + +// rewrite arith.constant op in form of vector<1xmxindex> into 1D form +// (vector) +struct ArithConstantOpConversion final + : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; + mlir::LogicalResult + matchAndRewrite(mlir::arith::ConstantOp constOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto value = llvm::dyn_cast(constOp.getValue()); + if (!value || value.getType().getRank() != 2) + return mlir::failure(); + auto type = value.getType(); + auto shape = type.getShape(); + auto elemTy = type.getElementType(); + if (shape[0] != 1 || !elemTy.isIndex()) + return mlir::failure(); + auto newTy = mlir::VectorType::get({shape[1]}, elemTy); + value = value.reshape(newTy); + auto newOp = + rewriter.create(constOp.getLoc(), value); + auto castOp = rewriter.create(constOp.getLoc(), + type, newOp); + rewriter.replaceOp(constOp, castOp); + return mlir::success(); + } +}; + struct VectorLoadOpConversion final : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; @@ -485,11 +513,29 @@ struct VectorLinearizePass final return (op && op.getAggregate().getType().getRank() == 1); }); + // borrowed from upstream with hacking for index type. Currently + // we only target vector<1xmxindex> to vector conversion. It is + // unclear whether others are valid or not; thus they are left untouched. 
+ target.addDynamicallyLegalOp( + [&](mlir::arith::ConstantOp op) -> bool { + auto vecTy = mlir::dyn_cast(op.getType()); + if (!vecTy || vecTy.getRank() == 0) + return true; + + auto elemTy = vecTy.getElementType(); + if (elemTy.isIndex()) { + if (vecTy.getRank() == 2 && vecTy.getShape()[0] == 1) + return false; + return true; + } + return !mlir::vector::isLinearizableVector(vecTy); + }); + patterns.add( - typeConverter, context); + VectorStoreOpConversion, VectorCreateMaskOpConversion, + ArithConstantOpConversion>(typeConverter, context); // Shuffle16x16 will fallback to Shuffle1D for non 16x16 sizes. mlir::vector::populateVectorTransposeLoweringPatterns( diff --git a/lib/Transforms/VnniTransformation.cpp b/lib/Transforms/VnniTransformation.cpp index af90a4a13..6f241b867 100644 --- a/lib/Transforms/VnniTransformation.cpp +++ b/lib/Transforms/VnniTransformation.cpp @@ -113,6 +113,9 @@ static bool isVNNIApplicable(mlir::Type type) { // VNNI transform only available for 2D vectors. if (!vecTy || vecTy.getRank() != 2) return false; + auto elemTy = vecTy.getElementType(); + if (!elemTy.isIntOrFloat()) + return false; auto factor = getVnniFactor(vecTy.getElementType()); auto shape = vecTy.getShape(); // factor == 1 means 32-bit data, and no need to apply VNNI. diff --git a/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir b/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir new file mode 100644 index 000000000..1f3eacdac --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir @@ -0,0 +1,78 @@ +// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \ +// RUN: --cse --convert-xetile-to-xegpu --cse --canonicalize %s -verify-diagnostics -o -| FileCheck %s + +gpu.module @test { + //CHECK-LABEL: @test_init_tile_for_scattered + //CHECK-SAME: %[[arg0:.*]]: memref<1024xf16> + gpu.func @test_init_tile_for_scattered(%arg0: memref<1024xf16>) { + //CHECK: %[[cst:.*]] = arith.constant dense : vector<32xi1> + //CHECK: %[[cst_0:.*]] = arith.constant dense<1> : vector<32xindex> + //CHECK: %[[r0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst_0]] : memref<1024xf16>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r1:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r2:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r3:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r4:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r5:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: %[[r6:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: %[[r7:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: 
%[[r8:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: xegpu.store %[[r1]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: xegpu.store %[[r2]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: xegpu.store %[[r3]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: xegpu.store %[[r4]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + %cst = arith.constant dense : vector<4x32xi1> + %cst_0 = arith.constant dense<1> : vector<4x32xindex> + %0 = xetile.init_tile %arg0, %cst_0 : memref<1024xf16>, vector<4x32xindex> -> !xetile.tile<4x32xf16, #xetile.tile_attr> + %1 = xetile.load %0, %cst : !xetile.tile<4x32xf16, #xetile.tile_attr>, vector<4x32xi1> -> vector<4x32xf16> + %2 = xetile.update_tile_offset %0, %cst_0 : !xetile.tile<4x32xf16, #xetile.tile_attr>, vector<4x32xindex> + xetile.store %1, %0, %cst : vector<4x32xf16>, !xetile.tile<4x32xf16, #xetile.tile_attr>, vector<4x32xi1> + gpu.return + } + + //CHECK-LABEL: @add_kernel + //CHECK-SAME: %[[arg0:.*]]: memref<*xf32>, %[[arg1:.*]]: memref<*xf32>, %[[arg2:.*]]: memref<*xf32> + gpu.func @add_kernel(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>) { + //CHECK: %[[cst:.*]] = arith.constant dense : vector<16xi1> + //CHECK: %[[c1024:.*]] = arith.constant 1024 : index + //CHECK: %[[cast:.*]] = memref.cast %[[arg0]] : memref<*xf32> to memref + //CHECK: %[[cast_0:.*]] = memref.cast %[[arg1]] : memref<*xf32> to memref + //CHECK: %[[cast_1:.*]] = memref.cast %[[arg2]] : memref<*xf32> to memref + //CHECK: %[[block_id_x:.*]] = gpu.block_id x + //CHECK: %[[r0:.*]] = arith.muli %[[block_id_x]], %[[c1024]] : index + //CHECK: %[[r1:.*]] = vector.splat %[[r0]] : vector<1x16xindex> + //CHECK: %[[r2:.*]] = vector.shape_cast %[[r1]] : vector<1x16xindex> to vector<16xindex> + //CHECK: %[[r3:.*]] = xegpu.create_tdesc %[[cast]], %[[r2]] : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r4:.*]] = xegpu.load %[[r3]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> + //CHECK: %[[r5:.*]] = vector.shape_cast %[[r4]] : vector<16xf32> to vector<1x16xf32> + //CHECK: %[[r6:.*]] = xegpu.load %[[r3]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> + //CHECK: %[[r7:.*]] = vector.shape_cast %[[r6]] : vector<16xf32> to vector<1x16xf32> + //CHECK: %[[r8:.*]] = xegpu.create_tdesc %[[cast_0]], %[[r2]] : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r9:.*]] = xegpu.load %[[r8]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, 
#xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> + //CHECK: %[[r10:.*]] = vector.shape_cast %[[r9]] : vector<16xf32> to vector<1x16xf32> + //CHECK: %[[r11:.*]] = xegpu.load %[[r8]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> + //CHECK: %[[r12:.*]] = vector.shape_cast %[[r11]] : vector<16xf32> to vector<1x16xf32> + //CHECK: %[[r13:.*]] = arith.addf %[[r5]], %[[r10]] : vector<1x16xf32> + //CHECK: %[[r14:.*]] = arith.addf %[[r7]], %[[r12]] : vector<1x16xf32> + //CHECK: %[[r15:.*]] = xegpu.create_tdesc %[[cast_1]], %[[r2]] : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r16:.*]] = vector.shape_cast %[[r13]] : vector<1x16xf32> to vector<16xf32> + //CHECK: xegpu.store %[[r16]], %[[r15]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + //CHECK: %[[r17:.*]] = vector.shape_cast %[[r14]] : vector<1x16xf32> to vector<16xf32> + //CHECK: xegpu.store %[[r17]], %[[r15]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + %c1024 = arith.constant 1024 : index + %cst = arith.constant dense : vector<1x32xi1> + %cast = memref.cast %arg0 : memref<*xf32> to memref + %cast_0 = memref.cast %arg1 : memref<*xf32> to memref + %cast_1 = memref.cast %arg2 : memref<*xf32> to memref + %block_id_x = gpu.block_id x + %0 = arith.muli %block_id_x, %c1024 : index + %1 = vector.splat %0 : vector<1x32xindex> + %2 = xetile.init_tile %cast, %1 : memref, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + %3 = xetile.load %2, %cst : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> -> vector<1x32xf32> + %4 = xetile.init_tile %cast_0, %1 : memref, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + %5 = xetile.load %4, %cst : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> -> vector<1x32xf32> + %6 = arith.addf %3, %5 : vector<1x32xf32> + %7 = xetile.init_tile %cast_1, %1 : memref, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + xetile.store %6, %7, %cst : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> + gpu.return + } +} diff --git a/test/Conversion/XeTileToXeGPU/sg_tiled_scattered_ops.mlir b/test/Conversion/XeTileToXeGPU/sg_tiled_scattered_ops.mlir new file mode 100644 index 000000000..86bceef91 --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_tiled_scattered_ops.mlir @@ -0,0 +1,69 @@ +// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --cse --canonicalize %s -verify-diagnostics -o -| FileCheck %s + +gpu.module @test { + //CHECK-LABEL: @test_init_tile_for_scattered + //CHECK-SAME: %[[arg0:.*]]: memref<1024xf16> + gpu.func @test_init_tile_for_scattered(%arg0: memref<1024xf16>) { + + //CHECK: %[[cst:.*]] = arith.constant dense<1> : vector<16xindex> + //CHECK: %[[cst_0:.*]] = arith.constant dense : vector<16xi1> + //CHECK: %[[cst_1:.*]] = arith.constant dense<{{.*}}0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15{{.*}}> : vector<1x16xindex> + //CHECK: %[[cst_2:.*]] = arith.constant dense<{{.*}}16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31{{.*}}]> : vector<1x16xindex> + //CHECK: %[[cst_3:.*]] = arith.constant dense<{{.*}}32, 33, 
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47{{.*}}]> : vector<1x16xindex> + //CHECK: %[[cst_4:.*]] = arith.constant dense<{{.*}}48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63{{.*}}]> : vector<1x16xindex> + //CHECK: %[[cst_5:.*]] = arith.constant dense<{{.*}}64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79{{.*}}]> : vector<1x16xindex> + //CHECK: %[[cst_6:.*]] = arith.constant dense<{{.*}}80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95{{.*}}]> : vector<1x16xindex> + //CHECK: %[[cst_7:.*]] = arith.constant dense<{{.*}}96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111{{.*}}> : vector<1x16xindex> + //CHECK: %[[cst_8:.*]] = arith.constant dense<{{.*}}112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127{{.*}}> : vector<1x16xindex> + //CHECK: %[[r0:.*]] = vector.shape_cast %[[cst_1]] : vector<1x16xindex> to vector<16xindex> + //CHECK: %[[r1:.*]] = xegpu.create_tdesc %[[arg0]], %[[r0]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r2:.*]] = vector.shape_cast %[[cst_2]] : vector<1x16xindex> to vector<16xindex> + //CHECK: %[[r3:.*]] = xegpu.create_tdesc %[[arg0]], %[[r2]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r4:.*]] = vector.shape_cast %[[cst_3]] : vector<1x16xindex> to vector<16xindex> + //CHECK: %[[r5:.*]] = xegpu.create_tdesc %[[arg0]], %[[r4]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r6:.*]] = vector.shape_cast %[[cst_4]] : vector<1x16xindex> to vector<16xindex> + //CHECK: %[[r7:.*]] = xegpu.create_tdesc %[[arg0]], %[[r6]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r8:.*]] = vector.shape_cast %[[cst_5]] : vector<1x16xindex> to vector<16xindex> + //CHECK: %[[r9:.*]] = xegpu.create_tdesc %[[arg0]], %[[r8]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r10:.*]] = vector.shape_cast %[[cst_6]] : vector<1x16xindex> to vector<16xindex> + //CHECK: %[[r11:.*]] = xegpu.create_tdesc %[[arg0]], %[[r10]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r12:.*]] = vector.shape_cast %[[cst_7]] : vector<1x16xindex> to vector<16xindex> + //CHECK: %[[r13:.*]] = xegpu.create_tdesc %[[arg0]], %[[r12]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r14:.*]] = vector.shape_cast %[[cst_8]] : vector<1x16xindex> to vector<16xindex> + //CHECK: %[[r15:.*]] = xegpu.create_tdesc %[[arg0]], %[[r14]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r16:.*]] = xegpu.load %[[r1]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> + //CHECK: %[[r17:.*]] = xegpu.load %[[r3]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> + //CHECK: %[[r18:.*]] = xegpu.load %[[r5]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> 
+ //CHECK: %[[r19:.*]] = xegpu.load %[[r7]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> + //CHECK: %[[r20:.*]] = xegpu.load %[[r9]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> + //CHECK: %[[r21:.*]] = xegpu.load %[[r11]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> + //CHECK: %[[r22:.*]] = xegpu.load %[[r13]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> + //CHECK: %[[r23:.*]] = xegpu.load %[[r15]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> + //CHECK: %[[r24:.*]] = xegpu.update_offset %[[r1]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> + //CHECK: %[[r25:.*]] = xegpu.update_offset %[[r3]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> + //CHECK: %[[r26:.*]] = xegpu.update_offset %[[r5]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> + //CHECK: %[[r27:.*]] = xegpu.update_offset %[[r7]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> + //CHECK: %[[r28:.*]] = xegpu.update_offset %[[r9]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> + //CHECK: %[[r29:.*]] = xegpu.update_offset %[[r11]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> + //CHECK: %[[r30:.*]] = xegpu.update_offset %[[r13]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> + //CHECK: %[[r31:.*]] = xegpu.update_offset %[[r15]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> + //CHECK: xegpu.store %[[r16]], %[[r1]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> + //CHECK: xegpu.store %[[r17]], %[[r3]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> + //CHECK: xegpu.store %[[r18]], %[[r5]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> + //CHECK: xegpu.store %[[r19]], %[[r7]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> + //CHECK: xegpu.store %[[r20]], %[[r9]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> + //CHECK: xegpu.store %[[r21]], %[[r11]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = 
#xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> + //CHECK: xegpu.store %[[r22]], %[[r13]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> + //CHECK: xegpu.store %[[r23]], %[[r15]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> + + + %cst = arith.constant dense : vector<4x2x1x16xi1> + %cst_0 = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : vector<4x2x1x16xindex> + %offsets = arith.constant dense<1> : vector<4x2x1x16xindex> + %0 = xetile.init_tile %arg0, %cst_0 : memref<1024xf16>, vector<4x2x1x16xindex> -> !xetile.tile<4x32xf16, #xetile.tile_attr> + %1 = xetile.load %0, %cst : !xetile.tile<4x32xf16, #xetile.tile_attr>, vector<4x2x1x16xi1> -> vector<4x2x1x16xf16> + %2 = xetile.update_tile_offset %0, %offsets : !xetile.tile<4x32xf16, #xetile.tile_attr>, vector<4x2x1x16xindex> + xetile.store %1, %0, %cst : vector<4x2x1x16xf16>, !xetile.tile<4x32xf16, #xetile.tile_attr>, vector<4x2x1x16xi1> + gpu.return + } +} diff --git a/test/Integration/Dialect/XeTile/sg_add_scattered_ops.mlir b/test/Integration/Dialect/XeTile/sg_add_scattered_ops.mlir new file mode 100644 index 000000000..9b6daf379 --- /dev/null +++ b/test/Integration/Dialect/XeTile/sg_add_scattered_ops.mlir @@ -0,0 
+1,102 @@ +// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck + +// NOTES : +// This example assumes one subgroup per one workgroup and the kernel specifies the computation +// done by a single subgroup. + +module @gemm attributes {gpu.container_module} { + func.func @test(%A: memref<1024xf32>, %B: memref<1024xf32>) -> memref<1024xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %A_gpu = gpu.alloc host_shared () : memref<1024xf32> + memref.copy %A, %A_gpu : memref<1024xf32> to memref<1024xf32> + %B_gpu = gpu.alloc host_shared () : memref<1024xf32> + memref.copy %B, %B_gpu : memref<1024xf32> to memref<1024xf32> + %C_gpu = gpu.alloc host_shared () : memref<1024xf32> + gpu.launch_func @test_kernel::@add_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024xf32>, %B_gpu : memref<1024xf32>, %C_gpu : memref<1024xf32>) + gpu.dealloc %A_gpu : memref<1024xf32> + gpu.dealloc %B_gpu : memref<1024xf32> + return %C_gpu : memref<1024xf32> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @add_kernel(%A: memref<1024xf32>, %B: memref<1024xf32>, %C: memref<1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %indices = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]>: vector<1x32xindex> + %offsets = arith.constant dense<32>: vector<1x32xindex> + %mask = arith.constant dense: vector<1x32xi1> + + %a_init_tile = xetile.init_tile %A, %indices : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + %b_init_tile = xetile.init_tile %B, %indices : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + %c_init_tile = xetile.init_tile %C, %indices : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + // %c_init_tile = xetile.init_tile %C[0, 0] : memref<1024xf32> -> !xetile.tile<1x32xf32> + + %out:3 = scf.for %k = %c0 to %c1024 step %c32 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_tile = %c_init_tile) + -> (!xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>) { + + // load A and B tiles + %a_value = xetile.load %a_tile, %mask : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> -> vector<1x32xf32> + %b_value = xetile.load %b_tile, %mask : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> -> vector<1x32xf32> + %c_value = arith.addf %a_value, %b_value : vector<1x32xf32> + + // xetile.store_tile %c_value, %c_tile : vector<1x32xf32>, !xetile.tile<1x32xf32> + xetile.store %c_value, %c_tile, %mask : 
vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> + + %a_next_tile = xetile.update_tile_offset %a_tile, %offsets : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> + %b_next_tile = xetile.update_tile_offset %b_tile, %offsets : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> + %c_next_tile = xetile.update_tile_offset %c_tile, %offsets : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> + // %c_next_tile = xetile.update_tile_offset %c_tile, [%c0, %c32] : !xetile.tile<1x32xf32> + + scf.yield %a_next_tile, %b_next_tile, %c_next_tile + : !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr> + } + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %cf_0 = arith.constant 0.0 : bf16 + %cf_1 = arith.constant 1.0 : bf16 + %A = memref.alloc() : memref<1024xf32> + %B = memref.alloc() : memref<1024xf32> + %C_ref = memref.alloc() : memref<1024xf32> + // intialize matrix A ; + scf.for %i = %c0 to %c1024 step %c1 { + %t = index.castu %i : index to i32 + %val = arith.uitofp %t : i32 to f32 + memref.store %val, %A[%i] : memref<1024xf32> + memref.store %val, %B[%i] : memref<1024xf32> + } + + // compute C for reference + scf.for %i = %c0 to %c1024 step %c1 { + %a_val = memref.load %A[%i] : memref<1024xf32> + %b_val = memref.load %B[%i] : memref<1024xf32> + %c_val = arith.addf %a_val, %b_val : f32 + memref.store %c_val, %C_ref[%i] : memref<1024xf32> + } + %2 = call @test(%A, %B) : (memref<1024xf32>, memref<1024xf32>) -> memref<1024xf32> + %cast_C = memref.cast %2 : memref<1024xf32> to memref<*xf32> + %cast_C_ref = memref.cast %C_ref : memref<1024xf32> to memref<*xf32> + // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %A : memref<1024xf32> + memref.dealloc %B : memref<1024xf32> + memref.dealloc %C_ref : memref<1024xf32> + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} +}