Skip to content

Commit

Permalink
add memory scope attribute for TileType (#819)
Browse files Browse the repository at this point in the history
  • Loading branch information
chencha3 authored Aug 2, 2024
1 parent 4e86e50 commit b523ab2
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 26 deletions.
24 changes: 17 additions & 7 deletions include/imex/Dialect/XeTile/IR/XeTileAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> {
OptionalParameter<"xetile::WorkGroupMapAttr">:$wg_map,
DefaultValuedParameter<"mlir::DenseI32ArrayAttr", "mlir::DenseI32ArrayAttr::get($_ctxt, {1, 0})">:$order,
OptionalParameter<"mlir::DenseI64ArrayAttr">:$inner_blocks,
OptionalParameter<"mlir::DenseI32ArrayAttr">:$wg_data
OptionalParameter<"mlir::DenseI32ArrayAttr">:$wg_data,
OptionalParameter<"mlir::Attribute">:$memory_scope
);
let assemblyFormat = "`<` struct(params) `>`";
let genVerifyDecl = true;
Expand All @@ -73,27 +74,36 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> {
CArg<"xetile::WorkGroupMapAttr", "{}">:$wg_map,
CArg<"llvm::ArrayRef<int32_t>", "{1, 0}">:$order,
CArg<"llvm::ArrayRef<int64_t>", "{}">:$inner_blocks,
CArg<"llvm::ArrayRef<int32_t>", "{}">:$wg_data),
CArg<"llvm::ArrayRef<int32_t>", "{}">:$wg_data,
CArg<"int", "0">:$memory_scope),
[{
mlir::Type intType = mlir::IntegerType::get($_ctxt, 32);
return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order),
mlir::DenseI64ArrayAttr::get($_ctxt, inner_blocks),
mlir::DenseI32ArrayAttr::get($_ctxt, wg_data));
mlir::DenseI32ArrayAttr::get($_ctxt, wg_data),
mlir::IntegerAttr::get(intType, memory_scope));
}]>,
AttrBuilder<(ins CArg<"llvm::ArrayRef<int32_t>", "{1, 0}">:$order),
AttrBuilder<(ins CArg<"llvm::ArrayRef<int32_t>", "{1, 0}">:$order,
CArg<"int", "0">:$memory_scope),
[{
mlir::Type intType = mlir::IntegerType::get($_ctxt, 32);
return $_get($_ctxt, xetile::SubGroupMapAttr(), xetile::WorkGroupMapAttr(),
mlir::DenseI32ArrayAttr::get($_ctxt, order),
mlir::DenseI64ArrayAttr::get($_ctxt, {}),
mlir::DenseI32ArrayAttr::get($_ctxt, {}));
mlir::DenseI32ArrayAttr::get($_ctxt, {}),
mlir::IntegerAttr::get(intType, memory_scope));
}]>,
AttrBuilder<(ins CArg<"xetile::SubGroupMapAttr", "{}">:$sg_map,
CArg<"xetile::WorkGroupMapAttr", "{}">:$wg_map,
CArg<"llvm::ArrayRef<int32_t>", "{1, 0}">:$order,
CArg<"llvm::ArrayRef<int32_t>", "{}">:$wg_data),
CArg<"llvm::ArrayRef<int32_t>", "{}">:$wg_data,
CArg<"int", "0">:$memory_scope),
[{
mlir::Type intType = mlir::IntegerType::get($_ctxt, 32);
return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order),
mlir::DenseI64ArrayAttr::get($_ctxt, {}),
mlir::DenseI32ArrayAttr::get($_ctxt, wg_data));
mlir::DenseI32ArrayAttr::get($_ctxt, wg_data),
mlir::IntegerAttr::get(intType, memory_scope));
}]>
];
}
Expand Down
15 changes: 15 additions & 0 deletions include/imex/Dialect/XeTile/IR/XeTileOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,21 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]>
return getDynamicStrides().size();
}

mlir::Attribute getSourceMemorySpace() {
if (isSourceMemRef())
return mlir::cast<mlir::MemRefType>(getSourceType()).getMemorySpace();
return mlir::Attribute();
}

unsigned getSourceMemorySpaceAsInt() {
auto attr = getSourceMemorySpace();
if (attr) {
if (mlir::isa<mlir::IntegerAttr>(attr))
return static_cast<unsigned>(mlir::cast<mlir::IntegerAttr>(attr).getInt());
}
return 0;
}

/// Returns the offsets info to the source. It consolidates
/// information from both dynamic_offsets and static_offsets
/// parameters. static_offsets parameter always has the expected
Expand Down
19 changes: 19 additions & 0 deletions include/imex/Dialect/XeTile/IR/XeTileTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,25 @@ def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface],
return mlir::DenseI32ArrayAttr::get(getContext(), {1, 0});
}

mlir::Attribute getMemoryScope() {
auto encoding = llvm::dyn_cast_if_present<xetile::XeTileAttr>(getEncoding());
if (encoding)
return encoding.getMemoryScope();
return mlir::Attribute();
}

int getMemoryScopeAsInt() {
auto encoding = llvm::dyn_cast_if_present<xetile::XeTileAttr>(getEncoding());
if (encoding && encoding.getMemoryScope()) {
auto memoryScope = encoding.getMemoryScope();
assert(mlir::isa<mlir::IntegerAttr>(memoryScope) &&
"Using `getMemorySpaceAsInt` with non-Integer attribute");
return mlir::cast<mlir::IntegerAttr>(memoryScope).getInt();
}
// return default value 0 indicating Global memory
return 0;
}

}];

let assemblyFormat = "`<` custom<XeTileType>($shape, $elementType, $encoding) `>`";
Expand Down
8 changes: 6 additions & 2 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
auto shape = llvm::to_vector(tileTy.getShape());
auto indexType = rewriter.getIndexType();

auto memoryScope = op.getSourceMemorySpaceAsInt() == 3
? mlir::xegpu::MemoryScope::SLM
: mlir::xegpu::MemoryScope::Global;

if (tileTy.getRank() != 2)
return op.emitOpError("The tile shape should be 2D.");

Expand Down Expand Up @@ -457,8 +461,8 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
std::swap(offsetsX, offsetsY);

auto tDescTy = mlir::xegpu::TensorDescType::get(
innerBlk, elemTy, false /*scattered*/, array_length,
mlir::xegpu::MemoryScope::Global, true /*boundary_check*/);
innerBlk, elemTy, false /*scattered*/, array_length, memoryScope,
true /*boundary_check*/);

auto createIndexConstant = [&](mlir::Type type, int64_t value) {
auto attr = rewriter.getIndexAttr(value);
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/XeTile/IR/XeTileDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ mlir::LogicalResult XeTileAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
::imex::xetile::SubGroupMapAttr sg_map, xetile::WorkGroupMapAttr wg_map,
mlir::DenseI32ArrayAttr order, mlir::DenseI64ArrayAttr inner_blocks,
mlir::DenseI32ArrayAttr wg_data) {
mlir::DenseI32ArrayAttr wg_data, mlir::Attribute memoryScope) {

if (order != mlir::DenseI32ArrayAttr() && order.size() != 2)
emitError() << "expect integer array of size 2 for order";
Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/XeTile/IR/XeTileOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ mlir::LogicalResult InitTileOp::verify() {
row_major = false;
}

if (getSourceMemorySpace() != tileTy.getMemoryScope())
return emitOpError(
"memory space of the tile doesn't match with the source.");

if (isSourceMemRef() && sourceMemRefHasStaticShape()) {
auto memrefType = mlir::dyn_cast<mlir::MemRefType>(getSourceType());

Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/XeTile/Transforms/BlockAligning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ struct InitTileOpPattern

auto attr = imex::xetile::XeTileAttr::get(
op.getContext(), tileTy.getSgMap(), tileTy.getWgMap(),
tileTy.getOrder(), newBlockSize, tileTy.getWgData());
tileTy.getOrder(), newBlockSize, tileTy.getWgData(),
tileTy.getMemoryScope());

auto newTileTy = imex::xetile::TileType::get(tileTy.getShape(),
tileTy.getElementType(), attr);
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/XeTile/Transforms/Blocking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,8 @@ struct InitTileOpPattern

auto attr = imex::xetile::XeTileAttr::get(
op.getContext(), tileTy.getSgMap(), tileTy.getWgMap(),
tileTy.getOrder(), innerBlocks, tileTy.getWgData());
tileTy.getOrder(), innerBlocks, tileTy.getWgData(),
tileTy.getMemoryScope());

auto newTileTy =
imex::xetile::TileType::get(tileTy.getShape(), elemTy, attr);
Expand Down
9 changes: 5 additions & 4 deletions lib/Dialect/XeTile/Transforms/OptimizeTranspose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,12 @@ struct InitTileOpPattern final
imex::swapLastTwoElements(sourceShape), newStrides);

// Create a new initTileOp with the new source by using the order attribute
auto orderAttr = sourceIsRowMajor
? mlir::DenseI32ArrayAttr::get(getContext(), {0, 1})
: mlir::DenseI32ArrayAttr::get(getContext(), {1, 0});
auto newTileAttr = imex::xetile::XeTileAttr::get(
getContext(), tileTy.getSgMap(), tileTy.getWgMap(),
(sourceIsRowMajor ? mlir::DenseI32ArrayAttr::get(getContext(), {0, 1})
: mlir::DenseI32ArrayAttr::get(getContext(), {1, 0})),
tileTy.getInnerBlocks(), tileTy.getWgData());
getContext(), tileTy.getSgMap(), tileTy.getWgMap(), orderAttr,
tileTy.getInnerBlocks(), tileTy.getWgData(), tileTy.getMemoryScope());
auto transposedTileTy = imex::xetile::TileType::get(
imex::swapLastTwoElements(initOp.getType().getShape()),
initOp.getElementType(), newTileAttr);
Expand Down
20 changes: 10 additions & 10 deletions test/Conversion/XeTileToXeGPU/sg_mixed_scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,23 @@ gpu.module @postop_reduce_m attributes {spirv.target_env = #spirv.target_env<#sp
//CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
%40 = vector.multi_reduction <add>, %39, %cst_1 [0] : vector<32x32xf32> to vector<32xf32>
%41 = vector.shape_cast %40 : vector<32xf32> to vector<1x32xf32>
%alloc = memref.alloc() : memref<8x128xf32, #spirv.storage_class<Workgroup>>
%alloc = memref.alloc() : memref<8x128xf32, 3>

//CHECK: %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf32, #spirv.storage_class<Workgroup>> -> !xegpu.tensor_desc<1x16xf32, #xegpu.tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true, scattered = false>>
//CHECK: %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf32, 3> -> !xegpu.tensor_desc<1x16xf32, #xegpu.tdesc_attr<memory_scope = slm, array_length = 1 : i64, boundary_check = true, scattered = false>>
//CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : index
//CHECK: %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf32, #spirv.storage_class<Workgroup>> -> !xegpu.tensor_desc<1x16xf32, #xegpu.tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true, scattered = false>>
%42 = xetile.init_tile %alloc[%17, %13] : memref<8x128xf32, #spirv.storage_class<Workgroup>> -> !xetile.tile<1x32xf32>
//CHECK: %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf32, 3> -> !xegpu.tensor_desc<1x16xf32, #xegpu.tdesc_attr<memory_scope = slm, array_length = 1 : i64, boundary_check = true, scattered = false>>
%42 = xetile.init_tile %alloc[%17, %13] : memref<8x128xf32, 3> -> !xetile.tile<1x32xf32, #xetile.tile_attr<memory_scope = 3>>

//CHECK-COUNT-2: vector.extract_strided_slice %{{.*}} {offsets = {{.*}}, sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
//CHECK-COUNT-2: xegpu.store_nd %{{.*}}, %{{.*}} <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32, #xegpu.tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true, scattered = false>>
xetile.store_tile %41, %42 : vector<1x32xf32>, !xetile.tile<1x32xf32>
//CHECK-COUNT-2: xegpu.store_nd %{{.*}}, %{{.*}} <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32, #xegpu.tdesc_attr<memory_scope = slm, array_length = 1 : i64, boundary_check = true, scattered = false>>
xetile.store_tile %41, %42 : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr<memory_scope = 3>>

//CHECK: xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf32, #spirv.storage_class<Workgroup>> -> !xegpu.tensor_desc<8x4xf32, #xegpu.tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true, scattered = false>>
//CHECK: xegpu.load_nd {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<8x4xf32, #xegpu.tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true, scattered = false>> -> vector<8x4xf32>
//CHECK: xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf32, 3> -> !xegpu.tensor_desc<8x4xf32, #xegpu.tdesc_attr<memory_scope = slm, array_length = 1 : i64, boundary_check = true, scattered = false>>
//CHECK: xegpu.load_nd {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<8x4xf32, #xegpu.tdesc_attr<memory_scope = slm, array_length = 1 : i64, boundary_check = true, scattered = false>> -> vector<8x4xf32>
//CHECK-COUNT-8: vector.extract_strided_slice %{{.*}} {offsets = {{.*}}, sizes = [1, 4], strides = [1, 1]} : vector<8x4xf32> to vector<1x4xf32>
//CHECK-COUNT-8: arith.addf %{{.*}}, %{{.*}} : vector<1x4xf32>
%43 = xetile.init_tile %alloc[%21, %23] : memref<8x128xf32, #spirv.storage_class<Workgroup>> -> !xetile.tile<8x4xf32>
%44 = xetile.load_tile %43 { padding = 0.000000e+00 : f32 } : !xetile.tile<8x4xf32> -> vector<8x4xf32>
%43 = xetile.init_tile %alloc[%21, %23] : memref<8x128xf32, 3> -> !xetile.tile<8x4xf32, #xetile.tile_attr<memory_scope = 3>>
%44 = xetile.load_tile %43 { padding = 0.000000e+00 : f32 } : !xetile.tile<8x4xf32, #xetile.tile_attr<memory_scope = 3>> -> vector<8x4xf32>
%45 = vector.multi_reduction <add>, %44, %cst_2 [0] : vector<8x4xf32> to vector<4xf32>
%46 = vector.shape_cast %45 : vector<4xf32> to vector<1x4xf32>
%47 = arith.addf %arg5, %46 : vector<1x4xf32>
Expand Down
7 changes: 7 additions & 0 deletions test/Dialect/XeTile/IR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,13 @@ func.func @tile_unpack_invalid_output_shape(%in : vector<4x4x16x16xf16>) {
%out = xetile.tile_unpack %in {inner_blocks = [16, 16]} : vector<4x4x16x16xf16> -> vector<32x64xf16>
}

// -----
func.func @test_init_tile_with_mismatch_memory_space(%a: memref<1024x1024xf16, 3>) {
// expected-error@+1 {{memory space of the tile doesn't match with the source}}
%1 = xetile.init_tile %a[8, 16] : memref<1024x1024xf16, 3> -> !xetile.tile<32x64xf16>
return
}

// -----
// expected-error@+1 {{expect integer array of size 2 for wi_layout}}
#sg_map_2 = #xetile.sg_map< wi_layout = [2, 8, 2], wi_data = [1, 2]>
Expand Down
11 changes: 11 additions & 0 deletions test/Dialect/XeTile/IR/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@
#wg_map_b = #xetile.wg_map<sg_layout = [16, 1], sg_data = [16, 1]>
#wg_map_b2 = #xetile.wg_map<sg_layout = [4, 4], sg_data = [64, 64]>

func.func @test_init_tile_for_slm(%a: memref<1024x1024xf16, 3>) {
//CHECK: xetile.init_tile {{.*}}[8, 16] : memref<1024x1024xf16, 3> -> !xetile.tile<32x64xf16, #xetile.tile_attr<memory_scope = 3 : i64>>
%1 = xetile.init_tile %a[8, 16] : memref<1024x1024xf16, 3> -> !xetile.tile<32x64xf16, #xetile.tile_attr<memory_scope = 3>>
return
}

func.func @test_init_tile_for_global(%a: memref<1024x1024xf16, 0>) {
//CHECK: xetile.init_tile {{.*}}[8, 16] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16>
%1 = xetile.init_tile %a[8, 16] : memref<1024x1024xf16, 0> -> !xetile.tile<32x64xf16>
return
}

// init_tile with a static shaped memref
// CHECK-LABEL: func @test_init_tile_using_static_memref({{.*}}) {
Expand Down

0 comments on commit b523ab2

Please sign in to comment.