Skip to content

Commit

Permalink
Add pass to emit helper funcs so that we have common call signature
Browse files Browse the repository at this point in the history
  • Loading branch information
vwellsTT committed Dec 16, 2024
1 parent 28c7381 commit 7f71e72
Show file tree
Hide file tree
Showing 12 changed files with 232 additions and 9 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ add_subdirectory(TTIR)
add_subdirectory(TTNN)
add_subdirectory(TTMetal)
add_subdirectory(TTKernel)
add_subdirectory(LLVM)
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/LLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(Transforms)
4 changes: 4 additions & 0 deletions include/ttmlir/Dialect/LLVM/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc --gen-pass-decls)
add_public_tablegen_target(MLIRLLVMPassesIncGen)
add_dependencies(mlir-headers MLIRLLVMPassesIncGen)
19 changes: 19 additions & 0 deletions include/ttmlir/Dialect/LLVM/Transforms/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_LLVM_TRANSFORMS_PASSES_H
#define TTMLIR_DIALECT_LLVM_TRANSFORMS_PASSES_H

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir::tt::llvm_util {
#define GEN_PASS_DECL
#include "ttmlir/Dialect/LLVM/Transforms/Passes.h.inc"

#define GEN_PASS_REGISTRATION
#include "ttmlir/Dialect/LLVM/Transforms/Passes.h.inc"
} // namespace mlir::tt::llvm_util

#endif
16 changes: 16 additions & 0 deletions include/ttmlir/Dialect/LLVM/Transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_TTMLIR_DIALECT_LLVM_LLVMPASSES_TD
#define TTMLIR_TTMLIR_DIALECT_LLVM_LLVMPASSES_TD

include "mlir/Pass/PassBase.td"

def LLVMEmitHelperFuncs: Pass<"emit-llvm-helpers", "::mlir::ModuleOp">
{
let summary = "Helper function for emitting standardized call-format for all of our functions lowered to LLVMDialect";
let dependentDialects = ["mlir::LLVM::LLVMDialect"];
}

#endif
1 change: 1 addition & 0 deletions lib/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ add_subdirectory(TTIR)
add_subdirectory(TTNN)
add_subdirectory(TTMetal)
add_subdirectory(TTKernel)
add_subdirectory(LLVM)
1 change: 1 addition & 0 deletions lib/Dialect/LLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(Transforms)
11 changes: 11 additions & 0 deletions lib/Dialect/LLVM/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
add_mlir_dialect_library(MLIRLLVMTransforms
EmitHelperFuncs.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/ttmlir

DEPENDS
MLIRTTIROpsIncGen
MLIRTTIRPassesIncGen
MLIRTTOpsIncGen
)
164 changes: 164 additions & 0 deletions lib/Dialect/LLVM/Transforms/EmitHelperFuncs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // For LLVM Dialect definitions
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" // For LLVM Type support (e.g., LLVMStructType, LLVMPointerType)

#include "llvm/ADT/ArrayRef.h" // For ArrayRef
#include "llvm/ADT/SmallVector.h" // For SmallVector
#include "llvm/Support/Casting.h" // For dyn_cast

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir::tt::llvm_util {
#define GEN_PASS_DEF_LLVMEMITHELPERFUNCS
#include "ttmlir/Dialect/LLVM/Transforms/Passes.h.inc"

void generateLLVMHelpersForArgRanks(mlir::ModuleOp moduleOp) {
auto *context = moduleOp.getContext();
OpBuilder builder(context);

for (auto func : moduleOp.getOps<LLVM::LLVMFuncOp>()) {
if (!func->hasAttr("arg_ranks")) {
continue;
}

// Extract the `arg_ranks` attribute
auto argRanksAttr = llvm::dyn_cast<ArrayAttr>(func->getAttr("arg_ranks"));
if (!argRanksAttr) {
continue;
}

builder.setInsertionPointToEnd(moduleOp.getBody());

// Define the helper function name and type
std::string helperName = func.getName().str() + "_helper";

// Create the helper function
auto helperFuncType = LLVM::LLVMFunctionType::get(
LLVM::LLVMVoidType::get(context), {LLVM::LLVMPointerType::get(context)},
false);

auto helperFunc = builder.create<LLVM::LLVMFuncOp>(
func.getLoc(), helperName, helperFuncType);

Block *entryBlock = helperFunc.addEntryBlock(builder);
builder.setInsertionPointToStart(entryBlock);

// Unpack the argument
Value structArrayPtr = entryBlock->getArgument(0);
SmallVector<Value, 16> originalCallArgs;

// Iterate over arg_ranks to unpack tensors
int tensorIdx = 0;
for (auto rankAttr : argRanksAttr) {
Value tensorIndex = builder.create<LLVM::ConstantOp>(
func.getLoc(), builder.getI32Type(),
builder.getI32IntegerAttr(tensorIdx++));

Value structPtr = builder.create<LLVM::GEPOp>(
func.getLoc(), LLVM::LLVMPointerType::get(context),
LLVM::LLVMPointerType::get(context), structArrayPtr,
ValueRange(tensorIndex));

int64_t rank = mlir::cast<IntegerAttr>(rankAttr).getInt();

Value index = builder.create<LLVM::ConstantOp>(
func.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0));
// `start`
Value tensorBase = builder.create<LLVM::GEPOp>(
func.getLoc(), LLVM::LLVMPointerType::get(context),
LLVM::LLVMPointerType::get(context), structPtr, ValueRange{index});

index = builder.create<LLVM::ConstantOp>(
func.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(1));
// `aligned_start`
Value alignedBase = builder.create<LLVM::GEPOp>(
func.getLoc(), LLVM::LLVMPointerType::get(context),
LLVM::LLVMPointerType::get(context), structPtr, ValueRange{index});

index = builder.create<LLVM::ConstantOp>(
func.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(2));
// `start_idx`
Value startIdxPtr = builder.create<LLVM::GEPOp>(
func.getLoc(), LLVM::LLVMPointerType::get(context),
builder.getI64Type(), structPtr, ValueRange{index});
// Convert the pointer to an integer (i64)
Value startIdx = builder.create<LLVM::PtrToIntOp>(
func.getLoc(), builder.getI64Type(), startIdxPtr);

index = builder.create<LLVM::ConstantOp>(
func.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(3));
// `sizes_and_strides`
Value sizesAndStrides = builder.create<LLVM::GEPOp>(
func.getLoc(), LLVM::LLVMPointerType::get(context),
LLVM::LLVMPointerType::get(context), structPtr, ValueRange{index});

originalCallArgs.push_back(tensorBase);
originalCallArgs.push_back(alignedBase);
originalCallArgs.push_back(startIdx);

// Iterate over size and stride pairs
for (int i = 0; i < 2 * rank; i++) {
// Compute the address of the i-th element
Value idx = builder.create<LLVM::ConstantOp>(
func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(i));

Value elementPtr = builder.create<LLVM::GEPOp>(
func.getLoc(),
LLVM::LLVMPointerType::get(context), // Pointer to i64
builder.getI64Type(),
sizesAndStrides, // Base pointer
ValueRange{idx} // Offset
);

// Load the value from the computed address
Value strideOrSize = builder.create<LLVM::LoadOp>(
func.getLoc(),
builder.getI64Type(), // Type of the loaded value
elementPtr // Computed address
);

// Add the loaded value to the call arguments
originalCallArgs.push_back(strideOrSize);
}
}

// Call the original function
builder.create<LLVM::CallOp>(func.getLoc(),
func.getFunctionType().getReturnType(),
func.getName(), originalCallArgs);

// Return the result
builder.create<LLVM::ReturnOp>(func.getLoc(), ValueRange());
}
}

class LLVMEmitHelperFuncs
: public impl::LLVMEmitHelperFuncsBase<LLVMEmitHelperFuncs> {
using impl::LLVMEmitHelperFuncsBase<
LLVMEmitHelperFuncs>::LLVMEmitHelperFuncsBase;
// using impl::createLLVMEmitHelperFuncs;

void runOnOperation() final {
auto moduleOp = getOperation();
// only run this on our hoisted cpu op modules
if (!moduleOp->getAttr("ttir.cpu_module")) {
return;
}
generateLLVMHelpersForArgRanks(moduleOp);

// for every func in this module, emit a corresponding unpacker
}
};

} // namespace mlir::tt::llvm_util
1 change: 1 addition & 0 deletions lib/Dialect/TTIR/Pipelines/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRTTIRPipelines
${PROJECT_SOURCE_DIR}/include/ttmlir

LINK_LIBS PUBLIC
MLIRLLVMTransforms
MLIRTTIRDialect
MLIRPass
MLIRTransforms
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlir/Transforms/Passes.h"

#include "ttmlir/Conversion/Passes.h"
#include "ttmlir/Dialect/LLVM/Transforms/Passes.h"

#ifdef TTMLIR_ENABLE_STABLEHLO
#include "stablehlo/transforms/Passes.h"
Expand Down
21 changes: 12 additions & 9 deletions lib/RegisterAll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "ttmlir/Conversion/Passes.h"
#include "ttmlir/Dialect/LLVM/Transforms/Passes.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h"
Expand All @@ -24,15 +25,16 @@
#endif

void mlir::tt::registerAllDialects(mlir::DialectRegistry &registry) {
registry
.insert<mlir::tt::TTDialect, mlir::tt::ttir::TTIRDialect,
mlir::tt::ttnn::TTNNDialect, mlir::tt::ttmetal::TTMetalDialect,
mlir::tt::ttkernel::TTKernelDialect, mlir::func::FuncDialect,
mlir::arith::ArithDialect, mlir::ml_program::MLProgramDialect,
mlir::tensor::TensorDialect, mlir::linalg::LinalgDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::tosa::TosaDialect, mlir::vector::VectorDialect,
mlir::emitc::EmitCDialect>();
registry.insert<
mlir::tt::TTDialect, mlir::tt::ttir::TTIRDialect,
mlir::tt::ttnn::TTNNDialect, mlir::tt::ttmetal::TTMetalDialect,
mlir::tt::ttkernel::TTKernelDialect, mlir::func::FuncDialect,
mlir::arith::ArithDialect, mlir::ml_program::MLProgramDialect,
mlir::tensor::TensorDialect, mlir::linalg::LinalgDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::tosa::TosaDialect, mlir::vector::VectorDialect,
mlir::emitc::EmitCDialect, mlir::bufferization::BufferizationDialect,
mlir::LLVM::LLVMDialect>();
#if TTMLIR_ENABLE_STABLEHLO
mlir::stablehlo::registerAllDialects(registry);
#endif
Expand All @@ -56,6 +58,7 @@ void mlir::tt::registerAllPasses() {
mlir::tt::ttnn::registerTTNNOptimizer();
mlir::tt::ttnn::registerPasses();
mlir::tt::ttmetal::registerPasses();
mlir::tt::llvm_util::registerPasses();

// Pipeline registration
mlir::tt::ttir::registerTTIRPipelines();
Expand Down

0 comments on commit 7f71e72

Please sign in to comment.