diff --git a/include/gc/Dialect/Linalgx/CMakeLists.txt b/include/gc/Dialect/Linalgx/CMakeLists.txt
index 1aceb8345..6090d833b 100644
--- a/include/gc/Dialect/Linalgx/CMakeLists.txt
+++ b/include/gc/Dialect/Linalgx/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS LinalgxDialect.td)
+mlir_tablegen(LinalgxOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=linalgx)
+mlir_tablegen(LinalgxOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=linalgx)
+
 add_mlir_dialect(LinalgxOps linalgx)
 set(LLVM_TARGET_DEFINITIONS LinalgxStructuredOps.td)
 mlir_tablegen(LinalgxStructuredOps.h.inc -gen-op-decls)
diff --git a/include/gc/Dialect/Linalgx/LinalgxDialect.td b/include/gc/Dialect/Linalgx/LinalgxDialect.td
index 1bb9521de..7c7e3f90d 100644
--- a/include/gc/Dialect/Linalgx/LinalgxDialect.td
+++ b/include/gc/Dialect/Linalgx/LinalgxDialect.td
@@ -10,6 +10,11 @@
 #define LINALGX_DIALECT
 
 include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
 
 //===----------------------------------------------------------------------===//
 // Linalgx dialect definition.
@@ -32,6 +37,7 @@ def LinalgxDialect : Dialect {
     "tensor::TensorDialect",
   ];
 
+  let useDefaultAttributePrinterParser = 1;
   let extraClassDeclaration = [{
     /// Attribute name used to memoize indexing maps for named ops.
     constexpr const static ::llvm::StringLiteral
@@ -47,4 +53,44 @@ def LinalgxDialect : Dialect {
   }];
 }
 
+class Linalgx_Attr<string name, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<LinalgxDialect, name, traits> {
+  let mnemonic = attrMnemonic;
+}
+
+def PackingMapAttr : Linalgx_Attr<"PackingMap", "packing_map"> {
+  let summary = "An attribute mapping dimension indices between matmul inputs/outputs";
+  let description = [{
+    A mapping between the dimension indices of two matmul operands. The side
+    listing exactly one dim is the packing source; the side listing one or
+    more dims is the packing destination.
+  }];
+
+  let cppNamespace = "::mlir::linalgx";
+  let parameters = (ins ArrayRefParameter<"uint64_t">:$first,
+                        ArrayRefParameter<"uint64_t">:$second);
+
+  let assemblyFormat = "`<` `[` $first `]` `->` `[` $second `]` `>`";
+
+  let extraClassDeclaration = [{
+    /// The operand index of $first is 0; the operand index of $second is 1.
+    unsigned getPackingSrcIndex() {
+      return getFirst().size() == 1 ? 0 : 1;
+    }
+    unsigned getPackingDstIndex() {
+      return getFirst().size() == 1 ? 1 : 0;
+    }
+    /// SrcDims.size() == 1; DstDims.size() >= 1
+    ArrayRef<uint64_t> getPackingSrcDims() {
+      return getPackingSrcIndex() == 0 ? getFirst() : getSecond();
+    }
+    ArrayRef<uint64_t> getPackingDstDims() {
+      return getPackingDstIndex() == 0 ? getFirst() : getSecond();
+    }
+  }];
+}
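+
+// Illustrative note, not part of the generated attribute: with the assembly
+// format above, a packing map whose single source dim 3 is packed into
+// destination dims [2, 4] is written as
+//   #linalgx.packing_map<[3] -> [2, 4]>
+// The single-dim side is the packing source, per the SrcDims.size() == 1
+// convention in the extraClassDeclaration above.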
+
+def PackingMapArrayAttr
+    : TypedArrayAttrBase<PackingMapAttr, "array of PackingMap attributes">;
+
 #endif // LINALGX_DIALECT
diff --git a/include/gc/Dialect/Linalgx/LinalgxOps.h b/include/gc/Dialect/Linalgx/LinalgxOps.h
index 9ea73a91e..c2a5e8eb8 100644
--- a/include/gc/Dialect/Linalgx/LinalgxOps.h
+++ b/include/gc/Dialect/Linalgx/LinalgxOps.h
@@ -19,6 +19,9 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#define GET_ATTRDEF_CLASSES
+#include "gc/Dialect/Linalgx/LinalgxOpsAttributes.h.inc"
+
 #define GET_OP_CLASSES
 #include "gc/Dialect/Linalgx/LinalgxOps.h.inc"
diff --git a/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td b/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td
index dee5eef74..9af766a12 100644
--- a/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td
+++ b/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td
@@ -211,6 +211,62 @@ def Linalgx_Mm4DVnniOp
   }];
 }
 
+def Linalgx_PackedMatmulOp
+    : LinalgxStructuredBase_Op<"packed_matmul", [AttrSizedOperandSegments]> {
+  let summary = "matmul with packed data format";
+  let description = [{
+    Uses m_packing, n_packing, and k_packing to describe how the dims of the
+    packed operands relate in C[M, N] = A[M, K] * B[K, N].
+  }];
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    PackingMapArrayAttr:$m_packing,
+    PackingMapArrayAttr:$n_packing,
+    PackingMapArrayAttr:$k_packing
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$results);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<
+      (ins
+        "TypeRange":$resultTensorTypes,
+        "ValueRange":$inputs,
+        "ValueRange":$outputs,
+        CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, PackedMatmulOp::getRegionBuilder());
+      }]>
+  ];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Declare functions necessary for LinalgStructuredInterface.
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+    ArrayAttr getIndexingMaps();
+    static unsigned getNumRegionArgs() { return 3; }
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement functions necessary for DestinationStyleOpInterface.
+    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+    static std::function<void(ImplicitLocOpBuilder &, Block &,
+                              ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+  }];
+}
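+
+// Illustrative usage, assuming the VNNI packing maps exercised in the test
+// added below (A:2x8x32x32, B:4x8x16x32x2, C:2x4x32x32):
+//
+//   %0 = linalgx.packed_matmul
+//     {m_packing = [#linalgx.packing_map<[0] -> [0]>,
+//                   #linalgx.packing_map<[2] -> [2]>],
+//      n_packing = [#linalgx.packing_map<[0] -> [1]>,
+//                   #linalgx.packing_map<[3] -> [3]>],
+//      k_packing = [#linalgx.packing_map<[1] -> [1]>,
+//                   #linalgx.packing_map<[3] -> [2, 4]>]}
+//     ins(%A, %B : tensor<2x8x32x32xbf16>, tensor<4x8x16x32x2xbf16>)
+//     outs(%C : tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16>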
 
 def Linalgx_BatchReduceMatmulVnniOp
     : LinalgxStructuredBase_Op<"batch_reduce_matmul_vnni", [AttrSizedOperandSegments]> {
   let summary = "Batch reduced matmul with 3d batch input and vnni packed weights";
diff --git a/lib/gc/Dialect/Linalgx/LinalgxDialect.cpp b/lib/gc/Dialect/Linalgx/LinalgxDialect.cpp
index d2d7a4389..0689b0d4e 100644
--- a/lib/gc/Dialect/Linalgx/LinalgxDialect.cpp
+++ b/lib/gc/Dialect/Linalgx/LinalgxDialect.cpp
@@ -16,12 +16,16 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/Parser/Parser.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::linalgx;
 
 #include "gc/Dialect/Linalgx/LinalgxOpsDialect.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "gc/Dialect/Linalgx/LinalgxOpsAttributes.cpp.inc"
 
 void LinalgxDialect::initialize() {
   addOperations<
@@ -32,4 +36,8 @@ void LinalgxDialect::initialize() {
 #define GET_OP_LIST
 #include "gc/Dialect/Linalgx/LinalgxStructuredOps.cpp.inc"
       >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "gc/Dialect/Linalgx/LinalgxOpsAttributes.cpp.inc"
+      >();
 }
diff --git a/lib/gc/Dialect/Linalgx/LinalgxOps.cpp b/lib/gc/Dialect/Linalgx/LinalgxOps.cpp
index 04eae3657..78eb47e6d 100644
--- a/lib/gc/Dialect/Linalgx/LinalgxOps.cpp
+++ b/lib/gc/Dialect/Linalgx/LinalgxOps.cpp
@@ -365,6 +365,218 @@ LogicalResult Mm4DVnniOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// PackedMatmulOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<utils::IteratorType> PackedMatmulOp::getIteratorTypesArray() {
+  SmallVector<utils::IteratorType> iteratorTypes;
+  // Get the packing dim count contributed by each packing map.
+  auto getPackingIteratorTypes = [&](ArrayAttr packingMaps,
+                                     utils::IteratorType iterTy) {
+    for (auto &attr : packingMaps) {
+      auto packingNum =
+          llvm::cast<PackingMapAttr>(attr).getPackingDstDims().size();
+      iteratorTypes.insert(iteratorTypes.end(), packingNum, iterTy);
+    }
+  };
+  // Process order: m, n, k packing.
+  getPackingIteratorTypes(getMPacking(), utils::IteratorType::parallel);
+  getPackingIteratorTypes(getNPacking(), utils::IteratorType::parallel);
+  getPackingIteratorTypes(getKPacking(), utils::IteratorType::reduction);
+  return iteratorTypes;
+}
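+
+// Note on the construction below (explanatory comment, not from the original
+// patch): for each packing map, every destination dim gets a fresh iteration
+// dim, and the single source dim is bound to the linearized expression built
+// by Horner's rule, compound = compound * size(dim) + curr. For example, with
+// #linalgx.packing_map<[3] -> [2, 4]> and destination sizes 16 and 2, source
+// dim 3 maps to d_i * 2 + d_j, where d_i indexes dim 2 and d_j indexes dim 4.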
+unsigned getPackingDimsExpr(PackedMatmulOp self,
+                            SmallVector<SmallVector<AffineExpr>> &exprsArr) {
+  MLIRContext *context = self.getContext();
+  auto typeA = cast<ShapedType>(self.getDpsInputOperand(0)->get().getType());
+  auto typeB = cast<ShapedType>(self.getDpsInputOperand(1)->get().getType());
+  auto typeC = cast<ShapedType>(self.getDpsInitOperand(0)->get().getType());
+  SmallVector<AffineExpr> exprsA(typeA.getRank());
+  SmallVector<AffineExpr> exprsB(typeB.getRank());
+  SmallVector<AffineExpr> exprsC(typeC.getRank());
+  // Iteration dims count from 0.
+  unsigned dims = 0;
+  auto getPackingExprs = [&](ArrayAttr attrArray, ArrayRef<ShapedType> types,
+                             ArrayRef<SmallVector<AffineExpr> *> exprs) {
+    for (auto &attr : attrArray) {
+      auto packingMap = cast<PackingMapAttr>(attr);
+      auto srcIndex = packingMap.getPackingSrcIndex();
+      auto dstIndex = packingMap.getPackingDstIndex();
+      auto srcDims = packingMap.getPackingSrcDims();
+      auto dstDims = packingMap.getPackingDstDims();
+      auto &dstExprs = *exprs[dstIndex];
+      auto &srcExprs = *exprs[srcIndex];
+      auto compound = getAffineConstantExpr(0, context);
+      for (auto dim : dstDims) {
+        auto curr = getAffineDimExpr(dims++, context);
+        auto constant =
+            getAffineConstantExpr(types[dstIndex].getDimSize(dim), context);
+        compound = compound * constant + curr;
+        dstExprs[dim] = curr;
+      }
+      srcExprs[srcDims.front()] = compound;
+    }
+  };
+  // Process order: m, n, k packing, kept the same as the iterator types.
+  getPackingExprs(self.getMPacking(), ArrayRef<ShapedType>{typeA, typeC},
+                  ArrayRef<SmallVector<AffineExpr> *>{&exprsA, &exprsC});
+  getPackingExprs(self.getNPacking(), ArrayRef<ShapedType>{typeB, typeC},
+                  ArrayRef<SmallVector<AffineExpr> *>{&exprsB, &exprsC});
+  getPackingExprs(self.getKPacking(), ArrayRef<ShapedType>{typeA, typeB},
+                  ArrayRef<SmallVector<AffineExpr> *>{&exprsA, &exprsB});
+  exprsArr.emplace_back(exprsA);
+  exprsArr.emplace_back(exprsB);
+  exprsArr.emplace_back(exprsC);
+  return dims;
+}
+
+ArrayAttr PackedMatmulOp::getIndexingMaps() {
+  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
+  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
+  if (cached)
+    return cached;
+
+  SmallVector<SmallVector<AffineExpr>> exprsArr;
+  auto dims = getPackingDimsExpr(*this, exprsArr);
+
+  MLIRContext *context = getContext();
+  auto mapA = simplifyAffineMap(AffineMap::get(dims, 0, exprsArr[0], context));
+  auto mapB = simplifyAffineMap(AffineMap::get(dims, 0, exprsArr[1], context));
+  auto mapC = simplifyAffineMap(AffineMap::get(dims, 0, exprsArr[2], context));
+
+  cached = Builder(context).getAffineMapArrayAttr({mapA, mapB, mapC});
+  getOperation()->setAttr(memoizeAttr, cached);
+  return cached;
+}
+
+void PackedMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                   ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "PackedMatmulOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+ParseResult PackedMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  return ::parseNamedStructuredOp(parser, result,
+                                  PackedMatmulOp::getNumRegionArgs(),
+                                  PackedMatmulOp::getRegionBuilder());
+}
+
+void PackedMatmulOp::print(OpAsmPrinter &p) {
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+}
+
+LogicalResult PackedMatmulOp::fold(FoldAdaptor,
+                                   SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+void PackedMatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
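+
+// Explanatory note (added for review, not from the original patch): the
+// verifier below checks that (1) all operands are ranked, (2) the packing
+// maps jointly reference every dim of A, B, and C, and (3) for each map the
+// product of the source dim sizes equals the product of the destination dim
+// sizes, skipping dynamic dims. E.g. for #linalgx.packing_map<[3] -> [2, 4]>
+// in k_packing with A = tensor<2x8x32x32xbf16>, B = tensor<4x8x16x32x2xbf16>:
+// A's dim 3 has size 32, and B's dims [2, 4] have sizes 16 * 2 == 32.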
+
+LogicalResult PackedMatmulOp::verify() {
+  // A[M, K]
+  // B[K, N]
+  // C[M, N]
+  // mPacking = A -> C
+  // nPacking = B -> C
+  // kPacking = A -> B
+  auto shapeA = cast<ShapedType>(getDpsInputOperand(0)->get().getType());
+  auto shapeB = cast<ShapedType>(getDpsInputOperand(1)->get().getType());
+  auto shapeC = cast<ShapedType>(getDpsInitOperand(0)->get().getType());
+  auto mPacking = getMPacking();
+  auto nPacking = getNPacking();
+  auto kPacking = getKPacking();
+
+  // Check rank.
+  bool hasRank = shapeA.hasRank() && shapeB.hasRank() && shapeC.hasRank();
+  if (!hasRank)
+    return emitOpError() << "input/output shape must have rank.";
+
+  // Check packing axis.
+  auto getAxisSet = [](ArrayAttr arrayAttr,
+                       llvm::SmallSet<uint64_t, 8> &firstIndexSet,
+                       llvm::SmallSet<uint64_t, 8> &secondIndexSet) {
+    for (auto &attr : arrayAttr) {
+      auto packingMap = cast<PackingMapAttr>(attr);
+      auto firstDims = packingMap.getFirst();
+      firstIndexSet.insert(firstDims.begin(), firstDims.end());
+      auto secondDims = packingMap.getSecond();
+      secondIndexSet.insert(secondDims.begin(), secondDims.end());
+    }
+  };
+  llvm::SmallSet<uint64_t, 8> indexSetA;
+  llvm::SmallSet<uint64_t, 8> indexSetB;
+  llvm::SmallSet<uint64_t, 8> indexSetC;
+  getAxisSet(mPacking, indexSetA, indexSetC);
+  getAxisSet(nPacking, indexSetB, indexSetC);
+  getAxisSet(kPacking, indexSetA, indexSetB);
+  bool checkAxis = (shapeA.getRank() == (int64_t)indexSetA.size()) &&
+                   (shapeB.getRank() == (int64_t)indexSetB.size()) &&
+                   (shapeC.getRank() == (int64_t)indexSetC.size());
+  if (!checkAxis)
+    return emitOpError() << "input/output must match packing axis.";
+
+  // Check that packing dim sizes match.
+  auto matchDims = [](ArrayAttr arrayAttr, ShapedType firstShape,
+                      ShapedType secondShape) {
+    for (auto &attr : arrayAttr) {
+      auto packingMap = cast<PackingMapAttr>(attr);
+      bool isDynamic = false;
+      int64_t firstSize = 1;
+      auto firstDims = packingMap.getFirst();
+      for (auto dim : firstDims) {
+        auto size = firstShape.getDimSize(dim);
+        if (size == ShapedType::kDynamic)
+          isDynamic = true;
+        firstSize *= size;
+      }
+      int64_t secondSize = 1;
+      auto secondDims = packingMap.getSecond();
+      for (auto dim : secondDims) {
+        auto size = secondShape.getDimSize(dim);
+        if (size == ShapedType::kDynamic)
+          isDynamic = true;
+        secondSize *= size;
+      }
+      if (isDynamic)
+        continue;
+      if (firstSize != secondSize)
+        return false;
+    }
+    return true;
+  };
+  bool matchM = matchDims(mPacking, shapeA, shapeC);
+  bool matchN = matchDims(nPacking, shapeB, shapeC);
+  bool matchK = matchDims(kPacking, shapeA, shapeB);
+  bool checkMatch = matchM && matchN && matchK;
+  if (!checkMatch)
+    return emitOpError() << "input/output must match packing dim size.";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BatchReduceMatmulVnniOp
 //===----------------------------------------------------------------------===//
diff --git a/test/mlir/test/gc/Dialect/Linlagx/linalgx-named-ops.mlir b/test/mlir/test/gc/Dialect/Linlagx/linalgx-named-ops.mlir
index c87ca2259..4a5711c04 100644
--- a/test/mlir/test/gc/Dialect/Linlagx/linalgx-named-ops.mlir
+++ b/test/mlir/test/gc/Dialect/Linlagx/linalgx-named-ops.mlir
@@ -25,6 +25,20 @@ func.func @mm4d_vnni(%arg0: tensor<2x8x32x32xbf16>, %arg1: tensor<4x8x16x32x2xbf
   return %0 : tensor<2x4x32x32xbf16>
 }
 
+// CHECK-LABEL: @packed_matmul
+#m_packing_vnni = [#linalgx.packing_map<[0] -> [0]>, #linalgx.packing_map<[2] -> [2]>]
+#n_packing_vnni = [#linalgx.packing_map<[0] -> [1]>, #linalgx.packing_map<[3] -> [3]>]
+#k_packing_vnni = [#linalgx.packing_map<[1] -> [1]>, #linalgx.packing_map<[3] -> [2, 4]>]
+func.func @packed_matmul(%A: tensor<2x8x32x32xbf16>, %B: tensor<4x8x16x32x2xbf16>,
+                         %C: tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> {
+  // CHECK: linalgx.packed_matmul
+  %0 = linalgx.packed_matmul
+    {m_packing = #m_packing_vnni, n_packing = #n_packing_vnni, k_packing = #k_packing_vnni}
+    ins(%A, %B : tensor<2x8x32x32xbf16>, tensor<4x8x16x32x2xbf16>)
+    outs(%C : tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16>
+  return %0 : tensor<2x4x32x32xbf16>
+}
+
 // CHECK-LABEL: @batch_reduce_matmul_vnni
 func.func @batch_reduce_matmul_vnni(%arg0: tensor<512x32x64xbf16>, %arg1: tensor<512x32x128x2xbf16>,
                                     %arg2: tensor<32x128xf32>) -> tensor<32x128xf32> {