diff --git a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp index 174d02b40..28b881b6c 100644 --- a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp +++ b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp @@ -722,6 +722,63 @@ static func::CallOp gen2DStoreIntrinsicCall( nblks, shape, payload, data); } +auto get1DTdescNumTotalElems = [](TensorDescType tdescTy) -> int64_t { + return tdescTy.getNumElements() * tdescTy.getArrayLength(); +}; + +auto getElemBitWidth = [](TensorDescType tdescTy) -> unsigned { + return tdescTy.getElementType().getIntOrFloatBitWidth(); +}; + +auto isLowPrecision = [](TensorDescType tdescTy) -> bool { + // Note: Handling for sub 8bit types is unclear so report as false + auto width = getElemBitWidth(tdescTy); + return width < 32 && width >= 8; +}; + +auto getScaled1DTdesc = + [](TensorDescType tdescTy, + ConversionPatternRewriter &rewriter) -> TensorDescType { + // return if not 1D tensor desc + if (tdescTy.getShape().size() != 1) + return tdescTy; + // return if not low precision + if (!isLowPrecision(tdescTy)) + return tdescTy; + + auto scaledTy = tdescTy.getElementType(); + auto totalBytes = + get1DTdescNumTotalElems(tdescTy) * getElemBitWidth(tdescTy) / 8; + switch (totalBytes) { + // i32 for 4, 8, 12, 16, 32, 64, 128, 256 + // i64 for 24 and 512 + case 4: + case 8: + case 12: + case 16: + case 32: + case 64: + case 128: + case 256: + scaledTy = rewriter.getI32Type(); + break; + case 24: + case 512: + scaledTy = rewriter.getI64Type(); + break; + default: + break; + } + return TensorDescType::get( + tdescTy.getContext(), + {totalBytes / (scaledTy.getIntOrFloatBitWidth() / 8)}, scaledTy, + tdescTy.getEncoding(), /*sg_map*/ nullptr); +}; + +auto isScaled = [](TensorDescType tdescTy, TensorDescType scaledTy) -> bool { + return getElemBitWidth(tdescTy) != getElemBitWidth(scaledTy); +}; + #define shrui(...) rewriter.createOrFold(loc, __VA_ARGS__) class LoadNdPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -759,16 +816,25 @@ class LoadNdPattern : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "transpose is not supported for slm and 1D tensor desc"); - auto elems = tdescTy.getNumElements() * tdescTy.getArrayLength(); + auto scaledTdescTy = getScaled1DTdesc(tdescTy, rewriter); + auto scaledElems = get1DTdescNumTotalElems(scaledTdescTy); + auto scaledElemTy = scaledTdescTy.getElementType(); - if (failed(isValid1DBlockSetup(elemTy, elems, loc, rewriter))) { + if (failed( + isValid1DBlockSetup(scaledElemTy, scaledElems, loc, rewriter))) { return rewriter.notifyMatchFailure( loc, "unsupported 1D/SLM TensorDescType."); } - + bool scaled = isScaled(tdescTy, scaledTdescTy); + auto resTy = + scaled ? VectorType::get({scaledElems}, scaledElemTy) : op.getType(); auto newValue = gen1DLoadInstrinsicCall( - rewriter, loc, op.getType(), l1hint, l3hint, elemTy, elems, + rewriter, loc, resTy, l1hint, l3hint, scaledElemTy, scaledElems, tdescTy.getMemorySpace(), adaptor.getTensorDesc()); + if (scaled) { + newValue = + rewriter.create(loc, op.getType(), newValue); + } rewriter.replaceOp(op, newValue); return success(); } else if (rank == 2) { // 2d.ugm.desc @@ -866,7 +932,6 @@ class PrefetchNdPattern : public OpConversionPattern { auto loc = op.getLoc(); auto tdescTy = op.getTensorDescType(); - auto elemTy = tdescTy.getElementType(); auto rank = tdescTy.getRank(); auto scope = tdescTy.getMemorySpace(); @@ -881,15 +946,17 @@ class PrefetchNdPattern : public OpConversionPattern { return success(); } - auto elems = tdescTy.getNumElements() * tdescTy.getArrayLength(); + auto scaledTdescTy = getScaled1DTdesc(tdescTy, rewriter); + auto scaledElems = get1DTdescNumTotalElems(scaledTdescTy); + auto scaledElemTy = scaledTdescTy.getElementType(); - if (failed(isValid1DBlockSetup(elemTy, elems, loc, rewriter))) + if (failed(isValid1DBlockSetup(scaledElemTy, scaledElems, loc, rewriter))) return rewriter.notifyMatchFailure( loc, "unsupported 1D/SLM TensorDescType."); - auto callOp = - gen1DPrefetchIntrinsicCall(rewriter, loc, l1hint, l3hint, elemTy, - elems, scope, adaptor.getTensorDesc()); + auto callOp = gen1DPrefetchIntrinsicCall(rewriter, loc, l1hint, l3hint, + scaledElemTy, scaledElems, scope, + adaptor.getTensorDesc()); rewriter.replaceOp(op, callOp); return success(); } else if (rank == 2) { // 2d.ugm.desc @@ -914,7 +981,6 @@ class StoreNdPattern : public OpConversionPattern { auto loc = op.getLoc(); auto tdescTy = op.getTensorDescType(); - auto elemTy = tdescTy.getElementType(); auto rank = tdescTy.getRank(); auto scope = tdescTy.getMemorySpace(); @@ -925,25 +991,21 @@ class StoreNdPattern : public OpConversionPattern { auto data = adaptor.getValue(); if (rank == 1) { - // for slm and 1D tensor desc, use lsc.store, - // all non 32-bit data has to be encoded as i32. - - // get instrinsic name, the data type has to be encoded - // as vNi32 for 8-bit/16-bit data in regular store. - // for example, Vector<8x16xf16> should be encoded as V128I32. - auto lscTy = getOrigOrI32VectorType(op.getValueType()); - auto typeStr = convertVectorType(lscTy).first; - auto intrinsicStr = getLSCIntrinsicStr("store", 1, scope, typeStr); - - auto elems = tdescTy.getNumElements(); + auto scaledTdescTy = getScaled1DTdesc(tdescTy, rewriter); + auto scaledElems = get1DTdescNumTotalElems(scaledTdescTy); + auto scaledElemTy = scaledTdescTy.getElementType(); - if (failed(isValid1DBlockSetup(elemTy, elems, loc, rewriter))) + if (failed(isValid1DBlockSetup(scaledElemTy, scaledElems, loc, rewriter))) return rewriter.notifyMatchFailure( loc, "unsupported 1D/SLM TensorDescType."); - auto callOp = - gen1DStoreInstrinsicCall(rewriter, loc, l1hint, l3hint, elemTy, elems, - scope, adaptor.getTensorDesc(), data); + if (isScaled(tdescTy, scaledTdescTy)) { + auto scaledVecTy = VectorType::get({scaledElems}, scaledElemTy); + data = rewriter.create(loc, scaledVecTy, data); + } + auto callOp = gen1DStoreInstrinsicCall(rewriter, loc, l1hint, l3hint, + scaledElemTy, scaledElems, scope, + adaptor.getTensorDesc(), data); rewriter.replaceOp(op, callOp); return success(); diff --git a/test/Conversion/XeGPUToVC/load_store_prefetch_1D_bf16.mlir b/test/Conversion/XeGPUToVC/load_store_prefetch_1D_bf16.mlir new file mode 100644 index 000000000..f78488379 --- /dev/null +++ b/test/Conversion/XeGPUToVC/load_store_prefetch_1D_bf16.mlir @@ -0,0 +1,27 @@ +// RUN: imex-opt -convert-xegpu-to-vc -cse %s | FileCheck %s + +gpu.module @load_store_bf16 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @load_store_bf16(%arg0: memref<4x2x128xbf16>, %arg1: memref<4x2x128xbf16>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c32 = arith.constant 32 : index + %thread_id_x = gpu.thread_id x + %thread_id_y = gpu.thread_id y + %thread_id_z = gpu.thread_id z + %0 = arith.muli %thread_id_z, %c32 : index + %1 = xegpu.create_nd_tdesc %arg0[%thread_id_x, %thread_id_y, %0], [4, 2, 128], [256, 128, 1] : memref<4x2x128xbf16> -> !xegpu.tensor_desc<32xbf16, #xegpu.block_tdesc_attr> + + // CHECK: func.call @llvm.genx.lsc.prefetch.stateless.v1i1.v1i64 + xegpu.prefetch_nd %1 : !xegpu.tensor_desc<32xbf16, #xegpu.block_tdesc_attr> + %2 = xegpu.create_nd_tdesc %arg0[%thread_id_x, %thread_id_y, %0], [4, 2, 128], [256, 128, 1] : memref<4x2x128xbf16> -> !xegpu.tensor_desc<32xbf16, #xegpu.block_tdesc_attr> + + // CHECK: %[[LOAD_VAL:.*]] = func.call @llvm.genx.lsc.load.stateless.v16i32.v1i1.v1i64 + // CHECK: %[[REAL_VAL:.*]] = vector.bitcast %[[LOAD_VAL]] : vector<16xi32> to vector<32xbf16> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32xbf16, #xegpu.block_tdesc_attr> -> vector<32xbf16> + %4 = xegpu.create_nd_tdesc %arg1[%thread_id_x, %thread_id_y, %0], [4, 2, 128], [256, 128, 1] : memref<4x2x128xbf16> -> !xegpu.tensor_desc<32xbf16, #xegpu.block_tdesc_attr> + + // CHECK: %[[STORE_VAL:.*]] = vector.bitcast %[[REAL_VAL]] : vector<32xbf16> to vector<16xi32> + // CHECK: func.call @llvm.genx.lsc.store.stateless.v1i1.v1i64.v16i32 + // CHECK: %[[STORE_VAL]], %[[LAST_ARG:.*]]) : + xegpu.store_nd %3, %4 : vector<32xbf16>, !xegpu.tensor_desc<32xbf16, #xegpu.block_tdesc_attr> + gpu.return + } +} diff --git a/test/Integration/Dialect/XeGPU/load_store_with_1d_bf16_tile.mlir b/test/Integration/Dialect/XeGPU/load_store_with_1d_bf16_tile.mlir new file mode 100644 index 000000000..570cd44e3 --- /dev/null +++ b/test/Integration/Dialect/XeGPU/load_store_with_1d_bf16_tile.mlir @@ -0,0 +1,86 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +module @gemm attributes {gpu.container_module} { + memref.global "private" constant @__constant_8x32xbf16 : memref<8x32xbf16> = dense<0.0> + func.func @test(%arg0: memref<8x32xbf16>, %arg1: memref<8x32xbf16>) -> memref<8x32xbf16> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + + %memref = gpu.alloc host_shared () : memref<8x32xbf16> + memref.copy %arg0, %memref : memref<8x32xbf16> to memref<8x32xbf16> + %memref_1 = gpu.alloc host_shared () : memref<8x32xbf16> + memref.copy %arg1, %memref_1 : memref<8x32xbf16> to memref<8x32xbf16> + %memref_2 = gpu.alloc host_shared () : memref<8x32xbf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%memref : memref<8x32xbf16>, %memref_1 : memref<8x32xbf16>, %memref_2 : memref<8x32xbf16>) + gpu.dealloc %memref : memref<8x32xbf16> + gpu.dealloc %memref_1 : memref<8x32xbf16> + return %memref_2 : memref<8x32xbf16> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%arg0: memref<8x32xbf16>, %arg1: memref<8x32xbf16>, %arg2: memref<8x32xbf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %thread_id_x = gpu.thread_id x + cf.br ^bb1 + ^bb1: + %0 = xegpu.create_nd_tdesc %arg1[%thread_id_x, 0] : memref<8x32xbf16> -> !xegpu.tensor_desc<32xbf16> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<32xbf16> -> vector<32xbf16> + %2 = xegpu.create_nd_tdesc %arg0[%thread_id_x, 0] : memref<8x32xbf16> -> !xegpu.tensor_desc<32xbf16> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32xbf16> -> vector<32xbf16> + %4 = arith.addf %3, %1 : vector<32xbf16> + %5 = xegpu.create_nd_tdesc %arg2[%thread_id_x, 0] : memref<8x32xbf16> -> !xegpu.tensor_desc<32xbf16> + xegpu.store_nd %4, %5 : vector<32xbf16>, !xegpu.tensor_desc<32xbf16> + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c_gen_int = arith.constant 0 : i1 + %cf_lower = arith.constant -0.5 : f32 + %cf_upper = arith.constant 0.5 : f32 + + %A = memref.alloc() : memref<8x32xbf16> + %A_random = memref.cast %A : memref<8x32xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () + + %B = memref.alloc() : memref<8x32xbf16> + %B_random = memref.cast %B : memref<8x32xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () + + // calculate the result C matrix + %c32 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %ref = memref.alloc() : memref<8x32xf32> + scf.for %i = %c0 to %c8 step %c1 { + scf.for %j = %c0 to %c32 step %c1 { + %a = memref.load %A[%i, %j] : memref<8x32xbf16> + %b = memref.load %B[%i, %j] : memref<8x32xbf16> + %a_ext = arith.extf %a : bf16 to f32 + %b_ext = arith.extf %b : bf16 to f32 + %c = arith.addf %a_ext, %b_ext : f32 + %c_trunc = arith.truncf %c : f32 to bf16 + %c_ext = arith.extf %c_trunc : bf16 to f32 + memref.store %c_ext, %ref[%i, %j] : memref<8x32xf32> + } + } + + %C = call @test(%A, %B) : (memref<8x32xbf16>, memref<8x32xbf16>) -> memref<8x32xbf16> + + %C_cast = memref.cast %C : memref<8x32xbf16> to memref<*xbf16> + %ref_cast = memref.cast %ref : memref<8x32xf32> to memref<*xf32> + //call @printMemrefBF16(%C_cast) : (memref<*xbf16>) -> () + //call @printMemrefF32(%ref_cast) : (memref<*xf32>) -> () + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseBF16(%C_cast, %ref_cast) : (memref<*xbf16>, memref<*xf32>) -> () + return + } + func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface} + func.func private @printAllcloseBF16(memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface} +}