LSC 2D prefetch: Unify type of dummy data arg to i32 (#940)
LSC 2D prefetch: Unify the type of the dummy data arg to i32, since
the generated intrinsic call is independent of the dummy data arg's type.
A type mismatch error occurs if multiple callers use different types
for the dummy data arg.
For non-32-bit element types, normalize the 2D prefetch shape around the i32
data type: the inner dimension is scaled by the ratio between
32 and the bitwidth of the element type.
Add a mixed-type 2D prefetch test case.
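For reference, a minimal C++ sketch of the shape normalization described above (the helper name and signature are illustrative, not the actual code in LSCPatterns.cpp; it assumes shape = {height, width}, with the inner width dimension being the one that is scaled):

#include <cassert>
#include <cstdint>
#include <utility>

// Normalize a 2D block shape to i32 granularity.
// Sub-32-bit elements are packed into i32 (width shrinks by 32 / bitWidth);
// 64-bit elements span two i32 slots (width doubles).
static std::pair<int64_t, int64_t>
normalizeShapeToI32(int64_t height, int64_t width, unsigned bitWidth) {
  if (bitWidth == 64)
    return {height, width * 2};
  unsigned packFactor = 32 / bitWidth; // e.g. f16 -> 2, i8 -> 4, f32/i32 -> 1
  assert(width % packFactor == 0 &&
         "width must be a multiple of the pack factor (32 / element bitwidth)");
  return {height, width / packFactor};
}

For example, an 8x16 f16 descriptor (bitWidth 16, packFactor 2) normalizes to 8x8, which is why the scaled width operand in the LSC prefetch intrinsic call drops from 16 to 8 in the updated test below.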
silee2 authored Oct 23, 2024
1 parent e40f1ea commit 52cadab
Showing 2 changed files with 59 additions and 19 deletions.
29 changes: 22 additions & 7 deletions lib/Conversion/XeGPUToVC/LSCPatterns.cpp
@@ -688,16 +688,31 @@ gen2DPrefetchIntrinsicCall(ConversionPatternRewriter &rewriter, Location &loc,
auto intrinsicStr = getBlockIntrinsicStr("prefetch");
auto nblks = tdescTy.getArrayLength();
auto shape = tdescTy.getShape();
auto elemTy = tdescTy.getElementType();
auto noRetTy = TypeRange({});
auto bitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();

// Sub 32bit data types are packed into 32bit data types (i32).
auto packFactor = 32 / bitWidth;

// If packing is needed, the innermost dimension gets scaled by the packing
// factor. In that case, shape[1] must be a multiple of the pack factor;
// otherwise, packing cannot be done correctly.
if (packFactor > 1) {
assert(
shape[1] % packFactor == 0 &&
"shape[1] must be a multiple of pack factor (32 / element bitwidth)");
}

// for arg8: dummy value
auto attr = elemTy.isInteger()
? (TypedAttr)rewriter.getIntegerAttr(elemTy, 0)
: (TypedAttr)rewriter.getFloatAttr(elemTy, 0.0);
// for arg8: dummy value. Its type must always be the same, since the intrinsic
// func name for prefetch is the same regardless of the element type.
// Using a different type for the dummy arg causes a type conflict when
// multiple calls pass dummy args of different types.
auto attr = (TypedAttr)rewriter.getIntegerAttr(rewriter.getI32Type(), 0);
auto dummy = constant_val(attr);
return gen2DBlockIntrinsicCall(rewriter, loc, intrinsicStr, noRetTy, l1, l3,
nblks, shape, payload, dummy);
return gen2DBlockIntrinsicCall(
rewriter, loc, intrinsicStr, noRetTy, l1, l3, nblks,
{shape[0], bitWidth == 64 ? shape[1] * 2 : shape[1] / packFactor},
payload, dummy);
}

// generate a call to lsc.store.2d.ugm.* intrinsic for 2D block store, which is
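As a quick cross-check of the normalization against the updated FileCheck expectations below (assuming the first of the two i16 operands in the lsc.prefetch.2d call is the scaled width and the second is the height):

  8x16xf16  : bitWidth 16, packFactor 2 -> width 16 / 2 = 8 (c8_i16),    height 8 (c8_i16)
  16x16xf16 : bitWidth 16, packFactor 2 -> width 16 / 2 = 8 (c8_i16),    height 16 (c16_i16)
  8x16xf32  : bitWidth 32, packFactor 1 -> width 16 unchanged (c16_i16), height 8 (c8_i16)

In every case the trailing dummy operand is now the same arith.constant 0 : i32, so mixing f16 and f32 descriptors in one kernel no longer produces conflicting declarations of @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.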
49 changes: 37 additions & 12 deletions test/Conversion/XeGPUToVC/prefetchnd.mlir
@@ -1,14 +1,9 @@
// RUN: imex-opt -convert-xegpu-to-vc --cse %s | FileCheck %s --check-prefixes=CHECK,LSC
// RUN: imex-opt --split-input-file -convert-xegpu-to-vc --cse %s | FileCheck %s --check-prefixes=CHECK,LSC

// -----
module @gemm attributes {gpu.container_module} {

gpu.module @test_kernel {

//RAW: func.func private @llvm.genx.raw.sends2.noresult.i1.v16i32.v128f32(i8, i8, i1, i8, i8, i8, i32, i32, vector<16xi32>, vector<128xf32>) attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes<linkage_name = "llvm.genx.raw.sends2.noresult.i1.v16i32.v128f32", linkage_type = <Import>>}
//RAW: func.func private @llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32(vector<128xi32>, vector<64xi32>, i32) -> vector<128xf32> attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes<linkage_name = "llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32", linkage_type = <Import>>}
//RAW: func.func private @llvm.genx.raw.send2.v128i32.i1.v16i32(i8, i8, i1, i8, i8, i8, i32, i32, vector<16xi32>, vector<128xi32>) -> vector<128xi32> attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes<linkage_name = "llvm.genx.raw.send2.v128i32.i1.v16i32", linkage_type = <Import>>}
//RAW: func.func private @llvm.genx.raw.send2.v64i32.i1.v16i32(i8, i8, i1, i8, i8, i8, i32, i32, vector<16xi32>, vector<64xi32>) -> vector<64xi32> attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes<linkage_name = "llvm.genx.raw.send2.v64i32.i1.v16i32", linkage_type = <Import>>}
//RAW: func.func private @llvm.genx.raw.send2.noresult.i1.v16i32(i8, i8, i1, i8, i8, i32, i32, vector<16xi32>) attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes<linkage_name = "llvm.genx.raw.send2.noresult.i1.v16i32", linkage_type = <Import>>}

gpu.func @test_prefetch(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {

//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %{{.*}} : memref<8x16xf16> -> index
@@ -55,15 +50,14 @@ module @gemm attributes {gpu.container_module} {
//CHECK: %[[r26:.*]] = vector.insert %[[c1807_i32]], %[[r25]] [7] : i32 into vector<16xi32>
%2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>

//LSC: %[[cst_2:.*]] = arith.constant 0.000000e+00 : f16
//LSC: %[[true:.*]] = arith.constant true
//LSC: %[[c0_i8:.*]] = arith.constant 0 : i8
//LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8>
//LSC: %[[c1_i8:.*]] = arith.constant 1 : i8
//LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
//LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
//LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
//LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
//LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
//LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
//LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16>
xegpu.prefetch_nd %1 : !xegpu.tensor_desc<16x16xf16>

@@ -89,3 +83,34 @@ module @gemm attributes {gpu.container_module} {

}
}

// -----
module @two_type attributes {gpu.container_module} {

gpu.module @test_kernel {
gpu.func @test_prefetch(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf32>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
%2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>

//LSC: %[[c0_i32:.*]] = arith.constant 0 : i32
//LSC: %[[true:.*]] = arith.constant true
//LSC: %[[c0_i8:.*]] = arith.constant 0 : i8
//LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8>
//LSC: %[[c1_i8:.*]] = arith.constant 1 : i8
//LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
//LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
//LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
//LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16>
xegpu.prefetch_nd %1 : !xegpu.tensor_desc<8x16xf32>

%3 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%4 = xegpu.load_nd %1 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>

xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.return
}

}
}
