Skip to content

Commit

Permalink
Add a JointMatrix test case and necessary patch
Browse files Browse the repository at this point in the history
This commit adds:
	- A JointMatrix test case
	- A patch that contains the updated definition of
	  JointMatrix spec supported by IGC. Current upstream
	  definition of JointMatrix is not supported by IGC
	  anymore.
  • Loading branch information
mshahneo authored and silee2 committed Sep 21, 2023
1 parent 8060e7f commit f826127
Show file tree
Hide file tree
Showing 2 changed files with 540 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
From 810eccbbbb872402391a0f01a53aaf0205ea10c4 Mon Sep 17 00:00:00 2001
From: Md Abdullah Shahneous Bari <[email protected]>
Date: Tue, 8 Aug 2023 00:28:31 +0000
Subject: [PATCH] Update the Joint Matrix support to match Spec supported by
IGC

Update the Joint Matrix support to match the following spec:
https://github.com/MrSidims/llvm/blob/private/MrSidims/add-matrix-use/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 30 ++++++++++++++-----
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 6 +++-
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 12 ++++++--
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 20 +++++++++----
.../SPIRV/Deserialization/Deserializer.cpp | 17 +++++++----
.../Target/SPIRV/Serialization/Serializer.cpp | 5 +++-
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 16 +++++-----
7 files changed, 75 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6f0f728f811e..c2ad6ff24bea 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4039,16 +4039,30 @@ def SPIRV_SamplerUseAttr: SPIRV_I32EnumAttr<
"image_sampler_use_info",
[SPIRV_ISUI_SamplerUnknown, SPIRV_ISUI_NeedSampler, SPIRV_ISUI_NoSampler]>;

-def SPIRV_ML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 0>;
-def SPIRV_ML_RowMajor : I32EnumAttrCase<"RowMajor", 1>;
-def SPIRV_ML_PackedA : I32EnumAttrCase<"PackedA", 2>;
-def SPIRV_ML_PackedB : I32EnumAttrCase<"PackedB", 3>;
-
-def SPIRV_MatrixLayoutAttr :
- SPIRV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [
- SPIRV_ML_ColumnMajor, SPIRV_ML_RowMajor, SPIRV_ML_PackedA, SPIRV_ML_PackedB
+// Change the layout parameter to IGC spec, the currnet MLIR version
+// does not match the IGC spec, IGC spec has been updated
+// https://github.com/MrSidims/llvm/blob/private/MrSidims/add-matrix-use/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc
+
+def SPIRV_ML_RowMajor : I32EnumAttrCase<"RowMajor", 0>;
+def SPIRV_ML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 1>;
+def SPIRV_ML_Packed : I32EnumAttrCase<"Packed", 2>;
+def SPIRV_ML_Unused : I32EnumAttrCase<"Unused", 3>;
+
+ def SPIRV_MatrixLayoutAttr :
+ SPIRV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [
+ SPIRV_ML_RowMajor, SPIRV_ML_ColumnMajor, SPIRV_ML_Packed, SPIRV_ML_Unused
]>;

+def SPIRV_ML_MATRIX_A : I32EnumAttrCase<"MatrixA", 0>;
+def SPIRV_ML_MATRIX_B : I32EnumAttrCase<"MatrixB", 1>;
+def SPIRV_ML_MATRIX_ACC : I32EnumAttrCase<"Accumulator", 2>;
+
+def SPIRV_MatrixUseAttr :
+ SPIRV_I32EnumAttr<"MatrixUse", "valid SPIR-V MatrixUse", "matrixUse", [
+ SPIRV_ML_MATRIX_A, SPIRV_ML_MATRIX_B, SPIRV_ML_MATRIX_ACC
+ ]>;
+
+
// Cooperative Matrix Use for the SPV_KHR_cooperative_matrix extension.
def SPIRV_KHR_CMU_MatrixA : I32EnumAttrCase<"MatrixA", 0>;
def SPIRV_KHR_CMU_MatrixB : I32EnumAttrCase<"MatrixB", 1>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 07f2f158ecab..e0b3c5448a44 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -459,7 +459,8 @@ public:
using Base::Base;

static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
- unsigned columns, MatrixLayout matrixLayout);
+ unsigned columns, MatrixLayout matrixLayout,
+ MatrixUse matrixUse);
Type getElementType() const;

/// Return the scope of the joint matrix.
@@ -472,6 +473,9 @@ public:
/// return the layout of the matrix
MatrixLayout getMatrixLayout() const;

+ /// return the use of the matrix
+ MatrixUse getMatrixUse() const;
+
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 9188f8b699b4..4c099bf77a88 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -394,7 +394,8 @@ static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,

// joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x`
// element-type
-// `,` layout `,` scope`>`
+// `,` layout `,` scope
+// `,` use`>`
static Type parseJointMatrixType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
@@ -421,10 +422,14 @@ static Type parseJointMatrixType(SPIRVDialect const &dialect,
if (parser.parseComma() ||
spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
return Type();
+ MatrixUse matrixUse;
+ if (parser.parseComma() ||
+ parseEnumKeywordAttr(matrixUse, parser, "matrixUse <id>"))
+ return Type();
if (parser.parseGreater())
return Type();
return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1],
- matrixLayout);
+ matrixLayout, matrixUse);
}

// TODO: Reorder methods to be utilities first and parse*Type
@@ -952,7 +957,8 @@ static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
os << type.getElementType() << ", "
<< stringifyMatrixLayout(type.getMatrixLayout());
- os << ", " << stringifyScope(type.getScope()) << ">";
+ os << ", " << stringifyScope(type.getScope()) << ", "
+ << stringifyMatrixUse(type.getMatrixUse()) << ">";
}

static void print(MatrixType type, DialectAsmPrinter &os) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 741d8069471d..49ded5c60951 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -352,7 +352,8 @@ void CooperativeMatrixNVType::getCapabilities(
//===----------------------------------------------------------------------===//

struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
- using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
+ using KeyTy =
+ std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope, MatrixUse>;

static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
@@ -361,26 +362,29 @@ struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
}

bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, rows, columns, matrixLayout, scope);
+ return key ==
+ KeyTy(elementType, rows, columns, matrixLayout, scope, matrixUse);
}

JointMatrixTypeStorage(const KeyTy &key)
: elementType(std::get<0>(key)), rows(std::get<1>(key)),
- columns(std::get<2>(key)), scope(std::get<4>(key)),
- matrixLayout(std::get<3>(key)) {}
+ columns(std::get<2>(key)), matrixLayout(std::get<3>(key)),
+ scope(std::get<4>(key)), matrixUse(std::get<5>(key)) {}

Type elementType;
unsigned rows;
unsigned columns;
Scope scope;
MatrixLayout matrixLayout;
+ MatrixUse matrixUse;
};

JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope,
unsigned rows, unsigned columns,
- MatrixLayout matrixLayout) {
+ MatrixLayout matrixLayout,
+ MatrixUse matrixUse) {
return Base::get(elementType.getContext(), elementType, rows, columns,
- matrixLayout, scope);
+ matrixLayout, scope, matrixUse);
}

Type JointMatrixINTELType::getElementType() const {
@@ -397,6 +401,10 @@ MatrixLayout JointMatrixINTELType::getMatrixLayout() const {
return getImpl()->matrixLayout;
}

+MatrixUse JointMatrixINTELType::getMatrixUse() const {
+ return getImpl()->matrixUse;
+}
+
void JointMatrixINTELType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 90416289134b..4598dc608034 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -939,7 +939,7 @@ spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {

LogicalResult
spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
- if (operands.size() != 6) {
+ if (operands.size() != 7) {
return emitError(unknownLoc, "OpTypeJointMatrix must have element "
"type and row x column parameters");
}
@@ -949,7 +949,13 @@ spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
<< operands[1];
}
-
+ auto matrixUse =
+ spirv::symbolizeMatrixUse(getConstantInt(operands[6]).getInt());
+ if (!matrixUse) {
+ return emitError(unknownLoc,
+ "OpTypeJointMatrix references undefined Use <id> ")
+ << operands[6];
+ }
auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
if (!scope) {
return emitError(unknownLoc,
@@ -960,14 +966,15 @@ spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
if (!matrixLayout) {
return emitError(unknownLoc,
- "OpTypeJointMatrix references undefined scope <id> ")
+ "OpTypeJointMatrix references undefined Layout <id> ")
<< operands[4];
}
unsigned rows = getConstantInt(operands[2]).getInt();
unsigned columns = getConstantInt(operands[3]).getInt();

- typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
- elementTy, scope.value(), rows, columns, matrixLayout.value());
+ typeMap[operands[0]] =
+ spirv::JointMatrixINTELType::get(elementTy, scope.value(), rows, columns,
+ matrixLayout.value(), matrixUse.value());
return success();
}

diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 988e60d08edf..b6ec58648d72 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -222,7 +222,8 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::LinkageAttributes: {
// Get the value of the Linkage Attributes
// e.g., LinkageAttributes=["linkageName", linkageType].
- auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
+ auto linkageAttr =
+ llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
auto linkageName = linkageAttr.getLinkageName();
auto linkageType = linkageAttr.getLinkageType().getValue();
// Encode the Linkage Name (string literal to uint32_t).
@@ -639,6 +640,8 @@ LogicalResult Serializer::prepareBasicType(
static_cast<uint32_t>(jointMatrixType.getMatrixLayout())));
operands.push_back(
getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
+ operands.push_back(
+ getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixUse())));
return success();
}

diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index ccf4240f8e56..a793564e0477 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -396,10 +396,9 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
avail.getMergeInstanceType(), avail.getQueryFnName(),
enumName);

- os << formatv(
- " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
- " && \"cannot have more than one bit set\");\n",
- underlyingType);
+ os << formatv(" assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
+ " && \"cannot have more than one bit set\");\n",
+ underlyingType);

os << " switch (value) {\n";
for (const auto &caseSpecPair : classCasePair.getValue()) {
@@ -523,7 +522,8 @@ static void emitAttributeSerialization(const Attribute &attr,
<< formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
- attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
+ attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr" ||
+ attr.getAttrDefName() == "SPIRV_MatrixUseAttr") {
// These two enums are encoded as <id> to constant values in SPIR-V blob,
// but we directly use the constant value as attribute in SPIR-V dialect. So
// need to handle them separately from normal enum attributes.
@@ -818,7 +818,8 @@ static void emitAttributeDeserialization(const Attribute &attr,
raw_ostream &os) {
if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
- attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
+ attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr" ||
+ attr.getAttrDefName() == "SPIRV_MatrixUseAttr") {
// These two enums are encoded as <id> to constant values in SPIR-V blob,
// but we directly use the constant value as attribute in SPIR-V dialect. So
// need to handle them separately from normal enum attributes.
@@ -926,7 +927,8 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
// Process operands/attributes
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
- if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) {
+ if (auto *valueArg =
+ llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) {
if (valueArg->isVariableLength()) {
if (i != e - 1) {
PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
--
2.34.1

Loading

0 comments on commit f826127

Please sign in to comment.