-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a JointMatrix test case and necessary patch
This commit adds: - A JointMatrix test case - A patch that contains the updated definition of JointMatrix spec supported by IGC. Current upstream definition of JointMatrix is not supported by IGC anymore.
- Loading branch information
Showing
2 changed files
with
540 additions
and
0 deletions.
There are no files selected for viewing
306 changes: 306 additions & 0 deletions
306
build_tools/patches/0003--Update-the-Joint-Matrix-support-to-match-IGC-Spec.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,306 @@ | ||
From 810eccbbbb872402391a0f01a53aaf0205ea10c4 Mon Sep 17 00:00:00 2001 | ||
From: Md Abdullah Shahneous Bari <[email protected]> | ||
Date: Tue, 8 Aug 2023 00:28:31 +0000 | ||
Subject: [PATCH] Update the Joint Matrix support to match Spec supported by | ||
IGC | ||
|
||
Update the Joint Matrix support to match the following spec: | ||
https://github.com/MrSidims/llvm/blob/private/MrSidims/add-matrix-use/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc | ||
--- | ||
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 30 ++++++++++++++----- | ||
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 6 +++- | ||
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 12 ++++++-- | ||
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 20 +++++++++---- | ||
.../SPIRV/Deserialization/Deserializer.cpp | 17 +++++++---- | ||
.../Target/SPIRV/Serialization/Serializer.cpp | 5 +++- | ||
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 16 +++++----- | ||
7 files changed, 75 insertions(+), 31 deletions(-) | ||
|
||
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td | ||
index 6f0f728f811e..c2ad6ff24bea 100644 | ||
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td | ||
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td | ||
@@ -4039,16 +4039,30 @@ def SPIRV_SamplerUseAttr: SPIRV_I32EnumAttr< | ||
"image_sampler_use_info", | ||
[SPIRV_ISUI_SamplerUnknown, SPIRV_ISUI_NeedSampler, SPIRV_ISUI_NoSampler]>; | ||
|
||
-def SPIRV_ML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 0>; | ||
-def SPIRV_ML_RowMajor : I32EnumAttrCase<"RowMajor", 1>; | ||
-def SPIRV_ML_PackedA : I32EnumAttrCase<"PackedA", 2>; | ||
-def SPIRV_ML_PackedB : I32EnumAttrCase<"PackedB", 3>; | ||
- | ||
-def SPIRV_MatrixLayoutAttr : | ||
- SPIRV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [ | ||
- SPIRV_ML_ColumnMajor, SPIRV_ML_RowMajor, SPIRV_ML_PackedA, SPIRV_ML_PackedB | ||
+// Change the layout parameter to IGC spec, the currnet MLIR version | ||
+// does not match the IGC spec, IGC spec has been updated | ||
+// https://github.com/MrSidims/llvm/blob/private/MrSidims/add-matrix-use/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc | ||
+ | ||
+def SPIRV_ML_RowMajor : I32EnumAttrCase<"RowMajor", 0>; | ||
+def SPIRV_ML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 1>; | ||
+def SPIRV_ML_Packed : I32EnumAttrCase<"Packed", 2>; | ||
+def SPIRV_ML_Unused : I32EnumAttrCase<"Unused", 3>; | ||
+ | ||
+ def SPIRV_MatrixLayoutAttr : | ||
+ SPIRV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [ | ||
+ SPIRV_ML_RowMajor, SPIRV_ML_ColumnMajor, SPIRV_ML_Packed, SPIRV_ML_Unused | ||
]>; | ||
|
||
+def SPIRV_ML_MATRIX_A : I32EnumAttrCase<"MatrixA", 0>; | ||
+def SPIRV_ML_MATRIX_B : I32EnumAttrCase<"MatrixB", 1>; | ||
+def SPIRV_ML_MATRIX_ACC : I32EnumAttrCase<"Accumulator", 2>; | ||
+ | ||
+def SPIRV_MatrixUseAttr : | ||
+ SPIRV_I32EnumAttr<"MatrixUse", "valid SPIR-V MatrixUse", "matrixUse", [ | ||
+ SPIRV_ML_MATRIX_A, SPIRV_ML_MATRIX_B, SPIRV_ML_MATRIX_ACC | ||
+ ]>; | ||
+ | ||
+ | ||
// Cooperative Matrix Use for the SPV_KHR_cooperative_matrix extension. | ||
def SPIRV_KHR_CMU_MatrixA : I32EnumAttrCase<"MatrixA", 0>; | ||
def SPIRV_KHR_CMU_MatrixB : I32EnumAttrCase<"MatrixB", 1>; | ||
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h | ||
index 07f2f158ecab..e0b3c5448a44 100644 | ||
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h | ||
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h | ||
@@ -459,7 +459,8 @@ public: | ||
using Base::Base; | ||
|
||
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, | ||
- unsigned columns, MatrixLayout matrixLayout); | ||
+ unsigned columns, MatrixLayout matrixLayout, | ||
+ MatrixUse matrixUse); | ||
Type getElementType() const; | ||
|
||
/// Return the scope of the joint matrix. | ||
@@ -472,6 +473,9 @@ public: | ||
/// return the layout of the matrix | ||
MatrixLayout getMatrixLayout() const; | ||
|
||
+ /// return the use of the matrix | ||
+ MatrixUse getMatrixUse() const; | ||
+ | ||
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, | ||
std::optional<StorageClass> storage = std::nullopt); | ||
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, | ||
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | ||
index 9188f8b699b4..4c099bf77a88 100644 | ||
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | ||
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | ||
@@ -394,7 +394,8 @@ static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect, | ||
|
||
// joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x` | ||
// element-type | ||
-// `,` layout `,` scope`>` | ||
+// `,` layout `,` scope | ||
+// `,` use`>` | ||
static Type parseJointMatrixType(SPIRVDialect const &dialect, | ||
DialectAsmParser &parser) { | ||
if (parser.parseLess()) | ||
@@ -421,10 +422,14 @@ static Type parseJointMatrixType(SPIRVDialect const &dialect, | ||
if (parser.parseComma() || | ||
spirv::parseEnumKeywordAttr(scope, parser, "scope <id>")) | ||
return Type(); | ||
+ MatrixUse matrixUse; | ||
+ if (parser.parseComma() || | ||
+ parseEnumKeywordAttr(matrixUse, parser, "matrixUse <id>")) | ||
+ return Type(); | ||
if (parser.parseGreater()) | ||
return Type(); | ||
return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1], | ||
- matrixLayout); | ||
+ matrixLayout, matrixUse); | ||
} | ||
|
||
// TODO: Reorder methods to be utilities first and parse*Type | ||
@@ -952,7 +957,8 @@ static void print(JointMatrixINTELType type, DialectAsmPrinter &os) { | ||
os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; | ||
os << type.getElementType() << ", " | ||
<< stringifyMatrixLayout(type.getMatrixLayout()); | ||
- os << ", " << stringifyScope(type.getScope()) << ">"; | ||
+ os << ", " << stringifyScope(type.getScope()) << ", " | ||
+ << stringifyMatrixUse(type.getMatrixUse()) << ">"; | ||
} | ||
|
||
static void print(MatrixType type, DialectAsmPrinter &os) { | ||
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | ||
index 741d8069471d..49ded5c60951 100644 | ||
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | ||
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | ||
@@ -352,7 +352,8 @@ void CooperativeMatrixNVType::getCapabilities( | ||
//===----------------------------------------------------------------------===// | ||
|
||
struct spirv::detail::JointMatrixTypeStorage : public TypeStorage { | ||
- using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>; | ||
+ using KeyTy = | ||
+ std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope, MatrixUse>; | ||
|
||
static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator, | ||
const KeyTy &key) { | ||
@@ -361,26 +362,29 @@ struct spirv::detail::JointMatrixTypeStorage : public TypeStorage { | ||
} | ||
|
||
bool operator==(const KeyTy &key) const { | ||
- return key == KeyTy(elementType, rows, columns, matrixLayout, scope); | ||
+ return key == | ||
+ KeyTy(elementType, rows, columns, matrixLayout, scope, matrixUse); | ||
} | ||
|
||
JointMatrixTypeStorage(const KeyTy &key) | ||
: elementType(std::get<0>(key)), rows(std::get<1>(key)), | ||
- columns(std::get<2>(key)), scope(std::get<4>(key)), | ||
- matrixLayout(std::get<3>(key)) {} | ||
+ columns(std::get<2>(key)), matrixLayout(std::get<3>(key)), | ||
+ scope(std::get<4>(key)), matrixUse(std::get<5>(key)) {} | ||
|
||
Type elementType; | ||
unsigned rows; | ||
unsigned columns; | ||
Scope scope; | ||
MatrixLayout matrixLayout; | ||
+ MatrixUse matrixUse; | ||
}; | ||
|
||
JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope, | ||
unsigned rows, unsigned columns, | ||
- MatrixLayout matrixLayout) { | ||
+ MatrixLayout matrixLayout, | ||
+ MatrixUse matrixUse) { | ||
return Base::get(elementType.getContext(), elementType, rows, columns, | ||
- matrixLayout, scope); | ||
+ matrixLayout, scope, matrixUse); | ||
} | ||
|
||
Type JointMatrixINTELType::getElementType() const { | ||
@@ -397,6 +401,10 @@ MatrixLayout JointMatrixINTELType::getMatrixLayout() const { | ||
return getImpl()->matrixLayout; | ||
} | ||
|
||
+MatrixUse JointMatrixINTELType::getMatrixUse() const { | ||
+ return getImpl()->matrixUse; | ||
+} | ||
+ | ||
void JointMatrixINTELType::getExtensions( | ||
SPIRVType::ExtensionArrayRefVector &extensions, | ||
std::optional<StorageClass> storage) { | ||
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | ||
index 90416289134b..4598dc608034 100644 | ||
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | ||
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | ||
@@ -939,7 +939,7 @@ spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) { | ||
|
||
LogicalResult | ||
spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) { | ||
- if (operands.size() != 6) { | ||
+ if (operands.size() != 7) { | ||
return emitError(unknownLoc, "OpTypeJointMatrix must have element " | ||
"type and row x column parameters"); | ||
} | ||
@@ -949,7 +949,13 @@ spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) { | ||
return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ") | ||
<< operands[1]; | ||
} | ||
- | ||
+ auto matrixUse = | ||
+ spirv::symbolizeMatrixUse(getConstantInt(operands[6]).getInt()); | ||
+ if (!matrixUse) { | ||
+ return emitError(unknownLoc, | ||
+ "OpTypeJointMatrix references undefined Use <id> ") | ||
+ << operands[6]; | ||
+ } | ||
auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt()); | ||
if (!scope) { | ||
return emitError(unknownLoc, | ||
@@ -960,14 +966,15 @@ spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) { | ||
spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt()); | ||
if (!matrixLayout) { | ||
return emitError(unknownLoc, | ||
- "OpTypeJointMatrix references undefined scope <id> ") | ||
+ "OpTypeJointMatrix references undefined Layout <id> ") | ||
<< operands[4]; | ||
} | ||
unsigned rows = getConstantInt(operands[2]).getInt(); | ||
unsigned columns = getConstantInt(operands[3]).getInt(); | ||
|
||
- typeMap[operands[0]] = spirv::JointMatrixINTELType::get( | ||
- elementTy, scope.value(), rows, columns, matrixLayout.value()); | ||
+ typeMap[operands[0]] = | ||
+ spirv::JointMatrixINTELType::get(elementTy, scope.value(), rows, columns, | ||
+ matrixLayout.value(), matrixUse.value()); | ||
return success(); | ||
} | ||
|
||
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | ||
index 988e60d08edf..b6ec58648d72 100644 | ||
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | ||
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | ||
@@ -222,7 +222,8 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, | ||
case spirv::Decoration::LinkageAttributes: { | ||
// Get the value of the Linkage Attributes | ||
// e.g., LinkageAttributes=["linkageName", linkageType]. | ||
- auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue()); | ||
+ auto linkageAttr = | ||
+ llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue()); | ||
auto linkageName = linkageAttr.getLinkageName(); | ||
auto linkageType = linkageAttr.getLinkageType().getValue(); | ||
// Encode the Linkage Name (string literal to uint32_t). | ||
@@ -639,6 +640,8 @@ LogicalResult Serializer::prepareBasicType( | ||
static_cast<uint32_t>(jointMatrixType.getMatrixLayout()))); | ||
operands.push_back( | ||
getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope()))); | ||
+ operands.push_back( | ||
+ getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixUse()))); | ||
return success(); | ||
} | ||
|
||
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | ||
index ccf4240f8e56..a793564e0477 100644 | ||
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | ||
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | ||
@@ -396,10 +396,9 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef, | ||
avail.getMergeInstanceType(), avail.getQueryFnName(), | ||
enumName); | ||
|
||
- os << formatv( | ||
- " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1" | ||
- " && \"cannot have more than one bit set\");\n", | ||
- underlyingType); | ||
+ os << formatv(" assert(::llvm::popcount(static_cast<{0}>(value)) <= 1" | ||
+ " && \"cannot have more than one bit set\");\n", | ||
+ underlyingType); | ||
|
||
os << " switch (value) {\n"; | ||
for (const auto &caseSpecPair : classCasePair.getValue()) { | ||
@@ -523,7 +522,8 @@ static void emitAttributeSerialization(const Attribute &attr, | ||
<< formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName); | ||
if (attr.getAttrDefName() == "SPIRV_ScopeAttr" || | ||
attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" || | ||
- attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") { | ||
+ attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr" || | ||
+ attr.getAttrDefName() == "SPIRV_MatrixUseAttr") { | ||
// These two enums are encoded as <id> to constant values in SPIR-V blob, | ||
// but we directly use the constant value as attribute in SPIR-V dialect. So | ||
// need to handle them separately from normal enum attributes. | ||
@@ -818,7 +818,8 @@ static void emitAttributeDeserialization(const Attribute &attr, | ||
raw_ostream &os) { | ||
if (attr.getAttrDefName() == "SPIRV_ScopeAttr" || | ||
attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" || | ||
- attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") { | ||
+ attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr" || | ||
+ attr.getAttrDefName() == "SPIRV_MatrixUseAttr") { | ||
// These two enums are encoded as <id> to constant values in SPIR-V blob, | ||
// but we directly use the constant value as attribute in SPIR-V dialect. So | ||
// need to handle them separately from normal enum attributes. | ||
@@ -926,7 +927,8 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc, | ||
// Process operands/attributes | ||
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { | ||
auto argument = op.getArg(i); | ||
- if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) { | ||
+ if (auto *valueArg = | ||
+ llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) { | ||
if (valueArg->isVariableLength()) { | ||
if (i != e - 1) { | ||
PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or " | ||
-- | ||
2.34.1 | ||
|
Oops, something went wrong.