Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding cache attributes to load/store/gather/scatter on XeTile level #968

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/imex/Dialect/XeTile/IR/XeTileAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def XeTile_AtomicRMWKindAttr : I64EnumAttr<
let cppNamespace = "::imex::xetile";
}

//TODO: !!!This is target specific information, cache attributes have to be passed transparently
// as custom arguments and handled properly on XeGPU side
//===----------------------------------------------------------------------===//
// XeTile Cache Enums.
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 18 additions & 8 deletions include/imex/Dialect/XeTile/IR/XeTileOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,12 @@ def XeTile_LoadTileOp : XeTile_Op<"load_tile", []> {
```
}];

let arguments = (ins
XeTile: $source,
OptionalAttr<XeTile_PaddingValueAttr>: $padding
);
let arguments = (ins XeTile: $source,
OptionalAttr<XeTile_PaddingValueAttr>: $padding,
OptionalAttr<XeTile_CacheHintAttr>: $l1_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l2_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l3_hint);

let results = (outs XeTile_2DOr4DVector: $value);

let assemblyFormat = "$source attr-dict `:` qualified(type($source)) `->` type($value)";
Expand Down Expand Up @@ -365,8 +367,10 @@ def XeTile_StoreTileOp : XeTile_Op<"store_tile", []> {

let arguments = (ins
XeTile_2DOr4DVector: $value,
XeTile: $tile
);
XeTile: $tile,
OptionalAttr<XeTile_CacheHintAttr>: $l1_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l2_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l3_hint);

let assemblyFormat = [{
$value`,`` `$tile attr-dict `:` qualified(type($value)) `,` qualified(type($tile))
Expand Down Expand Up @@ -655,7 +659,10 @@ def XeTile_LoadGatherOp: XeTile_Op<"load", [AllElementTypesMatch<["tile", "value

let arguments = (ins XeTile: $tile,
XeTile_MaskType: $mask,
OptionalAttr<XeTile_PaddingValueAttr>: $padding);
OptionalAttr<XeTile_PaddingValueAttr>: $padding,
OptionalAttr<XeTile_CacheHintAttr>: $l1_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l2_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l3_hint);
let results = (outs XeTile_1DOr2DOr4DVector: $value);
let assemblyFormat = [{
$tile `` `,` $mask attr-dict `:` qualified(type($tile)) `` `,` type($mask) `->` type($value)
Expand All @@ -673,7 +680,10 @@ def XeTile_StoreScatterOp: XeTile_Op<"store", [AllElementTypesMatch<["value", "t
}];
let arguments = (ins XeTile_1DOr2DOr4DVector: $value,
XeTile: $tile,
XeTile_MaskType: $mask);
XeTile_MaskType: $mask,
OptionalAttr<XeTile_CacheHintAttr>: $l1_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l2_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l3_hint);
let assemblyFormat = [{
$value `,` $tile `,` $mask attr-dict `:` type($value) `,` qualified(type($tile)) `,` type($mask)
}];
Expand Down
63 changes: 30 additions & 33 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,10 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
};

static mlir::xegpu::CachePolicy
translateCachePolicy(imex::xetile::CachePolicyAttr val) {
translateCachePolicy(imex::xetile::CachePolicyAttr val,
mlir::xegpu::CachePolicy defaultVal) {
if (!val)
return mlir::xegpu::CachePolicy::CACHED;
return defaultVal;

switch (val.getValue()) {
case imex::xetile::CachePolicy::CACHED:
Expand All @@ -522,6 +523,22 @@ translateCachePolicy(imex::xetile::CachePolicyAttr val) {
llvm_unreachable("Invalid CachePolicy value");
}

template <typename OpTy>
static auto getCachePolicy(OpTy op, mlir::xegpu::CachePolicy defaultVal =
mlir::xegpu::CachePolicy::CACHED) {

auto getCachePolicyAttr = [&](imex::xetile::CachePolicyAttr val) {
return mlir::xegpu::CachePolicyAttr::get(
op.getContext(), translateCachePolicy(val, defaultVal));
};

auto L1 = getCachePolicyAttr(op.getL1HintAttr());
auto L2 = getCachePolicyAttr(op.getL2HintAttr());
auto L3 = getCachePolicyAttr(op.getL3HintAttr());

return std::make_tuple(L1, L2, L3);
}

// It lowers a XeTile::prefetch_tile into one or more mlir::xegpu::prefetch_2d.
// The adaptor will provide the set of xegpu.create_nd_desc lowered for
// its input tile.
Expand Down Expand Up @@ -551,14 +568,7 @@ struct SgPrefetchTileOpPattern
return mlir::failure();
}

auto getCachePolicy = [&](imex::xetile::CachePolicyAttr val) {
return mlir::xegpu::CachePolicyAttr::get(op.getContext(),
translateCachePolicy(val));
};

auto L1 = getCachePolicy(op.getL1HintAttr());
auto L2 = getCachePolicy(op.getL2HintAttr());
auto L3 = getCachePolicy(op.getL3HintAttr());
auto [L1, L2, L3] = getCachePolicy(op);

for (auto tile : tiles) {
rewriter.create<mlir::xegpu::PrefetchNdOp>(op.getLoc(), tile, L1, L2, L3);
Expand Down Expand Up @@ -589,16 +599,7 @@ struct SgLoadTileOpPattern : public XeOneToNConversion<xetile::LoadTileOp> {
auto elemTy = tileTy.getElementType();
auto sources = adaptor.getSource();

auto ctx = op.getContext();

auto getDefaultCachePolicy = [&]() {
return mlir::xegpu::CachePolicyAttr::get(
ctx, mlir::xegpu::CachePolicy::CACHED);
};

auto L1 = getDefaultCachePolicy();
auto L2 = getDefaultCachePolicy();
auto L3 = getDefaultCachePolicy();
auto [L1, L2, L3] = getCachePolicy(op);

// The tile is in col-major order, which should be canonicalized to
// row-major in canonicalization pass.
Expand Down Expand Up @@ -657,13 +658,11 @@ struct SgLoadGatherOpPattern : public XeOneToNConversion<xetile::LoadGatherOp> {
rewriter.getIntegerType(1));
llvm::SmallVector<mlir::Value> xegpuOps;
auto transposeAttr = mlir::UnitAttr();
auto cacheAttr = mlir::xegpu::CachePolicyAttr::get(
op.getContext(), mlir::xegpu::CachePolicy::CACHED);
auto [L1, L2, L3] = getCachePolicy(op);
for (auto [t, m] : llvm::zip(tiles, masks)) {
m = rewriter.create<ShapeCastOp>(op.getLoc(), maskTy, m);
auto ldOp = rewriter.create<mlir::xegpu::LoadGatherOp>(
op.getLoc(), vecTy, t, m, transposeAttr, cacheAttr, cacheAttr,
cacheAttr);
op.getLoc(), vecTy, t, m, transposeAttr, L1, L2, L3);
auto v = rewriter.create<ShapeCastOp>(op.getLoc(), resTy, ldOp);
xegpuOps.push_back(v);
}
Expand Down Expand Up @@ -691,11 +690,9 @@ struct SgStoreTileOpPattern : public XeOneToNConversion<xetile::StoreTileOp> {
<< "values: " << values.size() << "\n";
}

auto context = op.getContext();
auto WRITE_BACK = mlir::xegpu::CachePolicy::WRITE_BACK;
auto L1 = mlir::xegpu::CachePolicyAttr::get(context, WRITE_BACK);
auto L2 = mlir::xegpu::CachePolicyAttr::get(context, WRITE_BACK);
auto L3 = mlir::xegpu::CachePolicyAttr::get(context, WRITE_BACK);
auto [L1, L2, L3] =
getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK);

for (size_t i = 0; i < tiles.size(); i++)
rewriter.create<mlir::xegpu::StoreNdOp>(op.getLoc(), values[i], tiles[i],
L1, L2, L3);
Expand Down Expand Up @@ -726,13 +723,13 @@ struct SgStoreScatterOpPattern
auto maskTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1],
rewriter.getIntegerType(1));
auto transposeAttr = mlir::UnitAttr();
auto cacheAttr = mlir::xegpu::CachePolicyAttr::get(
op.getContext(), mlir::xegpu::CachePolicy::WRITE_BACK);
auto [L1, L2, L3] =
getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK);
for (auto [v, t, m] : llvm::zip(values, tdescs, masks)) {
m = rewriter.create<ShapeCastOp>(op.getLoc(), maskTy, m);
v = rewriter.create<ShapeCastOp>(op.getLoc(), vecTy, v);
rewriter.create<mlir::xegpu::StoreScatterOp>(
op.getLoc(), v, t, m, transposeAttr, cacheAttr, cacheAttr, cacheAttr);
rewriter.create<mlir::xegpu::StoreScatterOp>(op.getLoc(), v, t, m,
transposeAttr, L1, L2, L3);
}
rewriter.eraseOp(op);
return mlir::success();
Expand Down
15 changes: 10 additions & 5 deletions lib/Dialect/XeTile/Transforms/Blocking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ struct LoadTileOpPattern
tileTy.getElementType());
mlir::Value newOp = rewriter.create<xetile::LoadTileOp>(
op.getLoc(), vecTy, adaptor.getSource(),
op.getPadding().value_or(mlir::Attribute()));
op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
newOp = addUnpackOp(newOp, rewriter);
rewriter.replaceOp(op, newOp);
return mlir::success();
Expand Down Expand Up @@ -324,7 +325,8 @@ struct LoadGatherOpPattern
auto mask = addPackOp(adaptor.getMask(), blockSize.asArrayRef(), rewriter);
mlir::Value newOp = rewriter.create<xetile::LoadGatherOp>(
op.getLoc(), vecTy, source, mask,
op.getPadding().value_or(mlir::Attribute()));
op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
newOp = addUnpackOp(newOp, rewriter);
rewriter.replaceOp(op, newOp);
return mlir::success();
Expand Down Expand Up @@ -352,7 +354,9 @@ struct StoreTileOpPattern
// its inputs has not been updated yet.
if (blockSize && valTy.getRank() == 2) {
value = addPackOp(value, blockSize.asArrayRef(), rewriter);
rewriter.replaceOpWithNewOp<xetile::StoreTileOp>(op, value, tile);
rewriter.replaceOpWithNewOp<xetile::StoreTileOp>(
op, value, tile, op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
return mlir::success();
}
return mlir::failure();
Expand Down Expand Up @@ -381,8 +385,9 @@ struct StoreScatterOpPattern
value = addPackOp(value, blockSize.asArrayRef(), rewriter);
auto mask =
addPackOp(adaptor.getMask(), blockSize.asArrayRef(), rewriter);
rewriter.replaceOpWithNewOp<xetile::StoreScatterOp>(op, value, tile,
mask);
rewriter.replaceOpWithNewOp<xetile::StoreScatterOp>(
op, value, tile, mask, op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
return mlir::success();
}
return mlir::failure();
Expand Down
10 changes: 6 additions & 4 deletions lib/Dialect/XeTile/Transforms/WgToSg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class WGToSGLoadTileOpPattern : public XeOneToNConversion<xetile::LoadTileOp> {
mlir::VectorType::get({tileTy.getShape()[0], tileTy.getShape()[1]},
tileTy.getElementType());
auto newLoadOp = rewriter.create<xetile::LoadTileOp>(
op.getLoc(), newResTy, src, op.getPaddingAttr());
op.getLoc(), newResTy, src, op.getPaddingAttr(), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
newLoadOps.push_back(newLoadOp);
newResultTypes.push_back(newLoadOp.getResult().getType());
}
Expand Down Expand Up @@ -306,7 +306,7 @@ class WGToSGStoreTileOpPattern : public XeOneToNConversion<xetile::StoreTileOp>

for (size_t i = 0; i < newValues.size(); i++) {
rewriter.create<xetile::StoreTileOp>(op.getLoc(), newValues[i],
newDstTiles[i]);
newDstTiles[i], op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
}

rewriter.eraseOp(op);
Expand Down Expand Up @@ -745,8 +745,9 @@ class WGToSGXeTileConvertLayout
loc, storeSgIdY, createIndexConstant(indexType, srcMapSgData[1]));
auto storeInitTileOp = rewriter.create<xetile::InitTileOp>(
loc, srcTileTy, slm, llvm::ArrayRef<mlir::OpFoldResult>({storeOffsetX, storeOffsetY}));
//TODO: Set up cache attributes
rewriter.create<xetile::StoreTileOp>(loc, adaptor.getSource()[0],
storeInitTileOp);
storeInitTileOp, nullptr, nullptr, nullptr);

// Add barrier
rewriter.create<mlir::gpu::BarrierOp>(loc);
Expand All @@ -766,8 +767,9 @@ class WGToSGXeTileConvertLayout
loc, loadSgIdY, createIndexConstant(indexType, dstMapSgData[1]));
auto loadInitTileOp = rewriter.create<xetile::InitTileOp>(
loc, dstTileTy, slm, llvm::ArrayRef<mlir::OpFoldResult>({loadOffsetX, loadOffsetY}));
//TODO: Set up cache attributes
auto loadTile = rewriter.create<xetile::LoadTileOp>(
loc, newResTy, loadInitTileOp, mlir::Attribute());
loc, newResTy, loadInitTileOp, mlir::Attribute(), nullptr, nullptr, nullptr);

rewriter.replaceOp(op, loadTile);
return mlir::success();
Expand Down
19 changes: 17 additions & 2 deletions test/Conversion/XeTileToXeGPU/sg_load_tile.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,22 @@ gpu.module @test_kernel {
%1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16>
//CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}>
//CHECK-SAME: !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x32xf16>
%2 = xetile.load_tile %1 : !xetile.tile<32x32xf16> -> vector<32x32xf16>
gpu.return
%2 = xetile.load_tile %1 {padding = 0 : i32} : !xetile.tile<32x32xf16> -> vector<32x32xf16>
gpu.return
}

//CHECK: gpu.func @sg_load_tile_cache_attr(%[[arg0:.*]]: memref<1024x1024xf16>, %[[arg1:.*]]: memref<1024x1024xf16>, %[[arg2:.*]]: memref<1024x1024xf32>) {
gpu.func @sg_load_tile_cache_attr(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) {
//CHECK: %[[c0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
//CHECK: %[[c64:.*]] = arith.constant 64 : index
%c64 = arith.constant 64 : index
//CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c64]]]
//CHECK-SAME: memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
%1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16>
//CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>
//CHECK-SAME: !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x32xf16>
%2 = xetile.load_tile %1 {l1_hint = #xetile.cache_hint<uncached>, l3_hint = #xetile.cache_hint<streaming>} : !xetile.tile<32x32xf16> -> vector<32x32xf16>
gpu.return
}
}
27 changes: 27 additions & 0 deletions test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,33 @@ gpu.module @test {
gpu.return
}

//CHECK-LABEL: @test_init_tile_for_scattered_cache_attr
//CHECK-SAME: %[[arg0:.*]]: memref<1024xf16>
gpu.func @test_init_tile_for_scattered_cache_attr(%arg0: memref<1024xf16>) {
//CHECK: %[[cst:.*]] = arith.constant dense<true> : vector<32xi1>
//CHECK: %[[cst_0:.*]] = arith.constant dense<1> : vector<32xindex>
//CHECK: %[[r0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst_0]] : memref<1024xf16>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>
//CHECK: %[[r1:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1> -> vector<32xf16>
//CHECK: %[[r2:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1> -> vector<32xf16>
//CHECK: %[[r3:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1> -> vector<32xf16>
//CHECK: %[[r4:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1> -> vector<32xf16>
//CHECK: %[[r5:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xindex>
//CHECK: %[[r6:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xindex>
//CHECK: %[[r7:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xindex>
//CHECK: %[[r8:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xindex>
//CHECK: xegpu.store %[[r1]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1>
//CHECK: xegpu.store %[[r2]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1>
//CHECK: xegpu.store %[[r3]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1>
//CHECK: xegpu.store %[[r4]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1>
%cst = arith.constant dense<true> : vector<4x32xi1>
%cst_0 = arith.constant dense<1> : vector<4x32xindex>
%0 = xetile.init_tile %arg0, %cst_0 : memref<1024xf16>, vector<4x32xindex> -> !xetile.tile<4x32xf16, #xetile.tile_attr<scattered = true>>
%1 = xetile.load %0, %cst {l1_hint = #xetile.cache_hint<uncached>, l3_hint = #xetile.cache_hint<streaming>} : !xetile.tile<4x32xf16, #xetile.tile_attr<scattered = true>>, vector<4x32xi1> -> vector<4x32xf16>
%2 = xetile.update_tile_offset %0, %cst_0 : !xetile.tile<4x32xf16, #xetile.tile_attr<scattered = true>>, vector<4x32xindex>
xetile.store %1, %0, %cst {l1_hint = #xetile.cache_hint<uncached>, l3_hint = #xetile.cache_hint<uncached>} : vector<4x32xf16>, !xetile.tile<4x32xf16, #xetile.tile_attr<scattered = true>>, vector<4x32xi1>
gpu.return
}

//CHECK-LABEL: @add_kernel
//CHECK-SAME: %[[arg0:.*]]: memref<*xf32>, %[[arg1:.*]]: memref<*xf32>, %[[arg2:.*]]: memref<*xf32>
gpu.func @add_kernel(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>) {
Expand Down
Loading
Loading