From 432f8d82338ba23dcc6471263382f8259adae41c Mon Sep 17 00:00:00 2001 From: Vincent Wells Date: Fri, 20 Dec 2024 11:01:17 -0600 Subject: [PATCH] Add hoist mechanism for specific TTIR ops, with placeholder analysis pass (#1586) ### Goal: The end-to-end goal is to integrate a path to compile and execute specific ops or sets of ops on the CPU. ### Context: The entire task will be split into (tentatively) 7 PRs, as follows: 1. Hoist specific ops into isolated funcs in a separate module 2. Convert TTIR ops to linalg ops within the module of hoisted funcs 3. **Build a pipeline to lower linalg to llvm from existing conversion passes** 4. Translate LLVM Dialect into a dynamic library for packing into flatbuffer 5. Generate helper functions so that we can call all of our hoisted funcs with a common signature 6. Insert TTNN instructions to move operands to host before executing hoisted func, then back to device afterwards 7. Update ttir-to-ttnn and ttnn-to-flatbuffer pipelines to use new passes, generate dylibs, and embed them into output flatbuffers, and update update runtime to consume dylibs from flatbuffers This PR represents the 1st point above. Here, we build hoisting (placeholder) analysis + transform pass to mark specific ops to be hoisted, and then actually pull them into separate functions in a new "cpu" module + replace the original op with a call to the func. ## Example: input: ``` module { func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { %0 = tensor.empty() : tensor<32x32xbf16> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> loc("add_op1") return %1 : tensor<32x32xbf16> } } ``` output: ``` module { func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { %0 = tensor.empty() : tensor<32x32xbf16> %1 = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> return %1 : tensor<32x32xbf16> } module @cpu_module attributes {ttir.cpu_module} { func.func @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>, %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> attributes {arg_ranks = [2, 2, 2, 2]} { %0 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> return %0 : tensor<32x32xbf16> } } func.func private @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> } ``` --- .../ttmlir/Dialect/TTIR/Transforms/Passes.h | 1 + .../ttmlir/Dialect/TTIR/Transforms/Passes.td | 39 +++ lib/Dialect/TTIR/Transforms/CMakeLists.txt | 1 + lib/Dialect/TTIR/Transforms/HoistCPUOps.cpp | 226 ++++++++++++++++++ lib/Dialect/TTNN/IR/TTNNOps.cpp | 1 + .../ttmlir/Dialect/TTIR/hoist/simple_add.mlir | 37 +++ 6 files changed, 305 insertions(+) create mode 100644 lib/Dialect/TTIR/Transforms/HoistCPUOps.cpp create mode 100644 test/ttmlir/Dialect/TTIR/hoist/simple_add.mlir diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.h b/include/ttmlir/Dialect/TTIR/Transforms/Passes.h index dd3772c37..05d039682 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.h +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.h @@ -10,6 +10,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallString.h" namespace mlir::tt::ttir { #define GEN_PASS_DECL diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index b6269f715..b26dcf8cc 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -125,4 +125,43 @@ def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> { }]; } +def TTIRHoistTransform: Pass<"ttir-cpu-hoist-transform", "::mlir::ModuleOp"> +{ + let summary = "Transform to perform hoist mechanics on any ops marked to be hoisted for CPU lowering"; + let description = [{ + Transform pass which runs an analysis pass to find ops which should be hoisted, and then hoists those ops. Currently we only have a manual analysis which requires a commandline list of named locs to hoist--in the future, we will have an automatic analysis as well. + + Example: + input: + module { + func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { + %0 = tensor.empty() : tensor<32x32xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> loc("add_op1") + return %1 : tensor<32x32xbf16> + } + } + + output: + module { + func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { + %0 = tensor.empty() : tensor<32x32xbf16> + %1 = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %1 : tensor<32x32xbf16> + } + module @cpu_module attributes {ttir.cpu_module} { + func.func @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>, %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> attributes {arg_ranks = [2, 2, 2, 2]} { + %0 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %0 : tensor<32x32xbf16> + } + } + func.func private @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + } + + }]; + let options = [ + ListOption<"hoistLocs", "hoist-locations", "std::string", + "comma-separated list of NameLoc's from input.">, + ]; +} + #endif diff --git a/lib/Dialect/TTIR/Transforms/CMakeLists.txt b/lib/Dialect/TTIR/Transforms/CMakeLists.txt index 597c55e3c..a16078f59 100644 --- a/lib/Dialect/TTIR/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTIR/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTTIRTransforms Broadcast.cpp Constant.cpp Generic.cpp + HoistCPUOps.cpp Layout.cpp Transforms.cpp Utility.cpp diff --git a/lib/Dialect/TTIR/Transforms/HoistCPUOps.cpp b/lib/Dialect/TTIR/Transforms/HoistCPUOps.cpp new file mode 100644 index 000000000..84089ffa3 --- /dev/null +++ b/lib/Dialect/TTIR/Transforms/HoistCPUOps.cpp @@ -0,0 +1,226 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TT/IR/TT.h" +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir::tt::ttir { +#define GEN_PASS_DEF_TTIRHOISTTRANSFORM +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" + +//===----------------------------------------------------------------------===// +// Hoist CPU ops to standalone funcs pass +//===----------------------------------------------------------------------===// + +// Helper function to get ranks of an op's operands +// we use this to populate attrs which we need to tensor unpacking operations +// later. +static llvm::SmallVector +getOperandTensorRanks(mlir::Operation *op) { + llvm::SmallVector ranks; + + // Iterate over operands (inputs) + for (auto operand : op->getOperands()) { + // Check if the operand is a tensor + if (auto tensorType = dyn_cast(operand.getType())) { + // Add the rank of the tensor (number of dimensions) + ranks.push_back(tensorType.getRank()); + } + } + + return ranks; +} + +// Generate unique name base on operation type + argument tensors dims & types. +static llvm::SmallString<16> generateHoistedFuncName(mlir::Operation *op) { + // Start building the unique function name + llvm::SmallString<16> uniqueName("hoisted_"); + uniqueName.append(op->getName().getStringRef().begin(), + op->getName().getStringRef().end()); + + // Iterate over operands to extract tensor shapes and types + for (auto operand : op->getOperands()) { + auto rankedTensorType = dyn_cast(operand.getType()); + if (rankedTensorType) { + // Append the shape (dimensions) and the element type + llvm::SmallString<5> shapeStr("_"); + for (auto dim : rankedTensorType.getShape()) { + shapeStr += std::to_string(dim) + "x"; + } + + // Append the element type (e.g., f32, i32) -- unforunately I don't think + // there's a better way to get string from mlir::Type + std::string elementTypeStr; + llvm::raw_string_ostream stream(elementTypeStr); + rankedTensorType.getElementType().print(stream); + + uniqueName.append(shapeStr.begin(), shapeStr.end()); + uniqueName.append(elementTypeStr.begin(), elementTypeStr.end()); + } + } + + // Add suffix to indicate it's a function + uniqueName += "_func"; + + return uniqueName; +} + +// Helper function to hoist an arbitrary op into a new function in targetModule, +// generate a matching extern prototype in the sourceModule, and replace the +// original op with a callOp to the extern function. +static void hoistOperationToFunction(mlir::Operation *opToHoist, + mlir::ModuleOp sourceModule, + mlir::ModuleOp targetModule) { + const llvm::SmallVector ranks = getOperandTensorRanks(opToHoist); + + const llvm::SmallString<16> functionName = generateHoistedFuncName(opToHoist); + + auto localFunc = llvm::dyn_cast_or_null( + sourceModule.lookupSymbol(functionName)); + + // Create a new hoisted function only if an equivalent one does not exist. + if (localFunc == nullptr) { + mlir::MLIRContext *context = sourceModule.getContext(); + + // Gather operand and result types. + llvm::SmallVector operandTypes, resultTypes; + for (auto operand : opToHoist->getOperands()) { + operandTypes.push_back(operand.getType()); + } + for (auto result : opToHoist->getResultTypes()) { + resultTypes.push_back(result); + } + + // Create the function signature. + mlir::FunctionType funcType = + mlir::FunctionType::get(context, operandTypes, resultTypes); + + // Create the function in the target module. + auto hoistedFunc = + func::FuncOp::create(opToHoist->getLoc(), functionName, funcType); + targetModule.push_back(hoistedFunc); + + // Add a basic block to the function. + mlir::Block *block = hoistedFunc.addEntryBlock(); + mlir::OpBuilder builder(block, block->end()); + + // Map operands to block arguments and clone the operation. + llvm::SmallVector newOperands; + for (auto operand : llvm::enumerate(opToHoist->getOperands())) { + newOperands.push_back(block->getArgument(operand.index())); + } + + mlir::IRMapping mapping; + for (auto operand : llvm::zip(opToHoist->getOperands(), newOperands)) { + mapping.map(std::get<0>(operand), std::get<1>(operand)); + } + + mlir::Operation *clonedOp = builder.clone(*opToHoist, mapping); + + // Add a return operation to the function. + builder.create(opToHoist->getLoc(), + clonedOp->getResults()); + + // Declare the function prototype in the source module. + localFunc = + func::FuncOp::create(opToHoist->getLoc(), functionName, funcType); + localFunc.setPrivate(); + sourceModule.push_back(localFunc); + + hoistedFunc->setAttr("arg_ranks", builder.getI64ArrayAttr(ranks)); + } + + // Replace the original operation with a call to the hoisted function. + mlir::OpBuilder opBuilder(opToHoist); + auto callOp = opBuilder.create( + opToHoist->getLoc(), localFunc, opToHoist->getOperands()); + + // Replace all results of the original operation with the call results. + opToHoist->replaceAllUsesWith(callOp); + + // Erase the original operation. + opToHoist->erase(); +} + +// An analysis class which currently relies on manually tagging ops with a +// `should_hoist` attribute, but in the future will also tag fall-back ops, etc. +class TTIRHoistAnalyze { +public: + using HoistOpSet = llvm::SmallVector>; + + TTIRHoistAnalyze(mlir::ModuleOp moduleOp) { + moduleOp.walk([&](mlir::Operation *nestedOp) { + if (nestedOp->hasAttr("should_hoist")) { + llvm::SmallSet opSet; + opSet.insert(nestedOp); + hoistedOps.push_back(opSet); + } + }); + } + + HoistOpSet getResults() { return hoistedOps; } + +private: + HoistOpSet hoistedOps; +}; + +// Transform pass to hoist specific ops (based on configured analysis pass) into +// a cpu submodule for later independent lowering. +class TTIRHoistTransform + : public impl::TTIRHoistTransformBase { +public: + using impl::TTIRHoistTransformBase< + TTIRHoistTransform>::TTIRHoistTransformBase; + + void runOnOperation() final { + mlir::ModuleOp moduleOp = getOperation(); + + IRRewriter rewriter(&getContext()); + + auto loc = moduleOp->getLoc(); + + TTIRHoistAnalyze analysisPass(moduleOp); + const TTIRHoistAnalyze::HoistOpSet &hoistOpSets = analysisPass.getResults(); + + // Check if a "cpu_module" already exists. + mlir::ModuleOp cpuModule; + for (auto &op : moduleOp.getBody()->getOperations()) { + if (auto module = llvm::dyn_cast(op)) { + if (module->hasAttr("ttir.cpu_module")) { + cpuModule = module; + break; + } + } + } + + // If no CPU module exists, create one. + if (!cpuModule) { + rewriter.setInsertionPointToEnd(moduleOp.getBody()); + cpuModule = rewriter.create(loc); + cpuModule->setAttr("ttir.cpu_module", rewriter.getUnitAttr()); + cpuModule->setAttr(mlir::SymbolTable::getSymbolAttrName(), + rewriter.getStringAttr("cpu_module")); + // try to make cpu module global + mlir::SymbolTable::setSymbolVisibility( + cpuModule, mlir::SymbolTable::Visibility::Public); + } + + for (const auto &opSet : hoistOpSets) { + assert(opSet.size() == 1 && + "currently don't support hoisting multiple instructions at once!"); + hoistOperationToFunction(*opSet.begin(), moduleOp, cpuModule); + } + } +}; + +} // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index cfe0dacc3..6a0fed3e8 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -4,6 +4,7 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/Dialect/TTNN/Types/Types.h" #include "ttmlir/Dialect/TTNN/Utils/Utils.h" diff --git a/test/ttmlir/Dialect/TTIR/hoist/simple_add.mlir b/test/ttmlir/Dialect/TTIR/hoist/simple_add.mlir new file mode 100644 index 000000000..7a4d788c4 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/hoist/simple_add.mlir @@ -0,0 +1,37 @@ +// RUN: ttmlir-opt --ttir-cpu-hoist-transform %s | FileCheck %s + +func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { + %0 = tensor.empty() : tensor<32x32xbf16> + // CHECK: %{{.*}} = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> {should_hoist} : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %1 : tensor<32x32xbf16> +} + +func.func @add2(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = tensor.empty() : tensor<32x32xf32> + // CHECK: %{{.*}} = call @hoisted_ttir.add_32x32xf32_32x32xf32_32x32xf32_func + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> {should_hoist} : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + return %1 : tensor<32x32xf32> +} + +func.func @add3(%arg0: tensor<32x3xf32>, %arg1: tensor<32x3xf32>) -> tensor<32x3xf32> { + %0 = tensor.empty() : tensor<32x3xf32> +// CHECK: %{{.*}} = call @hoisted_ttir.add_32x3xf32_32x3xf32_32x3xf32_func + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> {should_hoist} : (tensor<32x3xf32>, tensor<32x3xf32>, tensor<32x3xf32>) -> tensor<32x3xf32> + return %1 : tensor<32x3xf32> +} + +func.func @add4(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { + %0 = tensor.empty() : tensor<32x32xbf16> + // CHECK: %{{.*}} = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> {should_hoist} : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %1 : tensor<32x32xbf16> +} +// CHECK: module @cpu_module attributes {ttir.cpu_module} +// CHECK: func.func @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func +// CHECK: func.func @hoisted_ttir.add_32x32xf32_32x32xf32_32x32xf32_func +// CHECK: func.func @hoisted_ttir.add_32x3xf32_32x3xf32_32x3xf32_func + +// CHECK: func.func private @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func +// CHECK: func.func private @hoisted_ttir.add_32x32xf32_32x32xf32_32x32xf32_func +// CHECK: func.func private @hoisted_ttir.add_32x3xf32_32x3xf32_32x3xf32_func