diff --git a/build_tools/patches/0014-SPIR-V-Enable-native-bf16-support-in-SPIR-V-dialect.patch b/build_tools/patches/0014-SPIR-V-Enable-native-bf16-support-in-SPIR-V-dialect.patch new file mode 100644 index 000000000..73c274a1f --- /dev/null +++ b/build_tools/patches/0014-SPIR-V-Enable-native-bf16-support-in-SPIR-V-dialect.patch @@ -0,0 +1,440 @@ +From a2215d32e5fa84745a0843639a50c4b6e7d0a008 Mon Sep 17 00:00:00 2001 +From: Md Abdullah Shahneous Bari +Date: Wed, 24 Jul 2024 17:26:22 +0000 +Subject: [PATCH] [SPIR-V] Enable native bf16 support in SPIR-V dialect. +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 +Content-Transfer-Encoding: 8bit + +Enables Khronos extension: SPV_KHR_bfloat16. +Most of the ops specified in the extension is supported. +Some notable exceptions are: OpDot, OpCooperativeMatrixMulAddKHR. + +Also adds native bf16 support for several arithmetic and math ops: +Supported arithmetic ops: + • OpFAdd + • OpFSub + • OpFMul + • OpFDiv + +OpenCL extended instructions: + o fabs + o fmax + o fmin + o fma + o tanh +--- + .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 10 ++-- + .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 46 ++++++++++++++++--- + .../mlir/Dialect/SPIRV/IR/SPIRVCLOps.td | 18 ++++---- + .../mlir/Dialect/SPIRV/IR/SPIRVCastOps.td | 12 ++--- + mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 6 ++- + mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 27 +++++++++-- + .../SPIRV/Deserialization/Deserializer.cpp | 14 ++++-- + .../Target/SPIRV/Serialization/Serializer.cpp | 3 ++ + 8 files changed, 101 insertions(+), 35 deletions(-) + +diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +index 22d5afcd7738..de9e11493793 100644 +--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td ++++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +@@ -82,7 +82,7 @@ class SPIRV_ArithmeticExtendedBinaryOp { ++def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_AnyFloat, [Commutative]> { + let summary = "Floating-point addition of Operand 1 and Operand 2."; + + let description = [{ +@@ -104,7 +104,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> + + // ----- + +-def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> { ++def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_AnyFloat, []> { + let summary = "Floating-point division of Operand 1 divided by Operand 2."; + + let description = [{ +@@ -154,7 +154,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> { + + // ----- + +-def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> { ++def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_AnyFloat, [Commutative]> { + let summary = "Floating-point multiplication of Operand 1 and Operand 2."; + + let description = [{ +@@ -229,7 +229,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> { + + // ----- + +-def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> { ++def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_AnyFloat, []> { + let summary = "Floating-point subtraction of Operand 2 from Operand 1."; + + let description = [{ +@@ -450,7 +450,7 @@ def SPIRV_DotOp : SPIRV_Op<"Dot", + ); + + let results = (outs +- SPIRV_Float:$result ++ SPIRV_AnyFloat:$result + ); + + let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)"; +diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td 
+index 91a8bb51ad65..b102c143ee83 100644 +--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td ++++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +@@ -343,6 +343,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup + def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>; + def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>; + def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>; ++def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>; + + def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>; + def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>; +@@ -435,7 +436,7 @@ def SPIRV_ExtensionAttr : + SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask, + SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate, + SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation, +- SPV_KHR_cooperative_matrix, ++ SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16, + SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing, + SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density, + SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer, +@@ -1193,6 +1194,27 @@ def SPIRV_C_ShaderClockKHR : I32EnumAttrCase<"Shade + Extension<[SPV_KHR_shader_clock]> + ]; + } ++ ++def SPIRV_C_BFloat16TypeKHR : I32EnumAttrCase<"BFloat16TypeKHR", 5116> { ++ list availability = [ ++ Extension<[SPV_KHR_bfloat16]> ++ ]; ++} ++ ++def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> { ++ list implies = [SPIRV_C_BFloat16TypeKHR]; ++ list availability = [ ++ Extension<[SPV_KHR_bfloat16]> ++ ]; ++} ++ ++def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> { ++ list implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR]; ++ list availability = [ ++ Extension<[SPV_KHR_bfloat16]> ++ ]; ++} ++ + def SPIRV_C_FragmentFullyCoveredEXT : I32EnumAttrCase<"FragmentFullyCoveredEXT", 5265> { + list implies = [SPIRV_C_Shader]; + list availability = [ +@@ -1491,6 +1513,7 @@ def SPIRV_CapabilityAttr : + SPIRV_C_RayQueryKHR, SPIRV_C_RayTracingKHR, SPIRV_C_Float16ImageAMD, + SPIRV_C_ImageGatherBiasLodAMD, SPIRV_C_FragmentMaskAMD, SPIRV_C_StencilExportEXT, + SPIRV_C_ImageReadWriteLodAMD, SPIRV_C_Int64ImageEXT, SPIRV_C_ShaderClockKHR, ++ SPIRV_C_BFloat16TypeKHR, SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR, + SPIRV_C_FragmentFullyCoveredEXT, SPIRV_C_MeshShadingNV, SPIRV_C_FragmentDensityEXT, + SPIRV_C_ShaderNonUniform, SPIRV_C_RuntimeDescriptorArray, + SPIRV_C_StorageTexelBufferArrayDynamicIndexing, SPIRV_C_RayTracingNV, +@@ -4148,16 +4171,21 @@ def SPIRV_Bool : TypeAlias; + def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>; + def SPIRV_Int16 : TypeAlias; + def SPIRV_Int32 : TypeAlias; ++def SPIRV_BFloat16KHR : TypeAlias; + def SPIRV_Float32 : TypeAlias; +-def SPIRV_Float : FloatOfWidths<[16, 32, 64]>; +-def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>; ++// def SPIRV_Float : FloatOfWidths<[16, 32, 64]>; ++def SPIRV_Float : AnyTypeOf<[F16, F32, F64]>; ++// def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>; ++def SPIRV_Float16or32 : AnyTypeOf<[F16, F32]>; ++// Use this type for all kinds of floats. 
++def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_BFloat16KHR, SPIRV_Float]>; + // Remove the vector size restriction + // Although the vector size can be upto (2^64-1), uint64 + // In tablegen, int is signed int, hence using the upper + // limit of int64 (2^63-1) rather than uint64, it should serve the purpose + // for all practical cases + def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0x7FFFFFFFFFFFFFFF], +- [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>; ++ [SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR]>; + // Component type check is done in the type parser for the following SPIR-V + // dialect-specific types so we use "Any" here. + def SPIRV_AnyPtr : DialectType; + +-def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float]>; ++def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR]>; + def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>; + def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>; + def SPIRV_Composite : + AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, + SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>; + def SPIRV_Type : AnyTypeOf<[ +- SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector, ++ SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR, SPIRV_Vector, + SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, + SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, + SPIRV_AnySampledImage +@@ -4771,4 +4799,10 @@ def SPIRV_FPFastMathModeAttr : + SPIRV_FPFMM_AllowReassocINTEL + ]>; + ++def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0>; ++def SPIRV_FP_Encoding : ++ SPIRV_I32EnumAttr<"FPEncoding", "Valid floating-point encoding", "fpEncoding", [ ++ SPIRV_FPE_BFloat16KHR ++ ]>; ++ + #endif // MLIR_DIALECT_SPIRV_IR_BASE +diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td +index b5ca27d7d753..a91d2ffffc24 100644 +--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td ++++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td +@@ -386,7 +386,7 @@ def SPIRV_CLExpOp : SPIRV_CLUnaryArithmeticOp<"exp", 19, SPIRV_Float> { + + // ----- + +-def SPIRV_CLFAbsOp : SPIRV_CLUnaryArithmeticOp<"fabs", 23, SPIRV_Float> { ++def SPIRV_CLFAbsOp : SPIRV_CLUnaryArithmeticOp<"fabs", 23, SPIRV_AnyFloat> { + let summary = "Absolute value of operand"; + + let description = [{ +@@ -409,7 +409,7 @@ def SPIRV_CLFAbsOp : SPIRV_CLUnaryArithmeticOp<"fabs", 23, SPIRV_Float> { + + // ----- + +-def SPIRV_CLFMaxOp : SPIRV_CLBinaryArithmeticOp<"fmax", 27, SPIRV_Float> { ++def SPIRV_CLFMaxOp : SPIRV_CLBinaryArithmeticOp<"fmax", 27, SPIRV_AnyFloat> { + let summary = "Return maximum of two floating-point operands"; + + let description = [{ +@@ -433,7 +433,7 @@ def SPIRV_CLFMaxOp : SPIRV_CLBinaryArithmeticOp<"fmax", 27, SPIRV_Float> { + + // ----- + +-def SPIRV_CLFMinOp : SPIRV_CLBinaryArithmeticOp<"fmin", 28, SPIRV_Float> { ++def SPIRV_CLFMinOp : SPIRV_CLBinaryArithmeticOp<"fmin", 28, SPIRV_AnyFloat> { + let summary = "Return minimum of two floating-point operands"; + + let description = [{ +@@ -479,7 +479,7 @@ def SPIRV_CLFloorOp : SPIRV_CLUnaryArithmeticOp<"floor", 25, SPIRV_Float> { + + // ----- + +-def SPIRV_CLFmaOp : SPIRV_CLTernaryArithmeticOp<"fma", 26, SPIRV_Float> { ++def SPIRV_CLFmaOp : SPIRV_CLTernaryArithmeticOp<"fma", 26, SPIRV_AnyFloat> { + let summary = [{ + Compute the correctly rounded floating-point representation of the sum + of c with the infinitely precise 
product of a and b. Rounding of +@@ -789,7 +789,7 @@ def SPIRV_CLTanOp : SPIRV_CLUnaryArithmeticOp<"tan", 62, SPIRV_Float > { + + // ----- + +-def SPIRV_CLTanhOp : SPIRV_CLUnaryArithmeticOp<"tanh", 63, SPIRV_Float> { ++def SPIRV_CLTanhOp : SPIRV_CLUnaryArithmeticOp<"tanh", 63, SPIRV_AnyFloat> { + let summary = "Compute hyperbolic tangent of x radians."; + + let description = [{ +@@ -864,10 +864,10 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> { + + Result Type must be i32. + +- Format must be a pointer(constant) to i8. If there are insufficient +- arguments for the format, the behavior is undefined. If the format ++ Format must be a pointer(constant) to i8. If there are insufficient ++ arguments for the format, the behavior is undefined. If the format + is exhausted while arguments remain, the excess arguments are evaluated +- (as always) but are otherwise ignored. The printf instruction returns ++ (as always) but are otherwise ignored. The printf instruction returns + when the end of the format string is encountered. + + +@@ -883,7 +883,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> { + SPIRV_AnyPtr:$format, + Variadic:$arguments + ); +- ++ + let results = (outs + SPIRV_Integer:$result + ); +diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td +index b05ee0251df5..a5c8aa8fb450 100644 +--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td ++++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td +@@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> { + + // ----- + +-def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float, []> { ++def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_AnyFloat, []> { + let summary = [{ + Convert value numerically from floating point to signed integer, with + round toward 0.0. +@@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float + + // ----- + +-def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float, []> { ++def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_AnyFloat, []> { + let summary = [{ + Convert value numerically from floating point to unsigned integer, with + round toward 0.0. 
+@@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float + // ----- + + def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF", +- SPIRV_Float, ++ SPIRV_AnyFloat, + SPIRV_Integer, + [SignedOp]> { + let summary = [{ +@@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF", + // ----- + + def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF", +- SPIRV_Float, ++ SPIRV_AnyFloat, + SPIRV_Integer, + [UnsignedOp]> { + let summary = [{ +@@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF", + // ----- + + def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert", +- SPIRV_Float, +- SPIRV_Float, ++ SPIRV_AnyFloat, ++ SPIRV_AnyFloat, + [UsableInSpecConstantOp]> { + let summary = [{ + Convert value numerically from one floating-point width to another +diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +index b9f906ada3ee..5101fffd8d0e 100644 +--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp ++++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +@@ -171,8 +171,10 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, + + // Check other allowed types + if (auto t = llvm::dyn_cast(type)) { +- if (type.isBF16()) { +- parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types"); ++ if (!ScalarType::isValid(t)) { ++ parser.emitError(typeLoc, ++ "only 16/32/64-bit float type allowed but found ") ++ << type; + return Type(); + } + } else if (auto t = llvm::dyn_cast(type)) { +diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +index e78726e43045..d62ef36a5cf3 100644 +--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp ++++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +@@ -604,7 +604,7 @@ bool ScalarType::classof(Type type) { + } + + bool ScalarType::isValid(FloatType type) { +- return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16(); ++ return llvm::is_contained({16u, 32u, 64u}, type.getWidth()); + } + + bool ScalarType::isValid(IntegerType type) { +@@ -613,6 +613,14 @@ bool ScalarType::isValid(IntegerType type) { + + void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + std::optional storage) { ++ ++ // bf16 case ++ if (llvm::isa(*this)) { ++ static const Extension exts[] = {Extension::SPV_KHR_bfloat16}; ++ ArrayRef ref(exts, std::size(exts)); ++ extensions.push_back(ref); ++ } ++ + // 8- or 16-bit integer/floating-point numbers will require extra extensions + // to appear in interface storage classes. See SPV_KHR_16bit_storage and + // SPV_KHR_8bit_storage for more details. 
+@@ -631,7 +639,7 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + [[fallthrough]]; + case StorageClass::Input: + case StorageClass::Output: +- if (getIntOrFloatBitWidth() == 16) { ++ if (getIntOrFloatBitWidth() == 16 && !llvm::isa(*this)) { + static const Extension exts[] = {Extension::SPV_KHR_16bit_storage}; + ArrayRef ref(exts, std::size(exts)); + extensions.push_back(ref); +@@ -718,7 +726,20 @@ void ScalarType::getCapabilities( + } else { + assert(llvm::isa(*this)); + switch (bitwidth) { +- WIDTH_CASE(Float, 16); ++ case 16: { ++ if (llvm::isa(*this)) { ++ static const Capability caps[] = {Capability::BFloat16TypeKHR}; ++ ArrayRef ref(caps, std::size(caps)); ++ capabilities.push_back(ref); ++ ++ } else { ++ static const Capability caps[] = {Capability::Float16}; ++ ArrayRef ref(caps, std::size(caps)); ++ capabilities.push_back(ref); ++ } ++ break; ++ } ++ // WIDTH_CASE(Float, 16); + WIDTH_CASE(Float, 64); + case 32: + break; +diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +index ef6e22aff12e..b5a857452ff0 100644 +--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp ++++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +@@ -809,14 +809,20 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, + typeMap[operands[0]] = IntegerType::get(context, operands[1], sign); + } break; + case spirv::Opcode::OpTypeFloat: { +- if (operands.size() != 2) +- return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); ++ if (operands.size() < 2 || operands.size() > 3) ++ return emitError( ++ unknownLoc, ++ "OpTypeFloat must have bitwidth parameter and optional FP Encoding"); + + Type floatTy; + switch (operands[1]) { +- case 16: +- floatTy = opBuilder.getF16Type(); ++ case 16: { ++ if (operands.size() == 3 && operands[2] == 0) ++ floatTy = opBuilder.getBF16Type(); ++ else ++ floatTy = opBuilder.getF16Type(); + break; ++ } + case 32: + floatTy = opBuilder.getF32Type(); + break; +diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +index bdf786ff0afd..6a773009c780 100644 +--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp ++++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +@@ -468,6 +468,9 @@ LogicalResult Serializer::prepareBasicType( + if (auto floatType = dyn_cast(type)) { + typeEnum = spirv::Opcode::OpTypeFloat; + operands.push_back(floatType.getWidth()); ++ // Add extra parameter (FPEncoding) to opTypeFloat for bf16 data type ++ if (floatType.isBF16()) ++ operands.push_back(static_cast(spirv::FPEncoding::BFloat16KHR)); + return success(); + } + +-- +2.34.1 + diff --git a/lib/Transforms/SetSPIRVCapabilities.cpp b/lib/Transforms/SetSPIRVCapabilities.cpp index b7b3787a6..ab508c73a 100644 --- a/lib/Transforms/SetSPIRVCapabilities.cpp +++ b/lib/Transforms/SetSPIRVCapabilities.cpp @@ -54,11 +54,12 @@ struct SetSPIRVCapabilitiesPass spirv::Capability caps_opencl[] = { // clang-format off spirv::Capability::Addresses, + spirv::Capability::Bfloat16ConversionINTEL, + spirv::Capability::BFloat16TypeKHR, spirv::Capability::Float16Buffer, spirv::Capability::Int64, spirv::Capability::Int16, spirv::Capability::Int8, - spirv::Capability::Bfloat16ConversionINTEL, spirv::Capability::Kernel, spirv::Capability::Linkage, spirv::Capability::Vector16, @@ -77,10 +78,14 @@ struct SetSPIRVCapabilitiesPass // clang-format on }; spirv::Extension exts_opencl[] = { - 
spirv::Extension::SPV_INTEL_bfloat16_conversion, + // clang-format off spirv::Extension::SPV_EXT_shader_atomic_float_add, + spirv::Extension::SPV_KHR_bfloat16, spirv::Extension::SPV_KHR_expect_assume, - spirv::Extension::SPV_INTEL_vector_compute}; + spirv::Extension::SPV_INTEL_bfloat16_conversion, + spirv::Extension::SPV_INTEL_vector_compute + // clang-format on + }; spirv::Extension exts_vulkan[] = { spirv::Extension::SPV_KHR_storage_buffer_storage_class}; if (m_clientAPI == "opencl") { diff --git a/test/Integration/Dialect/Gpu/EltwiseAdd_BF16.mlir b/test/Integration/Dialect/Gpu/EltwiseAdd_BF16.mlir new file mode 100644 index 000000000..5a5cdc264 --- /dev/null +++ b/test/Integration/Dialect/Gpu/EltwiseAdd_BF16.mlir @@ -0,0 +1,51 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck + +module @eltwise_add attributes {gpu.container_module} { + memref.global "private" constant @__constant_10x20xbf16 : memref<10x20xbf16> = dense<5.000000e-01> + func.func @test(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>) -> memref<10x20xbf16> { + %c20 = arith.constant 20 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + %memref = gpu.alloc host_shared () : memref<10x20xbf16> + memref.copy %arg1, %memref : memref<10x20xbf16> to memref<10x20xbf16> + %memref_0 = gpu.alloc host_shared () : memref<10x20xbf16> + memref.copy %arg0, %memref_0 : memref<10x20xbf16> to memref<10x20xbf16> + %memref_1 = gpu.alloc host_shared () : memref<10x20xbf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<10x20xbf16>, %memref : memref<10x20xbf16>, %memref_1 : memref<10x20xbf16>) + %alloc = memref.alloc() : memref<10x20xbf16> + memref.copy %memref_1, %alloc : memref<10x20xbf16> to memref<10x20xbf16> + gpu.dealloc %memref_1 : memref<10x20xbf16> + gpu.dealloc %memref_0 : memref<10x20xbf16> + gpu.dealloc %memref : memref<10x20xbf16> + return %alloc : memref<10x20xbf16> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xbf16> + %1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xbf16> + %2 = arith.addf %0, %1 : bf16 + memref.store %2, %arg2[%block_id_x, %block_id_y] : memref<10x20xbf16> + gpu.return + } + } + func.func @main() { + %0 = memref.get_global @__constant_10x20xbf16 : memref<10x20xbf16> + %1 = memref.get_global @__constant_10x20xbf16 : memref<10x20xbf16> + %2 = call @test(%0, %1) : (memref<10x20xbf16>, memref<10x20xbf16>) -> memref<10x20xbf16> + %cast = memref.cast %2 : memref<10x20xbf16> to memref<*xbf16> 
+ // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} + // CHECK-COUNT-200: 1 + call @printMemrefBF16(%cast) : (memref<*xbf16>) -> () + return + } + func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/Gpu/gpu-to-llvm.pp b/test/Integration/Dialect/Gpu/gpu-to-llvm.pp new file mode 100644 index 000000000..4ed01a759 --- /dev/null +++ b/test/Integration/Dialect/Gpu/gpu-to-llvm.pp @@ -0,0 +1,28 @@ +// gpu dialect with intel intrinsic functions (func dialect) to +// llvm dialect (for host code) and +// spirv dialect (for device code) lowering pipeline. +// Ready for imex runner starting from GPU dialect. +builtin.module( + imex-vector-linearize + reconcile-unrealized-casts + imex-convert-gpu-to-spirv + spirv.module(spirv-lower-abi-attrs + spirv-update-vce) + func.func(llvm-request-c-wrappers) + serialize-spirv + convert-vector-to-scf + convert-gpu-to-gpux + convert-scf-to-cf + convert-cf-to-llvm + convert-vector-to-llvm + convert-index-to-llvm + convert-arith-to-llvm + convert-func-to-llvm + convert-math-to-llvm + convert-gpux-to-llvm + convert-index-to-llvm + expand-strided-metadata + lower-affine + finalize-memref-to-llvm + reconcile-unrealized-casts) +// End diff --git a/test/Integration/Dialect/XeGPU/gemm_256x256x256_bf16_bf16_f32.mlir b/test/Integration/Dialect/XeGPU/gemm_256x256x256_bf16_bf16_f32.mlir new file mode 100644 index 000000000..648bae142 --- /dev/null +++ b/test/Integration/Dialect/XeGPU/gemm_256x256x256_bf16_bf16_f32.mlir @@ -0,0 +1,615 @@ +// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +module @gemm attributes {gpu.container_module} { + func.func @test(%A: memref<256x256xbf16>, %B: memref<256x256xbf16>, %C: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %A_gpu = gpu.alloc host_shared () : memref<256x256xbf16> + memref.copy %A, %A_gpu : memref<256x256xbf16> to memref<256x256xbf16> + %B_gpu = gpu.alloc host_shared () : memref<256x256xbf16> + memref.copy %B, %B_gpu : memref<256x256xbf16> to memref<256x256xbf16> + %C_gpu = gpu.alloc host_shared () : memref<256x256xf32> + memref.copy %C, %C_gpu : memref<256x256xf32> to memref<256x256xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<256x256xbf16>, %B_gpu : memref<256x256xbf16>, %C_gpu : memref<256x256xf32>) + gpu.dealloc %A_gpu : memref<256x256xbf16> + gpu.dealloc %B_gpu : memref<256x256xbf16> + return %C_gpu : memref<256x256xf32> + } + + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + 
gpu.func @test_kernel(%A: memref<256x256xbf16>, %B: memref<256x256xbf16>, %C: memref<256x256xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + // constants + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %c48 = arith.constant 48 : index + %c16 = arith.constant 16 : index + %c24 = arith.constant 24 : index + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + // get IDs + %wg_id_x = gpu.block_id x + %wg_id_y = gpu.block_id y + // %sg_id = gpu.subgroup_id : index + + // each C wg tile is 256x256 and 32 SGs update it in 8x4 layout + // C sg tile size is 32x64 + // SG layout for one C tile update + // |0|1|2|3| + // |4|5|6|7| + // ......... + // |28|29|30|31| + // --> y means cols + // | + // V x means rows + + // get unique sg ID in global context + %global_sg_id_x = gpu.global_id x + %global_sg_id_y = gpu.global_id y + %local_sg_id_x = arith.remui %global_sg_id_x, %c8 : index + %local_sg_id_y = arith.remui %global_sg_id_y, %c4 : index + + // compute SG C tile offsets in x and y dims + %C_sg_tile_offset_x = arith.muli %global_sg_id_x, %c32 : index + %C_sg_tile_offset_y = arith.muli %global_sg_id_y, %c64 : index + + // each SG needs to do the follwoing compute to update its 32x64 sub tile + // (32x256)x(256x64)=(32x64) + // DPAS size is (8x16)x(16x16)=(8x16) + // K loop adavances in steps of 32, so inside the compute is (32x32)x(32x64) = (32x64) + // So we need to (4x2) A tiles of size (8x16) and (2x4) B tiles of size (16x16) + // tiled compute for a SG is (4x2x8x16)x(2x4x16x16)=(4x4x8x16) + // this will require 32 DPAS ops (4x2x2) inside the K loop + + // WG tiles are 256x256 so there offsets are same for A, B and C + %wg_tile_offset_x = arith.muli %wg_id_x, %c256 : index + %wg_tile_offset_y = arith.muli %wg_id_y, %c256 : index + + %local_sg_id_temp = arith.muli %local_sg_id_x, %c4 : index + %local_sg_id = arith.addi %local_sg_id_temp, %local_sg_id_y : index + + // prefetching A and B slice within the 256x256 WG tile + // + // prefetch the entire 256x32 slice of A WG tile, this means each subgroups needs to prefetch 8x32 slice + // each 1x4 row of SGs do a colloborative prefetch of 8x32 slice of the 32x32 tile + // SG 0 -> slice 0 | + // SG 1 -> slice 1 | + // SG 2 -> slice 2 > SG 0,1,2,3 share data prefetch from the top 32x32 tile. + // SG 3 -> slice 3 | + // SG 4 -> slice 4 + // .... 
+ // SG 31 -> slice 31 + %A_sg_prefetch_offset_x_temp = arith.muli %local_sg_id, %c8 : index + %A_sg_prefetch_offset_x = arith.addi %A_sg_prefetch_offset_x_temp, %wg_tile_offset_x : index + // create A preftech tiles and prefetch + // stage 1 + %A_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %A[%A_sg_prefetch_offset_x, %c0] : memref<256x256xbf16> -> !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %A_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + // stage 2 (move 32 elements in the y direction and prefetch next 8x32 tile) + %A_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter0, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %A_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + // stage 3 + %A_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter1, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %A_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + + // stage 4 + %A_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter2, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + + // prefetch the entire 32x256 slice of B WG tile, we still use the prefetch size 8x32. + // SGs have 8x4 layout. In this case 8 subgroups must do a colloborative prefetch of 32x64 tile. + // this because the B tile arrangement within the 32x256 slice is as follows + // 32x64 | 32x64 | 32x64 | 32x64 + // in terms of 8x32 slices the arrangement is, + // 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 + // 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 + // 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 + // 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 + // So SGs 0,1,2,3,....31 prefetch in following fashion + // | 0 | 16|| 1 | 17 || 2 | 18 || 3 | 19 | + // | 4 | 20|| 5 | 21 || 6 | 22 || 7 | 23 | + // | 8 | 24|| 9 | 25 || 10 | 26 || 11| 27 | + // | 12 | 28|| 13 | 29 || 14 | 30 || 15| 31 | + // For example, SGs 0,4,8,12,16,20,24,28 share the data in left 32x64 tile of B slice. 
+ + // calculate the x offsets and y offsets within the 32x256 slice + // XeTLA like co-operative prefetch for B + %B_sg_prefetch_offset_x_temp0 = arith.divui %local_sg_id_x, %c2 : index + %B_sg_prefetch_offset_x = arith.muli %B_sg_prefetch_offset_x_temp0, %c8 : index + + %B_sg_prefetch_offset_y_temp0 = arith.muli %local_sg_id_y, %c64 : index + %B_sg_prefetch_offset_y_temp1 = arith.remui %local_sg_id_x, %c2 : index + %B_sg_prefetch_offset_y_temp2 = arith.muli %B_sg_prefetch_offset_y_temp1, %c32 : index + + %B_sg_prefetch_offset_y_temp3 = arith.addi %B_sg_prefetch_offset_y_temp0, %B_sg_prefetch_offset_y_temp2 : index + %B_sg_prefetch_offset_y = arith.addi %wg_tile_offset_y, %B_sg_prefetch_offset_y_temp3 : index + + + // create B prefetch tiles and prefetch + %B_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %B[%B_sg_prefetch_offset_x, %B_sg_prefetch_offset_y] : memref<256x256xbf16> -> !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %B_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + // stage 2 (move 32 elements in the x direction and prefetch next 8x32 tile) + %B_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter0, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %B_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + // stage 3 + %B_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter1, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %B_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + // stage 4 + %B_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter2, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + + %A_sg_init_tile_0 = xegpu.create_nd_tdesc %A[%C_sg_tile_offset_x, %c0] : memref<256x256xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %A_sg_init_tile_1 = xegpu.update_nd_offset %A_sg_init_tile_0, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + + //create B tiles + %B_sg_init_tile_0 = xegpu.create_nd_tdesc %B[%c0, %C_sg_tile_offset_y] : memref<256x256xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %B_sg_init_tile_1 = xegpu.update_nd_offset %B_sg_init_tile_0, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + + %B_sg_init_tile_2 = xegpu.update_nd_offset %B_sg_init_tile_0, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %B_sg_init_tile_3 = xegpu.update_nd_offset %B_sg_init_tile_2, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + // ************************* // + + // init 16 C tiles of size 8x16 each is initialized to 0.0 assuming a zero C matrix + %zero_vec = arith.constant dense<0.0> : vector<128xf32> + %c_init_val_0_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_0_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_0_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_0_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_1_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_1_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_1_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> 
+ %c_init_val_1_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_2_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_2_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_2_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_2_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_3_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_3_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_3_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + %c_init_val_3_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> + + // Multi nbarrier implementation, + // one set nbarrier is used to sync subgroups with same sg_id_x (local_sg_id_x) + // second set nbarrier us used to sync subgroups with same sg_id_y (local_sg_id_y) + // In this case wg_size = 8,4 (wg_size_x = 8; wg_size_y = 4) + // So in Y-direction we need 4 nbarrier (to sync subgroups with same sg_id_y) + // In X-direction we need 8 nbarrier (to sync subgroups with same sg_id_x) + %c_wg_size_x = arith.constant 8 : index + %c_wg_size_y = arith.constant 4 : index + %num_nbarrier = arith.addi %c_wg_size_y, %c_wg_size_x : index // 8+4=12 + xegpu.alloc_nbarrier 12 // = 12 + + // First set of nbarriers work across coloumns, we have 4 coloums of subgroups, + // Hnece 4 nbrrier + // Each nbarrier has 8 producers and consumers + // nbarrier type is Producer_Consumer (https://gfxspecs.intel.com/Predator/Home/Index/57499) + + // %nbarrier_role = arith.constant 0 : i8 + %nbarrier_threads_y = arith.constant 8 : i8 + %nbarrier_id_y = arith.index_cast %local_sg_id_y : index to i8 + %nbarrier_y = xegpu.init_nbarrier %nbarrier_id_y, %nbarrier_threads_y : i8, i8 -> !xegpu.nbarrier + + // Second set of barriers work on across rows of subgroups, + // we have 8 rows of subgroups. 
Hnece, 8 nbarrier + // Each nbarrier has 4 producers and consumers + // nbarrier type is Producer_Consumer (https://gfxspecs.intel.com/Predator/Home/Index/57499) + + // We already have 4 (=%c_wg_size_y) nbarriers with id 0-3, + // Now the next set of barrier id would start from 4, hence, + %nbarrier_threads_x = arith.constant 4 : i8 + %index_nbarrier_id_x = arith.addi %c_wg_size_y, %local_sg_id_x : index + %nbarrier_id_x = arith.index_cast %index_nbarrier_id_x : index to i8 + %nbarrier_x = xegpu.init_nbarrier %nbarrier_id_x, %nbarrier_threads_x : i8, i8 -> !xegpu.nbarrier + + // K loop advances in 32 steps + %k_loop_result:24 = scf.for %k = %c0 to %c256 step %c32 iter_args ( + %A_tile_0 = %A_sg_init_tile_0, + %A_tile_1 = %A_sg_init_tile_1, + + %B_tile_0 = %B_sg_init_tile_0, + %B_tile_1 = %B_sg_init_tile_1, + %B_tile_2 = %B_sg_init_tile_2, + %B_tile_3 = %B_sg_init_tile_3, + + %c_val_0_0 = %c_init_val_0_0, + %c_val_0_1 = %c_init_val_0_1, + %c_val_0_2 = %c_init_val_0_2, + %c_val_0_3 = %c_init_val_0_3, + %c_val_1_0 = %c_init_val_1_0, + %c_val_1_1 = %c_init_val_1_1, + %c_val_1_2 = %c_init_val_1_2, + %c_val_1_3 = %c_init_val_1_3, + %c_val_2_0 = %c_init_val_2_0, + %c_val_2_1 = %c_init_val_2_1, + %c_val_2_2 = %c_init_val_2_2, + %c_val_2_3 = %c_init_val_2_3, + %c_val_3_0 = %c_init_val_3_0, + %c_val_3_1 = %c_init_val_3_1, + %c_val_3_2 = %c_init_val_3_2, + %c_val_3_3 = %c_init_val_3_3, + + %A_prefetch_tile = %A_sg_prefetch_tile_iter2, + %B_prefetch_tile = %B_sg_prefetch_tile_iter2 + ) -> + (!xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, + !xegpu.tensor_desc<8x32xbf16>, !xegpu.tensor_desc<8x32xbf16> + ) + { + // all SGs must arrive here first + %every_8th_iter = arith.remui %k, %c32 : index + %every_8th_iter_i32 = arith.index_cast %every_8th_iter : index to i32 + %every_8th_iter_cond = arith.cmpi eq, %every_8th_iter_i32, %c0_i32 : i32 + scf.if %every_8th_iter_cond { + xegpu.nbarrier_arrive %nbarrier_y : !xegpu.nbarrier + xegpu.nbarrier_arrive %nbarrier_x : !xegpu.nbarrier + } + + // Load smaller load (16 registers) with cache line size width : 64 bytes, 32 elements + // Although maximum load size supported is 2KB or 32 registers, we use smaller loads, for 2 main reasons: + // ** 1. Hide load latency: we do smaller load means for B, we do 4 loads, we set up the loads and dpas orderring in + // such a way that, the first set of DPAS works on data loaded by first 2 load operations, as a result the + // second set of loads' latency can be hidden by the first set of DPAS operations. + // + // ** 2. 
Reduce the impact of L3 miss: Larger load means more cache lines to be loaded, more chance of potential L3 miss + // which could increase the load time + + // load B tiles + %b_val_0 = xegpu.load_nd %B_tile_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %b_val_1 = xegpu.load_nd %B_tile_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %b_val_2 = xegpu.load_nd %B_tile_2 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %b_val_3 = xegpu.load_nd %B_tile_3 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + + // load A tiles + %a_val_0 = xegpu.load_nd %A_tile_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xbf16> + %a_val_1 = xegpu.load_nd %A_tile_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xbf16> + + xegpu.compile_hint + + // prefetch A and B tiles + xegpu.prefetch_nd %A_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %B_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + + xegpu.compile_hint + + // advance A and B prefetch tiles + %next_A_prefetch_tile = xegpu.update_nd_offset %A_prefetch_tile, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + %next_B_prefetch_tile = xegpu.update_nd_offset %B_prefetch_tile, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + // advance A and B tiles + %next_A_tile_0 = xegpu.update_nd_offset %A_tile_0, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %next_A_tile_1 = xegpu.update_nd_offset %A_tile_1, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + + %next_B_tile_0 = xegpu.update_nd_offset %B_tile_0, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %next_B_tile_1 = xegpu.update_nd_offset %B_tile_1, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %next_B_tile_2 = xegpu.update_nd_offset %B_tile_2, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %next_B_tile_3 = xegpu.update_nd_offset %B_tile_3, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + + %b_val_0_flat = vector.shape_cast %b_val_0 : vector<2x8x16x2xbf16> to vector<512xbf16> + %b_val_1_flat = vector.shape_cast %b_val_1 : vector<2x8x16x2xbf16> to vector<512xbf16> + %b_val_2_flat = vector.shape_cast %b_val_2 : vector<2x8x16x2xbf16> to vector<512xbf16> + %b_val_3_flat = vector.shape_cast %b_val_3 : vector<2x8x16x2xbf16> to vector<512xbf16> + + // b[0,0], b[0,1] + %b_val_0_0_flat = vector.extract_strided_slice %b_val_0_flat { offsets = [0], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_0_0 = vector.shape_cast %b_val_0_0_flat : vector<256xbf16> to vector<8x16x2xbf16> + %b_val_0_1_flat = 
vector.extract_strided_slice %b_val_0_flat { offsets = [256], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_0_1 = vector.shape_cast %b_val_0_1_flat : vector<256xbf16> to vector<8x16x2xbf16> + + // b[0,2], b[0,3] + %b_val_0_2_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [0], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_0_2 = vector.shape_cast %b_val_0_2_flat : vector<256xbf16> to vector<8x16x2xbf16> + %b_val_0_3_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [256], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_0_3 = vector.shape_cast %b_val_0_3_flat : vector<256xbf16> to vector<8x16x2xbf16> + + // b[1,0], b[1,1] + %b_val_1_0_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [0], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_1_0 = vector.shape_cast %b_val_1_0_flat : vector<256xbf16> to vector<8x16x2xbf16> + %b_val_1_1_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [256], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_1_1 = vector.shape_cast %b_val_1_1_flat : vector<256xbf16> to vector<8x16x2xbf16> + + // b[1,2], b[1,3] + %b_val_1_2_flat = vector.extract_strided_slice %b_val_3_flat { offsets = [0], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_1_2 = vector.shape_cast %b_val_1_2_flat : vector<256xbf16> to vector<8x16x2xbf16> + %b_val_1_3_flat = vector.extract_strided_slice %b_val_3_flat {offsets = [256], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_1_3 = vector.shape_cast %b_val_1_3_flat : vector<256xbf16> to vector<8x16x2xbf16> + + + // xegpu.compile_hint + %a_val_0_flat = vector.shape_cast %a_val_0 : vector<2x16x16xbf16> to vector<512xbf16> + %a_val_1_flat = vector.shape_cast %a_val_1 : vector<2x16x16xbf16> to vector<512xbf16> + + %a_val_0_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [0], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_0_0 = vector.shape_cast %a_val_0_0_flat : vector<128xbf16> to vector<8x8x2xbf16> + + %a_val_1_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [128], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_1_0 = vector.shape_cast %a_val_1_0_flat : vector<128xbf16> to vector<8x8x2xbf16> + + %a_val_0_1_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [256], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_0_1 = vector.shape_cast %a_val_0_1_flat : vector<128xbf16> to vector<8x8x2xbf16> + %a_val_1_1_flat = vector.extract_strided_slice %a_val_0_flat {offsets = [384], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_1_1 = vector.shape_cast %a_val_1_1_flat : vector<128xbf16> to vector<8x8x2xbf16> + + %a_val_2_0_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [0], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_2_0 = vector.shape_cast %a_val_2_0_flat : vector<128xbf16> to vector<8x8x2xbf16> + + %a_val_3_0_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [128], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_3_0 = vector.shape_cast %a_val_3_0_flat : vector<128xbf16> to vector<8x8x2xbf16> + + %a_val_2_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [256], sizes = [128], strides = [1]} : + vector<512xbf16> to 
vector<128xbf16> + %a_val_2_1 = vector.shape_cast %a_val_2_1_flat : vector<128xbf16> to vector<8x8x2xbf16> + %a_val_3_1_flat = vector.extract_strided_slice %a_val_1_flat {offsets = [384], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_3_1 = vector.shape_cast %a_val_3_1_flat : vector<128xbf16> to vector<8x8x2xbf16> + + + // do DPAS + xegpu.compile_hint + + %new_c_val_0_0_temp = xegpu.dpas %a_val_0_0, %b_val_0_0, %c_val_0_0 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_0_temp = xegpu.dpas %a_val_1_0, %b_val_0_0, %c_val_1_0 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_0_temp = xegpu.dpas %a_val_2_0, %b_val_0_0, %c_val_2_0 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_0_temp = xegpu.dpas %a_val_3_0, %b_val_0_0, %c_val_3_0 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %new_c_val_0_1_temp = xegpu.dpas %a_val_0_0, %b_val_0_1, %c_val_0_1 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_1_temp = xegpu.dpas %a_val_1_0, %b_val_0_1, %c_val_1_1 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_1_temp = xegpu.dpas %a_val_2_0, %b_val_0_1, %c_val_2_1 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_1_temp = xegpu.dpas %a_val_3_0, %b_val_0_1, %c_val_3_1 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %new_c_val_0_2_temp = xegpu.dpas %a_val_0_0, %b_val_0_2, %c_val_0_2 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_2_temp = xegpu.dpas %a_val_1_0, %b_val_0_2, %c_val_1_2 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_2_temp = xegpu.dpas %a_val_2_0, %b_val_0_2, %c_val_2_2 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_2_temp = xegpu.dpas %a_val_3_0, %b_val_0_2, %c_val_3_2 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %new_c_val_0_3_temp = xegpu.dpas %a_val_0_0, %b_val_0_3, %c_val_0_3 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_3_temp = xegpu.dpas %a_val_1_0, %b_val_0_3, %c_val_1_3 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_3_temp = xegpu.dpas %a_val_2_0, %b_val_0_3, %c_val_2_3 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_3_temp = xegpu.dpas %a_val_3_0, %b_val_0_3, %c_val_3_3 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + xegpu.compile_hint + + %new_c_val_0_0 = xegpu.dpas %a_val_0_1, %b_val_1_0, %new_c_val_0_0_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_0 = xegpu.dpas %a_val_1_1, %b_val_1_0, %new_c_val_1_0_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_0 = xegpu.dpas %a_val_2_1, %b_val_1_0, %new_c_val_2_0_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_0 = xegpu.dpas %a_val_3_1, %b_val_1_0, %new_c_val_3_0_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %new_c_val_0_1 = xegpu.dpas %a_val_0_1, %b_val_1_1, %new_c_val_0_1_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, 
vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_1 = xegpu.dpas %a_val_1_1, %b_val_1_1, %new_c_val_1_1_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_1 = xegpu.dpas %a_val_2_1, %b_val_1_1, %new_c_val_2_1_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_1 = xegpu.dpas %a_val_3_1, %b_val_1_1, %new_c_val_3_1_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %new_c_val_0_2 = xegpu.dpas %a_val_0_1, %b_val_1_2, %new_c_val_0_2_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_2 = xegpu.dpas %a_val_1_1, %b_val_1_2, %new_c_val_1_2_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_2 = xegpu.dpas %a_val_2_1, %b_val_1_2, %new_c_val_2_2_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_2 = xegpu.dpas %a_val_3_1, %b_val_1_2, %new_c_val_3_2_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %new_c_val_0_3 = xegpu.dpas %a_val_0_1, %b_val_1_3, %new_c_val_0_3_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_3 = xegpu.dpas %a_val_1_1, %b_val_1_3, %new_c_val_1_3_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_3 = xegpu.dpas %a_val_2_1, %b_val_1_3, %new_c_val_2_3_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_3 = xegpu.dpas %a_val_3_1, %b_val_1_3, %new_c_val_3_3_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + // barrier wait + scf.if %every_8th_iter_cond { + xegpu.nbarrier_wait %nbarrier_y : !xegpu.nbarrier + xegpu.nbarrier_wait %nbarrier_x : !xegpu.nbarrier + } + + scf.yield %next_A_tile_0, %next_A_tile_1, %next_B_tile_0, %next_B_tile_1, %next_B_tile_2, %next_B_tile_3, + %new_c_val_0_0, %new_c_val_0_1, %new_c_val_0_2, %new_c_val_0_3, %new_c_val_1_0, %new_c_val_1_1, %new_c_val_1_2, %new_c_val_1_3, %new_c_val_2_0, %new_c_val_2_1, %new_c_val_2_2, %new_c_val_2_3, %new_c_val_3_0, %new_c_val_3_1, %new_c_val_3_2, %new_c_val_3_3, + %next_A_prefetch_tile, %next_B_prefetch_tile + : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, + !xegpu.tensor_desc<8x32xbf16>, !xegpu.tensor_desc<8x32xbf16> + } + + // each SG needs to store the result of K loop into a 32x64 tile in C matrix. This is organized in 8x16 DPAS tiles + // in the layout of 4x4x8x16. The max store size HW supoprt in f32 is 8x16. 
+ + %c_sg_tile_00 = xegpu.create_nd_tdesc %C[%C_sg_tile_offset_x, %C_sg_tile_offset_y] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#6, %c_sg_tile_00 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.compile_hint + + %c_sg_tile_01 = xegpu.update_nd_offset %c_sg_tile_00, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#7, %c_sg_tile_01 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_02 = xegpu.update_nd_offset %c_sg_tile_01, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#8, %c_sg_tile_02 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_03 = xegpu.update_nd_offset %c_sg_tile_02, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#9, %c_sg_tile_03 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_10 = xegpu.update_nd_offset %c_sg_tile_00, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#10, %c_sg_tile_10 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.compile_hint + + %c_sg_tile_11 = xegpu.update_nd_offset %c_sg_tile_01, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#11, %c_sg_tile_11 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_12 = xegpu.update_nd_offset %c_sg_tile_02, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#12, %c_sg_tile_12 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_13 = xegpu.update_nd_offset %c_sg_tile_03, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#13, %c_sg_tile_13 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_20 = xegpu.update_nd_offset %c_sg_tile_10, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#14, %c_sg_tile_20 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.compile_hint + + %c_sg_tile_21 = xegpu.update_nd_offset %c_sg_tile_11, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#15, %c_sg_tile_21 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_22 = xegpu.update_nd_offset %c_sg_tile_12, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#16, %c_sg_tile_22 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_23 = xegpu.update_nd_offset %c_sg_tile_13, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#17, %c_sg_tile_23 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + 
%c_sg_tile_30 = xegpu.update_nd_offset %c_sg_tile_20, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#18, %c_sg_tile_30 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.compile_hint + + %c_sg_tile_31 = xegpu.update_nd_offset %c_sg_tile_21, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#19, %c_sg_tile_31 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_32 = xegpu.update_nd_offset %c_sg_tile_22, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#20, %c_sg_tile_32 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_33 = xegpu.update_nd_offset %c_sg_tile_23, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#21, %c_sg_tile_33 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + gpu.return + } + } + + // compute CPU reference (takes minutes) + func.func @cpu_reference(%A : memref<256x256xbf16>, %B : memref<256x256xbf16>, %C : memref<256x256xf32>) { + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + scf.for %i = %c0 to %c256 step %c1 { + scf.for %j = %c0 to %c256 step %c1 { + %c_curr = memref.load %C[%i, %j] : memref<256x256xf32> + %c_val = scf.for %k_tile = %c0 to %c256 step %c16 iter_args(%c_partial = %c_curr) -> f32 { + %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { + %k_dpas = arith.addi %k_tile, %k : index + %a_val = memref.load %A[%i, %k_dpas] : memref<256x256xbf16> + %b_val = memref.load %B[%k_dpas, %j] : memref<256x256xbf16> + %a_cast = arith.extf %a_val : bf16 to f32 + %b_cast = arith.extf %b_val : bf16 to f32 + %t = arith.mulf %a_cast, %b_cast : f32 + %c_sum = arith.addf %t, %c_dpas_partial : f32 + scf.yield %c_sum : f32 + } + scf.yield %c_val_dpas : f32 + } + memref.store %c_val , %C[%i, %j] : memref<256x256xf32> + } + } + return + } + + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_f16 = arith.constant 1.0 : bf16 + %c2_f16 = arith.constant 2.0 : bf16 + %c256 = arith.constant 256 : index + %cf_0 = arith.constant 0.0 : bf16 + %cf_1 = arith.constant 1.0 : bf16 + %c_gen_int = arith.constant 0 : i1 + %cf_lower = arith.constant 0.0 : f32 + %cf_upper = arith.constant 1.0 : f32 + + %A = memref.alloc() : memref<256x256xbf16> + %B = memref.alloc() : memref<256x256xbf16> + %C = memref.alloc() : memref<256x256xf32> + %C_ref = memref.alloc() : memref<256x256xf32> + + // Use one of the two options to initialize the A matrix + // Option 1: intialize matrix A ; A[i, j] = j + // scf.for %i = %c0 to %c256 step %c1 { + // scf.for %j = %c0 to %c256 step %c1 { + // %t = index.castu %j : index to i16 + // %val = arith.uitofp %t : i16 to bf16 + // memref.store %val, %A[%i, %j] : memref<256x256xbf16> + // // memref.store %c1_f16, %A[%i, %j] : memref<256x256xbf16> + // // memref.store %c2_f16, %B[%i, %j] : memref<256x256xbf16> + // } + // } + // Option 2: convert the memref to 1D and fill with random values in (0.0, 1.0) + %A_random = memref.cast %A : memref<256x256xbf16> to memref<*xbf16> + call 
@fillResource1DRandomBF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () + + // Use one of the two options below to initialize the B matrix + // Option 1: make matrix B an identity matrix + // scf.for %i = %c0 to %c256 step %c1 { + // scf.for %j = %c0 to %c256 step %c1 { + // %i_i32 = index.castu %i : index to i32 + // %j_i32 = index.castu %j : index to i32 + // %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 + + // scf.if %i_j_same { + // memref.store %cf_1, %B[%i, %j] : memref<256x256xbf16> + // } else { + // memref.store %cf_0, %B[%i, %j] : memref<256x256xbf16> + // } + // } + // } + // Option 2: convert the memref to 1D and fill with random values in (0.0, 1.0) + %B_random = memref.cast %B : memref<256x256xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () + + // intialize matrix C and C_ref ; C[i, j] = 0 + %c0_f16 = arith.constant 0.0 : bf16 + %c0_f32 = arith.constant 0.0 : f32 + scf.for %i = %c0 to %c256 step %c1 { + scf.for %j = %c0 to %c256 step %c1 { + memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32> + memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32> + } + } + + // run GPU + %2 = call @test(%A, %B, %C) : (memref<256x256xbf16>, memref<256x256xbf16>, memref<256x256xf32>) -> memref<256x256xf32> + + // run CPU + call @cpu_reference(%A, %B, %C_ref) : (memref<256x256xbf16>, memref<256x256xbf16>, memref<256x256xf32>) -> () + + %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32> + %cast_C_ref = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32> + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %A : memref<256x256xbf16> + memref.dealloc %B : memref<256x256xbf16> + memref.dealloc %C : memref<256x256xf32> + memref.dealloc %C_ref : memref<256x256xf32> + gpu.dealloc %2 : memref<256x256xf32> + return + } + func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseBF16(memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface} + +} diff --git a/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_bf16_bf16_f32_xetla_like_load_store_prefetch.mlir b/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_bf16_bf16_f32_xetla_like_load_store_prefetch.mlir new file mode 100644 index 000000000..af8c31649 --- /dev/null +++ b/test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_bf16_bf16_f32_xetla_like_load_store_prefetch.mlir @@ -0,0 +1,621 @@ +// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +module @gemm attributes 
{gpu.container_module} { + func.func @test(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + // Explicit memory copy to and from host + %A_gpu = gpu.alloc () : memref<4096x4096xbf16> + gpu.memcpy %A_gpu, %A : memref<4096x4096xbf16>, memref<4096x4096xbf16> + %B_gpu = gpu.alloc () : memref<4096x4096xbf16> + gpu.memcpy %B_gpu, %B : memref<4096x4096xbf16>, memref<4096x4096xbf16> + %C_gpu = gpu.alloc () : memref<4096x4096xf32> + gpu.memcpy %C_gpu, %C : memref<4096x4096xf32>, memref<4096x4096xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<4096x4096xbf16>, %B_gpu : memref<4096x4096xbf16>, %C_gpu : memref<4096x4096xf32>) + %C_host = memref.alloc() : memref<4096x4096xf32> + gpu.memcpy %C_host, %C_gpu : memref<4096x4096xf32>, memref<4096x4096xf32> + gpu.dealloc %A_gpu : memref<4096x4096xbf16> + gpu.dealloc %B_gpu : memref<4096x4096xbf16> + gpu.dealloc %C_gpu : memref<4096x4096xf32> + return %C_host : memref<4096x4096xf32> + // ******************************************* + + } + + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + // constants + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %c48 = arith.constant 48 : index + %c16 = arith.constant 16 : index + %c24 = arith.constant 24 : index + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + // get IDs + %wg_id_x = gpu.block_id x + %wg_id_y = gpu.block_id y + // %sg_id = gpu.subgroup_id : index + + // each C wg tile is 256x256 and 32 SGs update it in 8x4 layout + // C sg tile size is 32x64 + // SG layout for one C tile update + // |0|1|2|3| + // |4|5|6|7| + // ......... 
+ // |28|29|30|31|
+ // --> y means cols
+ // |
+ // V x means rows
+
+ // get unique sg ID in global context
+ %global_sg_id_x = gpu.global_id x
+ %global_sg_id_y = gpu.global_id y
+ %local_sg_id_x = arith.remui %global_sg_id_x, %c8 : index
+ %local_sg_id_y = arith.remui %global_sg_id_y, %c4 : index
+
+ // compute SG C tile offsets in x and y dims
+ %C_sg_tile_offset_x = arith.muli %global_sg_id_x, %c32 : index
+ %C_sg_tile_offset_y = arith.muli %global_sg_id_y, %c64 : index
+
+ // each SG needs to do the following compute to update its 32x64 sub tile
+ // (32x4096)x(4096x64)=(32x64)
+ // DPAS size is (8x16)x(16x16)=(8x16)
+ // K loop advances in steps of 32, so inside the compute is (32x32)x(32x64) = (32x64)
+ // So we need (4x2) A tiles of size (8x16) and (2x4) B tiles of size (16x16)
+ // tiled compute for a SG is (4x2x8x16)x(2x4x16x16)=(4x4x8x16)
+ // this will require 32 DPAS ops (4x4x2) inside the K loop
+
+ // WG tiles are 256x256, so their offsets are the same for A, B and C
+ %wg_tile_offset_x = arith.muli %wg_id_x, %c256 : index
+ %wg_tile_offset_y = arith.muli %wg_id_y, %c256 : index
+
+ %local_sg_id_temp = arith.muli %local_sg_id_x, %c4 : index
+ %local_sg_id = arith.addi %local_sg_id_temp, %local_sg_id_y : index
+
+ // prefetching A and B slice within the 256x256 WG tile
+ //
+ // prefetch the entire 256x32 slice of the A WG tile; this means each subgroup needs to prefetch an 8x32 slice
+ // each 1x4 row of SGs does a collaborative prefetch (one 8x32 slice each) of one 32x32 tile
+ // SG 0 -> slice 0 |
+ // SG 1 -> slice 1 |
+ // SG 2 -> slice 2 > SG 0,1,2,3 share data prefetch from the top 32x32 tile.
+ // SG 3 -> slice 3 |
+ // SG 4 -> slice 4
+ // ....
+ // SG 31 -> slice 31
+ %A_sg_prefetch_offset_x_temp = arith.muli %local_sg_id, %c8 : index
+ %A_sg_prefetch_offset_x = arith.addi %A_sg_prefetch_offset_x_temp, %wg_tile_offset_x : index
+ // create A prefetch tiles and prefetch
+ // stage 1
+ %A_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %A[%A_sg_prefetch_offset_x, %c0] : memref<4096x4096xbf16> -> !xegpu.tensor_desc<8x32xbf16>
+ xegpu.prefetch_nd %A_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16>
+ // stage 2 (move 32 elements in the y direction and prefetch next 8x32 tile)
+ %A_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter0, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16>
+ xegpu.prefetch_nd %A_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16>
+ // stage 3
+ %A_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter1, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16>
+ xegpu.prefetch_nd %A_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16>
+
+ // stage 4
+ %A_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter2, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16>
+
+ // prefetch the entire 32x256 slice of the B WG tile; we still use the prefetch size 8x32.
+ // SGs have 8x4 layout. In this case 8 subgroups must do a collaborative prefetch of a 32x64 tile.
+ // this because the B tile arrangement within the 32x256 slice is as follows + // 32x64 | 32x64 | 32x64 | 32x64 + // in terms of 8x32 slices the arrangement is, + // 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 + // 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 + // 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 + // 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 || 8x32 | 8x32 + // So SGs 0,1,2,3,....31 prefetch in following fashion + // | 0 | 16|| 1 | 17 || 2 | 18 || 3 | 19 | + // | 4 | 20|| 5 | 21 || 6 | 22 || 7 | 23 | + // | 8 | 24|| 9 | 25 || 10 | 26 || 11| 27 | + // | 12 | 28|| 13 | 29 || 14 | 30 || 15| 31 | + // For example, SGs 0,4,8,12,16,20,24,28 share the data in left 32x64 tile of B slice. + + // calculate the x offsets and y offsets within the 32x256 slice + // XeTLA like co-operative prefetch for B + %B_sg_prefetch_offset_x_temp0 = arith.divui %local_sg_id_x, %c2 : index + %B_sg_prefetch_offset_x = arith.muli %B_sg_prefetch_offset_x_temp0, %c8 : index + + %B_sg_prefetch_offset_y_temp0 = arith.muli %local_sg_id_y, %c64 : index + %B_sg_prefetch_offset_y_temp1 = arith.remui %local_sg_id_x, %c2 : index + %B_sg_prefetch_offset_y_temp2 = arith.muli %B_sg_prefetch_offset_y_temp1, %c32 : index + + %B_sg_prefetch_offset_y_temp3 = arith.addi %B_sg_prefetch_offset_y_temp0, %B_sg_prefetch_offset_y_temp2 : index + %B_sg_prefetch_offset_y = arith.addi %wg_tile_offset_y, %B_sg_prefetch_offset_y_temp3 : index + + + // create B prefetch tiles and prefetch + %B_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %B[%B_sg_prefetch_offset_x, %B_sg_prefetch_offset_y] : memref<4096x4096xbf16> -> !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %B_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + // stage 2 (move 32 elements in the x direction and prefetch next 8x32 tile) + %B_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter0, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %B_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + // stage 3 + %B_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter1, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %B_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + // stage 4 + %B_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter2, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + + %A_sg_init_tile_0 = xegpu.create_nd_tdesc %A[%C_sg_tile_offset_x, %c0] : memref<4096x4096xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %A_sg_init_tile_1 = xegpu.update_nd_offset %A_sg_init_tile_0, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + + //create B tiles + %B_sg_init_tile_0 = xegpu.create_nd_tdesc %B[%c0, %C_sg_tile_offset_y] : memref<4096x4096xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %B_sg_init_tile_1 = xegpu.update_nd_offset %B_sg_init_tile_0, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + + %B_sg_init_tile_2 = xegpu.update_nd_offset %B_sg_init_tile_0, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %B_sg_init_tile_3 = xegpu.update_nd_offset %B_sg_init_tile_2, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + // ************************* // 
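+ // Note: each descriptor above is loaded as two adjacent 16x16 blocks (the K-loop loads return
+ // vector<2x16x16xbf16> for A and vector<2x8x16x2xbf16> for B), so %A_sg_init_tile_0/_1 together
+ // cover the SG's 32 rows of A for one 32-wide K step, and %B_sg_init_tile_0..3 cover the matching
+ // 32x64 block of B.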
+
+ // init 16 C tiles of size 8x16; each is initialized to 0.0, assuming a zero C matrix
+ %zero_vec = arith.constant dense<0.0> : vector<128xf32>
+ %c_init_val_0_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_0_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_0_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_0_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_1_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_1_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_1_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_1_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_2_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_2_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_2_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_2_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_3_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_3_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_3_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+ %c_init_val_3_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32>
+
+ // Multi nbarrier implementation:
+ // one set of nbarriers is used to sync subgroups with the same sg_id_x (local_sg_id_x)
+ // the second set of nbarriers is used to sync subgroups with the same sg_id_y (local_sg_id_y)
+ // In this case wg_size = 8,4 (wg_size_x = 8; wg_size_y = 4)
+ // So in the Y-direction we need 4 nbarriers (to sync subgroups with the same sg_id_y)
+ // In the X-direction we need 8 nbarriers (to sync subgroups with the same sg_id_x)
+ %c_wg_size_x = arith.constant 8 : index
+ %c_wg_size_y = arith.constant 4 : index
+ %num_nbarrier = arith.addi %c_wg_size_y, %c_wg_size_x : index // 8+4=12
+ xegpu.alloc_nbarrier 12 // = 12
+
+ // The first set of nbarriers works across columns; we have 4 columns of subgroups,
+ // hence 4 nbarriers
+ // Each nbarrier has 8 producers and consumers
+ // nbarrier type is Producer_Consumer (https://gfxspecs.intel.com/Predator/Home/Index/57499)
+
+ // %nbarrier_role = arith.constant 0 : i8
+ %nbarrier_threads_y = arith.constant 8 : i8
+ %nbarrier_id_y = arith.index_cast %local_sg_id_y : index to i8
+ %nbarrier_y = xegpu.init_nbarrier %nbarrier_id_y, %nbarrier_threads_y : i8, i8 -> !xegpu.nbarrier
+
+ // The second set of nbarriers works across rows of subgroups;
+ // we have 8 rows of subgroups.
Hnece, 8 nbarrier + // Each nbarrier has 4 producers and consumers + // nbarrier type is Producer_Consumer (https://gfxspecs.intel.com/Predator/Home/Index/57499) + + // We already have 4 (=%c_wg_size_y) nbarriers with id 0-3, + // Now the next set of barrier id would start from 4, hence, + %nbarrier_threads_x = arith.constant 4 : i8 + %index_nbarrier_id_x = arith.addi %c_wg_size_y, %local_sg_id_x : index + %nbarrier_id_x = arith.index_cast %index_nbarrier_id_x : index to i8 + %nbarrier_x = xegpu.init_nbarrier %nbarrier_id_x, %nbarrier_threads_x : i8, i8 -> !xegpu.nbarrier + + // K loop advances in 32 steps + %k_loop_result:24 = scf.for %k = %c0 to %c4096 step %c32 iter_args ( + %A_tile_0 = %A_sg_init_tile_0, + %A_tile_1 = %A_sg_init_tile_1, + + %B_tile_0 = %B_sg_init_tile_0, + %B_tile_1 = %B_sg_init_tile_1, + %B_tile_2 = %B_sg_init_tile_2, + %B_tile_3 = %B_sg_init_tile_3, + + %c_val_0_0 = %c_init_val_0_0, + %c_val_0_1 = %c_init_val_0_1, + %c_val_0_2 = %c_init_val_0_2, + %c_val_0_3 = %c_init_val_0_3, + %c_val_1_0 = %c_init_val_1_0, + %c_val_1_1 = %c_init_val_1_1, + %c_val_1_2 = %c_init_val_1_2, + %c_val_1_3 = %c_init_val_1_3, + %c_val_2_0 = %c_init_val_2_0, + %c_val_2_1 = %c_init_val_2_1, + %c_val_2_2 = %c_init_val_2_2, + %c_val_2_3 = %c_init_val_2_3, + %c_val_3_0 = %c_init_val_3_0, + %c_val_3_1 = %c_init_val_3_1, + %c_val_3_2 = %c_init_val_3_2, + %c_val_3_3 = %c_init_val_3_3, + + %A_prefetch_tile = %A_sg_prefetch_tile_iter2, + %B_prefetch_tile = %B_sg_prefetch_tile_iter2 + ) -> + (!xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, + !xegpu.tensor_desc<8x32xbf16>, !xegpu.tensor_desc<8x32xbf16> + ) + { + // all SGs must arrive here first + %every_8th_iter = arith.remui %k, %c32 : index + %every_8th_iter_i32 = arith.index_cast %every_8th_iter : index to i32 + %every_8th_iter_cond = arith.cmpi eq, %every_8th_iter_i32, %c0_i32 : i32 + scf.if %every_8th_iter_cond { + xegpu.nbarrier_arrive %nbarrier_y : !xegpu.nbarrier + xegpu.nbarrier_arrive %nbarrier_x : !xegpu.nbarrier + } + + // Load smaller load (16 registers) with cache line size width : 64 bytes, 32 elements + // Although maximum load size supported is 2KB or 32 registers, we use smaller loads, for 2 main reasons: + // ** 1. Hide load latency: we do smaller load means for B, we do 4 loads, we set up the loads and dpas orderring in + // such a way that, the first set of DPAS works on data loaded by first 2 load operations, as a result the + // second set of loads' latency can be hidden by the first set of DPAS operations. + // + // ** 2. 
Reduce the impact of L3 miss: Larger load means more cache lines to be loaded, more chance of potential L3 miss + // which could increase the load time + + // load B tiles + %b_val_0 = xegpu.load_nd %B_tile_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %b_val_1 = xegpu.load_nd %B_tile_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %b_val_2 = xegpu.load_nd %B_tile_2 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %b_val_3 = xegpu.load_nd %B_tile_3 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + + // load A tiles + %a_val_0 = xegpu.load_nd %A_tile_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xbf16> + %a_val_1 = xegpu.load_nd %A_tile_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xbf16> + + xegpu.compile_hint + + // prefetch A and B tiles + xegpu.prefetch_nd %A_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %B_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> + + xegpu.compile_hint + + // advance A and B prefetch tiles + %next_A_prefetch_tile = xegpu.update_nd_offset %A_prefetch_tile, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + %next_B_prefetch_tile = xegpu.update_nd_offset %B_prefetch_tile, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + // advance A and B tiles + %next_A_tile_0 = xegpu.update_nd_offset %A_tile_0, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %next_A_tile_1 = xegpu.update_nd_offset %A_tile_1, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + + %next_B_tile_0 = xegpu.update_nd_offset %B_tile_0, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %next_B_tile_1 = xegpu.update_nd_offset %B_tile_1, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %next_B_tile_2 = xegpu.update_nd_offset %B_tile_2, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %next_B_tile_3 = xegpu.update_nd_offset %B_tile_3, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + + %b_val_0_flat = vector.shape_cast %b_val_0 : vector<2x8x16x2xbf16> to vector<512xbf16> + %b_val_1_flat = vector.shape_cast %b_val_1 : vector<2x8x16x2xbf16> to vector<512xbf16> + %b_val_2_flat = vector.shape_cast %b_val_2 : vector<2x8x16x2xbf16> to vector<512xbf16> + %b_val_3_flat = vector.shape_cast %b_val_3 : vector<2x8x16x2xbf16> to vector<512xbf16> + + // b[0,0], b[0,1] + %b_val_0_0_flat = vector.extract_strided_slice %b_val_0_flat { offsets = [0], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_0_0 = vector.shape_cast %b_val_0_0_flat : vector<256xbf16> to vector<8x16x2xbf16> + %b_val_0_1_flat = 
vector.extract_strided_slice %b_val_0_flat { offsets = [256], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_0_1 = vector.shape_cast %b_val_0_1_flat : vector<256xbf16> to vector<8x16x2xbf16> + + // b[0,2], b[0,3] + %b_val_0_2_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [0], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_0_2 = vector.shape_cast %b_val_0_2_flat : vector<256xbf16> to vector<8x16x2xbf16> + %b_val_0_3_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [256], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_0_3 = vector.shape_cast %b_val_0_3_flat : vector<256xbf16> to vector<8x16x2xbf16> + + // b[1,0], b[1,1] + %b_val_1_0_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [0], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_1_0 = vector.shape_cast %b_val_1_0_flat : vector<256xbf16> to vector<8x16x2xbf16> + %b_val_1_1_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [256], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_1_1 = vector.shape_cast %b_val_1_1_flat : vector<256xbf16> to vector<8x16x2xbf16> + + // b[1,2], b[1,3] + %b_val_1_2_flat = vector.extract_strided_slice %b_val_3_flat { offsets = [0], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_1_2 = vector.shape_cast %b_val_1_2_flat : vector<256xbf16> to vector<8x16x2xbf16> + %b_val_1_3_flat = vector.extract_strided_slice %b_val_3_flat {offsets = [256], sizes = [256], strides = [1]} : + vector<512xbf16> to vector<256xbf16> + %b_val_1_3 = vector.shape_cast %b_val_1_3_flat : vector<256xbf16> to vector<8x16x2xbf16> + + + // xegpu.compile_hint + %a_val_0_flat = vector.shape_cast %a_val_0 : vector<2x16x16xbf16> to vector<512xbf16> + %a_val_1_flat = vector.shape_cast %a_val_1 : vector<2x16x16xbf16> to vector<512xbf16> + + %a_val_0_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [0], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_0_0 = vector.shape_cast %a_val_0_0_flat : vector<128xbf16> to vector<8x8x2xbf16> + + %a_val_1_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [128], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_1_0 = vector.shape_cast %a_val_1_0_flat : vector<128xbf16> to vector<8x8x2xbf16> + + %a_val_0_1_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [256], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_0_1 = vector.shape_cast %a_val_0_1_flat : vector<128xbf16> to vector<8x8x2xbf16> + %a_val_1_1_flat = vector.extract_strided_slice %a_val_0_flat {offsets = [384], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_1_1 = vector.shape_cast %a_val_1_1_flat : vector<128xbf16> to vector<8x8x2xbf16> + + %a_val_2_0_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [0], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_2_0 = vector.shape_cast %a_val_2_0_flat : vector<128xbf16> to vector<8x8x2xbf16> + + %a_val_3_0_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [128], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_3_0 = vector.shape_cast %a_val_3_0_flat : vector<128xbf16> to vector<8x8x2xbf16> + + %a_val_2_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [256], sizes = [128], strides = [1]} : + vector<512xbf16> to 
vector<128xbf16> + %a_val_2_1 = vector.shape_cast %a_val_2_1_flat : vector<128xbf16> to vector<8x8x2xbf16> + %a_val_3_1_flat = vector.extract_strided_slice %a_val_1_flat {offsets = [384], sizes = [128], strides = [1]} : + vector<512xbf16> to vector<128xbf16> + %a_val_3_1 = vector.shape_cast %a_val_3_1_flat : vector<128xbf16> to vector<8x8x2xbf16> + + + // do DPAS + xegpu.compile_hint + + %new_c_val_0_0_temp = xegpu.dpas %a_val_0_0, %b_val_0_0, %c_val_0_0 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_0_temp = xegpu.dpas %a_val_1_0, %b_val_0_0, %c_val_1_0 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_0_temp = xegpu.dpas %a_val_2_0, %b_val_0_0, %c_val_2_0 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_0_temp = xegpu.dpas %a_val_3_0, %b_val_0_0, %c_val_3_0 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %new_c_val_0_1_temp = xegpu.dpas %a_val_0_0, %b_val_0_1, %c_val_0_1 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_1_temp = xegpu.dpas %a_val_1_0, %b_val_0_1, %c_val_1_1 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_1_temp = xegpu.dpas %a_val_2_0, %b_val_0_1, %c_val_2_1 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_1_temp = xegpu.dpas %a_val_3_0, %b_val_0_1, %c_val_3_1 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %new_c_val_0_2_temp = xegpu.dpas %a_val_0_0, %b_val_0_2, %c_val_0_2 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_2_temp = xegpu.dpas %a_val_1_0, %b_val_0_2, %c_val_1_2 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_2_temp = xegpu.dpas %a_val_2_0, %b_val_0_2, %c_val_2_2 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_2_temp = xegpu.dpas %a_val_3_0, %b_val_0_2, %c_val_3_2 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %new_c_val_0_3_temp = xegpu.dpas %a_val_0_0, %b_val_0_3, %c_val_0_3 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_3_temp = xegpu.dpas %a_val_1_0, %b_val_0_3, %c_val_1_3 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_3_temp = xegpu.dpas %a_val_2_0, %b_val_0_3, %c_val_2_3 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_3_temp = xegpu.dpas %a_val_3_0, %b_val_0_3, %c_val_3_3 : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + xegpu.compile_hint + + %new_c_val_0_0 = xegpu.dpas %a_val_0_1, %b_val_1_0, %new_c_val_0_0_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_0 = xegpu.dpas %a_val_1_1, %b_val_1_0, %new_c_val_1_0_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_0 = xegpu.dpas %a_val_2_1, %b_val_1_0, %new_c_val_2_0_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_0 = xegpu.dpas %a_val_3_1, %b_val_1_0, %new_c_val_3_0_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %new_c_val_0_1 = xegpu.dpas %a_val_0_1, %b_val_1_1, %new_c_val_0_1_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, 
vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_1 = xegpu.dpas %a_val_1_1, %b_val_1_1, %new_c_val_1_1_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_1 = xegpu.dpas %a_val_2_1, %b_val_1_1, %new_c_val_2_1_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_1 = xegpu.dpas %a_val_3_1, %b_val_1_1, %new_c_val_3_1_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %new_c_val_0_2 = xegpu.dpas %a_val_0_1, %b_val_1_2, %new_c_val_0_2_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_2 = xegpu.dpas %a_val_1_1, %b_val_1_2, %new_c_val_1_2_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_2 = xegpu.dpas %a_val_2_1, %b_val_1_2, %new_c_val_2_2_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_2 = xegpu.dpas %a_val_3_1, %b_val_1_2, %new_c_val_3_2_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %new_c_val_0_3 = xegpu.dpas %a_val_0_1, %b_val_1_3, %new_c_val_0_3_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_1_3 = xegpu.dpas %a_val_1_1, %b_val_1_3, %new_c_val_1_3_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_2_3 = xegpu.dpas %a_val_2_1, %b_val_1_3, %new_c_val_2_3_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %new_c_val_3_3 = xegpu.dpas %a_val_3_1, %b_val_1_3, %new_c_val_3_3_temp : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + // barrier wait + scf.if %every_8th_iter_cond { + xegpu.nbarrier_wait %nbarrier_y : !xegpu.nbarrier + xegpu.nbarrier_wait %nbarrier_x : !xegpu.nbarrier + } + + scf.yield %next_A_tile_0, %next_A_tile_1, %next_B_tile_0, %next_B_tile_1, %next_B_tile_2, %next_B_tile_3, + %new_c_val_0_0, %new_c_val_0_1, %new_c_val_0_2, %new_c_val_0_3, %new_c_val_1_0, %new_c_val_1_1, %new_c_val_1_2, %new_c_val_1_3, %new_c_val_2_0, %new_c_val_2_1, %new_c_val_2_2, %new_c_val_2_3, %new_c_val_3_0, %new_c_val_3_1, %new_c_val_3_2, %new_c_val_3_3, + %next_A_prefetch_tile, %next_B_prefetch_tile + : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, + vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, + !xegpu.tensor_desc<8x32xbf16>, !xegpu.tensor_desc<8x32xbf16> + } + + // each SG needs to store the result of K loop into a 32x64 tile in C matrix. This is organized in 8x16 DPAS tiles + // in the layout of 4x4x8x16. The max store size HW supoprt in f32 is 8x16. 
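+ // K-loop results: #0-#5 are the advanced A/B tile descriptors, #6-#21 are the sixteen 8x16
+ // accumulator tiles in row-major (i, j) order (c_val_0_0 ... c_val_3_3), and #22/#23 are the
+ // advanced A/B prefetch descriptors. The stores below write out #6-#21.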
+ + %c_sg_tile_00 = xegpu.create_nd_tdesc %C[%C_sg_tile_offset_x, %C_sg_tile_offset_y] : memref<4096x4096xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#6, %c_sg_tile_00 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.compile_hint + + %c_sg_tile_01 = xegpu.update_nd_offset %c_sg_tile_00, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#7, %c_sg_tile_01 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_02 = xegpu.update_nd_offset %c_sg_tile_01, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#8, %c_sg_tile_02 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_03 = xegpu.update_nd_offset %c_sg_tile_02, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#9, %c_sg_tile_03 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_10 = xegpu.update_nd_offset %c_sg_tile_00, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#10, %c_sg_tile_10 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.compile_hint + + %c_sg_tile_11 = xegpu.update_nd_offset %c_sg_tile_01, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#11, %c_sg_tile_11 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_12 = xegpu.update_nd_offset %c_sg_tile_02, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#12, %c_sg_tile_12 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_13 = xegpu.update_nd_offset %c_sg_tile_03, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#13, %c_sg_tile_13 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_20 = xegpu.update_nd_offset %c_sg_tile_10, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#14, %c_sg_tile_20 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.compile_hint + + %c_sg_tile_21 = xegpu.update_nd_offset %c_sg_tile_11, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#15, %c_sg_tile_21 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_22 = xegpu.update_nd_offset %c_sg_tile_12, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#16, %c_sg_tile_22 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_23 = xegpu.update_nd_offset %c_sg_tile_13, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#17, %c_sg_tile_23 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + 
%c_sg_tile_30 = xegpu.update_nd_offset %c_sg_tile_20, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#18, %c_sg_tile_30 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.compile_hint + + %c_sg_tile_31 = xegpu.update_nd_offset %c_sg_tile_21, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#19, %c_sg_tile_31 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_32 = xegpu.update_nd_offset %c_sg_tile_22, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#20, %c_sg_tile_32 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + %c_sg_tile_33 = xegpu.update_nd_offset %c_sg_tile_23, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %k_loop_result#21, %c_sg_tile_33 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + + gpu.return + } + } + + // compute CPU reference (takes minutes) + func.func @cpu_reference(%A : memref<4096x4096xbf16>, %B : memref<4096x4096xbf16>, %C : memref<4096x4096xf32>) { + %c4096 = arith.constant 4096 : index + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + scf.for %i = %c0 to %c4096 step %c1 { + scf.for %j = %c0 to %c4096 step %c1 { + %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32> + %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 { + %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { + %k_dpas = arith.addi %k_tile, %k : index + %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xbf16> + %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xbf16> + %a_cast = arith.extf %a_val : bf16 to f32 + %b_cast = arith.extf %b_val : bf16 to f32 + %t = arith.mulf %a_cast, %b_cast : f32 + %c_sum = arith.addf %t, %c_dpas_partial : f32 + scf.yield %c_sum : f32 + } + scf.yield %c_val_dpas : f32 + } + memref.store %c_val , %C[%i, %j] : memref<4096x4096xf32> + } + } + return + } + + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_f16 = arith.constant 1.0 : bf16 + %c2_f16 = arith.constant 2.0 : bf16 + %c4096 = arith.constant 4096 : index + %cf_0 = arith.constant 0.0 : bf16 + %cf_1 = arith.constant 1.0 : bf16 + %c_gen_int = arith.constant 0 : i1 + %cf_lower = arith.constant 0.0 : f32 + %cf_upper = arith.constant 1.0 : f32 + + %A = memref.alloc() : memref<4096x4096xbf16> + %B = memref.alloc() : memref<4096x4096xbf16> + %C = memref.alloc() : memref<4096x4096xf32> + %C_ref = memref.alloc() : memref<4096x4096xf32> + + // Use one of the two options to initialize the A matrix + // Option 1: intialize matrix A ; A[i, j] = j + // scf.for %i = %c0 to %c4096 step %c1 { + // scf.for %j = %c0 to %c4096 step %c1 { + // %t = index.castu %j : index to i16 + // %val = arith.uitofp %t : i16 to bf16 + // memref.store %val, %A[%i, %j] : memref<4096x4096xbf16> + // // memref.store %c1_f16, %A[%i, %j] : memref<4096x4096xbf16> + // // memref.store %c2_f16, %B[%i, %j] : memref<4096x4096xbf16> + // } + // } + // Option 2: convert the memref to 1D and fill with random values in (0.0, 1.0) + %A_random = memref.cast %A : memref<4096x4096xbf16> to 
memref<*xbf16> + call @fillResource1DRandomBF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () + + // Use one of the two options below to initialize the B matrix + // Option 1: make matrix B an identity matrix + // scf.for %i = %c0 to %c4096 step %c1 { + // scf.for %j = %c0 to %c4096 step %c1 { + // %i_i32 = index.castu %i : index to i32 + // %j_i32 = index.castu %j : index to i32 + // %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 + + // scf.if %i_j_same { + // memref.store %cf_1, %B[%i, %j] : memref<4096x4096xbf16> + // } else { + // memref.store %cf_0, %B[%i, %j] : memref<4096x4096xbf16> + // } + // } + // } + // Option 2: convert the memref to 1D and fill with random values in (0.0, 1.0) + %B_random = memref.cast %B : memref<4096x4096xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () + + // intialize matrix C and C_ref ; C[i, j] = 0 + %c0_f16 = arith.constant 0.0 : bf16 + %c0_f32 = arith.constant 0.0 : f32 + scf.for %i = %c0 to %c4096 step %c1 { + scf.for %j = %c0 to %c4096 step %c1 { + memref.store %c0_f32, %C[%i, %j] : memref<4096x4096xf32> + memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> + } + } + + // run GPU + %2 = call @test(%A, %B, %C) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> + + // run CPU + call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> () + + %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32> + %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %A : memref<4096x4096xbf16> + memref.dealloc %B : memref<4096x4096xbf16> + memref.dealloc %C : memref<4096x4096xf32> + memref.dealloc %C_ref : memref<4096x4096xf32> + return + } + func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseBF16(memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface} + +} diff --git a/test/PlaidML/CppEdsl.Convolution_BF16.mlir b/test/PlaidML/CppEdsl.Convolution_BF16.mlir new file mode 100644 index 000000000..c75b471a8 --- /dev/null +++ b/test/PlaidML/CppEdsl.Convolution_BF16.mlir @@ -0,0 +1,51 @@ +// RUN: %python_executable %imex_runner -i %s --pass-pipeline-file=%p/linalg-to-cpu.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils \ +// RUN: --entry-point-result=void --filecheck +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +module @convolution { +func.func @test(%arg0: tensor<1x56x56x64xbf16>, %arg1: tensor<3x3x64x64xbf16>) -> tensor<1x56x56x64xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<1x56x56x64xbf16> + %1 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x56x56x64xbf16>) outs(%0 : tensor<1x56x56x64xbf16>) { + ^bb0(%arg2: bf16, %arg3: bf16): + linalg.yield %arg2 : bf16 + } -> tensor<1x56x56x64xbf16> + %cst_0 = arith.constant 0.000000e+00 : bf16 + %2 = tensor.pad %1 low[0, 1, 1, 0] high[0, 1, 1, 0] { + ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index): + tensor.yield %cst_0 : bf16 + } : tensor<1x56x56x64xbf16> to tensor<1x58x58x64xbf16> + %3 = tensor.empty() : tensor<1x56x56x64xbf16> + %4 = linalg.fill ins(%cst : bf16) outs(%3 : tensor<1x56x56x64xbf16>) -> tensor<1x56x56x64xbf16> + %5 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%2, %arg1 : tensor<1x58x58x64xbf16>, tensor<3x3x64x64xbf16>) outs(%4 : tensor<1x56x56x64xbf16>) attrs = {iterator_ranges = [1, 56, 56, 64, 3, 3, 64]} { + ^bb0(%arg2: bf16, %arg3: bf16, %arg4: bf16): + %6 = arith.mulf %arg2, %arg3 : bf16 + %7 = arith.addf %arg4, %6 : bf16 + linalg.yield %7 : bf16 + } -> tensor<1x56x56x64xbf16> + return %5 : tensor<1x56x56x64xbf16> + } + func.func @main() { + %0 = arith.constant dense<1.0> : tensor<1x56x56x64xbf16> + %1 = arith.constant dense<0.5> : tensor<3x3x64x64xbf16> + %2 = call @test(%0, %1) : (tensor<1x56x56x64xbf16>, tensor<3x3x64x64xbf16>) -> tensor<1x56x56x64xbf16> + %unranked = tensor.cast %2 : tensor<1x56x56x64xbf16> to tensor<*xbf16> + call @printMemrefBF16(%unranked) : (tensor<*xbf16>) -> () + // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} + // CHECK-NEXT: [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + return + } + func.func private @printMemrefBF16(%ptr : tensor<*xbf16>) +} diff --git a/test/PlaidML/OpTest.BroadcastNonNumpy_BF16.mlir b/test/PlaidML/OpTest.BroadcastNonNumpy_BF16.mlir new file mode 100644 index 000000000..ffa0780ae --- /dev/null +++ b/test/PlaidML/OpTest.BroadcastNonNumpy_BF16.mlir @@ -0,0 +1,39 @@ +// RUN: %python_executable %imex_runner -i %s --pass-pipeline-file=%p/linalg-to-cpu.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils \ +// RUN: --entry-point-result=void --filecheck +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +#map0 = affine_map<(d0, d1) -> (d0)> 
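+// #map0 (above) indexes the 1-D input only by d0, while #map1 (below) is the identity map on the
+// 3x4 output, so the linalg.generic broadcasts each input element across the 4 columns.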
+#map1 = affine_map<(d0, d1) -> (d0, d1)> +module @broadcast_non_numpy { + func.func @test(%arg0: tensor<3xbf16>) -> tensor<3x4xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<3x4xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<3x4xbf16>) -> tensor<3x4xbf16> + %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<3xbf16>) outs(%1 : tensor<3x4xbf16>) attrs = {iterator_ranges = [3, 4], name = "broadcast"} { + ^bb0(%arg1: bf16, %arg2: bf16): + linalg.yield %arg1 : bf16 + } -> tensor<3x4xbf16> + return %2 : tensor<3x4xbf16> + } + func.func @main() { + %0 = arith.constant dense<[1.0, 2.0, 3.0]> : tensor<3xbf16> + %2 = call @test(%0) : (tensor<3xbf16>) -> tensor<3x4xbf16> + %unranked = tensor.cast %2 : tensor<3x4xbf16> to tensor<*xbf16> + call @printMemrefBF16(%unranked) : (tensor<*xbf16>) -> () + // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} + // CHECK-NEXT: [1, 1, 1, 1] + // CHECK-NEXT: [2, 2, 2, 2] + // CHECK-NEXT: [3, 3, 3, 3] + return + } + + func.func private @printMemrefBF16(%ptr : tensor<*xbf16>) +} diff --git a/test/PlaidML/OpTest.Relu_BF16.mlir b/test/PlaidML/OpTest.Relu_BF16.mlir new file mode 100644 index 000000000..0fe913ffc --- /dev/null +++ b/test/PlaidML/OpTest.Relu_BF16.mlir @@ -0,0 +1,64 @@ +// RUN: %python_executable %imex_runner -i %s --pass-pipeline-file=%p/linalg-to-cpu.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils \ +// RUN: --entry-point-result=void --filecheck +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +#map0 = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> ()> +module @relu { +func.func @main() { + %0= arith.constant dense<[[-0.1, -0.2, -0.3, 0.4, 0.5], [0.1, -0.2, 0.3, -0.4, 0.5], [0.1, 0.2, 0.3, -0.4, -0.5], [0.1, 0.2, 0.3, 0.4, 0.5]]>:tensor<4x5xbf16> + %1 = call @test(%0) : (tensor<4x5xbf16>) -> tensor<4x5xbf16> + %unranked = tensor.cast %1 : tensor<4x5xbf16>to tensor<*xbf16> + call @printMemrefBF16(%unranked) : (tensor<*xbf16>) -> () + return +} +func.func private @printMemrefBF16(tensor<*xbf16>) +func.func @test(%arg0: tensor<4x5xbf16>)->tensor<4x5xbf16>{ + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<4x5xi1> + %1 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %cst : tensor<4x5xbf16>, bf16) outs(%0 : tensor<4x5xi1>) { + ^bb0(%arg1: bf16, %arg2: bf16, %arg3: i1): + %arg1_f32 = arith.extf %arg1 : bf16 to f32 + %arg2_f32 = arith.extf %arg2 : bf16 to f32 + %4 = arith.cmpf olt, %arg1_f32, %arg2_f32 : f32 + linalg.yield %4 : i1 + } -> tensor<4x5xi1> + %2 = tensor.empty() : tensor<4x5xbf16> + %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%1, %cst, %arg0 : tensor<4x5xi1>, bf16, tensor<4x5xbf16>) outs(%2 : tensor<4x5xbf16>) { + ^bb0(%arg1: i1, %arg2: bf16, 
%arg3: bf16, %arg4: bf16): + %4 = arith.select %arg1, %arg2, %arg3 : bf16 + linalg.yield %4 : bf16 + } -> tensor<4x5xbf16> + return %3 : tensor<4x5xbf16> + } +} +// CHECK: Unranked Memref base@ = {{0x[-9a-f]*}} +// CHECK-SAME: rank = {{.}} offset = {{.}} sizes = [4, 5] strides = {{.*}} data = +// CHECK: 0 +// CHECK: 0 +// CHECK: 0 +// CHECK: 0.4 +// CHECK: 0.5 +// CHECK: 0.1 +// CHECK: 0 +// CHECK: 0.3 +// CHECK: 0 +// CHECK: 0.5 +// CHECK: 0.1 +// CHECK: 0.2 +// CHECK: 0.3 +// CHECK: 0 +// CHECK: 0 +// CHECK: 0.1 +// CHECK: 0.2 +// CHECK: 0.3 +// CHECK: 0.4 +// CHECK: 0.5 diff --git a/test/PlaidML/OpTest.Transpose_BF16.mlir b/test/PlaidML/OpTest.Transpose_BF16.mlir new file mode 100644 index 000000000..be87190ac --- /dev/null +++ b/test/PlaidML/OpTest.Transpose_BF16.mlir @@ -0,0 +1,66 @@ +// RUN: %python_executable %imex_runner -i %s --pass-pipeline-file=%p/linalg-to-cpu.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils \ +// RUN: --entry-point-result=void --filecheck +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1) -> (d0, d1)> +module @transpose { + func.func @test(%arg0: tensor<10x20xbf16>) -> tensor<20x10xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<20x10xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<20x10xbf16>) -> tensor<20x10xbf16> + %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<10x20xbf16>) outs(%1 : tensor<20x10xbf16>) attrs = {iterator_ranges = [20, 10], name = "transpose"} { + ^bb0(%arg1: bf16, %arg2: bf16): + linalg.yield %arg1 : bf16 + } -> tensor<20x10xbf16> + return %2 : tensor<20x10xbf16> + } + func.func @main() { + %0 = arith.constant dense<[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0] 
+ ]> : tensor<10x20xbf16> + %2 = call @test(%0) : (tensor<10x20xbf16>) -> tensor<20x10xbf16> + %unranked = tensor.cast %2 : tensor<20x10xbf16> to tensor<*xbf16> + call @printMemrefBF16(%unranked) : (tensor<*xbf16>) -> () + // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} + // CHECK-NEXT: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + // CHECK-NEXT: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2] + // CHECK-NEXT: [3, 3, 3, 3, 3, 3, 3, 3, 3, 3] + // CHECK-NEXT: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4] + // CHECK-NEXT: [5, 5, 5, 5, 5, 5, 5, 5, 5, 5] + // CHECK-NEXT: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6] + // CHECK-NEXT: [7, 7, 7, 7, 7, 7, 7, 7, 7, 7] + // CHECK-NEXT: [8, 8, 8, 8, 8, 8, 8, 8, 8, 8] + // CHECK-NEXT: [9, 9, 9, 9, 9, 9, 9, 9, 9, 9] + // CHECK-NEXT: [10, 10, 10, 10, 10, 10, 10, 10, 10, 10] + // CHECK-NEXT: [11, 11, 11, 11, 11, 11, 11, 11, 11, 11] + // CHECK-NEXT: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12] + // CHECK-NEXT: [13, 13, 13, 13, 13, 13, 13, 13, 13, 13] + // CHECK-NEXT: [14, 14, 14, 14, 14, 14, 14, 14, 14, 14] + // CHECK-NEXT: [15, 15, 15, 15, 15, 15, 15, 15, 15, 15] + // CHECK-NEXT: [16, 16, 16, 16, 16, 16, 16, 16, 16, 16] + // CHECK-NEXT: [17, 17, 17, 17, 17, 17, 17, 17, 17, 17] + // CHECK-NEXT: [18, 18, 18, 18, 18, 18, 18, 18, 18, 18] + // CHECK-NEXT: [19, 19, 19, 19, 19, 19, 19, 19, 19, 19] + // CHECK-NEXT: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20] + return + } + + func.func private @printMemrefBF16(%ptr : tensor<*xbf16>) +} diff --git a/test/Transforms/set-spirv-capability.mlir b/test/Transforms/set-spirv-capability.mlir index fe2965768..954015b50 100644 --- a/test/Transforms/set-spirv-capability.mlir +++ b/test/Transforms/set-spirv-capability.mlir @@ -4,7 +4,7 @@ module attributes {gpu.container_module} { // OPENCL: module attributes {gpu.container_module} { -// OPENCL: gpu.module @main_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { +// OPENCL: gpu.module @main_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { // VULKAN: module attributes {gpu.container_module} { // VULKAN: gpu.module @main_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=Vulkan, #spirv.resource_limits<>>} { gpu.module @main_kernel {