Skip to content

Commit

Permalink
Brgemm pattern matching
Browse files Browse the repository at this point in the history
  • Loading branch information
KavithaTipturMadhu committed Dec 24, 2024
1 parent 8d0da95 commit 4c67d4b
Show file tree
Hide file tree
Showing 6 changed files with 813 additions and 58 deletions.
48 changes: 48 additions & 0 deletions include/TPP/Dialect/Xsmm/XsmmUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,33 @@ class Operation;
class PatternRewriter;
class VectorType;
class MemRefType;
class ModuleOp;

namespace func {
class CallOp;
}

namespace vector {
class ContractionOp;
}

namespace xsmm {

// Parameters describing a (batch-reduce) GEMM extracted from a contraction.
//
// Every numeric field defaults to zero so that a default-constructed instance
// never exposes indeterminate values. Aggregate initialization
// (`BrgemmInfo{m, n, k, ...}`) keeps working.
struct BrgemmInfo {
  // Problem sizes.
  int64_t m = 0;
  int64_t n = 0;
  int64_t k = 0;
  // Batch count; the vector-to-XSMM conversion treats 0 as "plain GEMM" and
  // any other value as BRGEMM.
  int64_t batch = 0;

  // Per-operand values forwarded to the dispatch call after m, n, k.
  // NOTE(review): presumably the leading dimensions of A, B and C -- confirm
  // against the code that fills this struct.
  int64_t lda = 0;
  int64_t ldb = 0;
  int64_t ldc = 0;
  // Only forwarded to the dispatch call when `batch` is non-zero.
  int64_t strideA = 0;
  int64_t strideB = 0;

  // When set, the conversion adds the VNNI_B GEMM flag.
  bool isVnni = false;
};

class UnaryKindAttr;

struct UnaryInfo {
Expand Down Expand Up @@ -89,19 +110,41 @@ FailureOr<UnaryFlags> getUnaryFlags(Type inputType, Type outputType);

// Compute the broadcasting flags for 'operandType' based on 'outputType'.
enum class OperandPos { LHS = 0, RHS = 1 };
FailureOr<BinaryFlags> getBinFlags(ArrayRef<int64_t> shapeOutput,
ArrayRef<int64_t> shapeOperand,
OperandPos operandNumber);
FailureOr<BinaryFlags> getBinaryFlags(Type operandType, Type outputType,
OperandPos operandNumber);
FailureOr<BinaryFlags> getBinaryFlagsVectorType(Type operandType,
Type outputType,
OperandPos operandNumber);

FailureOr<FusedMatch> getFusedBrgemmSequenceFromProducer(Operation *op);

ArrayAttr getUnaryDispatchFlags(UnaryOp op);

ArrayAttr getBinaryDispatchFlags(BinaryOp op);

int64_t getOredFlags(ArrayAttr flags);

func::CallOp buildInvokeCall(RewriterBase &rewriter, Operation *parentOp,
ModuleOp module, SmallVector<Value> inputOperands,
SmallVector<Value> prependValues, int prependIndex,
SmallVector<Value> operands, StringRef invokeName,
DataTypeAttr dtype, bool getResults = false);

template <typename DispatchOpTy>
FailureOr<SmallVector<Attribute>> getBrgemmFlags(PatternRewriter &rewriter,
DispatchOpTy dispatchOpTy,
bool returnNone);
std::optional<unsigned>
getPosInCodomain(unsigned dim, vector::ContractionOp contractOp, AffineMap map);
FailureOr<xsmm::BrgemmInfo> checkAccess(PatternRewriter &rewriter,
vector::ContractionOp contractOp,
unsigned m, unsigned n, unsigned k,
std::optional<unsigned> batchPos,
SmallVector<Value> inputs,
ArrayRef<AffineMap> indexingMap);

SmallVector<Type> extractOperandTypes(OpBuilder &builder,
ArrayRef<Value> operands);
Expand All @@ -122,6 +165,11 @@ func::CallOp buildXsmmCall(RewriterBase &rewriter, XsmmCallType callType,
SmallVector<XsmmOperand> operands, TypeRange results,
FlatSymbolRefAttr fnName, Operation *parentOp,
Operation *insertBefore);
FailureOr<BrgemmInfo> isMappableToBrgemm(PatternRewriter &rewriter,
vector::ContractionOp contractOp,
SmallVector<Value> &inputs,
SmallVector<Value> &output,
ArrayRef<AffineMap> indexingMap);
} // namespace utils
} // namespace xsmm
} // namespace mlir
Expand Down
4 changes: 3 additions & 1 deletion include/TPP/Transforms/Utils/VNNIUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef TPP_TRANSFORMS_UTILS_VNNIUTILS_H
#define TPP_TRANSFORMS_UTILS_VNNIUTILS_H

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Support/LogicalResult.h"
#include <cstdint>
#include <optional>
Expand Down Expand Up @@ -51,7 +52,8 @@ bool isInVnniLayout(int64_t expectedRank, VectorType vector);
FailureOr<AffineDimExpr> isInVnniLayout(linalg::GenericOp linalgOp,
AffineMap affineMap,
int64_t blockingFactor);

FailureOr<AffineDimExpr> isInVnniLayout(mlir::vector::ContractionOp contractOp,
int64_t blockingFactor);
} // namespace utils
} // namespace vnni
} // namespace mlir
Expand Down
178 changes: 163 additions & 15 deletions lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.h"
#include "TPP/Dialect/Xsmm/XsmmUtils.h"
#include "TPP/Transforms/Transforms.h"
#include "TPP/Transforms/Utils/TransformUtils.h"
#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "TPP/Transforms/Utils/ValueUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -68,9 +70,9 @@ getUnaryXSMMCalls(PatternRewriter &rewriter, xsmm::UnaryInfo &unaryInfo,
}

static std::pair<Operation *, Operation *>
convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp,
Operation *input, Operation *output, Type outputType) {
LLVM_DEBUG(llvm::dbgs() << "convertTransposeOp\n");
convertTranspose(PatternRewriter &rewriter, Operation *transposeOp,
Operation *input, Operation *output, Type outputType) {
LLVM_DEBUG(llvm::dbgs() << "convertTranspose\n");
VectorType outType = cast<VectorType>(outputType);
xsmm::UnaryKind opType;
xsmm::UnaryInfo unaryInfo;
Expand Down Expand Up @@ -111,11 +113,10 @@ convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp,
output, transposeOp, unaryFlags.getInt());
}

static LogicalResult validateTransposeOp(PatternRewriter &rewriter,
Operation *transposeOp,
Operation *input, Operation *output,
Type outputType) {
LLVM_DEBUG(llvm::dbgs() << "validateTransposeOp\n");
static LogicalResult validateTranspose(PatternRewriter &rewriter,
Operation *transposeOp, Operation *input,
Operation *output, Type outputType) {
LLVM_DEBUG(llvm::dbgs() << "validateTranspose\n");
Value result = input->getResult(0);
Value source = input->getOperand(0);
VectorType outType = cast<VectorType>(outputType);
Expand Down Expand Up @@ -187,10 +188,10 @@ convertBroadcast(PatternRewriter &rewriter, Operation *broadcastOp,
broadcastOp, unaryFlags.getInt());
}

static LogicalResult validateBroadcastOp(PatternRewriter &rewriter,
Operation *broadcastOp,
Operation *input, Operation *output) {
LLVM_DEBUG(llvm::dbgs() << "validateBroadcastOp\n");
static LogicalResult validateBroadcast(PatternRewriter &rewriter,
Operation *broadcastOp, Operation *input,
Operation *output) {
LLVM_DEBUG(llvm::dbgs() << "validateBroadcast\n");
auto unaryFlag = xsmm::utils::getUnaryFlags(input->getOperand(0).getType(),
output->getOperand(1).getType());
auto inputMemRefType = dyn_cast<MemRefType>(input->getOperand(0).getType());
Expand All @@ -205,15 +206,162 @@ static LogicalResult validateBroadcastOp(PatternRewriter &rewriter,
return success();
}

/// Collects the first result of each of the three input ops and asks
/// xsmm::utils::isMappableToBrgemm whether `contractOp` can be expressed as a
/// (BR)GEMM.
///
/// Returns the BrgemmInfo produced by isMappableToBrgemm, or failure when
/// `contractOp` is not a vector::ContractionOp or the mapping is rejected.
FailureOr<xsmm::BrgemmInfo>
computeBrgemmInfo(PatternRewriter &rewriter, Operation *contractOp,
                  Operation *input0, Operation *input1, Operation *input2) {
  LLVM_DEBUG(llvm::dbgs() << "computeBrgemminfo\n");

  // dyn_cast yields a null op handle on a kind mismatch; bail out instead of
  // calling accessors on it.
  auto contraction = dyn_cast<mlir::vector::ContractionOp>(contractOp);
  if (!contraction)
    return failure();

  SmallVector<Value> inputs;
  inputs.push_back(input0->getResult(0));
  inputs.push_back(input1->getResult(0));
  inputs.push_back(input2->getResult(0));

  // A single null placeholder is handed over as the output list.
  SmallVector<Value> outputs;
  outputs.push_back(nullptr);

  auto failedOrBrgemmInfo = mlir::xsmm::utils::isMappableToBrgemm(
      rewriter, contraction, inputs, outputs,
      contraction.getIndexingMapsArray());
  if (failed(failedOrBrgemmInfo))
    return failure();
  return *failedOrBrgemmInfo;
}

/// Emits the XSMM dispatch call and the matching invoke call for
/// `contractOp`.
///
/// A zero `brgemmInfo.batch` selects the "xsmm_gemm_*" entry points. Any other
/// value selects "xsmm_brgemm_*", appends strideA/strideB to the dispatch
/// operands and passes the batch count as the last invoke operand.
///
/// \param input0,input1,input2 Values whose defining op's first operand is
///        forwarded to the invoke call. Each must have a defining op.
/// \param flags GEMM flag attributes; they are OR-ed into one dispatch operand.
/// \returns the {dispatch, invoke} call operations.
static std::pair<Operation *, Operation *>
createBrgemmImpl(PatternRewriter &rewriter, Operation *contractOp, Value input0,
                 Value input1, Value input2, xsmm::BrgemmInfo brgemmInfo,
                 SmallVector<Attribute> flags) {
  auto batch = brgemmInfo.batch;
  bool isBatchReduce = batch != 0;
  std::string dispatchName =
      isBatchReduce ? "xsmm_brgemm_dispatch" : "xsmm_gemm_dispatch";
  std::string invokeName =
      isBatchReduce ? "xsmm_brgemm_invoke" : "xsmm_gemm_invoke";

  auto loc = contractOp->getLoc();
  auto dtype =
      xsmm::utils::getDataType(rewriter, contractOp->getOperand(0).getType());
  auto dtypeValue = dyn_cast<DataTypeAttr>(dtype).getInt();

  // Dispatch operands: data type, sizes, per-operand leading values, optional
  // strides, then the OR-ed flags.
  SmallVector<xsmm::utils::XsmmOperand> dispatchOperands;
  dispatchOperands.push_back(dtypeValue);
  dispatchOperands.append(SmallVector<xsmm::utils::XsmmOperand>{
      brgemmInfo.m, brgemmInfo.n, brgemmInfo.k, brgemmInfo.lda, brgemmInfo.ldb,
      brgemmInfo.ldc});
  if (isBatchReduce) {
    dispatchOperands.push_back(brgemmInfo.strideA);
    dispatchOperands.push_back(brgemmInfo.strideB);
  }
  ArrayAttr brgemmFlags = rewriter.getArrayAttr(flags);
  int64_t oredFlag = xsmm::utils::getOredFlags(brgemmFlags);
  dispatchOperands.push_back(oredFlag);

  auto dispatchCall = xsmm::utils::buildXsmmCall(
      rewriter, xsmm::utils::XsmmCallType::DISPATCH, loc, dtype,
      dispatchOperands, IntegerType::get(rewriter.getContext(), 64),
      SymbolRefAttr::get(contractOp->getContext(), dispatchName), contractOp,
      nullptr);

  // Invoke operands: data type, the dispatch result, and the first operand of
  // each input's defining op.
  SmallVector<xsmm::utils::XsmmOperand> operandRange{
      dtypeValue,
      xsmm::utils::XsmmCall{xsmm::utils::XsmmCallType::DISPATCH,
                            dispatchCall.getResult(0)},
      input0.getDefiningOp()->getOperand(0),
      input1.getDefiningOp()->getOperand(0),
      input2.getDefiningOp()->getOperand(0)};
  if (isBatchReduce)
    operandRange.push_back(batch);

  auto invokeCall = xsmm::utils::buildXsmmCall(
      rewriter, xsmm::utils::XsmmCallType::INVOKE, loc, dtype, operandRange,
      TypeRange(), SymbolRefAttr::get(contractOp->getContext(), invokeName),
      contractOp, input2.getDefiningOp());
  return std::make_pair(&*dispatchCall, &*invokeCall);
}

/// Rewrite function that lowers `contractOp` to XSMM dispatch/invoke calls
/// with the BETA_0 GEMM flag set (preceded by VNNI_B when the computed
/// BrgemmInfo reports a VNNI layout).
///
/// NOTE(review): the BrgemmInfo is dereferenced without a failure check;
/// presumably the "ValidateBrgemm" constraint has already accepted this
/// contraction -- confirm against the PDL patterns. `betaZero` is not read.
static std::pair<Operation *, Operation *>
createBrgemmWithBetaZero(PatternRewriter &rewriter, Operation *contractOp,
                         Operation *input0, Operation *input1,
                         Operation *input2, Operation *betaZero) {
  LLVM_DEBUG(llvm::dbgs() << "createBrgemmWithBetaZero\n");
  FailureOr<xsmm::BrgemmInfo> info =
      computeBrgemmInfo(rewriter, contractOp, input0, input1, input2);
  auto *ctx = rewriter.getContext();

  SmallVector<Attribute> gemmFlags;
  if (info->isVnni)
    gemmFlags.push_back(xsmm::GemmFlagsAttr::get(ctx, xsmm::GemmFlags::VNNI_B));
  gemmFlags.push_back(xsmm::GemmFlagsAttr::get(ctx, xsmm::GemmFlags::BETA_0));

  Value operand0 = input0->getResult(0);
  Value operand1 = input1->getResult(0);
  Value operand2 = input2->getResult(0);
  return createBrgemmImpl(rewriter, contractOp, operand0, operand1, operand2,
                          *info, gemmFlags);
}

/// Rewrite function that lowers `contractOp` to XSMM dispatch/invoke calls.
/// The only flag that may be set is VNNI_B, added when the computed
/// BrgemmInfo reports a VNNI layout.
///
/// NOTE(review): the BrgemmInfo is dereferenced without a failure check;
/// presumably the "ValidateBrgemm" constraint has already accepted this
/// contraction -- confirm against the PDL patterns.
static std::pair<Operation *, Operation *>
createBrgemm(PatternRewriter &rewriter, Operation *contractOp,
             Operation *input0, Operation *input1, Operation *input2) {
  LLVM_DEBUG(llvm::dbgs() << "createBrgemm\n");
  FailureOr<mlir::xsmm::BrgemmInfo> info =
      computeBrgemmInfo(rewriter, contractOp, input0, input1, input2);

  SmallVector<Attribute> gemmFlags;
  if (info->isVnni)
    gemmFlags.push_back(xsmm::GemmFlagsAttr::get(rewriter.getContext(),
                                                 xsmm::GemmFlags::VNNI_B));

  Value operand0 = input0->getResult(0);
  Value operand1 = input1->getResult(0);
  Value operand2 = input2->getResult(0);
  return createBrgemmImpl(rewriter, contractOp, operand0, operand1, operand2,
                          *info, gemmFlags);
}

/// Constraint function: succeeds when computeBrgemmInfo can derive a
/// BrgemmInfo for `contractOp` and its three inputs, and fails otherwise.
/// `result` is not read.
static LogicalResult validateBrgemm(PatternRewriter &rewriter,
                                    Operation *contractOp, Operation *input0,
                                    Operation *input1, Operation *input2,
                                    Operation *result) {
  LLVM_DEBUG(llvm::dbgs() << "validateBrgemm\n");
  FailureOr<xsmm::BrgemmInfo> brgemmInfo =
      computeBrgemmInfo(rewriter, contractOp, input0, input1, input2);

  // Return plain success()/failure(): passing the op pointer would only be
  // correct through an implicit pointer-to-bool conversion.
  if (failed(brgemmInfo))
    return failure();
  return success();
}

/// Registers the native rewrite and constraint functions of this file with
/// the PDL pattern module of `patterns`, under the names used to look them up.
static void registerNativeRewrite(RewritePatternSet &patterns) {
  auto &pdlPatterns = patterns.getPDLPatterns();

  // Transpose.
  pdlPatterns.registerRewriteFunction("ConvertTranspose", convertTranspose);
  pdlPatterns.registerConstraintFunction("ValidateTranspose",
                                         validateTranspose);

  // Broadcast.
  pdlPatterns.registerRewriteFunction("ConvertBroadcast", convertBroadcast);
  pdlPatterns.registerConstraintFunction("ValidateBroadcast",
                                         validateBroadcast);

  // (BR)GEMM.
  pdlPatterns.registerRewriteFunction("CreateBrgemmWithBetaZero",
                                      createBrgemmWithBetaZero);
  pdlPatterns.registerRewriteFunction("CreateBrgemm", createBrgemm);
  pdlPatterns.registerConstraintFunction("ValidateBrgemm", validateBrgemm);
}

namespace mlir {
Expand Down
Loading

0 comments on commit 4c67d4b

Please sign in to comment.