From f826127d2d617c7fa938c1918a46a1b769f1f3d1 Mon Sep 17 00:00:00 2001 From: Md Abdullah Shahneous Bari Date: Tue, 5 Sep 2023 22:22:58 +0000 Subject: [PATCH] Add a JointMatrix test case and necessary patch This commit adds: - A JointMatrix test case - A patch that contains the updated definition of JointMatrix spec supported by IGC. Current upstream definition of JointMatrix is not supported by IGC anymore. --- ...int-Matrix-support-to-match-IGC-Spec.patch | 306 ++++++++++++++++++ ...addressing_matrixUse_Param_level_zero.mlir | 234 ++++++++++++++ 2 files changed, 540 insertions(+) create mode 100644 build_tools/patches/0003--Update-the-Joint-Matrix-support-to-match-IGC-Spec.patch create mode 100644 test/SPIRV/JointMatrix/gemm_using_joint_matrix_Physical_64_addressing_matrixUse_Param_level_zero.mlir diff --git a/build_tools/patches/0003--Update-the-Joint-Matrix-support-to-match-IGC-Spec.patch b/build_tools/patches/0003--Update-the-Joint-Matrix-support-to-match-IGC-Spec.patch new file mode 100644 index 000000000..0455310e9 --- /dev/null +++ b/build_tools/patches/0003--Update-the-Joint-Matrix-support-to-match-IGC-Spec.patch @@ -0,0 +1,306 @@ +From 810eccbbbb872402391a0f01a53aaf0205ea10c4 Mon Sep 17 00:00:00 2001 +From: Md Abdullah Shahneous Bari +Date: Tue, 8 Aug 2023 00:28:31 +0000 +Subject: [PATCH] Update the Joint Matrix support to match Spec supported by + IGC + +Update the Joint Matrix support to match the following spec: +https://github.com/MrSidims/llvm/blob/private/MrSidims/add-matrix-use/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc +--- + .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 30 ++++++++++++++----- + .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 6 +++- + mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 12 ++++++-- + mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 20 +++++++++---- + .../SPIRV/Deserialization/Deserializer.cpp | 17 +++++++---- + .../Target/SPIRV/Serialization/Serializer.cpp | 5 +++- + mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 16 +++++----- + 7 files changed, 75 insertions(+), 31 deletions(-) + +diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +index 6f0f728f811e..c2ad6ff24bea 100644 +--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td ++++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +@@ -4039,16 +4039,30 @@ def SPIRV_SamplerUseAttr: SPIRV_I32EnumAttr< + "image_sampler_use_info", + [SPIRV_ISUI_SamplerUnknown, SPIRV_ISUI_NeedSampler, SPIRV_ISUI_NoSampler]>; + +-def SPIRV_ML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 0>; +-def SPIRV_ML_RowMajor : I32EnumAttrCase<"RowMajor", 1>; +-def SPIRV_ML_PackedA : I32EnumAttrCase<"PackedA", 2>; +-def SPIRV_ML_PackedB : I32EnumAttrCase<"PackedB", 3>; +- +-def SPIRV_MatrixLayoutAttr : +- SPIRV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [ +- SPIRV_ML_ColumnMajor, SPIRV_ML_RowMajor, SPIRV_ML_PackedA, SPIRV_ML_PackedB ++// Change the layout parameter to IGC spec, the currnet MLIR version ++// does not match the IGC spec, IGC spec has been updated ++// https://github.com/MrSidims/llvm/blob/private/MrSidims/add-matrix-use/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc ++ ++def SPIRV_ML_RowMajor : I32EnumAttrCase<"RowMajor", 0>; ++def SPIRV_ML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 1>; ++def SPIRV_ML_Packed : I32EnumAttrCase<"Packed", 2>; ++def SPIRV_ML_Unused : I32EnumAttrCase<"Unused", 3>; ++ ++ def SPIRV_MatrixLayoutAttr : ++ SPIRV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [ 
++ SPIRV_ML_RowMajor, SPIRV_ML_ColumnMajor, SPIRV_ML_Packed, SPIRV_ML_Unused + ]>; + ++def SPIRV_ML_MATRIX_A : I32EnumAttrCase<"MatrixA", 0>; ++def SPIRV_ML_MATRIX_B : I32EnumAttrCase<"MatrixB", 1>; ++def SPIRV_ML_MATRIX_ACC : I32EnumAttrCase<"Accumulator", 2>; ++ ++def SPIRV_MatrixUseAttr : ++ SPIRV_I32EnumAttr<"MatrixUse", "valid SPIR-V MatrixUse", "matrixUse", [ ++ SPIRV_ML_MATRIX_A, SPIRV_ML_MATRIX_B, SPIRV_ML_MATRIX_ACC ++ ]>; ++ ++ + // Cooperative Matrix Use for the SPV_KHR_cooperative_matrix extension. + def SPIRV_KHR_CMU_MatrixA : I32EnumAttrCase<"MatrixA", 0>; + def SPIRV_KHR_CMU_MatrixB : I32EnumAttrCase<"MatrixB", 1>; +diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +index 07f2f158ecab..e0b3c5448a44 100644 +--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h ++++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +@@ -459,7 +459,8 @@ public: + using Base::Base; + + static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, +- unsigned columns, MatrixLayout matrixLayout); ++ unsigned columns, MatrixLayout matrixLayout, ++ MatrixUse matrixUse); + Type getElementType() const; + + /// Return the scope of the joint matrix. +@@ -472,6 +473,9 @@ public: + /// return the layout of the matrix + MatrixLayout getMatrixLayout() const; + ++ /// return the use of the matrix ++ MatrixUse getMatrixUse() const; ++ + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + std::optional storage = std::nullopt); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, +diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +index 9188f8b699b4..4c099bf77a88 100644 +--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp ++++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +@@ -394,7 +394,8 @@ static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect, + + // joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x` + // element-type +-// `,` layout `,` scope`>` ++// `,` layout `,` scope ++// `,` use`>` + static Type parseJointMatrixType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) +@@ -421,10 +422,14 @@ static Type parseJointMatrixType(SPIRVDialect const &dialect, + if (parser.parseComma() || + spirv::parseEnumKeywordAttr(scope, parser, "scope ")) + return Type(); ++ MatrixUse matrixUse; ++ if (parser.parseComma() || ++ parseEnumKeywordAttr(matrixUse, parser, "matrixUse ")) ++ return Type(); + if (parser.parseGreater()) + return Type(); + return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1], +- matrixLayout); ++ matrixLayout, matrixUse); + } + + // TODO: Reorder methods to be utilities first and parse*Type +@@ -952,7 +957,8 @@ static void print(JointMatrixINTELType type, DialectAsmPrinter &os) { + os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; + os << type.getElementType() << ", " + << stringifyMatrixLayout(type.getMatrixLayout()); +- os << ", " << stringifyScope(type.getScope()) << ">"; ++ os << ", " << stringifyScope(type.getScope()) << ", " ++ << stringifyMatrixUse(type.getMatrixUse()) << ">"; + } + + static void print(MatrixType type, DialectAsmPrinter &os) { +diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +index 741d8069471d..49ded5c60951 100644 +--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp ++++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +@@ -352,7 +352,8 @@ void 
CooperativeMatrixNVType::getCapabilities( + //===----------------------------------------------------------------------===// + + struct spirv::detail::JointMatrixTypeStorage : public TypeStorage { +- using KeyTy = std::tuple; ++ using KeyTy = ++ std::tuple; + + static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { +@@ -361,26 +362,29 @@ struct spirv::detail::JointMatrixTypeStorage : public TypeStorage { + } + + bool operator==(const KeyTy &key) const { +- return key == KeyTy(elementType, rows, columns, matrixLayout, scope); ++ return key == ++ KeyTy(elementType, rows, columns, matrixLayout, scope, matrixUse); + } + + JointMatrixTypeStorage(const KeyTy &key) + : elementType(std::get<0>(key)), rows(std::get<1>(key)), +- columns(std::get<2>(key)), scope(std::get<4>(key)), +- matrixLayout(std::get<3>(key)) {} ++ columns(std::get<2>(key)), matrixLayout(std::get<3>(key)), ++ scope(std::get<4>(key)), matrixUse(std::get<5>(key)) {} + + Type elementType; + unsigned rows; + unsigned columns; + Scope scope; + MatrixLayout matrixLayout; ++ MatrixUse matrixUse; + }; + + JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope, + unsigned rows, unsigned columns, +- MatrixLayout matrixLayout) { ++ MatrixLayout matrixLayout, ++ MatrixUse matrixUse) { + return Base::get(elementType.getContext(), elementType, rows, columns, +- matrixLayout, scope); ++ matrixLayout, scope, matrixUse); + } + + Type JointMatrixINTELType::getElementType() const { +@@ -397,6 +401,10 @@ MatrixLayout JointMatrixINTELType::getMatrixLayout() const { + return getImpl()->matrixLayout; + } + ++MatrixUse JointMatrixINTELType::getMatrixUse() const { ++ return getImpl()->matrixUse; ++} ++ + void JointMatrixINTELType::getExtensions( + SPIRVType::ExtensionArrayRefVector &extensions, + std::optional storage) { +diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +index 90416289134b..4598dc608034 100644 +--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp ++++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +@@ -939,7 +939,7 @@ spirv::Deserializer::processCooperativeMatrixType(ArrayRef operands) { + + LogicalResult + spirv::Deserializer::processJointMatrixType(ArrayRef operands) { +- if (operands.size() != 6) { ++ if (operands.size() != 7) { + return emitError(unknownLoc, "OpTypeJointMatrix must have element " + "type and row x column parameters"); + } +@@ -949,7 +949,13 @@ spirv::Deserializer::processJointMatrixType(ArrayRef operands) { + return emitError(unknownLoc, "OpTypeJointMatrix references undefined ") + << operands[1]; + } +- ++ auto matrixUse = ++ spirv::symbolizeMatrixUse(getConstantInt(operands[6]).getInt()); ++ if (!matrixUse) { ++ return emitError(unknownLoc, ++ "OpTypeJointMatrix references undefined Use ") ++ << operands[6]; ++ } + auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt()); + if (!scope) { + return emitError(unknownLoc, +@@ -960,14 +966,15 @@ spirv::Deserializer::processJointMatrixType(ArrayRef operands) { + spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt()); + if (!matrixLayout) { + return emitError(unknownLoc, +- "OpTypeJointMatrix references undefined scope ") ++ "OpTypeJointMatrix references undefined Layout ") + << operands[4]; + } + unsigned rows = getConstantInt(operands[2]).getInt(); + unsigned columns = getConstantInt(operands[3]).getInt(); + +- typeMap[operands[0]] = spirv::JointMatrixINTELType::get( +- elementTy, 
scope.value(), rows, columns, matrixLayout.value()); ++ typeMap[operands[0]] = ++ spirv::JointMatrixINTELType::get(elementTy, scope.value(), rows, columns, ++ matrixLayout.value(), matrixUse.value()); + return success(); + } + +diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +index 988e60d08edf..b6ec58648d72 100644 +--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp ++++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +@@ -222,7 +222,8 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, + case spirv::Decoration::LinkageAttributes: { + // Get the value of the Linkage Attributes + // e.g., LinkageAttributes=["linkageName", linkageType]. +- auto linkageAttr = llvm::dyn_cast(attr.getValue()); ++ auto linkageAttr = ++ llvm::dyn_cast(attr.getValue()); + auto linkageName = linkageAttr.getLinkageName(); + auto linkageType = linkageAttr.getLinkageType().getValue(); + // Encode the Linkage Name (string literal to uint32_t). +@@ -639,6 +640,8 @@ LogicalResult Serializer::prepareBasicType( + static_cast(jointMatrixType.getMatrixLayout()))); + operands.push_back( + getConstantOp(static_cast(jointMatrixType.getScope()))); ++ operands.push_back( ++ getConstantOp(static_cast(jointMatrixType.getMatrixUse()))); + return success(); + } + +diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +index ccf4240f8e56..a793564e0477 100644 +--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp ++++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +@@ -396,10 +396,9 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef, + avail.getMergeInstanceType(), avail.getQueryFnName(), + enumName); + +- os << formatv( +- " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1" +- " && \"cannot have more than one bit set\");\n", +- underlyingType); ++ os << formatv(" assert(::llvm::popcount(static_cast<{0}>(value)) <= 1" ++ " && \"cannot have more than one bit set\");\n", ++ underlyingType); + + os << " switch (value) {\n"; + for (const auto &caseSpecPair : classCasePair.getValue()) { +@@ -523,7 +522,8 @@ static void emitAttributeSerialization(const Attribute &attr, + << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName); + if (attr.getAttrDefName() == "SPIRV_ScopeAttr" || + attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" || +- attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") { ++ attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr" || ++ attr.getAttrDefName() == "SPIRV_MatrixUseAttr") { + // These two enums are encoded as to constant values in SPIR-V blob, + // but we directly use the constant value as attribute in SPIR-V dialect. So + // need to handle them separately from normal enum attributes. +@@ -818,7 +818,8 @@ static void emitAttributeDeserialization(const Attribute &attr, + raw_ostream &os) { + if (attr.getAttrDefName() == "SPIRV_ScopeAttr" || + attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" || +- attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") { ++ attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr" || ++ attr.getAttrDefName() == "SPIRV_MatrixUseAttr") { + // These two enums are encoded as to constant values in SPIR-V blob, + // but we directly use the constant value as attribute in SPIR-V dialect. So + // need to handle them separately from normal enum attributes. 
+@@ -926,7 +927,8 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef loc, + // Process operands/attributes + for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { + auto argument = op.getArg(i); +- if (auto *valueArg = llvm::dyn_cast_if_present(argument)) { ++ if (auto *valueArg = ++ llvm::dyn_cast_if_present(argument)) { + if (valueArg->isVariableLength()) { + if (i != e - 1) { + PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or " +-- +2.34.1 + diff --git a/test/SPIRV/JointMatrix/gemm_using_joint_matrix_Physical_64_addressing_matrixUse_Param_level_zero.mlir b/test/SPIRV/JointMatrix/gemm_using_joint_matrix_Physical_64_addressing_matrixUse_Param_level_zero.mlir new file mode 100644 index 000000000..141b51a93 --- /dev/null +++ b/test/SPIRV/JointMatrix/gemm_using_joint_matrix_Physical_64_addressing_matrixUse_Param_level_zero.mlir @@ -0,0 +1,234 @@ +module @gemm_using_jointmatrix_module attributes {gpu.container_module} { + memref.global "private" constant @__constant_A_2048x2048xbf16 : memref<2048x2048xbf16> = dense<1.100000e+00> + memref.global "private" constant @__constant_B_1024x2048x2xbf16 : memref<1024x2048x2xbf16> = dense<2.200000e+00> + memref.global "private" constant @__constant_C_2048x2048xf32 : memref<2048x2048xf32> = dense<0.000000e+00> + + // memref.global "private" constant @__constant_test_store : memref<1xf32> = dense<0.000000e+00> + // M = 8 + // K = 16 + // N = 16 + // SG_SIZE = 16 + func.func @test(%arg_A: memref<2048x2048xbf16>, %arg_B: memref<1024x2048x2xbf16>, %arg_C: memref<2048x2048xf32>) -> memref<2048x2048xf32> attributes {llvm.emit_c_interface} { + %c2048 = arith.constant 2048 : index + %c1024 = arith.constant 1024 : index + %c256 = arith.constant 256 : index + %c128 = arith.constant 128 : index + %c16 = arith.constant 16 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + + %memref_arg_A_i8 = gpu.alloc host_shared () : memref<8388608xi8> + %memref_arg_A_bf16 = memref.view %memref_arg_A_i8[%c0][] : memref<8388608xi8> to memref<2048x2048xbf16> + %memref_arg_A_i16 = memref.view %memref_arg_A_i8[%c0][] : memref<8388608xi8> to memref<2048x2048xi16> + memref.copy %arg_A, %memref_arg_A_bf16 : memref<2048x2048xbf16> to memref<2048x2048xbf16> + %memref_arg_A_i16_flat = memref.cast %memref_arg_A_i16 : memref<2048x2048xi16> to memref<*xi16> + + %memref_arg_B_i8 = gpu.alloc host_shared () : memref<8388608xi8> + %memref_arg_B_bf16 = memref.view %memref_arg_B_i8[%c0][] : memref<8388608xi8> to memref<1024x2048x2xbf16> //VNNI transformed + %memref_arg_B_i16 = memref.view %memref_arg_B_i8[%c0][] : memref<8388608xi8> to memref<1024x2048x2xi16> //VNNI transformed + memref.copy %arg_B, %memref_arg_B_bf16 : memref<1024x2048x2xbf16> to memref<1024x2048x2xbf16> + %memref_arg_B_i16_flat = memref.cast %memref_arg_B_i16 : memref<1024x2048x2xi16> to memref<*xi16> + + %memref_arg_C_i8 = gpu.alloc host_shared () : memref<16777216xi8> + %memref_arg_C_f32 = memref.view %memref_arg_C_i8[%c0][] : memref<16777216xi8> to memref<2048x2048xf32> + memref.copy %arg_C, %memref_arg_C_f32 : memref<2048x2048xf32> to memref<2048x2048xf32> + %memref_arg_C_f32_flat = memref.cast %memref_arg_C_f32 : memref<2048x2048xf32> to memref<*xf32> + + // To use sycl runtime, the blocks and threads needs to be passed in a slightly different way to match the Global and local size + // + + + // Calling convetion in MLIR: + // =========================== + // blocks in (gridX, gridY, gridZ) threads in (blockX, blockY, blockZ) + + // 
Calling convetion in SYCL/DPC++: + // ================================= + // nd_range({global_size.x, global_size.y, global_size.z}, {local_size.x, local_size.y, local_size.z}) + + // Conversion between MLIR and SYCL/DPC++ convetion: + // =================================================== + // Change of dimensions (X and Z dimensions are interchanged): + // ============================================================= + // Gobal Range/Size: + // =================== + // global_size.x = blockZ * gridZ, + // global_size.y = blockY * gridY, + // global_size.y = blockX * gridX + + // Local Range/Size: + // ==================== + // local_size.x = blockZ + // local_size.y = blockY + // local_size.z = blockX + + // For details see: mlir-extensions/lib/ExecutionEngine/SYCLRUNTIME/SyclRuntimeWrappers.cpp + + gpu.launch_func @gemm_using_jointmatrix_module::@gemm_using_jointmatrix blocks in (%c256, %c128, %c1) threads in (%c1, %c16, %c1) args(%memref_arg_A_i16 : memref<2048x2048xi16>, %memref_arg_B_i16 : memref<1024x2048x2xi16>, %memref_arg_C_f32 : memref<2048x2048xf32>) + + gpu.dealloc %memref_arg_A_i8 : memref<8388608xi8> + gpu.dealloc %memref_arg_B_i8 : memref<8388608xi8> + return %memref_arg_C_f32 : memref<2048x2048xf32> + + } + + func.func @main() attributes {llvm.emit_c_interface} { + %A = memref.get_global @__constant_A_2048x2048xbf16 : memref<2048x2048xbf16> + %B = memref.get_global @__constant_B_1024x2048x2xbf16 : memref<1024x2048x2xbf16> + %C = memref.get_global @__constant_C_2048x2048xf32 : memref<2048x2048xf32> + + %result = call @test(%A, %B, %C) : (memref<2048x2048xbf16>, memref<1024x2048x2xbf16>, memref<2048x2048xf32>) -> memref<2048x2048xf32> + %cast = memref.cast %result : memref<2048x2048xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + + spirv.module @__spv__gemm_using_jointmatrix_module Physical64 OpenCL requires #spirv.vce attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + // workgroup_size = [1, 16, 1], subgroup_size = 16 + spirv.GlobalVariable @__builtin_var_NumWorkgroups__ built_in("NumWorkgroups") : !spirv.ptr, Input> + spirv.GlobalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spirv.ptr, Input> + spirv.GlobalVariable @__builtin_var_SubgroupId__ built_in("SubgroupId") : !spirv.ptr + spirv.GlobalVariable @__builtin_var_SubgroupSize__ built_in("SubgroupSize") : !spirv.ptr + spirv.GlobalVariable @__builtin_var_GlobalInvocationId__ built_in("GlobalInvocationId") : !spirv.ptr, Input> + spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr, Input> + // gpu.known_block_size = array, gpu.known_grid_size = array, + spirv.func @gemm_using_jointmatrix(%arg0 : !spirv.ptr, CrossWorkgroup>, %arg1 : !spirv.ptr, CrossWorkgroup>, %arg2 : !spirv.ptr, CrossWorkgroup>) "None" attributes { workgroup_attributions = 0 : i64} { + %c0 = spirv.Constant 0 : i64 + %c1 = spirv.Constant 1 : i64 + + %subgroup_size = spirv.Constant 16 : i64 + %c_M = spirv.Constant 2048 : i64 + %c_N = spirv.Constant 2048 : i64 + %c_K = spirv.Constant 2048 : i64 + + %c_tM = spirv.Constant 8 : i64 + %c_tK = spirv.Constant 16 : i64 + %c_tN = spirv.Constant 16 : i64 + + %c_vnni_factor = spirv.Constant 2 : i64 + + // Get workgroup IDs + %__builtin_var_WorkgroupId___addr = spirv.mlir.addressof @__builtin_var_WorkgroupId__ : !spirv.ptr, 
Input> + %0 = spirv.Load "Input" %__builtin_var_WorkgroupId___addr : vector<3xi64> + %workgroupId_x = spirv.CompositeExtract %0[0 : i32] : vector<3xi64> + %workgroupId_y = spirv.CompositeExtract %0[1 : i32] : vector<3xi64> + %workgroupId_z = spirv.CompositeExtract %0[2 : i32] : vector<3xi64> + + // Get the subgroup ID + %__builtin_var_SubgroupId___addr = spirv.mlir.addressof @__builtin_var_SubgroupId__ : !spirv.ptr + %subgroupId = spirv.Load "Input" %__builtin_var_SubgroupId___addr : i64 + + // Get the numworkgroups + %__builtin_var_NumWorkgroups___addr = spirv.mlir.addressof @__builtin_var_NumWorkgroups__ : !spirv.ptr, Input> + %1 = spirv.Load "Input" %__builtin_var_NumWorkgroups___addr : vector<3xi64> + %numworkgroups_x = spirv.CompositeExtract %1[0 : i32] : vector<3xi64> + %numworkgroups_y = spirv.CompositeExtract %1[1 : i32] : vector<3xi64> + %numworkgroups_z = spirv.CompositeExtract %1[2 : i32] : vector<3xi64> + + // Get the global invocation ID (global ID) + %__builtin_var_GlobalInvocationId___addr = spirv.mlir.addressof @__builtin_var_GlobalInvocationId__ : !spirv.ptr, Input> + %2 = spirv.Load "Input" %__builtin_var_GlobalInvocationId___addr : vector<3xi64> + %globalId_x = spirv.CompositeExtract %2[0 : i32] : vector<3xi64> + %globalId_y = spirv.CompositeExtract %2[1 : i32] : vector<3xi64> + %globalId_z = spirv.CompositeExtract %2[2 : i32] : vector<3xi64> + + // Get the local ID (thred ID) + + %__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr, Input> + %3 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64> + %localId_x = spirv.CompositeExtract %3[0 : i32] : vector<3xi64> + %localId_y = spirv.CompositeExtract %3[1 : i32] : vector<3xi64> + %localId_z = spirv.CompositeExtract %3[2 : i32] : vector<3xi64> + + %sg_start_x = spirv.ISub %globalId_x, %localId_x : i64 + %sg_start_y = spirv.ISub %globalId_y, %localId_y : i64 + // Load C + // %load_offset_C = sg_start_x * tM * colsB + sg_start_y / SG_SIZE * tN + + %mul_c_1 = spirv.IMul %sg_start_x, %c_tM : i64 + %mul_c_2 = spirv.IMul %mul_c_1, %c_N : i64 + + %div_c_1 = spirv.UDiv %sg_start_y, %subgroup_size : i64 + %mul_c_3 = spirv.IMul %div_c_1, %c_tN : i64 + + %load_offset_C = spirv.IAdd %mul_c_2, %mul_c_3 : i64 + %load_address_C = spirv.AccessChain %arg2[%load_offset_C] : !spirv.ptr, CrossWorkgroup>, i64 + + %joint_matrix_C = spirv.INTEL.JointMatrixLoad %load_address_C, %c_N {memory_access = #spirv.memory_access} : (!spirv.ptr, i64) -> !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator> + + // Loop through (k = 0; k < colsA / tK; k++) + %loop_cnt = spirv.UDiv %c_K, %c_tK : i64 + spirv.mlir.loop { + spirv.Branch ^bb1(%c0, %joint_matrix_C: i64, !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator>) + ^bb1(%k: i64, %matrixC1: !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator>): // 2 preds: ^bb0, ^bb2 + %5 = spirv.ULessThan %k, %loop_cnt : i64 + spirv.BranchConditional %5, ^bb2, ^bb3 + ^bb2: // pred: ^bb1 + // Loading A + // %load_offset_A = sg_start_x * tM * colsA + k * tK + // %load_offset_A = (%globalId_x - %localId_x) * %c_K + (%k * %c_tK) + %mul_1 = spirv.IMul %sg_start_x, %c_tM : i64 + %mul_2 = spirv.IMul %mul_1, %c_K : i64 // sg_start_x * tM * colsA + + %mul_3 = spirv.IMul %k, %c_tK : i64 // k * tK + %load_offset_A = spirv.IAdd %mul_2, %mul_3 : i64 + %load_address_A = spirv.AccessChain %arg0[%load_offset_A] : !spirv.ptr, CrossWorkgroup>, i64 + + %joint_matrix_A = spirv.INTEL.JointMatrixLoad %load_address_A, %c_K {memory_access 
= #spirv.memory_access} : (!spirv.ptr, i64) -> !spirv.jointmatrix<8x16xi16, RowMajor, Subgroup, MatrixA> + + // Loading B + // %load_offset_B = (k * tK / vnniFactor) * (colsB * vnniFactor) + sg_start_y / SG_SIZE * tN * vnniFactor + %div_1 = spirv.UDiv %c_tK, %c_vnni_factor : i64 // k * tK + %mul_4 = spirv.IMul %k, %div_1 : i64 // (k * tK / vnniFactor) + + %mul_5 = spirv.IMul %c_N, %c_vnni_factor : i64 // (colsB * vnniFactor) + + %mul_6 = spirv.IMul %mul_4, %mul_5 : i64 // (k * tK / vnniFactor) * (colsB * vnniFactor) + + %div_2 = spirv.UDiv %sg_start_y, %subgroup_size : i64 // sg_start_y / SG_SIZE + %mul_7 = spirv.IMul %div_2, %c_tN : i64 // sg_start_y / SG_SIZE * tN + %mul_8 = spirv.IMul %mul_7, %c_vnni_factor : i64 // sg_start_y / SG_SIZE * tN * vnniFactor + + %load_offset_B = spirv.IAdd %mul_6, %mul_8 : i64 + + %load_address_B = spirv.AccessChain %arg1[%load_offset_B] : !spirv.ptr, CrossWorkgroup>, i64 + %stride_B = spirv.IMul %c_N, %c_vnni_factor : i64 + + + %joint_matrix_B = spirv.INTEL.JointMatrixLoad %load_address_B, %stride_B {memory_access = #spirv.memory_access} : (!spirv.ptr, i64) -> !spirv.jointmatrix<16x16xi16, Packed, Subgroup, MatrixB> + + %r = spirv.INTEL.JointMatrixMad %joint_matrix_A, %joint_matrix_B, %matrixC1 : !spirv.jointmatrix<8x16xi16, RowMajor, Subgroup, MatrixA>, !spirv.jointmatrix<16x16xi16, Packed, Subgroup, MatrixB> -> !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator> + + %incr_k = spirv.IAdd %k, %c1 : i64 + %6 = spirv.ULessThan %incr_k, %loop_cnt : i64 + spirv.BranchConditional %6, ^continue(%incr_k, %r : i64, !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator>), ^store(%incr_k, %r : i64, !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator>) + + ^store(%k2: i64, %matrixC2: !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator>): + spirv.INTEL.JointMatrixStore %load_address_C, %matrixC2, %c_N {memory_access = #spirv.memory_access} : (!spirv.ptr, !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator>, i64) + spirv.Branch ^continue(%k2, %matrixC2: i64, !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator>) + + ^continue(%k3: i64, %matrixC3: !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator>): + spirv.Branch ^bb1(%k3, %matrixC3: i64, !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator>) + ^bb3: // pred: ^bb1 + spirv.mlir.merge + } + spirv.Return + } + spirv.EntryPoint "Kernel" @gemm_using_jointmatrix, @__builtin_var_NumWorkgroups__, @__builtin_var_WorkgroupId__, @__builtin_var_SubgroupId__, @__builtin_var_SubgroupSize__, @__builtin_var_GlobalInvocationId__, @__builtin_var_LocalInvocationId__ + // Setting up workgroup size (@details in mlir/test/Dialect/SPIRV/Transforms/abi-interface-opencl.mlir) + // spirv.ExecutionMode @gemm_using_jointmatrix "LocalSize", 1, 16, 1 + + // Setting up subgroup size for the specific kernel + spirv.ExecutionMode @gemm_using_jointmatrix "SubgroupSize", 16 + } + + gpu.module @gemm_using_jointmatrix_module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits>} { + gpu.func @gemm_using_jointmatrix(%arg0 : memref<2048x2048xi16>, %arg1 : memref<1024x2048x2xi16>, %arg2 : memref<2048x2048xf32>) kernel attributes {gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + // skipping the gpu.func body, since we already have spirv.func body, this won't be used + gpu.return + } + } + +}
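
A worked example of the launch-configuration mapping described in the test comments, for this kernel's `gpu.launch_func` with blocks in (256, 128, 1) and threads in (1, 16, 1). The numbers below just instantiate the formulas already given above:

// MLIR (gridX, gridY, gridZ) = (256, 128, 1), (blockX, blockY, blockZ) = (1, 16, 1)
//   global_size.x = blockZ * gridZ = 1  * 1   = 1
//   global_size.y = blockY * gridY = 16 * 128 = 2048
//   global_size.z = blockX * gridX = 1  * 256 = 256
//   local_size    = (blockZ, blockY, blockX)  = (1, 16, 1)
// i.e. the SYCL/DPC++ side sees nd_range({1, 2048, 256}, {1, 16, 1}).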
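
The address arithmetic in the kernel can be sanity-checked with concrete numbers. A small worked example using the test's constants (M = N = K = 2048, tM = 8, tK = tN = 16, SG_SIZE = 16, vnniFactor = 2) and illustrative values sg_start_x = 1, sg_start_y = 32, k = 3 (these three are not taken from the test, they are only for illustration):

//   load_offset_C = sg_start_x * tM * N + (sg_start_y / SG_SIZE) * tN
//                 = 1 * 8 * 2048 + (32 / 16) * 16           = 16416
//   load_offset_A = sg_start_x * tM * K + k * tK
//                 = 1 * 8 * 2048 + 3 * 16                   = 16432
//   load_offset_B = (k * tK / vnniFactor) * (N * vnniFactor)
//                   + (sg_start_y / SG_SIZE) * tN * vnniFactor
//                 = (3 * 16 / 2) * (2048 * 2) + 2 * 16 * 2  = 98368
// Row strides used by the loads: K = 2048 for A, N * vnniFactor = 4096 for the
// VNNI-packed B, and N = 2048 for C.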
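
Finally, a minimal sketch of the joint-matrix type syntax that patch 0003 enables, assuming the patch is applied; the function below is illustrative only (it is not part of the test and has not been run), and it simply mirrors the operand types of the JointMatrixMad in the kernel above. The type takes a matrix-use parameter after the scope, i.e. !spirv.jointmatrix<RowsxColsxElemType, Layout, Scope, Use>, with MatrixLayout encoded as RowMajor = 0, ColumnMajor = 1, Packed = 2, Unused = 3 and MatrixUse as MatrixA = 0, MatrixB = 1, Accumulator = 2:

// Hypothetical round-trip example (sketch only, untested):
spirv.func @joint_matrix_use_roundtrip(
    %a : !spirv.jointmatrix<8x16xi16, RowMajor, Subgroup, MatrixA>,
    %b : !spirv.jointmatrix<16x16xi16, Packed, Subgroup, MatrixB>,
    %c : !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator>) "None" {
  // Same element types and shapes as the MAD in the test kernel.
  %r = spirv.INTEL.JointMatrixMad %a, %b, %c : !spirv.jointmatrix<8x16xi16, RowMajor, Subgroup, MatrixA>, !spirv.jointmatrix<16x16xi16, Packed, Subgroup, MatrixB> -> !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup, Accumulator>
  spirv.Return
}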