From 7f71e72af56ab3d9bccac7f6977ba63c734832c4 Mon Sep 17 00:00:00 2001 From: Vincent Wells Date: Fri, 13 Dec 2024 09:35:52 -0600 Subject: [PATCH] Add pass to emit helper funcs so that we have common call signature --- include/ttmlir/Dialect/CMakeLists.txt | 1 + include/ttmlir/Dialect/LLVM/CMakeLists.txt | 1 + .../Dialect/LLVM/Transforms/CMakeLists.txt | 4 + .../ttmlir/Dialect/LLVM/Transforms/Passes.h | 19 ++ .../ttmlir/Dialect/LLVM/Transforms/Passes.td | 16 ++ lib/Dialect/CMakeLists.txt | 1 + lib/Dialect/LLVM/CMakeLists.txt | 1 + lib/Dialect/LLVM/Transforms/CMakeLists.txt | 11 ++ .../LLVM/Transforms/EmitHelperFuncs.cpp | 164 ++++++++++++++++++ lib/Dialect/TTIR/Pipelines/CMakeLists.txt | 1 + lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp | 1 + lib/RegisterAll.cpp | 21 ++- 12 files changed, 232 insertions(+), 9 deletions(-) create mode 100644 include/ttmlir/Dialect/LLVM/CMakeLists.txt create mode 100644 include/ttmlir/Dialect/LLVM/Transforms/CMakeLists.txt create mode 100644 include/ttmlir/Dialect/LLVM/Transforms/Passes.h create mode 100644 include/ttmlir/Dialect/LLVM/Transforms/Passes.td create mode 100644 lib/Dialect/LLVM/CMakeLists.txt create mode 100644 lib/Dialect/LLVM/Transforms/CMakeLists.txt create mode 100644 lib/Dialect/LLVM/Transforms/EmitHelperFuncs.cpp diff --git a/include/ttmlir/Dialect/CMakeLists.txt b/include/ttmlir/Dialect/CMakeLists.txt index 5c05b3097a..7f9046da8d 100644 --- a/include/ttmlir/Dialect/CMakeLists.txt +++ b/include/ttmlir/Dialect/CMakeLists.txt @@ -3,3 +3,4 @@ add_subdirectory(TTIR) add_subdirectory(TTNN) add_subdirectory(TTMetal) add_subdirectory(TTKernel) +add_subdirectory(LLVM) diff --git a/include/ttmlir/Dialect/LLVM/CMakeLists.txt b/include/ttmlir/Dialect/LLVM/CMakeLists.txt new file mode 100644 index 0000000000..e31af32661 --- /dev/null +++ b/include/ttmlir/Dialect/LLVM/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Transforms) diff --git a/include/ttmlir/Dialect/LLVM/Transforms/CMakeLists.txt b/include/ttmlir/Dialect/LLVM/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..eedb956d66 --- /dev/null +++ b/include/ttmlir/Dialect/LLVM/Transforms/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc --gen-pass-decls) +add_public_tablegen_target(MLIRLLVMPassesIncGen) +add_dependencies(mlir-headers MLIRLLVMPassesIncGen) diff --git a/include/ttmlir/Dialect/LLVM/Transforms/Passes.h b/include/ttmlir/Dialect/LLVM/Transforms/Passes.h new file mode 100644 index 0000000000..8cf2bbd9de --- /dev/null +++ b/include/ttmlir/Dialect/LLVM/Transforms/Passes.h @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_LLVM_TRANSFORMS_PASSES_H +#define TTMLIR_DIALECT_LLVM_TRANSFORMS_PASSES_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::tt::llvm_util { +#define GEN_PASS_DECL +#include "ttmlir/Dialect/LLVM/Transforms/Passes.h.inc" + +#define GEN_PASS_REGISTRATION +#include "ttmlir/Dialect/LLVM/Transforms/Passes.h.inc" +} // namespace mlir::tt::llvm_util + +#endif diff --git a/include/ttmlir/Dialect/LLVM/Transforms/Passes.td b/include/ttmlir/Dialect/LLVM/Transforms/Passes.td new file mode 100644 index 0000000000..77abd91583 --- /dev/null +++ b/include/ttmlir/Dialect/LLVM/Transforms/Passes.td @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_TTMLIR_DIALECT_LLVM_LLVMPASSES_TD +#define TTMLIR_TTMLIR_DIALECT_LLVM_LLVMPASSES_TD + +include "mlir/Pass/PassBase.td" + +def LLVMEmitHelperFuncs: Pass<"emit-llvm-helpers", "::mlir::ModuleOp"> +{ + let summary = "Helper function for emitting standardized call-format for all of our functions lowered to LLVMDialect"; + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; +} + +#endif diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index 5c05b3097a..7f9046da8d 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -3,3 +3,4 @@ add_subdirectory(TTIR) add_subdirectory(TTNN) add_subdirectory(TTMetal) add_subdirectory(TTKernel) +add_subdirectory(LLVM) diff --git a/lib/Dialect/LLVM/CMakeLists.txt b/lib/Dialect/LLVM/CMakeLists.txt new file mode 100644 index 0000000000..e31af32661 --- /dev/null +++ b/lib/Dialect/LLVM/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Transforms) diff --git a/lib/Dialect/LLVM/Transforms/CMakeLists.txt b/lib/Dialect/LLVM/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..eabc811356 --- /dev/null +++ b/lib/Dialect/LLVM/Transforms/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_dialect_library(MLIRLLVMTransforms + EmitHelperFuncs.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/ttmlir + + DEPENDS + MLIRTTIROpsIncGen + MLIRTTIRPassesIncGen + MLIRTTOpsIncGen + ) diff --git a/lib/Dialect/LLVM/Transforms/EmitHelperFuncs.cpp b/lib/Dialect/LLVM/Transforms/EmitHelperFuncs.cpp new file mode 100644 index 0000000000..cd54986f65 --- /dev/null +++ b/lib/Dialect/LLVM/Transforms/EmitHelperFuncs.cpp @@ -0,0 +1,164 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // For LLVM Dialect definitions +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" // For LLVM Type support (e.g., LLVMStructType, LLVMPointerType) + +#include "llvm/ADT/ArrayRef.h" // For ArrayRef +#include "llvm/ADT/SmallVector.h" // For SmallVector +#include "llvm/Support/Casting.h" // For dyn_cast + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::tt::llvm_util { +#define GEN_PASS_DEF_LLVMEMITHELPERFUNCS +#include "ttmlir/Dialect/LLVM/Transforms/Passes.h.inc" + +void generateLLVMHelpersForArgRanks(mlir::ModuleOp moduleOp) { + auto *context = moduleOp.getContext(); + OpBuilder builder(context); + + for (auto func : moduleOp.getOps()) { + if (!func->hasAttr("arg_ranks")) { + continue; + } + + // Extract the `arg_ranks` attribute + auto argRanksAttr = llvm::dyn_cast(func->getAttr("arg_ranks")); + if (!argRanksAttr) { + continue; + } + + builder.setInsertionPointToEnd(moduleOp.getBody()); + + // Define the helper function name and type + std::string helperName = func.getName().str() + "_helper"; + + // Create the helper function + auto helperFuncType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(context), {LLVM::LLVMPointerType::get(context)}, + false); + + auto helperFunc = builder.create( + func.getLoc(), helperName, helperFuncType); + + Block *entryBlock = helperFunc.addEntryBlock(builder); + builder.setInsertionPointToStart(entryBlock); + + // Unpack the argument + Value structArrayPtr = entryBlock->getArgument(0); + SmallVector originalCallArgs; + + // Iterate over arg_ranks to unpack tensors + int tensorIdx = 0; + for (auto rankAttr : argRanksAttr) { + Value tensorIndex = builder.create( + func.getLoc(), builder.getI32Type(), + builder.getI32IntegerAttr(tensorIdx++)); + + Value structPtr = builder.create( + func.getLoc(), LLVM::LLVMPointerType::get(context), + LLVM::LLVMPointerType::get(context), structArrayPtr, + ValueRange(tensorIndex)); + + int64_t rank = mlir::cast(rankAttr).getInt(); + + Value index = builder.create( + func.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0)); + // `start` + Value tensorBase = builder.create( + func.getLoc(), LLVM::LLVMPointerType::get(context), + LLVM::LLVMPointerType::get(context), structPtr, ValueRange{index}); + + index = builder.create( + func.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(1)); + // `aligned_start` + Value alignedBase = builder.create( + func.getLoc(), LLVM::LLVMPointerType::get(context), + LLVM::LLVMPointerType::get(context), structPtr, ValueRange{index}); + + index = builder.create( + func.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(2)); + // `start_idx` + Value startIdxPtr = builder.create( + func.getLoc(), LLVM::LLVMPointerType::get(context), + builder.getI64Type(), structPtr, ValueRange{index}); + // Convert the pointer to an integer (i64) + Value startIdx = builder.create( + func.getLoc(), builder.getI64Type(), startIdxPtr); + + index = builder.create( + func.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(3)); + // `sizes_and_strides` + Value sizesAndStrides = builder.create( + func.getLoc(), LLVM::LLVMPointerType::get(context), + LLVM::LLVMPointerType::get(context), structPtr, ValueRange{index}); + + originalCallArgs.push_back(tensorBase); + originalCallArgs.push_back(alignedBase); + originalCallArgs.push_back(startIdx); + + // Iterate over size and stride pairs + for (int i = 0; i < 2 * rank; i++) { + // Compute the address of the i-th element + Value idx = builder.create( + func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(i)); + + Value elementPtr = builder.create( + func.getLoc(), + LLVM::LLVMPointerType::get(context), // Pointer to i64 + builder.getI64Type(), + sizesAndStrides, // Base pointer + ValueRange{idx} // Offset + ); + + // Load the value from the computed address + Value strideOrSize = builder.create( + func.getLoc(), + builder.getI64Type(), // Type of the loaded value + elementPtr // Computed address + ); + + // Add the loaded value to the call arguments + originalCallArgs.push_back(strideOrSize); + } + } + + // Call the original function + builder.create(func.getLoc(), + func.getFunctionType().getReturnType(), + func.getName(), originalCallArgs); + + // Return the result + builder.create(func.getLoc(), ValueRange()); + } +} + +class LLVMEmitHelperFuncs + : public impl::LLVMEmitHelperFuncsBase { + using impl::LLVMEmitHelperFuncsBase< + LLVMEmitHelperFuncs>::LLVMEmitHelperFuncsBase; + // using impl::createLLVMEmitHelperFuncs; + + void runOnOperation() final { + auto moduleOp = getOperation(); + // only run this on our hoisted cpu op modules + if (!moduleOp->getAttr("ttir.cpu_module")) { + return; + } + generateLLVMHelpersForArgRanks(moduleOp); + + // for every func in this module, emit a corresponding unpacker + } +}; + +} // namespace mlir::tt::llvm_util diff --git a/lib/Dialect/TTIR/Pipelines/CMakeLists.txt b/lib/Dialect/TTIR/Pipelines/CMakeLists.txt index 3296c3f11c..7ae38f5ab0 100644 --- a/lib/Dialect/TTIR/Pipelines/CMakeLists.txt +++ b/lib/Dialect/TTIR/Pipelines/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRTTIRPipelines ${PROJECT_SOURCE_DIR}/include/ttmlir LINK_LIBS PUBLIC + MLIRLLVMTransforms MLIRTTIRDialect MLIRPass MLIRTransforms diff --git a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp index a092a36e1c..562d8ec722 100644 --- a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp +++ b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp @@ -8,6 +8,7 @@ #include "mlir/Transforms/Passes.h" #include "ttmlir/Conversion/Passes.h" +#include "ttmlir/Dialect/LLVM/Transforms/Passes.h" #ifdef TTMLIR_ENABLE_STABLEHLO #include "stablehlo/transforms/Passes.h" diff --git a/lib/RegisterAll.cpp b/lib/RegisterAll.cpp index 9ee4c30e16..bc2361c90c 100644 --- a/lib/RegisterAll.cpp +++ b/lib/RegisterAll.cpp @@ -9,6 +9,7 @@ #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "ttmlir/Conversion/Passes.h" +#include "ttmlir/Dialect/LLVM/Transforms/Passes.h" #include "ttmlir/Dialect/TT/IR/TT.h" #include "ttmlir/Dialect/TTIR/IR/TTIR.h" #include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h" @@ -24,15 +25,16 @@ #endif void mlir::tt::registerAllDialects(mlir::DialectRegistry ®istry) { - registry - .insert(); + registry.insert< + mlir::tt::TTDialect, mlir::tt::ttir::TTIRDialect, + mlir::tt::ttnn::TTNNDialect, mlir::tt::ttmetal::TTMetalDialect, + mlir::tt::ttkernel::TTKernelDialect, mlir::func::FuncDialect, + mlir::arith::ArithDialect, mlir::ml_program::MLProgramDialect, + mlir::tensor::TensorDialect, mlir::linalg::LinalgDialect, + mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect, + mlir::tosa::TosaDialect, mlir::vector::VectorDialect, + mlir::emitc::EmitCDialect, mlir::bufferization::BufferizationDialect, + mlir::LLVM::LLVMDialect>(); #if TTMLIR_ENABLE_STABLEHLO mlir::stablehlo::registerAllDialects(registry); #endif @@ -56,6 +58,7 @@ void mlir::tt::registerAllPasses() { mlir::tt::ttnn::registerTTNNOptimizer(); mlir::tt::ttnn::registerPasses(); mlir::tt::ttmetal::registerPasses(); + mlir::tt::llvm_util::registerPasses(); // Pipeline registration mlir::tt::ttir::registerTTIRPipelines();