From ca096a1a4f028192e2b151bc454678e691393328 Mon Sep 17 00:00:00 2001 From: Garra1980 Date: Fri, 22 Nov 2024 23:50:50 +0100 Subject: [PATCH] Adding cache attributes to load/store/gather/scatter on XeTile level --- include/imex/Dialect/XeTile/IR/XeTileAttrs.td | 2 + include/imex/Dialect/XeTile/IR/XeTileOps.td | 26 ++++++--- .../XeTileToXeGPU/XeTileOpConversion.cpp | 58 +++++++++---------- lib/Dialect/XeTile/Transforms/Blocking.cpp | 8 +-- lib/Dialect/XeTile/Transforms/WgToSg.cpp | 10 ++-- .../XeTileToXeGPU/sg_load_tile.mlir | 19 +++++- .../XeTileToXeGPU/sg_scattered_ops.mlir | 27 +++++++++ .../XeTileToXeGPU/sg_store_tile.mlir | 24 ++++++++ 8 files changed, 124 insertions(+), 50 deletions(-) diff --git a/include/imex/Dialect/XeTile/IR/XeTileAttrs.td b/include/imex/Dialect/XeTile/IR/XeTileAttrs.td index b6e6735c5..6d1e5a092 100644 --- a/include/imex/Dialect/XeTile/IR/XeTileAttrs.td +++ b/include/imex/Dialect/XeTile/IR/XeTileAttrs.td @@ -133,6 +133,8 @@ def XeTile_AtomicRMWKindAttr : I64EnumAttr< let cppNamespace = "::imex::xetile"; } +//TODO: !!!This is target specific information, cache attributes have to be passed transparently +// as custom arguments and handled properly on XeGPU side //===----------------------------------------------------------------------===// // XeTile Cache Enums. //===----------------------------------------------------------------------===// diff --git a/include/imex/Dialect/XeTile/IR/XeTileOps.td b/include/imex/Dialect/XeTile/IR/XeTileOps.td index 18e85354d..44e446a42 100644 --- a/include/imex/Dialect/XeTile/IR/XeTileOps.td +++ b/include/imex/Dialect/XeTile/IR/XeTileOps.td @@ -317,10 +317,12 @@ def XeTile_LoadTileOp : XeTile_Op<"load_tile", []> { ``` }]; - let arguments = (ins - XeTile: $source, - OptionalAttr: $padding - ); + let arguments = (ins XeTile: $source, + OptionalAttr: $padding, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint); + let results = (outs XeTile_2DOr4DVector: $value); let assemblyFormat = "$source attr-dict `:` qualified(type($source)) `->` type($value)"; @@ -365,8 +367,10 @@ def XeTile_StoreTileOp : XeTile_Op<"store_tile", []> { let arguments = (ins XeTile_2DOr4DVector: $value, - XeTile: $tile - ); + XeTile: $tile, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint); let assemblyFormat = [{ $value`,`` `$tile attr-dict `:` qualified(type($value)) `,` qualified(type($tile)) @@ -655,7 +659,10 @@ def XeTile_LoadGatherOp: XeTile_Op<"load", [AllElementTypesMatch<["tile", "value let arguments = (ins XeTile: $tile, XeTile_MaskType: $mask, - OptionalAttr: $padding); + OptionalAttr: $padding, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint); let results = (outs XeTile_1DOr2DOr4DVector: $value); let assemblyFormat = [{ $tile `` `,` $mask attr-dict `:` qualified(type($tile)) `` `,` type($mask) `->` type($value) @@ -673,7 +680,10 @@ def XeTile_StoreScatterOp: XeTile_Op<"store", [AllElementTypesMatch<["value", "t }]; let arguments = (ins XeTile_1DOr2DOr4DVector: $value, XeTile: $tile, - XeTile_MaskType: $mask); + XeTile_MaskType: $mask, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint); let assemblyFormat = [{ $value `,` $tile `,` $mask attr-dict `:` type($value) `,` qualified(type($tile)) `,` type($mask) }]; diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp index e119338bd..cd4153f66 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -501,9 +501,9 @@ class SgInitTileOpPattern : public XeOneToNConversion { }; static mlir::xegpu::CachePolicy -translateCachePolicy(imex::xetile::CachePolicyAttr val) { +translateCachePolicy(imex::xetile::CachePolicyAttr val, mlir::xegpu::CachePolicy defaultVal) { if (!val) - return mlir::xegpu::CachePolicy::CACHED; + return defaultVal; switch (val.getValue()) { case imex::xetile::CachePolicy::CACHED: @@ -522,6 +522,22 @@ translateCachePolicy(imex::xetile::CachePolicyAttr val) { llvm_unreachable("Invalid CachePolicy value"); } +template +static auto +getCachePolicy(OpTy op, mlir::xegpu::CachePolicy defaultVal = mlir::xegpu::CachePolicy::CACHED) { + + auto getCachePolicyAttr = [&](imex::xetile::CachePolicyAttr val) { + return mlir::xegpu::CachePolicyAttr::get(op.getContext(), + translateCachePolicy(val, defaultVal)); + }; + + auto L1 = getCachePolicyAttr(op.getL1HintAttr()); + auto L2 = getCachePolicyAttr(op.getL2HintAttr()); + auto L3 = getCachePolicyAttr(op.getL3HintAttr()); + + return std::make_tuple(L1, L2, L3); +} + // It lowers a XeTile::prefetch_tile into one or more mlir::xegpu::prefetch_2d. // The adaptor will provide the set of xegpu.create_nd_desc lowered for // its input tile. @@ -551,14 +567,7 @@ struct SgPrefetchTileOpPattern return mlir::failure(); } - auto getCachePolicy = [&](imex::xetile::CachePolicyAttr val) { - return mlir::xegpu::CachePolicyAttr::get(op.getContext(), - translateCachePolicy(val)); - }; - - auto L1 = getCachePolicy(op.getL1HintAttr()); - auto L2 = getCachePolicy(op.getL2HintAttr()); - auto L3 = getCachePolicy(op.getL3HintAttr()); + auto [L1, L2, L3] = getCachePolicy(op); for (auto tile : tiles) { rewriter.create(op.getLoc(), tile, L1, L2, L3); @@ -589,16 +598,7 @@ struct SgLoadTileOpPattern : public XeOneToNConversion { auto elemTy = tileTy.getElementType(); auto sources = adaptor.getSource(); - auto ctx = op.getContext(); - - auto getDefaultCachePolicy = [&]() { - return mlir::xegpu::CachePolicyAttr::get( - ctx, mlir::xegpu::CachePolicy::CACHED); - }; - - auto L1 = getDefaultCachePolicy(); - auto L2 = getDefaultCachePolicy(); - auto L3 = getDefaultCachePolicy(); + auto [L1, L2, L3] = getCachePolicy(op); // The tile is in col-major order, which should be canonicalized to // row-major in canonicalization pass. @@ -657,13 +657,11 @@ struct SgLoadGatherOpPattern : public XeOneToNConversion { rewriter.getIntegerType(1)); llvm::SmallVector xegpuOps; auto transposeAttr = mlir::UnitAttr(); - auto cacheAttr = mlir::xegpu::CachePolicyAttr::get( - op.getContext(), mlir::xegpu::CachePolicy::CACHED); + auto [L1, L2, L3] = getCachePolicy(op); for (auto [t, m] : llvm::zip(tiles, masks)) { m = rewriter.create(op.getLoc(), maskTy, m); auto ldOp = rewriter.create( - op.getLoc(), vecTy, t, m, transposeAttr, cacheAttr, cacheAttr, - cacheAttr); + op.getLoc(), vecTy, t, m, transposeAttr, L1, L2, L3); auto v = rewriter.create(op.getLoc(), resTy, ldOp); xegpuOps.push_back(v); } @@ -691,11 +689,8 @@ struct SgStoreTileOpPattern : public XeOneToNConversion { << "values: " << values.size() << "\n"; } - auto context = op.getContext(); - auto WRITE_BACK = mlir::xegpu::CachePolicy::WRITE_BACK; - auto L1 = mlir::xegpu::CachePolicyAttr::get(context, WRITE_BACK); - auto L2 = mlir::xegpu::CachePolicyAttr::get(context, WRITE_BACK); - auto L3 = mlir::xegpu::CachePolicyAttr::get(context, WRITE_BACK); + auto [L1, L2, L3] = getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK); + for (size_t i = 0; i < tiles.size(); i++) rewriter.create(op.getLoc(), values[i], tiles[i], L1, L2, L3); @@ -726,13 +721,12 @@ struct SgStoreScatterOpPattern auto maskTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1], rewriter.getIntegerType(1)); auto transposeAttr = mlir::UnitAttr(); - auto cacheAttr = mlir::xegpu::CachePolicyAttr::get( - op.getContext(), mlir::xegpu::CachePolicy::WRITE_BACK); + auto [L1, L2, L3] = getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK); for (auto [v, t, m] : llvm::zip(values, tdescs, masks)) { m = rewriter.create(op.getLoc(), maskTy, m); v = rewriter.create(op.getLoc(), vecTy, v); rewriter.create( - op.getLoc(), v, t, m, transposeAttr, cacheAttr, cacheAttr, cacheAttr); + op.getLoc(), v, t, m, transposeAttr, L1, L2, L3); } rewriter.eraseOp(op); return mlir::success(); diff --git a/lib/Dialect/XeTile/Transforms/Blocking.cpp b/lib/Dialect/XeTile/Transforms/Blocking.cpp index ffc152d69..24e9e0a3e 100644 --- a/lib/Dialect/XeTile/Transforms/Blocking.cpp +++ b/lib/Dialect/XeTile/Transforms/Blocking.cpp @@ -290,7 +290,7 @@ struct LoadTileOpPattern tileTy.getElementType()); mlir::Value newOp = rewriter.create( op.getLoc(), vecTy, adaptor.getSource(), - op.getPadding().value_or(mlir::Attribute())); + op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); newOp = addUnpackOp(newOp, rewriter); rewriter.replaceOp(op, newOp); return mlir::success(); @@ -324,7 +324,7 @@ struct LoadGatherOpPattern auto mask = addPackOp(adaptor.getMask(), blockSize.asArrayRef(), rewriter); mlir::Value newOp = rewriter.create( op.getLoc(), vecTy, source, mask, - op.getPadding().value_or(mlir::Attribute())); + op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); newOp = addUnpackOp(newOp, rewriter); rewriter.replaceOp(op, newOp); return mlir::success(); @@ -352,7 +352,7 @@ struct StoreTileOpPattern // its inputs has not been updated yet. if (blockSize && valTy.getRank() == 2) { value = addPackOp(value, blockSize.asArrayRef(), rewriter); - rewriter.replaceOpWithNewOp(op, value, tile); + rewriter.replaceOpWithNewOp(op, value, tile, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); return mlir::success(); } return mlir::failure(); @@ -382,7 +382,7 @@ struct StoreScatterOpPattern auto mask = addPackOp(adaptor.getMask(), blockSize.asArrayRef(), rewriter); rewriter.replaceOpWithNewOp(op, value, tile, - mask); + mask, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); return mlir::success(); } return mlir::failure(); diff --git a/lib/Dialect/XeTile/Transforms/WgToSg.cpp b/lib/Dialect/XeTile/Transforms/WgToSg.cpp index 1b1656783..8880077ab 100644 --- a/lib/Dialect/XeTile/Transforms/WgToSg.cpp +++ b/lib/Dialect/XeTile/Transforms/WgToSg.cpp @@ -242,7 +242,7 @@ class WGToSGLoadTileOpPattern : public XeOneToNConversion { mlir::VectorType::get({tileTy.getShape()[0], tileTy.getShape()[1]}, tileTy.getElementType()); auto newLoadOp = rewriter.create( - op.getLoc(), newResTy, src, op.getPaddingAttr()); + op.getLoc(), newResTy, src, op.getPaddingAttr(), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); newLoadOps.push_back(newLoadOp); newResultTypes.push_back(newLoadOp.getResult().getType()); } @@ -306,7 +306,7 @@ class WGToSGStoreTileOpPattern : public XeOneToNConversion for (size_t i = 0; i < newValues.size(); i++) { rewriter.create(op.getLoc(), newValues[i], - newDstTiles[i]); + newDstTiles[i], op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); } rewriter.eraseOp(op); @@ -745,8 +745,9 @@ class WGToSGXeTileConvertLayout loc, storeSgIdY, createIndexConstant(indexType, srcMapSgData[1])); auto storeInitTileOp = rewriter.create( loc, srcTileTy, slm, llvm::ArrayRef({storeOffsetX, storeOffsetY})); + //TODO: Set up cache attributes rewriter.create(loc, adaptor.getSource()[0], - storeInitTileOp); + storeInitTileOp, nullptr, nullptr, nullptr); // Add barrier rewriter.create(loc); @@ -766,8 +767,9 @@ class WGToSGXeTileConvertLayout loc, loadSgIdY, createIndexConstant(indexType, dstMapSgData[1])); auto loadInitTileOp = rewriter.create( loc, dstTileTy, slm, llvm::ArrayRef({loadOffsetX, loadOffsetY})); + //TODO: Set up cache attributes auto loadTile = rewriter.create( - loc, newResTy, loadInitTileOp, mlir::Attribute()); + loc, newResTy, loadInitTileOp, mlir::Attribute(), nullptr, nullptr, nullptr); rewriter.replaceOp(op, loadTile); return mlir::success(); diff --git a/test/Conversion/XeTileToXeGPU/sg_load_tile.mlir b/test/Conversion/XeTileToXeGPU/sg_load_tile.mlir index 194edb960..9e2262ca5 100644 --- a/test/Conversion/XeTileToXeGPU/sg_load_tile.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_load_tile.mlir @@ -12,7 +12,22 @@ gpu.module @test_kernel { %1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> //CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> //CHECK-SAME: !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> - %2 = xetile.load_tile %1 : !xetile.tile<32x32xf16> -> vector<32x32xf16> - gpu.return + %2 = xetile.load_tile %1 {padding = 0 : i32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> + gpu.return + } + + //CHECK: gpu.func @sg_load_tile_cache_attr(%[[arg0:.*]]: memref<1024x1024xf16>, %[[arg1:.*]]: memref<1024x1024xf16>, %[[arg2:.*]]: memref<1024x1024xf32>) { + gpu.func @sg_load_tile_cache_attr(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) { + //CHECK: %[[c0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + //CHECK: %[[c64:.*]] = arith.constant 64 : index + %c64 = arith.constant 64 : index + //CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c64]]] + //CHECK-SAME: memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> + %1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + //CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> + //CHECK-SAME: !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> + %2 = xetile.load_tile %1 {l1_hint = #xetile.cache_hint, l3_hint = #xetile.cache_hint} : !xetile.tile<32x32xf16> -> vector<32x32xf16> + gpu.return } } diff --git a/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir b/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir index 1f3eacdac..f044f1d03 100644 --- a/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir @@ -29,6 +29,33 @@ gpu.module @test { gpu.return } + //CHECK-LABEL: @test_init_tile_for_scattered_cache_attr + //CHECK-SAME: %[[arg0:.*]]: memref<1024xf16> + gpu.func @test_init_tile_for_scattered_cache_attr(%arg0: memref<1024xf16>) { + //CHECK: %[[cst:.*]] = arith.constant dense : vector<32xi1> + //CHECK: %[[cst_0:.*]] = arith.constant dense<1> : vector<32xindex> + //CHECK: %[[r0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst_0]] : memref<1024xf16>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r1:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r2:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r3:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r4:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r5:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: %[[r6:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: %[[r7:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: %[[r8:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: xegpu.store %[[r1]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: xegpu.store %[[r2]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: xegpu.store %[[r3]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: xegpu.store %[[r4]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + %cst = arith.constant dense : vector<4x32xi1> + %cst_0 = arith.constant dense<1> : vector<4x32xindex> + %0 = xetile.init_tile %arg0, %cst_0 : memref<1024xf16>, vector<4x32xindex> -> !xetile.tile<4x32xf16, #xetile.tile_attr> + %1 = xetile.load %0, %cst {l1_hint = #xetile.cache_hint, l3_hint = #xetile.cache_hint} : !xetile.tile<4x32xf16, #xetile.tile_attr>, vector<4x32xi1> -> vector<4x32xf16> + %2 = xetile.update_tile_offset %0, %cst_0 : !xetile.tile<4x32xf16, #xetile.tile_attr>, vector<4x32xindex> + xetile.store %1, %0, %cst {l1_hint = #xetile.cache_hint, l3_hint = #xetile.cache_hint} : vector<4x32xf16>, !xetile.tile<4x32xf16, #xetile.tile_attr>, vector<4x32xi1> + gpu.return + } + //CHECK-LABEL: @add_kernel //CHECK-SAME: %[[arg0:.*]]: memref<*xf32>, %[[arg1:.*]]: memref<*xf32>, %[[arg2:.*]]: memref<*xf32> gpu.func @add_kernel(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>) { diff --git a/test/Conversion/XeTileToXeGPU/sg_store_tile.mlir b/test/Conversion/XeTileToXeGPU/sg_store_tile.mlir index 901b84f1d..daea233a0 100644 --- a/test/Conversion/XeTileToXeGPU/sg_store_tile.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_store_tile.mlir @@ -17,4 +17,28 @@ gpu.module @test_kernel { xetile.store_tile %result, %1: vector<32x32xf32>, !xetile.tile<32x32xf32> gpu.return } + + // test cache attributes acceptance + //CHECK: gpu.func @sg_tiled_store_cache_attr(%[[arg0:.*]]: memref<1024x1024xf32>) { + gpu.func @sg_tiled_store_cache_attr(%a: memref<1024x1024xf32>) { + + %result = arith.constant dense<0.0>: vector<32x32xf32> + //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg0]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg0]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg0]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg0]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg0]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + %1 = xetile.init_tile %a[0, 32] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> + + //CHECK: xegpu.store_nd %cst, %0 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %cst, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %cst, %2 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %cst, %3 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %cst, %4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %cst, %5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %cst, %6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %cst, %7 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + xetile.store_tile %result, %1 {l1_hint = #xetile.cache_hint, l3_hint = #xetile.cache_hint} : vector<32x32xf32>, !xetile.tile<32x32xf32> + gpu.return + } }