From 4c67d4b452ee56fb8788f638a74458b8dadf6b2f Mon Sep 17 00:00:00 2001 From: Kavitha Madhu Date: Tue, 24 Dec 2024 00:32:08 -0800 Subject: [PATCH] Brgemm pattern matching --- include/TPP/Dialect/Xsmm/XsmmUtils.h | 48 ++ include/TPP/Transforms/Utils/VNNIUtils.h | 4 +- .../ConvertVectorToXsmm.cpp | 178 ++++++- .../ConvertVectorToXsmmPDL.pdll | 86 +++- lib/TPP/Dialect/Xsmm/XsmmUtils.cpp | 477 +++++++++++++++++- lib/TPP/Transforms/Utils/VNNIUtils.cpp | 78 +++ 6 files changed, 813 insertions(+), 58 deletions(-) diff --git a/include/TPP/Dialect/Xsmm/XsmmUtils.h b/include/TPP/Dialect/Xsmm/XsmmUtils.h index 65784ba4b..ad13b1c5f 100644 --- a/include/TPP/Dialect/Xsmm/XsmmUtils.h +++ b/include/TPP/Dialect/Xsmm/XsmmUtils.h @@ -22,12 +22,33 @@ class Operation; class PatternRewriter; class VectorType; class MemRefType; +class ModuleOp; namespace func { class CallOp; } +namespace vector { +class ContractionOp; +} + namespace xsmm { + +struct BrgemmInfo { + int64_t m; + int64_t n; + int64_t k; + int64_t batch; + + int64_t lda; + int64_t ldb; + int64_t ldc; + int64_t strideA; + int64_t strideB; + + bool isVnni = false; +}; + class UnaryKindAttr; struct UnaryInfo { @@ -89,8 +110,14 @@ FailureOr getUnaryFlags(Type inputType, Type outputType); // Compute the broadcasting flags for 'operandType' based on 'outputType'. enum class OperandPos { LHS = 0, RHS = 1 }; +FailureOr getBinFlags(ArrayRef shapeOutput, + ArrayRef shapeOperand, + OperandPos operandNumber); FailureOr getBinaryFlags(Type operandType, Type outputType, OperandPos operandNumber); +FailureOr getBinaryFlagsVectorType(Type operandType, + Type outputType, + OperandPos operandNumber); FailureOr getFusedBrgemmSequenceFromProducer(Operation *op); @@ -98,10 +125,26 @@ ArrayAttr getUnaryDispatchFlags(UnaryOp op); ArrayAttr getBinaryDispatchFlags(BinaryOp op); +int64_t getOredFlags(ArrayAttr flags); + +func::CallOp buildInvokeCall(RewriterBase &rewriter, Operation *parentOp, + ModuleOp module, SmallVector inputOperands, + SmallVector prependValues, int prependIndex, + SmallVector operands, StringRef invokeName, + DataTypeAttr dtype, bool getResults = false); + template FailureOr> getBrgemmFlags(PatternRewriter &rewriter, DispatchOpTy dispatchOpTy, bool returnNone); +std::optional +getPosInCodomain(unsigned dim, vector::ContractionOp contractOp, AffineMap map); +FailureOr checkAccess(PatternRewriter &rewriter, + vector::ContractionOp contractOp, + unsigned m, unsigned n, unsigned k, + std::optional batchPos, + SmallVector inputs, + ArrayRef indexingMap); SmallVector extractOperandTypes(OpBuilder &builder, ArrayRef operands); @@ -122,6 +165,11 @@ func::CallOp buildXsmmCall(RewriterBase &rewriter, XsmmCallType callType, SmallVector operands, TypeRange results, FlatSymbolRefAttr fnName, Operation *parentOp, Operation *insertBefore); +FailureOr isMappableToBrgemm(PatternRewriter &rewriter, + vector::ContractionOp contractOp, + SmallVector &inputs, + SmallVector &output, + ArrayRef indexingMap); } // namespace utils } // namespace xsmm } // namespace mlir diff --git a/include/TPP/Transforms/Utils/VNNIUtils.h b/include/TPP/Transforms/Utils/VNNIUtils.h index 562bdb088..bda252367 100644 --- a/include/TPP/Transforms/Utils/VNNIUtils.h +++ b/include/TPP/Transforms/Utils/VNNIUtils.h @@ -9,6 +9,7 @@ #ifndef TPP_TRANSFORMS_UTILS_VNNIUTILS_H #define TPP_TRANSFORMS_UTILS_VNNIUTILS_H +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Support/LogicalResult.h" #include #include @@ -51,7 +52,8 @@ bool isInVnniLayout(int64_t expectedRank, VectorType vector); FailureOr 
isInVnniLayout(linalg::GenericOp linalgOp, AffineMap affineMap, int64_t blockingFactor); - +FailureOr isInVnniLayout(mlir::vector::ContractionOp contractOp, + int64_t blockingFactor); } // namespace utils } // namespace vnni } // namespace mlir diff --git a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp index 04c1f6fdb..a3665e6a2 100644 --- a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.h" #include "TPP/Dialect/Xsmm/XsmmUtils.h" +#include "TPP/Transforms/Transforms.h" +#include "TPP/Transforms/Utils/TransformUtils.h" #include "TPP/Transforms/Utils/VNNIUtils.h" #include "TPP/Transforms/Utils/ValueUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -68,9 +70,9 @@ getUnaryXSMMCalls(PatternRewriter &rewriter, xsmm::UnaryInfo &unaryInfo, } static std::pair -convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp, - Operation *input, Operation *output, Type outputType) { - LLVM_DEBUG(llvm::dbgs() << "convertTransposeOp\n"); +convertTranspose(PatternRewriter &rewriter, Operation *transposeOp, + Operation *input, Operation *output, Type outputType) { + LLVM_DEBUG(llvm::dbgs() << "convertTranspose\n"); VectorType outType = cast(outputType); xsmm::UnaryKind opType; xsmm::UnaryInfo unaryInfo; @@ -111,11 +113,10 @@ convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp, output, transposeOp, unaryFlags.getInt()); } -static LogicalResult validateTransposeOp(PatternRewriter &rewriter, - Operation *transposeOp, - Operation *input, Operation *output, - Type outputType) { - LLVM_DEBUG(llvm::dbgs() << "validateTransposeOp\n"); +static LogicalResult validateTranspose(PatternRewriter &rewriter, + Operation *transposeOp, Operation *input, + Operation *output, Type outputType) { + LLVM_DEBUG(llvm::dbgs() << "validateTranspose\n"); Value result = input->getResult(0); Value source = input->getOperand(0); VectorType outType = cast(outputType); @@ -187,10 +188,10 @@ convertBroadcast(PatternRewriter &rewriter, Operation *broadcastOp, broadcastOp, unaryFlags.getInt()); } -static LogicalResult validateBroadcastOp(PatternRewriter &rewriter, - Operation *broadcastOp, - Operation *input, Operation *output) { - LLVM_DEBUG(llvm::dbgs() << "validateBroadcastOp\n"); +static LogicalResult validateBroadcast(PatternRewriter &rewriter, + Operation *broadcastOp, Operation *input, + Operation *output) { + LLVM_DEBUG(llvm::dbgs() << "validateBroadcast\n"); auto unaryFlag = xsmm::utils::getUnaryFlags(input->getOperand(0).getType(), output->getOperand(1).getType()); auto inputMemRefType = dyn_cast(input->getOperand(0).getType()); @@ -205,15 +206,162 @@ static LogicalResult validateBroadcastOp(PatternRewriter &rewriter, return success(); } +FailureOr +computeBrgemmInfo(PatternRewriter &rewriter, Operation *contractOp, + Operation *input0, Operation *input1, Operation *input2) { + SmallVector inputs; + LLVM_DEBUG(llvm::dbgs() << "computeBrgemminfo\n"); + + inputs.push_back(input0->getResult(0)); + inputs.push_back(input1->getResult(0)); + inputs.push_back(input2->getResult(0)); + + SmallVector outputs; + outputs.push_back(nullptr); + auto failedOrbrgemmInfo = mlir::xsmm::utils::isMappableToBrgemm( + rewriter, dyn_cast(contractOp), inputs, + outputs, + 
dyn_cast<vector::ContractionOp>(contractOp).getIndexingMapsArray());
+  if (failed(failedOrbrgemmInfo))
+    return failure();
+  return *failedOrbrgemmInfo;
+}
+
+static std::pair<Operation *, Operation *>
+createBrgemmImpl(PatternRewriter &rewriter, Operation *contractOp, Value input0,
+                 Value input1, Value input2, xsmm::BrgemmInfo brgemmInfo,
+                 SmallVector<Attribute> flags) {
+  auto m = brgemmInfo.m;
+  auto n = brgemmInfo.n;
+  auto k = brgemmInfo.k;
+  auto batch = brgemmInfo.batch;
+  int64_t lda = brgemmInfo.lda;
+  int64_t ldb = brgemmInfo.ldb;
+  int64_t ldc = brgemmInfo.ldc;
+  int64_t strideA = brgemmInfo.strideA;
+  int64_t strideB = brgemmInfo.strideB;
+  auto loc = contractOp->getLoc();
+  auto dtype =
+      xsmm::utils::getDataType(rewriter, contractOp->getOperand(0).getType());
+  SmallVector<XsmmOperand> dispatchOperands;
+  // Dispatch the data type.
+  dispatchOperands.push_back(dyn_cast(dtype).getInt());
+
+  ArrayAttr brgemmFlags = rewriter.getArrayAttr(flags);
+  std::string dispatchName = "xsmm_gemm_dispatch";
+  std::string invokeName = "xsmm_gemm_invoke";
+
+  if (batch != 0) {
+    dispatchName = "xsmm_brgemm_dispatch";
+    invokeName = "xsmm_brgemm_invoke";
+  }
+
+  dispatchOperands.append(
+      SmallVector<XsmmOperand>{m, n, k, lda, ldb, ldc});
+  if (batch != 0) {
+    dispatchOperands.push_back(strideA);
+    dispatchOperands.push_back(strideB);
+  }
+  int64_t oredFlag = xsmm::utils::getOredFlags(brgemmFlags);
+  dispatchOperands.push_back(oredFlag);
+
+  auto dispatchCall = xsmm::utils::buildXsmmCall(
+      rewriter, xsmm::utils::XsmmCallType::DISPATCH, loc, dtype,
+      dispatchOperands, IntegerType::get(rewriter.getContext(), 64),
+      SymbolRefAttr::get(contractOp->getContext(), dispatchName), contractOp,
+      nullptr);
+
+  SmallVector<XsmmOperand> operandRange{
+      dyn_cast(dtype).getInt(),
+      xsmm::utils::XsmmCall{xsmm::utils::XsmmCallType::DISPATCH,
+                            dispatchCall.getResult(0)},
+      input0.getDefiningOp()->getOperand(0),
+      input1.getDefiningOp()->getOperand(0),
+      input2.getDefiningOp()->getOperand(0)};
+
+  if (batch != 0) {
+    operandRange.push_back(batch);
+  }
+  auto invokeCall = xsmm::utils::buildXsmmCall(
+      rewriter, xsmm::utils::XsmmCallType::INVOKE, loc, dtype, operandRange,
+      TypeRange(), SymbolRefAttr::get(contractOp->getContext(), invokeName),
+      contractOp, input2.getDefiningOp());
+  return std::make_pair(&*dispatchCall, &*invokeCall);
+}
+
+static std::pair<Operation *, Operation *>
+createBrgemmWithBetaZero(PatternRewriter &rewriter, Operation *contractOp,
+                         Operation *input0, Operation *input1,
+                         Operation *input2, Operation *betaZero) {
+  LLVM_DEBUG(llvm::dbgs() << "createBrgemmWithBetaZero\n");
+  auto brgemmInfo =
+      computeBrgemmInfo(rewriter, contractOp, input0, input1, input2);
+  assert(succeeded(brgemmInfo) && "guaranteed by the ValidateBrgemm constraint");
+  SmallVector<Attribute> flags;
+  if (brgemmInfo->isVnni) {
+    flags.push_back(xsmm::GemmFlagsAttr::get(rewriter.getContext(),
+                                             xsmm::GemmFlags::VNNI_B));
+  }
+  flags.push_back(
+      xsmm::GemmFlagsAttr::get(rewriter.getContext(), xsmm::GemmFlags::BETA_0));
+
+  return createBrgemmImpl(rewriter, contractOp, input0->getResult(0),
+                          input1->getResult(0), input2->getResult(0),
+                          *brgemmInfo, flags);
+}
+
+static std::pair<Operation *, Operation *>
+createBrgemm(PatternRewriter &rewriter, Operation *contractOp,
+             Operation *input0, Operation *input1, Operation *input2) {
+  LLVM_DEBUG(llvm::dbgs() << "createBrgemm\n");
+  FailureOr<xsmm::BrgemmInfo> brgemmInfo =
+      computeBrgemmInfo(rewriter, contractOp, input0, input1, input2);
+  assert(succeeded(brgemmInfo) && "guaranteed by the ValidateBrgemm constraint");
+  SmallVector<Attribute> flags;
+  if (brgemmInfo->isVnni) {
+    flags.push_back(xsmm::GemmFlagsAttr::get(rewriter.getContext(),
+                                             xsmm::GemmFlags::VNNI_B));
+  }
+
+  return createBrgemmImpl(rewriter, contractOp, input0->getResult(0),
+                          input1->getResult(0), input2->getResult(0),
+                          *brgemmInfo, flags);
+}
+
+static LogicalResult validateBrgemm(PatternRewriter &rewriter,
+                                    Operation *contractOp, Operation *input0,
+                                    Operation *input1, Operation *input2,
+                                    Operation *result) {
+  LLVM_DEBUG(llvm::dbgs() << "validateBrgemm\n");
+  FailureOr<xsmm::BrgemmInfo> brgemmInfo =
+      computeBrgemmInfo(rewriter, contractOp, input0, input1, input2);
+  if (failed(brgemmInfo))
+    return failure();
+  return success();
+}
+
 static void registerNativeRewrite(RewritePatternSet &patterns) {
   patterns.getPDLPatterns().registerRewriteFunction("ConvertTranspose",
-                                                    convertTransposeOp);
+                                                    convertTranspose);
   patterns.getPDLPatterns().registerConstraintFunction("ValidateTranspose",
-                                                       validateTransposeOp);
+                                                       validateTranspose);
   patterns.getPDLPatterns().registerRewriteFunction("ConvertBroadcast",
                                                     convertBroadcast);
   patterns.getPDLPatterns().registerConstraintFunction("ValidateBroadcast",
-                                                       validateBroadcastOp);
+                                                       validateBroadcast);
+  patterns.getPDLPatterns().registerRewriteFunction("CreateBrgemmWithBetaZero",
+                                                    createBrgemmWithBetaZero);
+  patterns.getPDLPatterns().registerRewriteFunction("CreateBrgemm",
+                                                    createBrgemm);
+  patterns.getPDLPatterns().registerConstraintFunction("ValidateBrgemm",
+                                                       validateBrgemm);
 }
 
 namespace mlir {
diff --git a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmmPDL.pdll b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmmPDL.pdll
index 004bc6cff..0f1fdb9bf 100644
--- a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmmPDL.pdll
+++ b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmmPDL.pdll
@@ -1,40 +1,74 @@
+#include "mlir/Dialect/Arith/IR/ArithOps.td"
 #include "mlir/Dialect/Func/IR/FuncOps.td"
 #include "mlir/Dialect/Vector/IR/VectorOps.td"
-#include "mlir/Dialect/Arith/IR/ArithOps.td"
-#include "mlir/IR/OpBase.td"
 #include "mlir/IR/BuiltinTypes.td"
+#include "mlir/IR/OpBase.td"
+
+Rewrite ConvertTranspose(op: Op, input0: Op, output: Op, outputType: TypeRange) -> (dispatch: Op, invoke: Op);
+Constraint ValidateTranspose(op: Op, input0: Op, output: Op, outputType: TypeRange);
 
-Rewrite ConvertTranspose(op:Op, input0:Op, output: Op, outputType:TypeRange)->(dispatch:Op, invoke:Op);
+Rewrite ConvertBroadcast(op: Op, input: Op, output: Op) -> (dispatch: Op, invoke: Op);
 
-Constraint ValidateTranspose(op:Op, input0:Op, output:Op, outputType:TypeRange);
+Constraint ValidateBroadcast(op: Op, input0: Op, output: Op);
+Constraint ValidateBrgemm(op: Op, input0: Op, input1: Op, input2: Op, output: Op);
 
-Rewrite ConvertBroadcast(op:Op, input:Op, output:Op)->(dispatch:Op, invoke:Op);
+Rewrite CreateBrgemmWithBetaZero(op: Op, input0: Op, input1: Op, input2: Op, betaZero: Op) -> (dispatch: Op, invoke: Op);
 
-Constraint ValidateBroadcast(op:Op, input0:Op, output:Op);
+Rewrite CreateBrgemm(op: Op, input0: Op, input1: Op, input2: Op) -> (dispatch: Op, invoke: Op);
 
-Pattern ConvertTransposePattern{
-    let input0 = op<vector.transfer_read>(alloc0:Value, indices0:ValueRange, const0:Value, constIndex:ValueRange)->(output:TypeRange);
-    let transpose = op<vector.transpose>(input0)->(transposeOutput0:Type);
-    let output0 = op<vector.transfer_write>(transpose, alloc1:Value, outindices:ValueRange, constIndex2:ValueRange)->(typeRange:TypeRange);
-    ValidateTranspose(transpose, input0, output0, transposeOutput0);
-    rewrite transpose with{
-        let replacement = ConvertTranspose(transpose, input0, output0, transposeOutput0);
-        replace transpose with (replacement.dispatch, replacement.invoke);
-        erase output0;
-    };
+Pattern ConvertContractToBrgemmWithBetaZero {
+  let input0 = op<vector.transfer_read>(alloc0: Value, indices0: ValueRange, const0: Value, constIndex0: ValueRange) -> (output0: TypeRange);
+  let input1 = op<vector.transfer_read>(alloc1: Value, indices1: ValueRange, const0, constIndex1: ValueRange) -> (output1: TypeRange);
+  let input2 = op<vector.transfer_read>(alloc2: Value, indices2: ValueRange, const0, constIndex2: ValueRange) -> (output2: TypeRange);
+  let cst = op<arith.constant>() -> (constantVector: AnyVectorOfNonZeroRank);
+  let betaZero = op<vector.transfer_write>(cst, alloc2, input3: ValueRange, input4: ValueRange);
+  let root = op<vector.contract>(input0, input1, input2) -> (output: TypeRange);
+  let contractOutput = op<vector.transfer_write>(root, alloc2, outIndices: ValueRange, outBounds: ValueRange);
+  ValidateBrgemm(root, input0, input1, input2, contractOutput);
+
+  rewrite root with {
+    let replacement = CreateBrgemmWithBetaZero(root, input0, input1, input2, contractOutput);
+    replace root with (replacement.dispatch, replacement.invoke);
+    erase contractOutput;
+    erase betaZero;
+  };
 }
 
-Pattern ConvertBroadcastPattern{
-    let input0 = op<vector.transfer_read>(alloc0:Value, indices0:ValueRange, const0:Value, constIndex:ValueRange)->(output:TypeRange);
-    let broadcast = op<vector.broadcast>(input0)->(broadcastOutput0:Type);
-    let output0 = op<vector.transfer_write>(broadcast, alloc1:Value, outindices:ValueRange, constIndex2:ValueRange)->(typeRange:TypeRange);
-    ValidateBroadcast(broadcast, input0, output0);
-    rewrite broadcast with{
-        let replacement = ConvertBroadcast(broadcast, input0, output0);
-        replace broadcast with (replacement.dispatch, replacement.invoke);
-        erase output0;
-    };
+Pattern ConvertContractToBrgemm {
+  let input0 = op<vector.transfer_read>(alloc0: Value, indices0: ValueRange, const0: Value, constIndex0: ValueRange) -> (output0: TypeRange);
+  let input1 = op<vector.transfer_read>(alloc1: Value, indices1: ValueRange, const1: Value, constIndex1: ValueRange) -> (output1: TypeRange);
+  let input2 = op<vector.transfer_read>(alloc2: Value, indices2: ValueRange, const2: Value, constIndex2: ValueRange) -> (output2: TypeRange);
+  let root = op<vector.contract>(input0, input1, input2) -> (output: TypeRange);
+  let rootOutput = op<vector.transfer_write>(root, alloc2, outindices: ValueRange, constIndex3: ValueRange);
+  ValidateBrgemm(root, input0, input1, input2, rootOutput);
+  rewrite root with {
+    let replacement = CreateBrgemm(root, input0, input1, input2);
+    replace root with (replacement.dispatch, replacement.invoke);
+    erase rootOutput;
+  };
+}
+
+Pattern ConvertTransposePattern {
+  let input0 = op<vector.transfer_read>(alloc0: Value, indices0: ValueRange, const0: Value, constIndex: ValueRange) -> (output: TypeRange);
+  let transpose = op<vector.transpose>(input0) -> (transposeOutput0: Type);
+  let output0 = op<vector.transfer_write>(transpose, alloc1: Value, outindices: ValueRange, constIndex2: ValueRange) -> (typeRange: TypeRange);
+  ValidateTranspose(transpose, input0, output0, transposeOutput0);
+  rewrite transpose with {
+    let replacement = ConvertTranspose(transpose, input0, output0, transposeOutput0);
+    replace transpose with (replacement.dispatch, replacement.invoke);
+    erase output0;
+  };
+}
+
+Pattern ConvertBroadcastPattern {
+  let input0 = op<vector.transfer_read>(alloc0: Value, indices0: ValueRange, const0: Value, constIndex: ValueRange) -> (output: TypeRange);
+  let broadcast = op<vector.broadcast>(input0) -> (broadcastOutput0: Type);
+  let output0 = op<vector.transfer_write>(broadcast, alloc1: Value, outindices: ValueRange, constIndex2: ValueRange) -> (typeRange: TypeRange);
+  ValidateBroadcast(broadcast, input0, output0);
+  rewrite broadcast with {
+    let replacement = ConvertBroadcast(broadcast, input0, output0);
+    replace broadcast with (replacement.dispatch, replacement.invoke);
+    erase output0;
+  };
 }
diff --git a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp index 70f915329..69a6ccf8b 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp @@ -12,6 +12,7 @@ #include "TPP/Transforms/Utils/ValueUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -23,10 +24,135 @@ #include "llvm/Support/Debug.h" #define DEBUG_TYPE "xsmm-utils" +using namespace mlir::linalg; + namespace mlir { namespace xsmm { namespace utils { +// Callable object to verify if `operand` has static shape. +struct HasStaticShape { + SmallVectorImpl *shape = nullptr; + HasStaticShape() = default; + HasStaticShape(SmallVectorImpl *shape) : shape(shape){}; + + bool operator()(Value operand, Operation *op) const { + auto operandType = operand.getType(); + if (auto shapedType = dyn_cast_or_null(operandType)) { + if (!shapedType.hasStaticShape()) + return false; + if (shape) { + for (int64_t shapeOnDim : shapedType.getShape()) + shape->push_back(shapeOnDim); + } + } + return true; + } +}; + +// Callable object to verify if `operand` has static strides. +// If `operand` is a tensor type or a scalar, return true. +struct HasStaticStrides { + SmallVectorImpl *strides = nullptr; + HasStaticStrides() = default; + HasStaticStrides(SmallVector *strides) : strides(strides){}; + + bool operator()(Value operand, Operation *op) const { + auto operandType = operand.getType(); + SmallVector strides; + if (auto memRefType = dyn_cast_or_null(operandType)) { + int64_t offset; + if (failed(getStridesAndOffset(memRefType, strides, offset))) + return false; + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) { + return false; + } + if (this->strides) + this->strides->append(strides.begin(), strides.end()); + } + return true; + } +}; + +// Structural matcher. +static FailureOr +checkStructure(vector::ContractionOp contractOp, SmallVector &inputs, + SmallVector &outputs, ArrayRef indexingMap) { + if (!HasStaticShape()(inputs[0], inputs[0].getDefiningOp()) || + !HasStaticShape()(inputs[1], inputs[1].getDefiningOp()) || + !HasStaticShape()(inputs[2], inputs[2].getDefiningOp()) || + (outputs[0] != nullptr && + !HasStaticShape()(outputs[0], outputs[0].getDefiningOp())) || + !HasStaticStrides()(inputs[0], inputs[0].getDefiningOp()) || + !HasStaticStrides()(inputs[1], inputs[1].getDefiningOp()) || + !HasStaticStrides()(inputs[2], inputs[2].getDefiningOp()) || + (outputs[0] != nullptr && + !HasStaticStrides()(outputs[0], outputs[0].getDefiningOp()))) { + return failure(); + } + + return inferContractionDims(indexingMap); +} + +// Return the position of `dim` in the codomain of `operand`. 
+std::optional<unsigned> getPosInCodomain(unsigned dim,
+                                         vector::ContractionOp contractOp,
+                                         AffineMap map) {
+  return map.getResultPosition(getAffineDimExpr(dim, contractOp.getContext()));
+}
+
+static SmallVector<int64_t>
+createFlatListOfOperandStaticDims(vector::ContractionOp contractOp) {
+  SmallVector<int64_t> res;
+  for (unsigned op = 0; op < contractOp.getOperation()->getNumOperands();
+       op++) {
+    Value operand = contractOp.getOperation()->getOperand(op);
+    llvm::append_range(res, dyn_cast<ShapedType>(operand.getType()).getShape());
+  }
+  return res;
+}
+
+static SmallVector<int64_t>
+computeStaticLoopSizes(vector::ContractionOp contractOp,
+                       ArrayRef<AffineMap> maps) {
+  AffineMap map = concatAffineMaps(maps, contractOp.getContext());
+  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
+  SmallVector<int64_t> res(numDims, 0);
+  auto allShapeSizes = createFlatListOfOperandStaticDims(contractOp);
+  for (unsigned idx = 0; idx < numRes; ++idx) {
+    auto result = map.getResult(idx);
+    if (auto d = dyn_cast<AffineDimExpr>(result))
+      res[d.getPosition()] = allShapeSizes[idx];
+  }
+  return res;
+}
+
+static FailureOr<SmallVector<int64_t>>
+getVNNIStaticStrides(MemRefType valueType) {
+  // Compute contiguous strides for the buffer with its two innermost
+  // dimensions swapped, i.e. the pre-packing view of a VNNI operand.
+  SmallVector<int64_t> shape = llvm::to_vector(valueType.getShape());
+  std::swap(shape[shape.size() - 1], shape[shape.size() - 2]);
+  auto memrefType = MemRefType::get(shape, valueType.getElementType());
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(getStridesAndOffset(memrefType, strides, offset)))
+    return failure();
+  if (llvm::any_of(strides, [](int64_t stride) {
+        return stride == ShapedType::kDynamic;
+      }))
+    return failure();
+  return strides;
+}
+
 // Examples:
 // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
 // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
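For orientation (not part of the patch): getPosInCodomain above is a thin wrapper over AffineMap::getResultPosition, answering "at which result position of this operand's indexing map does loop dimension dim appear?". A standalone C++ sketch against core MLIR, using a made-up GEMM-style map:

// Standalone sketch: where does the k loop land in A's indexing map?
#include <optional>
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  // Loops ordered (m, n, k); A(m, k) has map (d0, d1, d2) -> (d0, d2).
  mlir::AffineExpr d0 = mlir::getAffineDimExpr(0, &ctx);
  mlir::AffineExpr d2 = mlir::getAffineDimExpr(2, &ctx);
  mlir::AffineMap mapA =
      mlir::AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {d0, d2}, &ctx);
  // k (d2) appears as result #1 of A's map, so its codomain position is 1.
  std::optional<unsigned> kPos = mapA.getResultPosition(d2);
  llvm::outs() << "k position in A's codomain: " << *kPos << "\n";
  return 0;
}

The same lookup, applied per operand, is what drives the stride checks in the access matcher further down.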
@@ -287,22 +413,9 @@ FailureOr getUnaryFlags(Type inputType, Type outputType) { return failure(); } -FailureOr getBinaryFlags(Type operandType, Type outputType, - OperandPos operandNumber) { - assert(isa(outputType) && "expect shaped type on output"); - assert(cast(outputType).getRank() == 2 && - "expect rank 2 on output"); - - if (!isa(operandType) || - cast(operandType).getRank() == 0) { - if (operandNumber == OperandPos::LHS) - return xsmm::BinaryFlags::BCAST_SCALAR_IN_0; - return xsmm::BinaryFlags::BCAST_SCALAR_IN_1; - } - - enum class BCastType { NONE = 0, SCALAR, ROW, COL }; - auto shapeOutput = cast(outputType).getShape(); - auto shapeOperand = cast(operandType).getShape(); +FailureOr getBinFlags(ArrayRef shapeOutput, + ArrayRef shapeOperand, + OperandPos operandNumber) { assert(shapeOutput.size() >= shapeOperand.size() && "Output rank must be >= operand rank"); SmallVector bOperandShape; @@ -310,6 +423,7 @@ FailureOr getBinaryFlags(Type operandType, Type outputType, assert(shapeOutput.size() == bOperandShape.size()); assert(shapeOutput.size() == 2); + enum class BCastType { NONE = 0, SCALAR, ROW, COL }; auto getBCastEnum = [](BCastType bCastType, OperandPos operandPos) -> xsmm::BinaryFlags { switch (bCastType) { @@ -350,9 +464,70 @@ FailureOr getBinaryFlags(Type operandType, Type outputType, return failure(); } +FailureOr getBinaryFlags(Type operandType, Type outputType, + OperandPos operandNumber) { + assert(isa(outputType) && "expect shaped type on output"); + assert(cast(outputType).getRank() == 2 && + "expect rank 2 on output"); + + if (!isa(operandType) || + cast(operandType).getRank() == 0) { + if (operandNumber == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_SCALAR_IN_0; + return xsmm::BinaryFlags::BCAST_SCALAR_IN_1; + } + + enum class BCastType { NONE = 0, SCALAR, ROW, COL }; + auto shapeOutput = cast(outputType).getShape(); + auto shapeOperand = cast(operandType).getShape(); + return getBinFlags(shapeOutput, shapeOperand, operandNumber); +} + +FailureOr getBinaryFlagsVectorType(Type operandType, + Type outputType, + OperandPos operandNumber) { + assert(isa(outputType) && "expect shaped type on output"); + assert(cast(outputType).getRank() == 2 && + "expect rank 2 on output"); + + if (!isa(operandType) || + cast(operandType).getRank() == 0) { + if (operandNumber == OperandPos::LHS) + return xsmm::BinaryFlags::BCAST_SCALAR_IN_0; + return xsmm::BinaryFlags::BCAST_SCALAR_IN_1; + } + + enum class BCastType { NONE = 0, SCALAR, ROW, COL }; + auto shapeOutput = cast(outputType).getShape(); + auto shapeOperand = cast(operandType).getShape(); + return getBinFlags(shapeOutput, shapeOperand, operandNumber); +} + +FailureOr getLeadingDim(Type type, size_t pos) { + // Not shaped type, the leading dimension is the single scalar. + auto memref = dyn_cast(type); + if (!memref) + return 1; + // For 1d memref we cannot use the stride as leading dimension, but the + // leading dimension is the dimension itself. 
+  if (memref.getRank() == 1)
+    return memref.getShape()[0];
+
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(getStridesAndOffset(memref, strides, offset)))
+    return failure();
+  // Fail if the strides are non-constant.
+  if (llvm::any_of(strides, [](int64_t stride) {
+        return stride == ShapedType::kDynamic;
+      }))
+    return failure();
+  return strides[pos];
+}
 
 SmallVector<Type> extractOperandTypes(OpBuilder &builder,
                                       SmallVector<XsmmOperand> operands) {
+
   SmallVector<Type> results;
   for (XsmmOperand operand : operands) {
     if (std::holds_alternative(operand))
@@ -582,6 +757,276 @@ template FailureOr<SmallVector<Attribute>> getBrgemmFlags(
     PatternRewriter &rewriter, xsmm::FusedBrgemmDispatchOp dispatchOpTy,
     bool returnNone);
 
+// Access matcher.
+FailureOr<BrgemmInfo> checkAccess(PatternRewriter &rewriter,
+                                  vector::ContractionOp contractOp, unsigned m,
+                                  unsigned n, unsigned k,
+                                  std::optional<unsigned> batchPos,
+                                  SmallVector<Value> inputs,
+                                  ArrayRef<AffineMap> indexingMap) {
+  Value operandA = inputs[0];
+  Value operandB = inputs[1];
+  Value operandC = inputs[2];
+
+  auto kPos = *xsmm::utils::getPosInCodomain(k, contractOp, indexingMap[0]);
+  auto checkStridesAndGetLdaAndBatch =
+      [&](int minorDim, int majorDim, Value operand, AffineMap indexingMap,
+          int operandIndex, std::optional<unsigned> batchPos, bool isVnni,
+          int vnniFactor) -> FailureOr<std::pair<int64_t, int64_t>> {
+    auto minorDimPosInCodomain =
+        xsmm::utils::getPosInCodomain(minorDim, contractOp, indexingMap);
+    auto majorDimPosInCodomain =
+        xsmm::utils::getPosInCodomain(majorDim, contractOp, indexingMap);
+    if (!minorDimPosInCodomain || !majorDimPosInCodomain)
+      return failure();
+    // Recover the underlying memref type: either the defining op produces
+    // the memref directly, or it reads from one (e.g. a transfer_read), in
+    // which case look at its source operand.
+    MemRefType type;
+    if (Operation *defOp = operand.getDefiningOp()) {
+      type = dyn_cast<MemRefType>(defOp->getResult(0).getType());
+      if (!type)
+        type = dyn_cast<MemRefType>(defOp->getOperand(0).getType());
+    } else if (isa<MemRefType>(operand.getType())) {
+      type = dyn_cast<MemRefType>(operand.getType());
+    }
+
+    int64_t stride = 1;
+    if (batchPos) {
+      auto batchPosCodomainA =
+          getPosInCodomain(batchPos.value(), contractOp, indexingMap);
+      auto stridesOnA = ::mlir::utils::getStaticStrides(type);
+      if (succeeded(stridesOnA) && batchPosCodomainA)
+        stride = (*stridesOnA)[*batchPosCodomainA];
+    }
+
+    FailureOr<SmallVector<int64_t>> stridesOnOperand;
+    if (isVnni && operandIndex == 1)
+      stridesOnOperand = getVNNIStaticStrides(type);
+    else
+      stridesOnOperand = ::mlir::utils::getStaticStrides(type);
+    if (failed(stridesOnOperand) ||
+        (!isVnni && (*stridesOnOperand)[*minorDimPosInCodomain] != 1))
+      return failure();
+
+    if (isVnni) {
+      if (operandIndex == 1) {
+        if (*majorDimPosInCodomain == (*stridesOnOperand).size() - 3)
+          return std::make_pair(
+              (*stridesOnOperand)[*majorDimPosInCodomain] / vnniFactor, stride);
+        if (*majorDimPosInCodomain == (*stridesOnOperand).size() - 2)
+          return std::make_pair((*stridesOnOperand)[*majorDimPosInCodomain] + 1,
+                                stride);
+        if (*majorDimPosInCodomain == (*stridesOnOperand).size() - 1)
+          return std::make_pair(static_cast<int64_t>(vnniFactor), stride);
+      } else if (operandIndex == 0) {
+        if (*majorDimPosInCodomain == (*stridesOnOperand).size() - 2)
+          return std::make_pair(static_cast<int64_t>(vnniFactor), stride);
+        if (*majorDimPosInCodomain == (*stridesOnOperand).size() - 3)
+          return std::make_pair(
+              (*stridesOnOperand)[*majorDimPosInCodomain] / vnniFactor, stride);
+      }
+    }
+
+    return std::make_pair((*stridesOnOperand)[*majorDimPosInCodomain], stride);
+  };
+
+  auto vnniBlockingFactor =
+      vnni::utils::getVnniBlockingFactor(inputs[1].getType());
+  bool isVnni = false;
+  int vnniFactor = 1;
+  if (vnniBlockingFactor) {
+    vnniFactor = *vnniBlockingFactor;
+    isVnni = succeeded(vnni::utils::isInVnniLayout(contractOp, vnniFactor));
+  }
+
+  // A(m, k)
+  auto ldaVal = checkStridesAndGetLdaAndBatch(k, m, operandA, indexingMap[0],
+                                              0, batchPos, isVnni, vnniFactor);
+  if (failed(ldaVal)) {
+    LLVM_DEBUG(llvm::dbgs() << "Failed to compute lda\n");
+    return failure();
+  }
+  auto lda = (*ldaVal).first;
+  auto strideA = (*ldaVal).second;
+  LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] Strides on A: OK " << lda
+                          << "\n");
+
+  // B(k, n)
+  auto ldbVal = checkStridesAndGetLdaAndBatch(n, k, operandB, indexingMap[1],
+                                              1, batchPos, isVnni, vnniFactor);
+  if (failed(ldbVal)) {
+    LLVM_DEBUG(llvm::dbgs() << "Failed to compute ldb\n");
+    return failure();
+  }
+  auto ldb = (*ldbVal).first;
+  auto strideB = (*ldbVal).second;
+  LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] Strides on B: OK " << ldb
+                          << "\n");
+
+  // C(m, n)
+  auto ldcVal = checkStridesAndGetLdaAndBatch(n, m, operandC, indexingMap[2],
+                                              2, batchPos, isVnni, vnniFactor);
+  if (failed(ldcVal)) {
+    LLVM_DEBUG(llvm::dbgs() << "Failed to compute ldc\n");
+    return failure();
+  }
+  auto ldc = (*ldcVal).first;
+  LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] Strides on C: OK " << ldc
+                          << "\n");
+
+  auto loops = computeStaticLoopSizes(contractOp, indexingMap);
+  int64_t batchVal = (batchPos) ? loops[batchPos.value()] : 0;
+  auto loopsK = loops[k];
+  // The VNNI layout folds the blocking factor into an extra reduction dim;
+  // recover the full K extent.
+  if (isVnni && !batchVal &&
+      dyn_cast<ShapedType>(inputs[0].getType()).getRank() - 2 == kPos)
+    loopsK *= vnniFactor;
+
+  xsmm::BrgemmInfo info{loops[m], loops[n], loopsK,  batchVal, lda,
+                        ldb,      ldc,      strideA, strideB,  isVnni};
+  return info;
+}
+
+// Check if the given contraction is mappable to a brgemm xsmm op:
+// - It has 1 m, 1 n, and 2 k dimensions (k plus an optional batch).
+// - m appears on the LHS and OUT but not on the RHS.
+// - n appears on the RHS and OUT but not on the LHS.
+// - k and k' appear on the LHS and RHS but not on OUT.
+// - The stride of the minor dimension for A (k) is 1.
+// - The stride of the minor dimension for B (n) is 1.
+// - The stride of the minor dimension for C (n) is 1.
+FailureOr<BrgemmInfo> isMappableToBrgemm(PatternRewriter &rewriter,
+                                         vector::ContractionOp contractOp,
+                                         SmallVector<Value> &inputs,
+                                         SmallVector<Value> &output,
+                                         ArrayRef<AffineMap> indexingMap) {
+  auto contractionDims =
+      checkStructure(contractOp, inputs, output, indexingMap);
+  if (failed(contractionDims)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "[isMappableToBrgemm] Failed on checkStructure\n");
+    return failure();
+  }
+  unsigned m = contractionDims->m.back();
+  unsigned n = contractionDims->n.back();
+  std::optional<unsigned> batch;
+  // Order the reduction dims by their position in A's map: the outermost
+  // becomes the batch dimension and the next one is k.
+  auto pos = xsmm::utils::getPosInCodomain(
+      contractionDims->k[0], contractOp, contractOp.getIndexingMapsArray()[0]);
+  int prevPos = -1;
+  int prevIndex = -1;
+  int index = 0;
+  bool isVnni = vnni::utils::isInVnniLayout(
+      dyn_cast<VectorType>(inputs[1].getType()).getRank(),
+      dyn_cast<VectorType>(inputs[1].getType()));
+
+  if (contractionDims->k.size() > 1) {
+    for (unsigned i = 1; i < contractionDims->k.size(); i++) {
+      auto posTwo =
+          xsmm::utils::getPosInCodomain(contractionDims->k[i], contractOp,
+                                        contractOp.getIndexingMapsArray()[0]);
+      if (*posTwo < *pos) {
+        prevPos = *pos;
+        prevIndex = index;
+        pos = posTwo;
+        index = i;
+      } else if (prevIndex == -1 ||
+                 *posTwo < static_cast<unsigned>(prevPos)) {
+        prevPos = *posTwo;
+        prevIndex = i;
+      }
+    }
+  }
+
+  unsigned k;
+  if (prevIndex == -1 ||
+      (dyn_cast<ShapedType>(inputs[0].getType()).getRank() - 1 == prevPos &&
+       isVnni)) {
+    k = contractionDims->k[index];
+  } else {
+    batch = contractionDims->k[index];
+    k = contractionDims->k[prevIndex];
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] Candidate dims:\n");
+  LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] m: " << m << "\n");
+  LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] n: " << n << "\n");
+  if (batch)
+    LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] batch: " << *batch
+                            << "\n");
+  else
+    LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] no batch dim\n");
+  auto retval =
+      checkAccess(rewriter, contractOp, m, n, k, batch, inputs, indexingMap);
+  if (failed(retval)) {
+    LLVM_DEBUG(llvm::dbgs() << "Failed to check access\n");
+    return failure();
+  }
+  return retval;
+}
+
+int64_t getOredFlags(ArrayAttr flags) {
+  int64_t oredFlag = 0;
+  for (auto flag : flags) {
+    int64_t intAttr = dyn_cast(flag).getInt();
+    // LIBXSMM is col-major, swap A and B flags.
+    if (auto gemmFlag = dyn_cast_or_null<GemmFlagsAttr>(flag)) {
+      if (gemmFlag.getValue() == GemmFlags::VNNI_A)
+        intAttr = static_cast<int64_t>(GemmFlags::VNNI_B);
+      else if (gemmFlag.getValue() == GemmFlags::VNNI_B)
+        intAttr = static_cast<int64_t>(GemmFlags::VNNI_A);
+    }
+    oredFlag |= intAttr;
+  }
+  return oredFlag;
+}
+
 } // namespace utils
 } // namespace xsmm
 } // namespace mlir
diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp
index 226e515c8..6b5938db0 100644
--- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp
+++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp
@@ -8,10 +8,12 @@
 
 #include "TPP/Transforms/Utils/VNNIUtils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
+#include "llvm/ADT/SetOperations.h"
 
 #include "libxsmm.h"
 
@@ -26,6 +28,23 @@ std::optional<int64_t> getVnniBlockingFactor(Type type) {
   return std::nullopt;
 }
 
+static llvm::SmallDenseSet<unsigned>
+findPermutationsIndexingOperand(AffineMap indexingMap,
+                                ArrayRef<mlir::vector::IteratorType> iterators,
+                                mlir::vector::IteratorType iter) {
+  assert(iterators.size() == indexingMap.getNumDims());
+  llvm::SmallDenseSet<unsigned> res;
+  for (AffineExpr e : indexingMap.getResults()) {
+    if (auto d = dyn_cast<AffineDimExpr>(e)) {
+      // Keep dims of the requested iterator kind that appear in exactly one
+      // result of the map.
+      if (iterators[d.getPosition()] == iter &&
+          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
+            return e.isFunctionOfDim(d.getPosition());
+          }) == 1)
+        res.insert(d.getPosition());
+    }
+  }
+  return res;
+}
+
 // Until we have a better way to express the VNNI layout (see: #563), it is up
 // to the callee to specify the expected rank in the VNNI layout as the rank
 // depends on the operations we are dealing with.
@@ -80,6 +99,65 @@ bool isInVnniLayout(int64_t expectedRank, VectorType vector) {
   }
   return vector.getShape().back() == vnni::utils::getVnniBlockingFactor(vector);
 }
+
+FailureOr<AffineDimExpr> isInVnniLayout(mlir::vector::ContractionOp contractOp,
+                                        int64_t blockingFactor) {
+  AffineMap map = contractOp.getIndexingMapsArray()[1];
+  auto arrayAttr = contractOp.getIteratorTypes();
+  SmallVector<mlir::vector::IteratorType> iteratorTypes;
+  for (auto attr : arrayAttr)
+    iteratorTypes.push_back(
+        cast<mlir::vector::IteratorTypeAttr>(attr).getValue());
+
+  // Both vector operands must already be in VNNI layout.
+  int inputZeroRank =
+      dyn_cast<VectorType>(contractOp.getOperand(0).getType()).getRank();
+  int inputOneRank =
+      dyn_cast<VectorType>(contractOp.getOperand(1).getType()).getRank();
+  bool isVnni =
+      isInVnniLayout(inputZeroRank,
+                     dyn_cast<VectorType>(contractOp.getOperand(0).getType())) &&
+      isInVnniLayout(inputOneRank,
+                     dyn_cast<VectorType>(contractOp.getOperand(1).getType()));
+  if (!isVnni)
+    return failure();
+  ArrayRef<AffineExpr> results = map.getResults();
+
+  // The innermost result of the RHS map must be a reduction dim: the VNNI
+  // blocking dim.
+  AffineExpr vnniDim = results.back();
+  auto dimExpr = dyn_cast<AffineDimExpr>(vnniDim);
+  if (!dimExpr || iteratorTypes[dimExpr.getPosition()] !=
+                      mlir::vector::IteratorType::reduction)
+    return failure();
+  // Track the last reduction dim appearing in the RHS map; given the check
+  // above, this ends up being the VNNI dim.
+  AffineDimExpr rhsCst = dimExpr;
+  for (AffineExpr result : results) {
+    auto d = dyn_cast<AffineDimExpr>(result);
+    if (!d)
+      continue;
+    if (iteratorTypes[d.getPosition()] !=
+        mlir::vector::IteratorType::reduction)
+      continue;
+    rhsCst = d;
+  }
+
+  llvm::SmallDenseSet<unsigned> a = findPermutationsIndexingOperand(
+      contractOp.getIndexingMapsArray()[0], iteratorTypes,
+      vector::IteratorType::reduction);
+  llvm::SmallDenseSet<unsigned> b = findPermutationsIndexingOperand(
+      contractOp.getIndexingMapsArray()[1], iteratorTypes,
+      vector::IteratorType::reduction);
+  llvm::set_union(a, b);
+
+  // A VNNI contraction carries at least two reduction dims.
+  if (a.size() < 2)
+    return failure();
+  llvm::SmallDenseSet<unsigned> c = findPermutationsIndexingOperand(
+      contractOp.getIndexingMapsArray()[2], iteratorTypes,
+      vector::IteratorType::reduction);
+  if (!c.contains(*a.begin())) {
+    // GEMM
+    return failure();
+  }
+  return rhsCst;
+}
 } // namespace utils
 } // namespace vnni
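To make the BrgemmInfo bookkeeping above concrete (not part of the patch): for a row-major batch-reduce GEMM C[m][n] += sum over b of A[b][m][k] * B[b][k][n], the leading dimensions are the strides of the major dim of each 2-D tile, and the batch strides are the outermost strides. A self-contained C++ sketch with made-up sizes:

// Standalone sketch: stride bookkeeping for a row-major brgemm.
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t batch = 4, M = 32, N = 64, K = 16;
  // Row-major strides: the innermost dimension has stride 1.
  std::array<int64_t, 3> stridesA = {M * K, K, 1}; // A[batch][M][K]
  std::array<int64_t, 3> stridesB = {K * N, N, 1}; // B[batch][K][N]
  std::array<int64_t, 2> stridesC = {N, 1};        // C[M][N]

  // The minor (unit-stride) dim must have stride 1; the leading dimension
  // is the stride of the major dim of each operand's 2-D tile.
  int64_t lda = stridesA[1]; // stride of m in A -> K
  int64_t ldb = stridesB[1]; // stride of k in B -> N
  int64_t ldc = stridesC[0]; // stride of m in C -> N
  int64_t strideA = stridesA[0], strideB = stridesB[0]; // batch strides

  std::cout << "m=" << M << " n=" << N << " k=" << K << " batch=" << batch
            << " lda=" << lda << " ldb=" << ldb << " ldc=" << ldc
            << " strideA=" << strideA << " strideB=" << strideB << "\n";
  return 0;
}

These ten numbers are exactly what createBrgemmImpl forwards to xsmm_brgemm_dispatch.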
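The dispatch flag word assembled by getOredFlags is an OR of the per-flag bit values, with VNNI_A and VNNI_B swapped because LIBXSMM computes in column-major. A standalone C++ sketch; the enum values here are illustrative placeholders, not LIBXSMM's actual encoding:

// Standalone sketch: OR-ing dispatch flags with the VNNI_A/VNNI_B swap.
#include <cstdint>
#include <iostream>
#include <vector>

enum class GemmFlags : int64_t { NONE = 0, BETA_0 = 4, VNNI_A = 2048, VNNI_B = 4096 };

int64_t getOredFlags(const std::vector<GemmFlags> &flags) {
  int64_t ored = 0;
  for (GemmFlags flag : flags) {
    int64_t bits = static_cast<int64_t>(flag);
    // Column-major LIBXSMM sees our row-major A as B and vice versa,
    // so the VNNI flags trade places.
    if (flag == GemmFlags::VNNI_A)
      bits = static_cast<int64_t>(GemmFlags::VNNI_B);
    else if (flag == GemmFlags::VNNI_B)
      bits = static_cast<int64_t>(GemmFlags::VNNI_A);
    ored |= bits;
  }
  return ored;
}

int main() {
  // VNNI_B (4096) swaps to VNNI_A (2048); 2048 | 4 = 2052.
  std::cout << getOredFlags({GemmFlags::VNNI_B, GemmFlags::BETA_0}) << "\n";
  return 0;
}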
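Finally, the VNNI-layout checks in VNNIUtils boil down to a shape test: a packed operand gains an innermost dimension equal to the blocking factor. A minimal C++ sketch of that test, assuming a blocking factor of 2 as is typical for bf16:

// Standalone sketch: the shape-level VNNI layout check.
#include <cstdint>
#include <iostream>
#include <vector>

bool isInVnniLayout(const std::vector<int64_t> &shape, int64_t expectedRank,
                    int64_t blockingFactor) {
  // Rank must match the expected VNNI rank, and the innermost dimension
  // must equal the VNNI blocking factor.
  return static_cast<int64_t>(shape.size()) == expectedRank &&
         shape.back() == blockingFactor;
}

int main() {
  std::vector<int64_t> vnniB = {8, 64, 2}; // B[K/2][N][2] with K = 16
  std::cout << std::boolalpha << isInVnniLayout(vnniB, 3, 2) << "\n"; // true
  std::vector<int64_t> flatB = {16, 64};   // plain B[K][N]
  std::cout << isInVnniLayout(flatB, 3, 2) << "\n";                   // false
  return 0;
}

On top of this shape test, the contraction-level overload added by the patch additionally requires the innermost RHS map result to be a reduction dim and at least two reduction dims overall.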