Skip to content

Commit

Permalink
[SYCL][NFCI] Finalize switch to SPV_KHR_cooperative_matrix (#16045)
Browse files Browse the repository at this point in the history
Signed-off-by: Sidorov, Dmitry <[email protected]>
  • Loading branch information
MrSidims authored Nov 21, 2024
1 parent 925ff76 commit 3edd618
Show file tree
Hide file tree
Showing 95 changed files with 12 additions and 2,353 deletions.
77 changes: 1 addition & 76 deletions clang/lib/CodeGen/CodeGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,34 +350,6 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
return ResultType;
}

template <bool NeedTypeInterpret = false>
llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
ArrayRef<TemplateArgument> TemplateArgs,
const unsigned Val = 0) {
// TODO: we should actually have exactly 5 template parameters: 1 for
// type and 4 for type parameters. But in previous version of the SPIR-V
// spec we have Layout matrix type parameter, that was later removed.
// Once we update to the newest version of the spec - this should be updated.
assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) &&
"Wrong JointMatrixINTEL template parameters number");
// This is required to represent optional 'Component Type Interpretation'
// parameter
std::vector<unsigned> Params;
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
"Wrong JointMatrixINTEL template parameter");
Params.push_back(TemplateArgs[I].getAsIntegral().getExtValue());
}
// Don't add type interpretation for legacy matrices.
// Legacy matrices has 5 template parameters, while new representation
// has 6.
if (NeedTypeInterpret && TemplateArgs.size() != 5)
Params.push_back(Val);

return llvm::TargetExtType::get(CompTy->getContext(),
"spirv.JointMatrixINTEL", {CompTy}, Params);
}

llvm::Type *
getCooperativeMatrixKHRExtType(llvm::Type *CompTy,
ArrayRef<TemplateArgument> TemplateArgs) {
Expand All @@ -394,49 +366,6 @@ getCooperativeMatrixKHRExtType(llvm::Type *CompTy,
CompTy->getContext(), "spirv.CooperativeMatrixKHR", {CompTy}, Params);
}

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
/// The expected representation is:
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
"1st JointMatrixINTEL template parameter must be type");
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());

// Per JointMatrixINTEL spec the type can have an optional
// 'Component Type Interpretation' parameter. We should emit it in case
// if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
// matrix's components. Yet 'bfloat16' should be represented as 'int16' and
// 'tf32' as 'float' types.
if (CompTy->isStructTy()) {
StringRef LlvmTyName = CompTy->getStructName();
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
if (LlvmTyName.starts_with("class.sycl::") ||
LlvmTyName.starts_with("class.__sycl_internal::"))
LlvmTyName = LlvmTyName.rsplit("::").second;
if (LlvmTyName == "half") {
CompTy = llvm::Type::getHalfTy(getLLVMContext());
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
} else if (LlvmTyName == "tf32") {
CompTy = llvm::Type::getFloatTy(getLLVMContext());
// 'tf32' interpretation is mapped to '0'
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 0);
} else if (LlvmTyName == "bfloat16") {
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
// 'bfloat16' interpretation is mapped to '1'
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 1);
} else {
llvm_unreachable("Wrong matrix base type!");
}
}
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
}

/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
Expand Down Expand Up @@ -733,11 +662,7 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
if (ClangETy && ClangETy->isStructureOrClassType()) {
RecordDecl *RD = ClangETy->getAsCXXRecordDecl();
if (RD && RD->getQualifiedNameAsString() ==
"__spv::__spirv_JointMatrixINTEL") {
ResultType = ConvertSYCLJointMatrixINTELType(RD);
break;
} else if (RD && RD->getQualifiedNameAsString() ==
"__spv::__spirv_CooperativeMatrixKHR") {
"__spv::__spirv_CooperativeMatrixKHR") {
ResultType = ConvertSPVCooperativeMatrixType(RD);
break;
} else if (RD && RD->getQualifiedNameAsString() ==
Expand Down
8 changes: 0 additions & 8 deletions clang/lib/CodeGen/CodeGenTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,6 @@ class CodeGenTypes {
/// load/store type are the same.
llvm::Type *convertTypeForLoadStore(QualType T, llvm::Type *LLVMTy = nullptr);

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
/// The expected representation is:
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);

/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
Expand Down
41 changes: 0 additions & 41 deletions clang/test/CodeGenSYCL/joint_matrix.cpp

This file was deleted.

150 changes: 0 additions & 150 deletions sycl/include/sycl/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,155 +27,6 @@

extern __DPCPP_SYCL_EXTERNAL float __spirv_RoundFToTF32INTEL(float a);

#ifndef __SPIRV_USE_COOPERATIVE_MATRIX
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
__spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
T *Ptr, __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
std::size_t Stride, __spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_CooperativeMatrixConstructCheckedINTEL(int32_t CoordX,
int32_t CoordY,
uint32_t Height,
uint32_t Width,
const T Value);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_CooperativeMatrixLoadCheckedINTEL(
T *Ptr, int32_t CoordX, int32_t CoordY, __spv::MatrixLayout Layout = L,
uint32_t Height = 0, uint32_t Width = 0, std::size_t Stride = 0,
int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
T *Ptr, int32_t CoordX, int32_t CoordY,
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
__spv::MatrixLayout Layout = L, uint32_t Height = 0, uint32_t Width = 0,
std::size_t Stride = 0, int MemOperand = 0);

template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
__spv::MatrixUse UC,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<TC, M, N, LC, S, UC> *
__spirv_JointMatrixMadINTEL(
__spv::__spirv_JointMatrixINTEL<TA, M, K, LA, S, UA> *A,
__spv::__spirv_JointMatrixINTEL<TB, K, N, LB, S, UB> *B,
__spv::__spirv_JointMatrixINTEL<TC, M, N, LC, S, UC> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
__spv::MatrixUse UC,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
__spirv_JointMatrixUUMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
__spv::MatrixUse UC,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
__spirv_JointMatrixUSMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
__spv::MatrixUse UC,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
__spirv_JointMatrixSUMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_CompositeConstruct(const T v);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
__spirv_JointMatrixGetElementCoordINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *);

template <typename Ts, typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL Ts __spirv_VectorExtractDynamic(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);

template <typename Ts, typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
Ts val, size_t i);
#else // __SPIRV_USE_COOPERATIVE_MATRIX
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
Expand Down Expand Up @@ -304,7 +155,6 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
__spv::MatrixLayout Layout = L, uint32_t Height = 0, uint32_t Width = 0,
std::size_t Stride = 0, int MemOperand = 0);
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

template <typename T>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixPrefetchINTEL(
Expand Down
10 changes: 0 additions & 10 deletions sycl/include/sycl/__spirv/spirv_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ enum class MatrixLayout : uint32_t {

enum class MatrixUse : uint32_t { MatrixA = 0, MatrixB = 1, Accumulator = 2 };

#ifdef __SPIRV_USE_COOPERATIVE_MATRIX
enum class MatrixOperands : uint32_t {
// SPV_KHR_cooperative_matrix operands
NoneKHR = 0,
Expand All @@ -133,19 +132,10 @@ enum class MatrixOperands : uint32_t {
MatrixCBFloat16ComponentsINTEL = 0x80,
MatrixResultBFloat16ComponentsINTEL = 0x100
};
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

#ifndef __SPIRV_USE_COOPERATIVE_MATRIX

template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
Scope::Flag S = Scope::Flag::Subgroup,
MatrixUse U = MatrixUse::MatrixA>
struct __spirv_JointMatrixINTEL;
#else
template <typename T, Scope::Flag S = Scope::Flag::Subgroup, std::size_t R = 1,
std::size_t C = 1, MatrixUse U = MatrixUse::MatrixA>
struct __spirv_CooperativeMatrixKHR;
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

struct __spirv_TaskSequenceINTEL;

Expand Down
Loading

0 comments on commit 3edd618

Please sign in to comment.