diff --git a/include/imex/Utils/XeCommon.h b/include/imex/Utils/XeCommon.h
index a7a551d07..8298126dc 100644
--- a/include/imex/Utils/XeCommon.h
+++ b/include/imex/Utils/XeCommon.h
@@ -560,6 +560,42 @@ llvm::SmallVector<int64_t> defaultStrides(llvm::ArrayRef<int64_t> shape);
 /// capability).
 bool isVectorAnyINTELType(mlir::Type type);
 
+/// Convert an OpFoldResult to a Value by replacing integer
+/// attributes with arith::ConstantOps. It also performs
+/// simple type conversions.
+mlir::Value getValueOrConstantOp(mlir::OpFoldResult ofr, mlir::Location loc,
+                                 mlir::PatternRewriter &rewriter,
+                                 mlir::Type type = nullptr);
+
+// A universal get method for the offsets, shapes, or strides (OSS) of
+// xetile::InitTileOp and xegpu::CreateNdDescOp ops.
+// The OSS (Offsets, Shapes, Strides) information provided
+// to InitTileOp & CreateNdDescOp is multifaceted. In other words, the OSS
+// info can be provided to InitTileOp & CreateNdDescOp in multiple ways,
+// especially the shapes and strides:
+// 1. For static memrefs: the shapes and strides info is inherent in the
+//    memref data type.
+//
+// 2. For dynamic memrefs and i64/i32 sources: the shapes and strides info is
+//    provided via the operands `sizes` and `strides` respectively; however,
+//    these operands can also take two different forms:
+//
+//    2.1 Constant type: a constant attribute can be passed.
+//    2.2 Value type: a value can be passed.
+//
+// This function collects this info for the different scenarios and returns
+// it in Value types.
+//
+// One can pass the result of getMixedOffsets(), getMixedSizes(), or
+// getMixedStrides() to the following utility to get them as Value types.
+// Since both xetile::InitTileOp and xegpu::CreateNdDescOp implement the
+// OffsetSizeAndStrideOpInterface, getMixedOffsets(), getMixedSizes(), and
+// getMixedStrides() take care of the different scenarios mentioned above.
+
+llvm::SmallVector<mlir::Value> getStridesOrOffsetsOrShapesInValueType(
+    mlir::PatternRewriter &rewriter,
+    ::llvm::SmallVector<mlir::OpFoldResult> mixedOSS, mlir::Location loc);
+
 } // namespace imex
 
 #endif
diff --git a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
index 8fa5b779e..05562a390 100644
--- a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
+++ b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
@@ -13,28 +13,26 @@
 ///
 //===----------------------------------------------------------------------===//
 
-#include
-
+#include "imex/Conversion/XeGPUToVC/XeGPUToVC.h"
+#include "imex/Conversion/ArithToVC/ArithToVC.h"
+#include "imex/Conversion/MathToVC/MathToVC.h"
+#include "imex/Utils/VCUtils.h"
+#include "imex/Utils/XeCommon.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/ValueRange.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
-
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-
-#include "imex/Conversion/ArithToVC/ArithToVC.h"
-#include "imex/Conversion/MathToVC/MathToVC.h"
-#include "imex/Utils/VCUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
@@ -71,23 +69,6 @@ static bool isOneOrUnknow(OpFoldResult ofr) {
   return !val || *val == 1;
 }
 
-// convert OpFoldResult to Value by replacing integer
-// attributes with arith::ConstantOps. It also performs
-// simple type conversions
-static Value getValueOrConstantOp(OpFoldResult ofr, Location loc,
-                                  ConversionPatternRewriter &rewriter,
-                                  Type type = nullptr) {
-  if (ofr.is<Value>())
-    return ofr.get<Value>();
-
-  auto intAttr = cast<IntegerAttr>(ofr.get<Attribute>());
-
-  if (type)
-    intAttr = IntegerAttr::get(type, intAttr.getInt());
-
-  return rewriter.create<arith::ConstantOp>(loc, intAttr);
-}
-
 static Value castValueTo(Value val, Type toType, Location loc,
                          ConversionPatternRewriter &rewriter) {
@@ -120,36 +101,6 @@ static Value castValueTo(Value val, Type toType, Location loc,
 #define addi(a, b) rewriter.createOrFold<arith::AddIOp>(loc, a, b)
 #define subi(a, b) rewriter.createOrFold<arith::SubIOp>(loc, a, b)
 
-// A universal get method for strides of CreateNdDescOp op
-
-// Get the effective strides of the CreateNdDescOp op
-// Stride information provided to CreateNdDescOp is multifaceted
-// In other words strides info provided to CreateNdDescOp in multiple ways:
-// 1. For static memrefs: the strides info are inherent in the memref data type
-// 2. For dynamic memrefs and i64/i32 source: the strides info is provided via
-// the operand `strides`, however the strides operand can also take two
-// different types:
-// 2.1 Constant type: constant attribute can be passed
-// 2.2 Value type: a value type can be passed
-// This function collects these info based on different scenarios and returns
-// them in Value types.
-
-// We use getMixedStrides() to collect the strides info instead of handling
-// aforementioned cases manually, since, it already uses
-// OffsetSizeAndStrideOpInterface, getMixedStrides() already takes care of
-// memref type.
-SmallVector<Value> getStridesInValueType(ConversionPatternRewriter &rewriter,
-                                         CreateNdDescOp op) {
-  SmallVector<Value> stridesVal;
-  auto mixedStrides = op.getMixedStrides();
-  for (size_t i = 0; i < mixedStrides.size(); i++) {
-    auto stride =
-        getValueOrConstantOp(mixedStrides[i], op.getLoc(), rewriter, indexTy);
-    stridesVal.push_back(stride);
-  }
-  return stridesVal;
-}
-
 // Given an n-dim memref, a tensor descriptor with tile rank of 2 defines a
 // 2d memory region with respect to the two inner-most dimensions. Other
 // outer dimensions affect the base address of the 2d plane. For 2d, we
@@ -177,9 +128,11 @@ static Value adjustBasePointer(ConversionPatternRewriter &rewriter,
   auto tileRank = op.getTensorDesc().getType().getRank();
   auto offsets = op.getMixedOffsets();
-  auto strides = getStridesInValueType(rewriter, op);
+  auto strides = getStridesOrOffsetsOrShapesInValueType(
+      rewriter, op.getMixedStrides(), loc);
 
-  // Calculate the effective rank of the source based on strides arrayref size
+  // Calculate the effective rank of the source based on strides arrayref
+  // size
   auto effectiveRank = strides.size();
   int64_t ranksToAdjust = effectiveRank;
   auto bytesPerElem =
diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
index 2dd8b3293..c0607cf34 100644
--- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
+++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
@@ -15,10 +15,12 @@
 #include "XeTileOpConversion.h"
 #include "imex/Utils/XeArch.h"
+#include "imex/Utils/XeCommon.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/SmallVector.h"
 #include
 #include
 #include
@@ -28,7 +30,7 @@
 #include
 
 namespace imex {
-
+using namespace mlir;
 using mlir::vector::CreateMaskOp;
 using mlir::vector::ExtractOp;
 using mlir::vector::ExtractStridedSliceOp;
@@ -479,14 +481,32 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
       tDescOffsets.push_back(tDescOffsetX);
       tDescOffsets.push_back(tDescOffsetY);
 
-      // TODO: this needs improvement, it assumes the source is static
-      // memeref.
+      // Handle memref source.
       if (auto MemRefTypedSource =
-              mlir::cast<mlir::TypedValue<mlir::MemRefType>>(source)) {
+              mlir::dyn_cast<mlir::TypedValue<mlir::MemRefType>>(source)) {
+        // Handle the case where the shape is static.
+        if (MemRefTypedSource.getType().hasStaticShape()) {
+          auto createNdOp = rewriter.create<mlir::xegpu::CreateNdDescOp>(
+              op.getLoc(), tDescTy /*resultTy*/,
+              MemRefTypedSource
+              /*source*/,
+              tDescOffsets /*offsets*/);
+
+          xegpuOps[blocks[1] * i + j] = createNdOp;
+        } else {
+          // Handle the case where the shape is dynamic.
+          auto createNdOp = rewriter.create<mlir::xegpu::CreateNdDescOp>(
+              loc, tDescTy, MemRefTypedSource, tDescOffsets,
+              op.getMixedSizes(), op.getMixedStrides());
+          xegpuOps[blocks[1] * i + j] = createNdOp;
+        }
+      } else if (auto intSourceType =
+                     mlir::dyn_cast<mlir::TypedValue<mlir::IntegerType>>(
+                         source)) {
+        // Handle the case where the source is an integer.
         auto createNdOp = rewriter.create<mlir::xegpu::CreateNdDescOp>(
-            op.getLoc(), tDescTy /*resultTy*/, MemRefTypedSource /*source*/,
-            tDescOffsets /*offsets*/);
-
+            loc, tDescTy, intSourceType, tDescOffsets, op.getMixedSizes(),
+            op.getMixedStrides());
         xegpuOps[blocks[1] * i + j] = createNdOp;
       } else {
         return mlir::failure();
diff --git a/lib/Utils/XeCommon.cpp b/lib/Utils/XeCommon.cpp
index 1db7ffc23..fa18e692b 100644
--- a/lib/Utils/XeCommon.cpp
+++ b/lib/Utils/XeCommon.cpp
@@ -230,4 +230,34 @@ bool isVectorAnyINTELType(mlir::Type type) {
                     spirvSupportedSizes.end());
 }
 
+// Convert an OpFoldResult to a Value by replacing integer
+// attributes with arith::ConstantOps. It also performs
+// simple type conversions.
+mlir::Value getValueOrConstantOp(mlir::OpFoldResult ofr, mlir::Location loc,
+                                 mlir::PatternRewriter &rewriter,
+                                 mlir::Type type) {
+  if (ofr.is<mlir::Value>())
+    return ofr.get<mlir::Value>();
+
+  auto intAttr = llvm::cast<mlir::IntegerAttr>(ofr.get<mlir::Attribute>());
+
+  if (type)
+    intAttr = mlir::IntegerAttr::get(type, intAttr.getInt());
+
+  return rewriter.create<mlir::arith::ConstantOp>(loc, intAttr);
+}
+
+llvm::SmallVector<mlir::Value> getStridesOrOffsetsOrShapesInValueType(
+    mlir::PatternRewriter &rewriter,
+    ::llvm::SmallVector<mlir::OpFoldResult> mixedOSS, mlir::Location loc) {
+  llvm::SmallVector<mlir::Value> valueVec;
+  // auto mixedStrides = op.getMixedStrides();
+  for (size_t i = 0; i < mixedOSS.size(); i++) {
+    auto oss = getValueOrConstantOp(mixedOSS[i], loc, rewriter,
+                                    rewriter.getIndexType());
+    valueVec.push_back(oss);
+  }
+  return valueVec;
+}
+
 } // namespace imex
diff --git a/test/Conversion/XeTileToXeGPU/sg_init_tile.mlir b/test/Conversion/XeTileToXeGPU/sg_init_tile.mlir
new file mode 100644
index 000000000..f6c14049b
--- /dev/null
+++ b/test/Conversion/XeTileToXeGPU/sg_init_tile.mlir
@@ -0,0 +1,35 @@
+// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \
+// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s
+
+gpu.module @test_kernel {
+  //CHECK: gpu.func @sg_init_tile(%[[arg0:.*]]: memref<1024x1024xf32>, %[[arg1:.*]]: memref<?x?xf32>) {
+  gpu.func @sg_init_tile(%a: memref<1024x1024xf32>, %b: memref<?x?xf32>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+
+    %result = arith.constant dense<0.0>: vector<32x32xf32>
+
+    // CHECK: %[[SRC_AS_INDEX:.*]] = memref.extract_aligned_pointer_as_index {{.*}} : memref<1024x1024xf32> -> index
+    %src_as_index = memref.extract_aligned_pointer_as_index %a : memref<1024x1024xf32> -> index
+    // CHECK-NEXT: %[[SRC_AS_INT:.*]] = arith.index_cast %[[SRC_AS_INDEX]] : index to i64
+    %src_as_int = arith.index_cast %src_as_index : index to i64
+
+    //CHECK-COUNT-8: {{.*}} = xegpu.create_nd_tdesc %[[arg0]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
+    %static_memref_src = xetile.init_tile %a[0, 32] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32>
+
+    // CHECK-COUNT-8: {{.*}} = xegpu.create_nd_tdesc %[[arg1]][{{.*}}], [%c1024, %c1024], [%c1024, %c1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
+    %dynamic_memref_src = xetile.init_tile %b[%c0, %c32],[%c1024, %c1024], [%c1024, %c1] : memref<?x?xf32> -> !xetile.tile<32x32xf32>
+
+    // CHECK-COUNT-8: {{.*}} = xegpu.create_nd_tdesc %[[SRC_AS_INT]][{{.*}}], [%c1024, %c1024], [%c1024, %c1] : i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
+    %int_src = xetile.init_tile %src_as_int[%c0, %c32], [%c1024, %c1024], [%c1024, %c1] : i64 -> !xetile.tile<32x32xf32>
+
+    //CHECK-COUNT-8: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
+    xetile.store_tile %result, %static_memref_src: vector<32x32xf32>, !xetile.tile<32x32xf32>
+    xetile.store_tile %result, %dynamic_memref_src: vector<32x32xf32>, !xetile.tile<32x32xf32>
+    xetile.store_tile %result, %int_src: vector<32x32xf32>, !xetile.tile<32x32xf32>
+
+    gpu.return
+  }
+}
diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_dynamic_memref.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_dynamic_memref.mlir
new file mode 100644
index 000000000..c2dc9c0d8
--- /dev/null
+++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_dynamic_memref.mlir
@@ -0,0 +1,153 @@
+// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \
+// RUN: --runner imex-cpu-runner -e main \
+// RUN: --entry-point-result=void \
+// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
+// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \
+// RUN: --runner imex-cpu-runner -e main \
+// RUN: --entry-point-result=void \
+// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
+
+// NOTES :
+// This example assumes one subgroup per workgroup and the kernel specifies the computation
+// done by a single subgroup.
+
+module @gemm attributes {gpu.container_module} {
+  func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} {
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %c512 = arith.constant 512 : index
+    %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16>
+    memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16>
+    %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16>
+    memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16>
+    %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32>
+    memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32>
+    // Make the memrefs dynamic
+    %A_gpu_cast = memref.cast %A_gpu : memref<1024x1024xf16> to memref<?x?xf16>
+    %B_gpu_cast = memref.cast %B_gpu : memref<1024x1024xf16> to memref<?x?xf16>
+    %C_gpu_cast = memref.cast %C_gpu : memref<1024x1024xf32> to memref<?x?xf32>
+
+    gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu_cast : memref<?x?xf16>, %B_gpu_cast : memref<?x?xf16>, %C_gpu_cast : memref<?x?xf32>)
+    gpu.dealloc %A_gpu : memref<1024x1024xf16>
+    gpu.dealloc %B_gpu : memref<1024x1024xf16>
+    return %C_gpu : memref<1024x1024xf32>
+  }
+  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
+    gpu.func @test_kernel(%A: memref<?x?xf16>, %B: memref<?x?xf16>, %C: memref<?x?xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      %c8 = arith.constant 8 : index
+      %c16 = arith.constant 16 : index
+      %c32 = arith.constant 32 : index
+      %c1024 = arith.constant 1024 : index
+      %block_id_x = gpu.block_id x
+      %block_id_y = gpu.block_id y
+      %m = arith.muli %block_id_x, %c16 : index
+      %n = arith.muli %block_id_y, %c32 : index
+      // initialize C tile and load it
+      %c_init_tile = xetile.init_tile %C[%m, %n], [%c1024, %c1024], [%c1024, %c1] : memref<?x?xf32> -> !xetile.tile<16x32xf32>
+      %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32>
+      // initialize A and B tiles
+      %a_init_tile = xetile.init_tile %A[%m, %c0], [%c1024, %c1024], [%c1024, %c1] : memref<?x?xf16> -> !xetile.tile<16x32xf16>
+      %b_init_tile = xetile.init_tile %B[%c0, %n], [%c1024, %c1024], [%c1024, %c1] : memref<?x?xf16> -> !xetile.tile<32x32xf16>
+      // compute the value of C tile by iterating over tiles in k-dimension and doing dpas
+      %out:3 = scf.for %k = %c0 to %c1024 step %c32
+        iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value)
+        -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) {
+
+        // load A and B tiles
+        %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16>
+        %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16>
+        // perform dpas and accumulate
+        %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value
+          : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32>
+        // update the offsets for A and B tiles
+        %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32]
+          : !xetile.tile<16x32xf16>
+        %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0]
+          : !xetile.tile<32x32xf16>
+        // partial C tile result
+        scf.yield %a_next_tile, %b_next_tile, %c_new_value
+          : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>
+      }
+      // store the final accumulated C tile result back to memory
+      xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32>
+      gpu.return
+    }
+  }
+  func.func @main() attributes {llvm.emit_c_interface} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c1024 = arith.constant 1024 : index
+    %cf_0 = arith.constant 0.0 : f16
+    %cf_1 = arith.constant 1.0 : f16
+    %A = memref.alloc() : memref<1024x1024xf16>
+    %B = memref.alloc() : memref<1024x1024xf16>
+    %C = memref.alloc() : memref<1024x1024xf32>
+    %C_ref = memref.alloc() : memref<1024x1024xf32>
+    // initialize matrix A ; A[i, j] = j
+    scf.for %i = %c0 to %c1024 step %c1 {
+      scf.for %j = %c0 to %c1024 step %c1 {
+        %t = index.castu %j : index to i16
+        %val = arith.uitofp %t : i16 to f16
+        memref.store %val, %A[%i, %j] : memref<1024x1024xf16>
+      }
+    }
+    // make matrix B an identity matrix
+    scf.for %i = %c0 to %c1024 step %c1 {
+      scf.for %j = %c0 to %c1024 step %c1 {
+        %i_i32 = index.castu %i : index to i32
+        %j_i32 = index.castu %j : index to i32
+        %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32
+
+        scf.if %i_j_same {
+          memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf16>
+        } else {
+          memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf16>
+        }
+      }
+    }
+    // initialize matrix C and C_ref ; C[i, j] = 0
+    %c0_f32 = arith.constant 0.0 : f32
+    scf.for %i = %c0 to %c1024 step %c1 {
+      scf.for %j = %c0 to %c1024 step %c1 {
+        memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32>
+        memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32>
+      }
+    }
+    // compute C for reference
+    scf.for %i = %c0 to %c1024 step %c1 {
+      scf.for %j = %c0 to %c1024 step %c1 {
+        %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32>
+        %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 {
+          %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16>
+          %b_val = memref.load %B[%k, %j] : memref<1024x1024xf16>
+          %t = arith.mulf %a_val, %b_val : f16
+          %t_cast = arith.extf %t : f16 to f32
+          %c_sum = arith.addf %t_cast, %c_partial : f32
+          scf.yield %c_sum : f32
+        }
+        memref.store %c_val, %C_ref[%i, %j] : memref<1024x1024xf32>
+      }
+    }
+    %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32>
+    %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32>
+    %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32>
+
+    // CHECK: [ALLCLOSE: TRUE]
+    call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> ()
+    memref.dealloc %A : memref<1024x1024xf16>
+    memref.dealloc %B : memref<1024x1024xf16>
+    memref.dealloc %C : memref<1024x1024xf32>
+    memref.dealloc %C_ref : memref<1024x1024xf32>
+    return
+  }
+  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+  func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
+  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
+}
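For reference, a minimal usage sketch (illustration only, not part of the diff) of the utilities introduced in XeCommon; `op` and `rewriter` stand for the matched op and the pattern rewriter in a caller such as the XeGPUToVC patterns above, and `op` is assumed to implement OffsetSizeAndStrideOpInterface as the header comment describes:

  // Collect strides and sizes as Values; each element is either the original
  // Value operand or an arith::ConstantOp of index type materialized from the
  // corresponding constant attribute.
  llvm::SmallVector<mlir::Value> strides =
      imex::getStridesOrOffsetsOrShapesInValueType(rewriter, op.getMixedStrides(), op.getLoc());
  llvm::SmallVector<mlir::Value> sizes =
      imex::getStridesOrOffsetsOrShapesInValueType(rewriter, op.getMixedSizes(), op.getLoc());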