diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 3108d59e4..710c5768a 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -58,6 +58,7 @@ def TTIR_GenericOp : TTIR_DPSOp<"generic", [AttrSizedOperandSegments]> { TT_OperandConstraintArrayAttr:$operand_constraints); let results = (outs Variadic:$results); let regions = (region AnyRegion:$region); + let hasVerifier = 1; } def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpInterface]> { diff --git a/include/ttmlir/Dialect/TTMetal/IR/TTMetalOps.td b/include/ttmlir/Dialect/TTMetal/IR/TTMetalOps.td index 02d8f3e82..acf15bf68 100644 --- a/include/ttmlir/Dialect/TTMetal/IR/TTMetalOps.td +++ b/include/ttmlir/Dialect/TTMetal/IR/TTMetalOps.td @@ -30,8 +30,7 @@ def TTMetal_DispatchOp : TTMetal_Op<"dispatch", [DestinationStyleOpInterface, At let arguments = (ins Variadic:$inputs, Variadic:$outputs, TTMetal_CoreRangeArrayAttr:$core_ranges, - TTKernel_ThreadTypeArrayAttr:$threadTypes, - ArrayAttr:$operand_cb_port_mapping); + TTKernel_ThreadTypeArrayAttr:$threadTypes); let results = (outs Variadic:$results); let regions = (region VariadicRegion:$regions); diff --git a/include/ttmlir/Utils.h b/include/ttmlir/Utils.h index 32a1c036a..19baee00c 100644 --- a/include/ttmlir/Utils.h +++ b/include/ttmlir/Utils.h @@ -48,6 +48,9 @@ llvm::SmallVector evalShape(mlir::AffineMap map, Vector shape) { return result; } +template std::underlying_type_t enum_as_int(Enum e) { + return static_cast>(e); +} } // namespace ttmlir::utils #endif diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index a97ad6695..5f23e7967 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -51,6 +51,14 @@ mlir::tt::ttir::ToLayoutOp::compoundComponents() { isMemorySpaceChange); } +::mlir::LogicalResult mlir::tt::ttir::GenericOp::verify() { + if (getNumOperands() != getRegion().getNumArguments()) { + return emitOpError( + "The number of op operands and region/block operands must match"); + } + return success(); +} + template static void buildGenericEltwiseBinaryRegion(::mlir::Location loc, ::mlir::OpBuilder &opBuilder, diff --git a/lib/Dialect/TTMetal/Transforms/Passes.cpp b/lib/Dialect/TTMetal/Transforms/Passes.cpp index 4ac93286c..4146db7c2 100644 --- a/lib/Dialect/TTMetal/Transforms/Passes.cpp +++ b/lib/Dialect/TTMetal/Transforms/Passes.cpp @@ -7,6 +7,7 @@ #include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/PassManager.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" @@ -217,7 +218,6 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { auto noc0Attr = rewriter.getAttr(ttkernel::ThreadType::Noc0); SmallVector threadTypes(dm.size(), noc0Attr); - SmallVector operand_cb_port_mapping; SmallVector coreRanges; coreRanges.reserve(dm.size()); for (auto [dstCoord, srcs] : dm) { @@ -231,8 +231,7 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { op.getLoc(), SmallVector({outputTy}), SmallVector({op.getInput()}), SmallVector({op.getOutput()}), rewriter.getArrayAttr(coreRanges), - rewriter.getArrayAttr(threadTypes), - rewriter.getArrayAttr(operand_cb_port_mapping), threadTypes.size()); + rewriter.getArrayAttr(threadTypes), threadTypes.size()); int i = 0; PhysicalCoreCoordMapping physicalCoordMapping(systemDesc.getChipDescs()); @@ -272,10 +271,6 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { auto tensixAttr = rewriter.getAttr( ttkernel::ThreadType::Tensix); SmallVector threadTypes = {tensixAttr}; - SmallVector operand_cb_port_mapping = { - rewriter.getI64IntegerAttr(0), - rewriter.getI64IntegerAttr(16), - }; SmallVector coreRanges = { rewriter.getAttr(inputLayout.getGrid()), @@ -285,8 +280,7 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { op.getLoc(), SmallVector({outputTy}), SmallVector({op.getInput()}), SmallVector({op.getOutput()}), rewriter.getArrayAttr(coreRanges), - rewriter.getArrayAttr(threadTypes), - rewriter.getArrayAttr(operand_cb_port_mapping), threadTypes.size()); + rewriter.getArrayAttr(threadTypes), threadTypes.size()); std::int64_t inputBaseAddress = lookupAddress(op.getInput()); std::int64_t outputBaseAddress = lookupAddress(op.getOutput()); @@ -417,17 +411,6 @@ class TTIRToTTMetalKernelRewriter : public OpRewritePattern { } }; -class TTIRToTTMetalReturnRewriter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ttir::YieldOp op, - PatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp(op); - return success(); - } -}; - class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -442,21 +425,327 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { return exists; } - SmallVector getBlockArgumentTypesAsCBs( - mlir::Block::BlockArgListType blockArguments, - SmallVector const &operand_cb_port_mapping, - PatternRewriter &rewriter) const { + ttkernel::CBPort getPort(unsigned argNumber, + std::int64_t numDPSInputs) const { + std::int64_t operandInOutPartition = numDPSInputs; + std::uint32_t portIdx = 0; + if (argNumber < static_cast(operandInOutPartition)) { + assert(argNumber < 8 && "Exceeds max 8 input ports"); + portIdx = ttmlir::utils::enum_as_int(ttkernel::CBPort::In0) + argNumber; + } else { + assert((argNumber - operandInOutPartition) < 8 && + "Exceeds max 8 output ports"); + portIdx = ttmlir::utils::enum_as_int(ttkernel::CBPort::Out0) + + (argNumber - operandInOutPartition); + } + std::optional maybePort = + ttkernel::symbolizeCBPort(portIdx); + assert(maybePort.has_value() && "Expected legal port value"); + return maybePort.value(); + } + + // This routine evaluates the memref's affine map with it's shape to return a + // single result affine map, e.g.: + // - Given: shape{2, 4} and affine_map<(d0, d1) -> (d0, d1)> + // - Becomes: affine_map<(d0, d1) -> (d0 * 4 + d1) + // This is useful for evaluating iterator increment steps between each loop. + AffineMap getAffineIterator(MemRefType memref) const { + ArrayRef shape = memref.getShape(); + SmallVector physShape = + memref.getLayout().getAffineMap().compose(shape); + + mlir::AffineExpr resultExpr = getAffineConstantExpr(0, memref.getContext()); + int volume = 1; + for (int i = static_cast(physShape.size()) - 1; i >= 0; i--) { + mlir::AffineExpr dimExpr = getAffineDimExpr(i, memref.getContext()); + mlir::AffineExpr strideExpr = + getAffineConstantExpr(volume, memref.getContext()); + resultExpr = dimExpr * strideExpr + resultExpr; + volume *= physShape[i]; + } + return AffineMap::get(physShape.size(), 0, resultExpr, memref.getContext()); + } + + Value i32(std::int32_t value, OpBuilder &builder) const { + return builder + .create(builder.getUnknownLoc(), + builder.getI32Type(), + builder.getI32IntegerAttr(value)) + .getResult(); + } + + struct LoopNest { + SmallVector loops; + SmallVector loopRegions; + SmallVector blockArgIteratorMapping; + }; + + // Creates a loop nest that walks the input/output operand tiles in the shard. + // Converts this: + // %0 = arith.add(%1, %2) : tensor<2x4x!tile>, tensor<2x4x!tile> + // -> tensor<2x4x!tile> + // Into this: + // for (%i0 = 0; %i0 < 2; %i0++) + // for (%i1 = 0; %i1 < 4; %i1++) + // %ii = %i0 * 4 + %i1 + // %3 = ttkernel.add_tiles(%1, %2, %ii, %ii) + LoopNest createLoopNest(ArrayRef blockArguments, + std::int64_t numDPSInputs, OpBuilder &builder) const { + Value output = blockArguments[numDPSInputs]; + ttkernel::CBType outputTy = mlir::cast(output.getType()); + MemRefType outputMemref = outputTy.getMemref(); + ArrayRef outputShape = outputMemref.getShape(); + AffineMap outputAffineMap = outputMemref.getLayout().getAffineMap(); + + // Uniquify the iterators, i.e. operands that have identical access pattern + // can be shared. + llvm::MapVector iteratorMaps; + auto getOrInsertIterator = [&iteratorMaps, &builder, + this](AffineMap affineIterator) { + if (iteratorMaps.find(affineIterator) == iteratorMaps.end()) { + iteratorMaps[affineIterator] = i32(0, builder); + } + return iteratorMaps[affineIterator]; + }; + + // Map block arguments to their respective unique iterators Values + SmallVector iterators; + iterators.resize(blockArguments.size()); + for (BlockArgument operand : blockArguments) { + auto cbType = mlir::cast(operand.getType()); + AffineMap affineIterator = getAffineIterator(cbType.getMemref()); + assert(affineIterator.getNumDims() == outputAffineMap.getNumDims()); + iterators[operand.getArgNumber()] = getOrInsertIterator(affineIterator); + } + + // Map block arguments to their respective unique iterator offset in the + // map. This is needed by the caller to know how to wire the iterators into + // the ttkernel tile operation. + SmallVector blockArgIteratorMapping; + blockArgIteratorMapping.resize(blockArguments.size()); + for (BlockArgument operand : blockArguments) { + auto cbType = mlir::cast(operand.getType()); + AffineMap affineIterator = getAffineIterator(cbType.getMemref()); + auto match = iteratorMaps.find(affineIterator); + assert(match != iteratorMaps.end()); + blockArgIteratorMapping[operand.getArgNumber()] = + std::distance(iteratorMaps.begin(), match); + } + + // Convert the map data structure into a vector because it's easier to work + // with when creating the loop nest below. + SmallVector uniqueIterators; + for (auto [affineMap, iterator] : iteratorMaps) { + uniqueIterators.push_back(iterator); + } + + // Create loop nest + // The loop nest is created from outermost to innermost. The innermost loop + // is special in the sense that it implements the actual iterator increment + // and the tile operation. The outer loops are responsible for fixing up the + // iterator offset for the current dimension if there was a stride or we're + // accessing the tiles in non-row-major order. + // + // iterators are just ints that correspond to absolute offsets in the CB. + // They walk the order defined by the affine map associated with the memref. + LoopNest loopNest; + loopNest.blockArgIteratorMapping = blockArgIteratorMapping; + SmallVector loops; + SmallVector loopRegions; + SmallVector> iteratorsNest = {uniqueIterators}; + for (unsigned dim = 0; dim < outputAffineMap.getNumDims(); ++dim) { + OpBuilder regionBuilder(builder); + if (!loopNest.loopRegions.empty()) { + regionBuilder = OpBuilder(loopNest.loopRegions.back()); + } + // Loop variables, these are decoupled from the iterators + Value lowerBound = i32(0, regionBuilder); + Value upperBound = i32(outputShape[dim], regionBuilder); + Value loopStep = i32(1, regionBuilder); + scf::ForOp forOp = regionBuilder.create( + output.getLoc(), lowerBound, upperBound, loopStep, + iteratorsNest.back()); + loopNest.loops.push_back(forOp); + + SmallVector innerIndexStep(outputAffineMap.getNumDims(), 0); + innerIndexStep[dim] = 1; + + bool innerLoop = dim == (outputAffineMap.getNumDims() - 1); + if (innerLoop) { + OpBuilder innerLoopRegion(loopNest.loops.back().getRegion()); + SmallVector innerIndices; + int i = 0; + for (auto [affineMap, iterator] : iteratorMaps) { + // Calculate how far a single step in the inner dim is. + SmallVector innerOffset = + affineMap.compose(innerIndexStep); + assert(innerOffset.size() == 1); + innerIndices.push_back(innerLoopRegion.create( + output.getLoc(), forOp.getRegionIterArg(i), + i32(innerOffset[0], innerLoopRegion))); + ++i; + } + innerLoopRegion.create(output.getLoc(), innerIndices); + } + + // Backpedal and adjust the iterator offset for the current dimension. + if (dim > 0) { + SmallVector outerIndices; + SmallVector outerIndexStep(outputAffineMap.getNumDims(), + 0); + outerIndexStep[dim - 1] = 1; + int i = 0; + for (auto [affineMap, iterator] : iteratorMaps) { + // Calculate how far a single step in the inner dim is. + SmallVector innerOffset = + affineMap.compose(innerIndexStep); + assert(innerOffset.size() == 1); + // Calculate how far a single step in the outer dim is. + SmallVector outerOffset = + affineMap.compose(outerIndexStep); + assert(outerOffset.size() == 1); + // Multiply by the number of steps that the inner loop took. + // FIXME: test this for higher dims + std::int64_t offset = + outerOffset[0] - innerOffset[0] * outputShape[dim]; + outerIndices.push_back(regionBuilder.create( + output.getLoc(), forOp.getResult(i), i32(offset, regionBuilder))); + ++i; + } + regionBuilder.create(output.getLoc(), outerIndices); + } + + loopNest.loopRegions.push_back(&loopNest.loops.back().getRegion()); + iteratorsNest.emplace_back(forOp.getRegionIterArgs()); + } + + return loopNest; + } + + // Convert arith and math dialect operations into ttkernel init tile + // operations. HLK requires the FPU to be initialized before any tile ops get + // executed. We separate the init tile operation from the actual tile + // operation so that we can hoist the init tile operation outside of the loop + // nest. + void convertComputeInitOp(Operation &op, ArrayRef cbOperands, + std::int64_t numDpsInputs, + OpBuilder &builder) const { + SmallVector operandIndices; + for (OpOperand &operand : op.getOpOperands()) { + operandIndices.push_back(operand.getOperandNumber()); + } + if (mlir::isa(op)) { + builder.create( + op.getLoc(), cbOperands[operandIndices[0]], + cbOperands[operandIndices[1]], cbOperands[numDpsInputs]); + builder.create(op.getLoc(), + cbOperands[operandIndices[0]], + cbOperands[operandIndices[1]]); + } else if (mlir::isa(op)) { + builder.create( + op.getLoc(), cbOperands[operandIndices[0]], + cbOperands[operandIndices[1]], cbOperands[numDpsInputs]); + builder.create(op.getLoc(), + cbOperands[operandIndices[0]], + cbOperands[operandIndices[1]]); + } else { + llvm_unreachable("Unhandled conversion"); + } + } + + // Convert arith and math dialect operations into ttkernel tile operations. + // Here iterators are the block arguments from the innermost scf.for loop. + // The iterators are unique-ified so we need blockArgIteratorMapping to + // recover which top level tensor operand is associated with which iterator. + void convertComputeOp(Operation &op, ArrayRef cbOperands, + ArrayRef iterators, + SmallVector blockArgIteratorMapping, + Value dstIndex, OpBuilder &builder) const { + SmallVector operandIndices; + for (OpOperand &operand : op.getOpOperands()) { + operandIndices.push_back(operand.getOperandNumber()); + } + if (mlir::isa(op)) { + builder.create( + op.getLoc(), cbOperands[operandIndices[0]], + cbOperands[operandIndices[1]], iterators[blockArgIteratorMapping[0]], + iterators[blockArgIteratorMapping[1]], dstIndex); + } else if (mlir::isa(op)) { + builder.create( + op.getLoc(), cbOperands[operandIndices[0]], + cbOperands[operandIndices[1]], iterators[blockArgIteratorMapping[0]], + iterators[blockArgIteratorMapping[1]], dstIndex); + } else { + llvm_unreachable("Unhandled conversion"); + } + } + + // Convert the original block into a lowered block that contains a fully + // expanded loop nest and inner loop that implements the underlying arith or + // math operation as a tile operation. + void lowerBlock(Block *origBlock, Block *computeBlock, + std::int64_t numDPSInputs) const { + Block::OpListType &operations = origBlock->getOperations(); + assert(operations.size() == 2); + Operation::user_range users = operations.front().getUsers(); + assert(users.begin() != users.end()); + assert(mlir::isa(*users.begin())); + assert(computeBlock->getNumArguments() > numDPSInputs); + assert((computeBlock->getNumArguments() - numDPSInputs) == 1 && + "Expected 1 output"); + + OpBuilder builder(computeBlock, computeBlock->begin()); + convertComputeInitOp(operations.front(), computeBlock->getArguments(), + numDPSInputs, builder); + LoopNest loopNest = + createLoopNest(computeBlock->getArguments(), numDPSInputs, builder); + builder.create(origBlock->getTerminator()->getLoc()); + + // Build the inner loop compute / unpack / pack + { + Value output = computeBlock->getArgument(numDPSInputs); + Region *innerLoopRegion = loopNest.loopRegions.back(); + ArrayRef iterators = + loopNest.loops.back().getRegionIterArgs(); + SmallVector blockArgIteratorMapping = + loopNest.blockArgIteratorMapping; + OpBuilder innerLoopBuilder(&innerLoopRegion->front(), + innerLoopRegion->front().begin()); + Value dstIndex = i32(0, innerLoopBuilder); + innerLoopBuilder.create( + computeBlock->front().getLoc()); + convertComputeOp(operations.front(), computeBlock->getArguments(), + iterators, blockArgIteratorMapping, dstIndex, + innerLoopBuilder); + innerLoopBuilder.create( + computeBlock->front().getLoc()); + innerLoopBuilder.create( + computeBlock->front().getLoc()); + innerLoopBuilder.create( + computeBlock->front().getLoc(), dstIndex, output, + iterators[blockArgIteratorMapping[numDPSInputs]]); + innerLoopBuilder.create( + computeBlock->front().getLoc()); + } + } + + SmallVector + getBlockArgumentTypesAsCBs(mlir::OperandRange dispatchOperands, + mlir::Block::BlockArgListType blockArguments, + std::int64_t numDPSInputs, + PatternRewriter &rewriter) const { SmallVector rewrittenBlockArgumentTypes; for (auto arg : blockArguments) { - auto address = lookupAddress(arg); - auto port = - mlir::cast(operand_cb_port_mapping[arg.getArgNumber()]) - .getInt(); + auto address = lookupAddress(dispatchOperands[arg.getArgNumber()]); + assert(address && "Expected valid address"); + auto port = getPort(arg.getArgNumber(), numDPSInputs); auto tensor = mlir::cast(arg.getType()); auto buffer = mlir::cast(tensor.getEncoding()); auto memref = buffer.getMemref(); - rewrittenBlockArgumentTypes.push_back(rewriter.getType( - ttkernel::symbolizeCBPort(port).value(), address, memref)); + assert(buffer.getBufferAccess() == BufferAccess::Alias && + "Currently only alias mode is supported"); + rewrittenBlockArgumentTypes.push_back( + rewriter.getType(port, address, memref)); } return rewrittenBlockArgumentTypes; } @@ -468,66 +757,28 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { } SmallVector threadTypes = { - rewriter.getAttr(ttkernel::ThreadType::Noc0), - rewriter.getAttr(ttkernel::ThreadType::Noc1), rewriter.getAttr( ttkernel::ThreadType::Tensix), }; SmallVector coreRanges = { rewriter.getAttr(op.getGrid()), - rewriter.getAttr(op.getGrid()), - rewriter.getAttr(op.getGrid()), }; - SmallVector operand_cb_port_mapping; - for (auto &operand : op->getOpOperands()) { - operand_cb_port_mapping.push_back( - rewriter.getI64IntegerAttr(operand.getOperandNumber())); - } + auto metalDispatch = rewriter.create( op.getLoc(), op.getResults().getTypes(), op.getInputs(), op.getOutputs(), rewriter.getArrayAttr(coreRanges), - rewriter.getArrayAttr(threadTypes), - rewriter.getArrayAttr(operand_cb_port_mapping), threadTypes.size()); + rewriter.getArrayAttr(threadTypes), threadTypes.size()); auto rewrittenBlockArgumentTypes = getBlockArgumentTypesAsCBs( - op->getRegion(0).getArguments(), operand_cb_port_mapping, rewriter); + op->getOperands(), op->getRegion(0).getArguments(), + op.getNumDpsInputs(), rewriter); - metalDispatch.getRegion(2).takeBody(op->getRegion(0)); - Block *tensixBlock = &metalDispatch.getRegion(2).front(); - Block *noc0Block = rewriter.createBlock(&metalDispatch.getRegion(0)); - Block *noc1Block = rewriter.createBlock(&metalDispatch.getRegion(1)); - - int i = 0; + Block *tensixBlock = &metalDispatch.getRegion(0).emplaceBlock(); for (auto ty : rewrittenBlockArgumentTypes) { - noc0Block->addArgument(ty, op.getLoc()); - noc1Block->addArgument(ty, op.getLoc()); - auto arg = tensixBlock->getArgument(i++); - arg.setType(ty); + tensixBlock->addArgument(ty, op.getLoc()); } - { - OpBuilder noc0Builder(noc0Block, noc0Block->begin()); - auto one = noc0Builder.create( - op.getLoc(), noc0Builder.getI32Type(), - noc0Builder.getI32IntegerAttr(1)); - noc0Builder.create( - op.getLoc(), noc0Block->getArgument(0), one); - noc0Builder.create( - op.getLoc(), noc0Block->getArgument(0), one); - noc0Builder.create(op.getLoc(), ValueRange()); - } - - { - OpBuilder noc1Builder(noc1Block, noc1Block->begin()); - auto one = noc1Builder.create( - op.getLoc(), noc1Builder.getI32Type(), - noc1Builder.getI32IntegerAttr(1)); - noc1Builder.create( - op.getLoc(), noc1Block->getArgument(0), one); - noc1Builder.create( - op.getLoc(), noc1Block->getArgument(0), one); - noc1Builder.create(op.getLoc(), ValueRange()); - } + lowerBlock(&op->getRegion(0).front(), tensixBlock, op.getNumDpsInputs()); rewriter.replaceOp(op, metalDispatch); @@ -567,9 +818,8 @@ class ConvertTTIRToTTMetal void runOnOperation() final { RewritePatternSet patterns(&getContext()); patterns.add( - &getContext()); + TTIRToTTMetalDispatchRewriter, TTIRToTTMetalAllocRewriter, + TTIRToTTMetalDeallocRewriter>(&getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { signalPassFailure();