diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.h b/include/ttmlir/Dialect/TTIR/Transforms/Passes.h
index dd3772c37..05d039682 100644
--- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.h
+++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.h
@@ -10,6 +10,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
+#include "llvm/ADT/SmallString.h"
 
 namespace mlir::tt::ttir {
 #define GEN_PASS_DECL
diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
index b6269f715..b26dcf8cc 100644
--- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
+++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -125,4 +125,43 @@ def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> {
   }];
 }
 
+def TTIRHoistTransform: Pass<"ttir-cpu-hoist-transform", "::mlir::ModuleOp">
+{
+  let summary = "Transform to hoist any ops marked for CPU lowering into standalone functions";
+  let description = [{
+    Transform pass which runs an analysis pass to find ops which should be
+    hoisted, and then hoists those ops. Currently we only have a manual
+    analysis which requires a command-line list of named locs to hoist; in the
+    future, we will have an automatic analysis as well.
+
+    Example:
+    input:
+      module {
+        func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
+          %0 = tensor.empty() : tensor<32x32xbf16>
+          %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> loc("add_op1")
+          return %1 : tensor<32x32xbf16>
+        }
+      }
+
+    output:
+      module {
+        func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
+          %0 = tensor.empty() : tensor<32x32xbf16>
+          %1 = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+          return %1 : tensor<32x32xbf16>
+        }
+        module @cpu_module attributes {ttir.cpu_module} {
+          func.func @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>, %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> attributes {arg_ranks = [2, 2, 2]} {
+            %0 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+            return %0 : tensor<32x32xbf16>
+          }
+        }
+        func.func private @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+      }
+
+  }];
+  let options = [
+    ListOption<"hoistLocs", "hoist-locations", "std::string",
+               "Comma-separated list of NameLocs from the input.">,
+  ];
+}
+
 #endif
diff --git a/lib/Dialect/TTIR/Transforms/CMakeLists.txt b/lib/Dialect/TTIR/Transforms/CMakeLists.txt
index 597c55e3c..a16078f59 100644
--- a/lib/Dialect/TTIR/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TTIR/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTTIRTransforms
   Broadcast.cpp
   Constant.cpp
   Generic.cpp
+  HoistCPUOps.cpp
   Layout.cpp
   Transforms.cpp
   Utility.cpp
diff --git a/lib/Dialect/TTIR/Transforms/HoistCPUOps.cpp b/lib/Dialect/TTIR/Transforms/HoistCPUOps.cpp
new file mode 100644
index 000000000..84089ffa3
--- /dev/null
+++ b/lib/Dialect/TTIR/Transforms/HoistCPUOps.cpp
@@ -0,0 +1,226 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "ttmlir/Dialect/TT/IR/TT.h"
+#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
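+
+// This pass clones ops tagged for CPU execution into standalone functions in a
+// nested `cpu_module`, declares matching private prototypes in the original
+// module, and replaces each tagged op with a call to its hoisted function.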
"ttmlir/Dialect/TTIR/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir::tt::ttir { +#define GEN_PASS_DEF_TTIRHOISTTRANSFORM +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" + +//===----------------------------------------------------------------------===// +// Hoist CPU ops to standalone funcs pass +//===----------------------------------------------------------------------===// + +// Helper function to get ranks of an op's operands +// we use this to populate attrs which we need to tensor unpacking operations +// later. +static llvm::SmallVector +getOperandTensorRanks(mlir::Operation *op) { + llvm::SmallVector ranks; + + // Iterate over operands (inputs) + for (auto operand : op->getOperands()) { + // Check if the operand is a tensor + if (auto tensorType = dyn_cast(operand.getType())) { + // Add the rank of the tensor (number of dimensions) + ranks.push_back(tensorType.getRank()); + } + } + + return ranks; +} + +// Generate unique name base on operation type + argument tensors dims & types. +static llvm::SmallString<16> generateHoistedFuncName(mlir::Operation *op) { + // Start building the unique function name + llvm::SmallString<16> uniqueName("hoisted_"); + uniqueName.append(op->getName().getStringRef().begin(), + op->getName().getStringRef().end()); + + // Iterate over operands to extract tensor shapes and types + for (auto operand : op->getOperands()) { + auto rankedTensorType = dyn_cast(operand.getType()); + if (rankedTensorType) { + // Append the shape (dimensions) and the element type + llvm::SmallString<5> shapeStr("_"); + for (auto dim : rankedTensorType.getShape()) { + shapeStr += std::to_string(dim) + "x"; + } + + // Append the element type (e.g., f32, i32) -- unforunately I don't think + // there's a better way to get string from mlir::Type + std::string elementTypeStr; + llvm::raw_string_ostream stream(elementTypeStr); + rankedTensorType.getElementType().print(stream); + + uniqueName.append(shapeStr.begin(), shapeStr.end()); + uniqueName.append(elementTypeStr.begin(), elementTypeStr.end()); + } + } + + // Add suffix to indicate it's a function + uniqueName += "_func"; + + return uniqueName; +} + +// Helper function to hoist an arbitrary op into a new function in targetModule, +// generate a matching extern prototype in the sourceModule, and replace the +// original op with a callOp to the extern function. +static void hoistOperationToFunction(mlir::Operation *opToHoist, + mlir::ModuleOp sourceModule, + mlir::ModuleOp targetModule) { + const llvm::SmallVector ranks = getOperandTensorRanks(opToHoist); + + const llvm::SmallString<16> functionName = generateHoistedFuncName(opToHoist); + + auto localFunc = llvm::dyn_cast_or_null( + sourceModule.lookupSymbol(functionName)); + + // Create a new hoisted function only if an equivalent one does not exist. + if (localFunc == nullptr) { + mlir::MLIRContext *context = sourceModule.getContext(); + + // Gather operand and result types. + llvm::SmallVector operandTypes, resultTypes; + for (auto operand : opToHoist->getOperands()) { + operandTypes.push_back(operand.getType()); + } + for (auto result : opToHoist->getResultTypes()) { + resultTypes.push_back(result); + } + + // Create the function signature. 
+    mlir::FunctionType funcType =
+        mlir::FunctionType::get(context, operandTypes, resultTypes);
+
+    // Create the function in the target module.
+    auto hoistedFunc =
+        func::FuncOp::create(opToHoist->getLoc(), functionName, funcType);
+    targetModule.push_back(hoistedFunc);
+
+    // Add a basic block to the function.
+    mlir::Block *block = hoistedFunc.addEntryBlock();
+    mlir::OpBuilder builder(block, block->end());
+
+    // Map operands to block arguments and clone the operation.
+    llvm::SmallVector<mlir::Value> newOperands;
+    for (auto operand : llvm::enumerate(opToHoist->getOperands())) {
+      newOperands.push_back(block->getArgument(operand.index()));
+    }
+
+    mlir::IRMapping mapping;
+    for (auto operand : llvm::zip(opToHoist->getOperands(), newOperands)) {
+      mapping.map(std::get<0>(operand), std::get<1>(operand));
+    }
+
+    mlir::Operation *clonedOp = builder.clone(*opToHoist, mapping);
+
+    // Add a return operation to the function.
+    builder.create<mlir::func::ReturnOp>(opToHoist->getLoc(),
+                                         clonedOp->getResults());
+
+    // Declare the function prototype in the source module.
+    localFunc =
+        func::FuncOp::create(opToHoist->getLoc(), functionName, funcType);
+    localFunc.setPrivate();
+    sourceModule.push_back(localFunc);
+
+    hoistedFunc->setAttr("arg_ranks", builder.getI64ArrayAttr(ranks));
+  }
+
+  // Replace the original operation with a call to the hoisted function.
+  mlir::OpBuilder opBuilder(opToHoist);
+  auto callOp = opBuilder.create<mlir::func::CallOp>(
+      opToHoist->getLoc(), localFunc, opToHoist->getOperands());
+
+  // Replace all results of the original operation with the call results.
+  opToHoist->replaceAllUsesWith(callOp);
+
+  // Erase the original operation.
+  opToHoist->erase();
+}
+
+// An analysis class which currently relies on manually tagging ops with a
+// `should_hoist` attribute, but in the future will also tag fall-back ops, etc.
+class TTIRHoistAnalyze {
+public:
+  using HoistOpSet = llvm::SmallVector<llvm::SmallSet<mlir::Operation *, 4>>;
+
+  TTIRHoistAnalyze(mlir::ModuleOp moduleOp) {
+    moduleOp.walk([&](mlir::Operation *nestedOp) {
+      if (nestedOp->hasAttr("should_hoist")) {
+        llvm::SmallSet<mlir::Operation *, 4> opSet;
+        opSet.insert(nestedOp);
+        hoistedOps.push_back(opSet);
+      }
+    });
+  }
+
+  HoistOpSet getResults() { return hoistedOps; }
+
+private:
+  HoistOpSet hoistedOps;
+};
+
+// Transform pass to hoist specific ops (based on the configured analysis) into
+// a CPU submodule for later independent lowering.
+class TTIRHoistTransform
+    : public impl::TTIRHoistTransformBase<TTIRHoistTransform> {
+public:
+  using impl::TTIRHoistTransformBase<
+      TTIRHoistTransform>::TTIRHoistTransformBase;
+
+  void runOnOperation() final {
+    mlir::ModuleOp moduleOp = getOperation();
+
+    IRRewriter rewriter(&getContext());
+
+    auto loc = moduleOp->getLoc();
+
+    TTIRHoistAnalyze analysisPass(moduleOp);
+    const TTIRHoistAnalyze::HoistOpSet &hoistOpSets = analysisPass.getResults();
+
+    // Check if a "cpu_module" already exists.
+    mlir::ModuleOp cpuModule;
+    for (auto &op : moduleOp.getBody()->getOperations()) {
+      if (auto module = llvm::dyn_cast<mlir::ModuleOp>(op)) {
+        if (module->hasAttr("ttir.cpu_module")) {
+          cpuModule = module;
+          break;
+        }
+      }
+    }
+
+    // If no CPU module exists, create one.
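+    // The new module is nested inside the top-level module and takes the form
+    //   module @cpu_module attributes {ttir.cpu_module} { ... }
+    // as shown in the Passes.td example, so hoisted funcs can later be lowered
+    // independently of the device code.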
+    if (!cpuModule) {
+      rewriter.setInsertionPointToEnd(moduleOp.getBody());
+      cpuModule = rewriter.create<mlir::ModuleOp>(loc);
+      cpuModule->setAttr("ttir.cpu_module", rewriter.getUnitAttr());
+      cpuModule->setAttr(mlir::SymbolTable::getSymbolAttrName(),
+                         rewriter.getStringAttr("cpu_module"));
+      // Make the CPU module symbol public.
+      mlir::SymbolTable::setSymbolVisibility(
+          cpuModule, mlir::SymbolTable::Visibility::Public);
+    }
+
+    for (const auto &opSet : hoistOpSets) {
+      assert(opSet.size() == 1 &&
+             "currently don't support hoisting multiple instructions at once!");
+      hoistOperationToFunction(*opSet.begin(), moduleOp, cpuModule);
+    }
+  }
+};
+
+} // namespace mlir::tt::ttir
diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index cfe0dacc3..6a0fed3e8 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -4,6 +4,7 @@
 #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
+#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
 #include "ttmlir/Dialect/TTNN/Types/Types.h"
 #include "ttmlir/Dialect/TTNN/Utils/Utils.h"
diff --git a/test/ttmlir/Dialect/TTIR/hoist/simple_add.mlir b/test/ttmlir/Dialect/TTIR/hoist/simple_add.mlir
new file mode 100644
index 000000000..7a4d788c4
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/hoist/simple_add.mlir
@@ -0,0 +1,37 @@
+// RUN: ttmlir-opt --ttir-cpu-hoist-transform %s | FileCheck %s
+
+func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
+  %0 = tensor.empty() : tensor<32x32xbf16>
+  // CHECK: %{{.*}} = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func
+  %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+  return %1 : tensor<32x32xbf16>
+}
+
+func.func @add2(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> {
+  %0 = tensor.empty() : tensor<32x32xf32>
+  // CHECK: %{{.*}} = call @hoisted_ttir.add_32x32xf32_32x32xf32_32x32xf32_func
+  %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
+  return %1 : tensor<32x32xf32>
+}
+
+func.func @add3(%arg0: tensor<32x3xf32>, %arg1: tensor<32x3xf32>) -> tensor<32x3xf32> {
+  %0 = tensor.empty() : tensor<32x3xf32>
+  // CHECK: %{{.*}} = call @hoisted_ttir.add_32x3xf32_32x3xf32_32x3xf32_func
+  %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x3xf32>, tensor<32x3xf32>, tensor<32x3xf32>) -> tensor<32x3xf32>
+  return %1 : tensor<32x3xf32>
+}
+
+func.func @add4(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
+  %0 = tensor.empty() : tensor<32x32xbf16>
+  // CHECK: %{{.*}} = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func
+  %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+  return %1 : tensor<32x32xbf16>
+}
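+
+// @add and @add4 use identical operand shapes and types, so they are expected
+// to reuse a single hoisted function; only three hoisted funcs appear below.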
+// CHECK: module @cpu_module attributes {ttir.cpu_module}
+// CHECK: func.func @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func
+// CHECK: func.func @hoisted_ttir.add_32x32xf32_32x32xf32_32x32xf32_func
+// CHECK: func.func @hoisted_ttir.add_32x3xf32_32x3xf32_32x3xf32_func
+
+// CHECK: func.func private @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func
+// CHECK: func.func private @hoisted_ttir.add_32x32xf32_32x32xf32_32x32xf32_func
+// CHECK: func.func private @hoisted_ttir.add_32x3xf32_32x3xf32_32x3xf32_func