Adding cache attributes to load/store/gather/scatter on XeTile level
Garra1980 committed Nov 22, 2024
1 parent 2b8422d commit ca096a1
Showing 8 changed files with 124 additions and 50 deletions.
2 changes: 2 additions & 0 deletions include/imex/Dialect/XeTile/IR/XeTileAttrs.td
@@ -133,6 +133,8 @@ def XeTile_AtomicRMWKindAttr : I64EnumAttr<
let cppNamespace = "::imex::xetile";
}

//TODO: This is target-specific information! Cache attributes have to be passed transparently
// as custom arguments and handled properly on the XeGPU side
//===----------------------------------------------------------------------===//
// XeTile Cache Enums.
//===----------------------------------------------------------------------===//
26 changes: 18 additions & 8 deletions include/imex/Dialect/XeTile/IR/XeTileOps.td
@@ -317,10 +317,12 @@ def XeTile_LoadTileOp : XeTile_Op<"load_tile", []> {
```
}];

let arguments = (ins
XeTile: $source,
OptionalAttr<XeTile_PaddingValueAttr>: $padding
);
let arguments = (ins XeTile: $source,
OptionalAttr<XeTile_PaddingValueAttr>: $padding,
OptionalAttr<XeTile_CacheHintAttr>: $l1_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l2_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l3_hint);

let results = (outs XeTile_2DOr4DVector: $value);

let assemblyFormat = "$source attr-dict `:` qualified(type($source)) `->` type($value)";
@@ -365,8 +367,10 @@ def XeTile_StoreTileOp : XeTile_Op<"store_tile", []> {

let arguments = (ins
XeTile_2DOr4DVector: $value,
XeTile: $tile
);
XeTile: $tile,
OptionalAttr<XeTile_CacheHintAttr>: $l1_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l2_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l3_hint);

let assemblyFormat = [{
$value`,`` `$tile attr-dict `:` qualified(type($value)) `,` qualified(type($tile))
@@ -655,7 +659,10 @@ def XeTile_LoadGatherOp: XeTile_Op<"load", [AllElementTypesMatch<["tile", "value

let arguments = (ins XeTile: $tile,
XeTile_MaskType: $mask,
OptionalAttr<XeTile_PaddingValueAttr>: $padding);
OptionalAttr<XeTile_PaddingValueAttr>: $padding,
OptionalAttr<XeTile_CacheHintAttr>: $l1_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l2_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l3_hint);
let results = (outs XeTile_1DOr2DOr4DVector: $value);
let assemblyFormat = [{
$tile `` `,` $mask attr-dict `:` qualified(type($tile)) `` `,` type($mask) `->` type($value)
@@ -673,7 +680,10 @@ def XeTile_StoreScatterOp: XeTile_Op<"store", [AllElementTypesMatch<["value", "t
}];
let arguments = (ins XeTile_1DOr2DOr4DVector: $value,
XeTile: $tile,
XeTile_MaskType: $mask);
XeTile_MaskType: $mask,
OptionalAttr<XeTile_CacheHintAttr>: $l1_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l2_hint,
OptionalAttr<XeTile_CacheHintAttr>: $l3_hint);
let assemblyFormat = [{
$value `,` $tile `,` $mask attr-dict `:` type($value) `,` qualified(type($tile)) `,` type($mask)
}];
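
With this change, load_tile, store_tile and the scattered load/store ops all carry optional l1_hint/l2_hint/l3_hint cache attributes. A minimal usage sketch in XeTile IR, mirroring the syntax exercised by the tests added later in this commit (the %t/%v names and the 32x32 tile shape are illustrative only, not taken from the diff):

// Hints are optional; any level that is left out falls back to the lowering
// default (cached for loads, write_back for stores; see the conversion patterns below).
%v = xetile.load_tile %t {l1_hint = #xetile.cache_hint<uncached>, l3_hint = #xetile.cache_hint<streaming>} : !xetile.tile<32x32xf16> -> vector<32x32xf16>
xetile.store_tile %v, %t {l1_hint = #xetile.cache_hint<uncached>} : vector<32x32xf16>, !xetile.tile<32x32xf16>
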
58 changes: 26 additions & 32 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
@@ -501,9 +501,9 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
};

static mlir::xegpu::CachePolicy
translateCachePolicy(imex::xetile::CachePolicyAttr val) {
translateCachePolicy(imex::xetile::CachePolicyAttr val, mlir::xegpu::CachePolicy defaultVal) {
if (!val)
return mlir::xegpu::CachePolicy::CACHED;
return defaultVal;

switch (val.getValue()) {
case imex::xetile::CachePolicy::CACHED:
@@ -522,6 +522,22 @@ translateCachePolicy(imex::xetile::CachePolicyAttr val) {
llvm_unreachable("Invalid CachePolicy value");
}

template<typename OpTy>
static auto
getCachePolicy(OpTy op, mlir::xegpu::CachePolicy defaultVal = mlir::xegpu::CachePolicy::CACHED) {

auto getCachePolicyAttr = [&](imex::xetile::CachePolicyAttr val) {
return mlir::xegpu::CachePolicyAttr::get(op.getContext(),
translateCachePolicy(val, defaultVal));
};

auto L1 = getCachePolicyAttr(op.getL1HintAttr());
auto L2 = getCachePolicyAttr(op.getL2HintAttr());
auto L3 = getCachePolicyAttr(op.getL3HintAttr());

return std::make_tuple(L1, L2, L3);
}

// It lowers a XeTile::prefetch_tile into one or more mlir::xegpu::prefetch_2d.
// The adaptor will provide the set of xegpu.create_nd_desc lowered for
// its input tile.
@@ -551,14 +567,7 @@ struct SgPrefetchTileOpPattern
return mlir::failure();
}

auto getCachePolicy = [&](imex::xetile::CachePolicyAttr val) {
return mlir::xegpu::CachePolicyAttr::get(op.getContext(),
translateCachePolicy(val));
};

auto L1 = getCachePolicy(op.getL1HintAttr());
auto L2 = getCachePolicy(op.getL2HintAttr());
auto L3 = getCachePolicy(op.getL3HintAttr());
auto [L1, L2, L3] = getCachePolicy(op);

for (auto tile : tiles) {
rewriter.create<mlir::xegpu::PrefetchNdOp>(op.getLoc(), tile, L1, L2, L3);
@@ -589,16 +598,7 @@ struct SgLoadTileOpPattern : public XeOneToNConversion<xetile::LoadTileOp> {
auto elemTy = tileTy.getElementType();
auto sources = adaptor.getSource();

auto ctx = op.getContext();

auto getDefaultCachePolicy = [&]() {
return mlir::xegpu::CachePolicyAttr::get(
ctx, mlir::xegpu::CachePolicy::CACHED);
};

auto L1 = getDefaultCachePolicy();
auto L2 = getDefaultCachePolicy();
auto L3 = getDefaultCachePolicy();
auto [L1, L2, L3] = getCachePolicy(op);

// The tile is in col-major order, which should be canonicalized to
// row-major in canonicalization pass.
@@ -657,13 +657,11 @@ struct SgLoadGatherOpPattern : public XeOneToNConversion<xetile::LoadGatherOp> {
rewriter.getIntegerType(1));
llvm::SmallVector<mlir::Value> xegpuOps;
auto transposeAttr = mlir::UnitAttr();
auto cacheAttr = mlir::xegpu::CachePolicyAttr::get(
op.getContext(), mlir::xegpu::CachePolicy::CACHED);
auto [L1, L2, L3] = getCachePolicy(op);
for (auto [t, m] : llvm::zip(tiles, masks)) {
m = rewriter.create<ShapeCastOp>(op.getLoc(), maskTy, m);
auto ldOp = rewriter.create<mlir::xegpu::LoadGatherOp>(
op.getLoc(), vecTy, t, m, transposeAttr, cacheAttr, cacheAttr,
cacheAttr);
op.getLoc(), vecTy, t, m, transposeAttr, L1, L2, L3);
auto v = rewriter.create<ShapeCastOp>(op.getLoc(), resTy, ldOp);
xegpuOps.push_back(v);
}
@@ -691,11 +689,8 @@ struct SgStoreTileOpPattern : public XeOneToNConversion<xetile::StoreTileOp> {
<< "values: " << values.size() << "\n";
}

auto context = op.getContext();
auto WRITE_BACK = mlir::xegpu::CachePolicy::WRITE_BACK;
auto L1 = mlir::xegpu::CachePolicyAttr::get(context, WRITE_BACK);
auto L2 = mlir::xegpu::CachePolicyAttr::get(context, WRITE_BACK);
auto L3 = mlir::xegpu::CachePolicyAttr::get(context, WRITE_BACK);
auto [L1, L2, L3] = getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK);

for (size_t i = 0; i < tiles.size(); i++)
rewriter.create<mlir::xegpu::StoreNdOp>(op.getLoc(), values[i], tiles[i],
L1, L2, L3);
@@ -726,13 +721,12 @@ struct SgStoreScatterOpPattern
auto maskTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1],
rewriter.getIntegerType(1));
auto transposeAttr = mlir::UnitAttr();
auto cacheAttr = mlir::xegpu::CachePolicyAttr::get(
op.getContext(), mlir::xegpu::CachePolicy::WRITE_BACK);
auto [L1, L2, L3] = getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK);
for (auto [v, t, m] : llvm::zip(values, tdescs, masks)) {
m = rewriter.create<ShapeCastOp>(op.getLoc(), maskTy, m);
v = rewriter.create<ShapeCastOp>(op.getLoc(), vecTy, v);
rewriter.create<mlir::xegpu::StoreScatterOp>(
op.getLoc(), v, t, m, transposeAttr, cacheAttr, cacheAttr, cacheAttr);
op.getLoc(), v, t, m, transposeAttr, L1, L2, L3);
}
rewriter.eraseOp(op);
return mlir::success();
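
All four conversion patterns now go through the shared getCachePolicy helper: each XeTile hint is translated to the corresponding xegpu cache_hint attribute, and any missing hint is filled with a per-op default (CACHED for loads and prefetches, WRITE_BACK for stores). A before/after sketch of the expected lowering, assuming a value-preserving translation and with the tensor_desc attributes omitted for brevity (the cached/uncached/streaming load case is confirmed by the sg_load_tile.mlir test below; the store_nd form is an assumption about the XeGPU side):

// XeTile input: hints given on the load, omitted on the store.
%v = xetile.load_tile %t {l1_hint = #xetile.cache_hint<uncached>, l3_hint = #xetile.cache_hint<streaming>} : !xetile.tile<32x32xf16> -> vector<32x32xf16>
xetile.store_tile %v, %t : vector<32x32xf16>, !xetile.tile<32x32xf16>

// XeGPU output: present hints carried over, absent ones defaulted.
%1 = xegpu.load_nd %0 <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}> : !xegpu.tensor_desc<32x32xf16> -> vector<32x32xf16>
xegpu.store_nd %1, %0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<32x32xf16>, !xegpu.tensor_desc<32x32xf16>
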
8 changes: 4 additions & 4 deletions lib/Dialect/XeTile/Transforms/Blocking.cpp
@@ -290,7 +290,7 @@ struct LoadTileOpPattern
tileTy.getElementType());
mlir::Value newOp = rewriter.create<xetile::LoadTileOp>(
op.getLoc(), vecTy, adaptor.getSource(),
op.getPadding().value_or(mlir::Attribute()));
op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
newOp = addUnpackOp(newOp, rewriter);
rewriter.replaceOp(op, newOp);
return mlir::success();
@@ -324,7 +324,7 @@ struct LoadGatherOpPattern
auto mask = addPackOp(adaptor.getMask(), blockSize.asArrayRef(), rewriter);
mlir::Value newOp = rewriter.create<xetile::LoadGatherOp>(
op.getLoc(), vecTy, source, mask,
op.getPadding().value_or(mlir::Attribute()));
op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
newOp = addUnpackOp(newOp, rewriter);
rewriter.replaceOp(op, newOp);
return mlir::success();
@@ -352,7 +352,7 @@ struct StoreTileOpPattern
// its inputs has not been updated yet.
if (blockSize && valTy.getRank() == 2) {
value = addPackOp(value, blockSize.asArrayRef(), rewriter);
rewriter.replaceOpWithNewOp<xetile::StoreTileOp>(op, value, tile);
rewriter.replaceOpWithNewOp<xetile::StoreTileOp>(op, value, tile, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
return mlir::success();
}
return mlir::failure();
@@ -382,7 +382,7 @@ struct StoreScatterOpPattern
auto mask =
addPackOp(adaptor.getMask(), blockSize.asArrayRef(), rewriter);
rewriter.replaceOpWithNewOp<xetile::StoreScatterOp>(op, value, tile,
mask);
mask, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
return mlir::success();
}
return mlir::failure();
10 changes: 6 additions & 4 deletions lib/Dialect/XeTile/Transforms/WgToSg.cpp
@@ -242,7 +242,7 @@ class WGToSGLoadTileOpPattern : public XeOneToNConversion<xetile::LoadTileOp> {
mlir::VectorType::get({tileTy.getShape()[0], tileTy.getShape()[1]},
tileTy.getElementType());
auto newLoadOp = rewriter.create<xetile::LoadTileOp>(
op.getLoc(), newResTy, src, op.getPaddingAttr());
op.getLoc(), newResTy, src, op.getPaddingAttr(), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
newLoadOps.push_back(newLoadOp);
newResultTypes.push_back(newLoadOp.getResult().getType());
}
@@ -306,7 +306,7 @@ class WGToSGStoreTileOpPattern : public XeOneToNConversion<xetile::StoreTileOp>

for (size_t i = 0; i < newValues.size(); i++) {
rewriter.create<xetile::StoreTileOp>(op.getLoc(), newValues[i],
newDstTiles[i]);
newDstTiles[i], op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
}

rewriter.eraseOp(op);
@@ -745,8 +745,9 @@ class WGToSGXeTileConvertLayout
loc, storeSgIdY, createIndexConstant(indexType, srcMapSgData[1]));
auto storeInitTileOp = rewriter.create<xetile::InitTileOp>(
loc, srcTileTy, slm, llvm::ArrayRef<mlir::OpFoldResult>({storeOffsetX, storeOffsetY}));
//TODO: Set up cache attributes
rewriter.create<xetile::StoreTileOp>(loc, adaptor.getSource()[0],
storeInitTileOp);
storeInitTileOp, nullptr, nullptr, nullptr);

// Add barrier
rewriter.create<mlir::gpu::BarrierOp>(loc);
@@ -766,8 +767,9 @@
loc, loadSgIdY, createIndexConstant(indexType, dstMapSgData[1]));
auto loadInitTileOp = rewriter.create<xetile::InitTileOp>(
loc, dstTileTy, slm, llvm::ArrayRef<mlir::OpFoldResult>({loadOffsetX, loadOffsetY}));
//TODO: Set up cache attributes
auto loadTile = rewriter.create<xetile::LoadTileOp>(
loc, newResTy, loadInitTileOp, mlir::Attribute());
loc, newResTy, loadInitTileOp, mlir::Attribute(), nullptr, nullptr, nullptr);

rewriter.replaceOp(op, loadTile);
return mlir::success();
19 changes: 17 additions & 2 deletions test/Conversion/XeTileToXeGPU/sg_load_tile.mlir
@@ -12,7 +12,22 @@ gpu.module @test_kernel {
%1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16>
//CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}>
//CHECK-SAME: !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x32xf16>
%2 = xetile.load_tile %1 : !xetile.tile<32x32xf16> -> vector<32x32xf16>
gpu.return
%2 = xetile.load_tile %1 {padding = 0 : i32} : !xetile.tile<32x32xf16> -> vector<32x32xf16>
gpu.return
}

//CHECK: gpu.func @sg_load_tile_cache_attr(%[[arg0:.*]]: memref<1024x1024xf16>, %[[arg1:.*]]: memref<1024x1024xf16>, %[[arg2:.*]]: memref<1024x1024xf32>) {
gpu.func @sg_load_tile_cache_attr(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) {
//CHECK: %[[c0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
//CHECK: %[[c64:.*]] = arith.constant 64 : index
%c64 = arith.constant 64 : index
//CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c64]]]
//CHECK-SAME: memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
%1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16>
//CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}>
//CHECK-SAME: !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x32xf16>
%2 = xetile.load_tile %1 {l1_hint = #xetile.cache_hint<uncached>, l3_hint = #xetile.cache_hint<streaming>} : !xetile.tile<32x32xf16> -> vector<32x32xf16>
gpu.return
}
}
27 changes: 27 additions & 0 deletions test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir
@@ -29,6 +29,33 @@ gpu.module @test {
gpu.return
}

//CHECK-LABEL: @test_init_tile_for_scattered_cache_attr
//CHECK-SAME: %[[arg0:.*]]: memref<1024xf16>
gpu.func @test_init_tile_for_scattered_cache_attr(%arg0: memref<1024xf16>) {
//CHECK: %[[cst:.*]] = arith.constant dense<true> : vector<32xi1>
//CHECK: %[[cst_0:.*]] = arith.constant dense<1> : vector<32xindex>
//CHECK: %[[r0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst_0]] : memref<1024xf16>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>
//CHECK: %[[r1:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1> -> vector<32xf16>
//CHECK: %[[r2:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1> -> vector<32xf16>
//CHECK: %[[r3:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1> -> vector<32xf16>
//CHECK: %[[r4:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<streaming>}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1> -> vector<32xf16>
//CHECK: %[[r5:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xindex>
//CHECK: %[[r6:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xindex>
//CHECK: %[[r7:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xindex>
//CHECK: %[[r8:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xindex>
//CHECK: xegpu.store %[[r1]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1>
//CHECK: xegpu.store %[[r2]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1>
//CHECK: xegpu.store %[[r3]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1>
//CHECK: xegpu.store %[[r4]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 1 : i64>>, vector<32xi1>
%cst = arith.constant dense<true> : vector<4x32xi1>
%cst_0 = arith.constant dense<1> : vector<4x32xindex>
%0 = xetile.init_tile %arg0, %cst_0 : memref<1024xf16>, vector<4x32xindex> -> !xetile.tile<4x32xf16, #xetile.tile_attr<scattered = true>>
%1 = xetile.load %0, %cst {l1_hint = #xetile.cache_hint<uncached>, l3_hint = #xetile.cache_hint<streaming>} : !xetile.tile<4x32xf16, #xetile.tile_attr<scattered = true>>, vector<4x32xi1> -> vector<4x32xf16>
%2 = xetile.update_tile_offset %0, %cst_0 : !xetile.tile<4x32xf16, #xetile.tile_attr<scattered = true>>, vector<4x32xindex>
xetile.store %1, %0, %cst {l1_hint = #xetile.cache_hint<uncached>, l3_hint = #xetile.cache_hint<uncached>} : vector<4x32xf16>, !xetile.tile<4x32xf16, #xetile.tile_attr<scattered = true>>, vector<4x32xi1>
gpu.return
}

//CHECK-LABEL: @add_kernel
//CHECK-SAME: %[[arg0:.*]]: memref<*xf32>, %[[arg1:.*]]: memref<*xf32>, %[[arg2:.*]]: memref<*xf32>
gpu.func @add_kernel(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>) {
