Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add StableHLO to TTIR build #353

Closed
wants to merge 11 commits into from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.local
build
third_party/stablehlo
third_party/tt-metal
.DS_STORE
.vscode/*
Expand Down
12 changes: 12 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ endif()

option(TT_RUNTIME_ENABLE_PERF_TRACE "Enable performance mode" OFF)
option(TTMLIR_ENABLE_RUNTIME "Enable runtime" OFF)
option(TTMLIR_ENABLE_STABLEHLO "Enable StableHLO support" ON)

if (TTMLIR_ENABLE_STABLEHLO)
add_compile_definitions(TTMLIR_ENABLE_STABLEHLO)
endif()

set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON)

Expand Down Expand Up @@ -39,6 +44,7 @@ set(TTMLIR_TOOLCHAIN_DIR $ENV{TTMLIR_TOOLCHAIN_DIR})
set(TTMLIR_SOURCE_DIR ${PROJECT_SOURCE_DIR})
set(TTMLIR_BINARY_DIR ${PROJECT_BINARY_DIR})
set(LLVM_LIT_TOOLS_DIR "${TTMLIR_TOOLCHAIN_DIR}/src/llvm-project/llvm/utils/lit")
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(SYSTEM ${LLVM_INCLUDE_DIRS})
include_directories(SYSTEM ${MLIR_INCLUDE_DIRS})
include_directories(${TTMLIR_SOURCE_DIR}/include)
Expand All @@ -47,6 +53,12 @@ link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS})
include(TTMLIRPythonSitePackages)

if (TTMLIR_ENABLE_STABLEHLO)
set(STABLEHLO_BUILD_EMBEDDED ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/stablehlo ${CMAKE_CURRENT_BINARY_DIR}/stablehlo EXCLUDE_FROM_ALL)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/stablehlo)
endif()

add_subdirectory(third_party)
add_subdirectory(include)
add_subdirectory(lib)
Expand Down
8 changes: 8 additions & 0 deletions include/ttmlir/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo)
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo-build)
include_directories(${TTMLIR_SOURCE_DIR}/include)

set(LLVM_TARGET_DEFINITIONS Passes.td)
if (TTMLIR_ENABLE_STABLEHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TTMLIRConversion -DTTMLIR_ENABLE_STABLEHLO)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TTMLIRConversion)
endif()
add_public_tablegen_target(TTMLIRConversionPassIncGen)
3 changes: 3 additions & 0 deletions include/ttmlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#ifndef TTMLIR_CONVERSION_PASSES_H
#define TTMLIR_CONVERSION_PASSES_H

#ifdef TTMLIR_ENABLE_STABLEHLO
#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"
nsmithtt marked this conversation as resolved.
Show resolved Hide resolved
#endif
#include "ttmlir/Conversion/TTIRToTTNN/TTIRToTTNN.h"
#include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h"
#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h"
Expand Down
9 changes: 9 additions & 0 deletions include/ttmlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@

include "mlir/Pass/PassBase.td"

#ifdef TTMLIR_ENABLE_STABLEHLO
def ConvertStableHLOToTTIR : Pass<"convert-stablehlo-to-ttir", "::mlir::ModuleOp"> {
uazizTT marked this conversation as resolved.
Show resolved Hide resolved
let summary = "Convert StableHLO dialect to TTIR dialect.";
let constructor = "createConvertStableHLOToTTIRPass()";
// TODO(mrakita): Probably will need to add some include here for StableHLO dialect.
let dependentDialects = ["mlir::stablehlo::StablehloDialect", "mlir::tt::ttir::TTIRDialect"];
}
#endif

def ConvertTosaToTTIR : Pass<"convert-tosa-to-ttir", "::mlir::ModuleOp"> {
let summary = "Convert TOSA dialect to TTIR dialect.";
let constructor = "createConvertTosaToTTIRPass()";
Expand Down
20 changes: 20 additions & 0 deletions include/ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_CONVERSION_StableHLOToTTIR_StableHLOToTTIR_H
#define TTMLIR_CONVERSION_StableHLOToTTIR_StableHLOToTTIR_H
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All caps, I think clang-tidy will complain



#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir::tt {

#ifdef TTMLIR_ENABLE_STABLEHLO
std::unique_ptr<OperationPass<ModuleOp>> createConvertStableHLOToTTIRPass();
#endif

} // namespace mlir::tt

#endif
21 changes: 17 additions & 4 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
add_library(TTMLIRConversions INTERFACE)

add_subdirectory(TosaToTTIR)
add_subdirectory(TTNNToEmitC)
add_subdirectory(TTIRToTTNN)
if (TTMLIR_ENABLE_STABLEHLO)
add_subdirectory(StableHLOToTTIR)
endif()

add_library(TTMLIRConversions INTERFACE)
include_directories(${TTMLIR_SOURCE_DIR}/include)

set(link_libs
TTMLIRTosaToTTIR;
TTMLIRTTNNToEmitC;
TTMLIRTTIRToTTNN
)

if (TTMLIR_ENABLE_STABLEHLO)
list(APPEND link_libs TTMLIRStableHLOToTTIR)
endif()

target_link_libraries(TTMLIRConversions INTERFACE
TTMLIRTosaToTTIR
TTMLIRTTNNToEmitC
TTMLIRTTIRToTTNN
${link_libs}
)
18 changes: 18 additions & 0 deletions lib/Conversion/StableHLOToTTIR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo)
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo-build)
uazizTT marked this conversation as resolved.
Show resolved Hide resolved
include_directories(${TTMLIR_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/build/stablehlo)

add_mlir_library(TTMLIRStableHLOToTTIR
StableHLOToTTIR.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/StableHLOToTTIR

DEPENDS
TTMLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIR
)
103 changes: 103 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIR.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/ValueRange.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/DialectConversion.h>

#include "stablehlo/dialect/StablehloOps.h"
using namespace mlir;
using namespace tt;

namespace mlir::tt::ttir {

#define GEN_PASS_DEF_CONVERTSTABLEHLOTOTTIR
#include "ttmlir/Conversion/Passes.h.inc"

} // namespace mlir::tt::ttir

namespace {

template <typename SrcOp, typename DestOp,
typename Adaptor = typename SrcOp::Adaptor>
class StableHLOToTTIROpConversionPattern : public OpConversionPattern<SrcOp> {
using OpConversionPattern<SrcOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(SrcOp srcOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outputType = mlir::cast<RankedTensorType>(srcOp.getResult().getType());
auto outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
rewriter.replaceOpWithNewOp<DestOp>(
srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands(),
ValueRange(outputTensor),
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
return success();
}
};

struct ConvertStableHLOToTTIRPass
: public ttir::impl::ConvertStableHLOToTTIRBase<ConvertStableHLOToTTIRPass> {
void runOnOperation() override {
mlir::ConversionTarget target(getContext());

target.addIllegalDialect<mlir::stablehlo::StablehloDialect>();

target.addLegalDialect<ttir::TTIRDialect>();
target.addLegalOp<mlir::tensor::EmptyOp>();
target.addLegalOp<mlir::ModuleOp>();
target.addLegalOp<mlir::func::FuncOp>();
target.addLegalOp<mlir::func::ReturnOp>();

// For now keep the same type assuming StableHLO ops operate on builtin tensor.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) {
assert(isa<RankedTensorType>(type) &&
"only ranked tensor type supported");
return type;
});
RewritePatternSet patterns(&getContext());

// Add conversion patterns.
patterns
.add<StableHLOToTTIROpConversionPattern<mlir::stablehlo::AddOp, mlir::tt::ttir::AddOp>>(
typeConverter, &getContext());

// Apply conversion.
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns)))) {
signalPassFailure();
return;
}
}
};

} // namespace

namespace mlir::tt {

std::unique_ptr<OperationPass<ModuleOp>> createConvertStableHLOToTTIRPass() {
return std::make_unique<ConvertStableHLOToTTIRPass>();
}

} // namespace mlir::tt
6 changes: 6 additions & 0 deletions lib/RegisterAll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/Pipelines/Passes.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"
#ifdef TTMLIR_ENABLE_STABLEHLO
#include "stablehlo/dialect/Register.h"
#endif

void mlir::tt::registerAllDialects(mlir::DialectRegistry &registry) {
registry
Expand All @@ -28,6 +31,9 @@ void mlir::tt::registerAllDialects(mlir::DialectRegistry &registry) {
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::tosa::TosaDialect, mlir::vector::VectorDialect,
mlir::emitc::EmitCDialect>();
#if TTMLIR_ENABLE_STABLEHLO
mlir::stablehlo::registerAllDialects(registry);
#endif
}

void mlir::tt::registerAllPasses() {
Expand Down
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTIR/stablehlo_to_ttir_addop.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --convert-stablehlo-to-ttir %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_eltwise_add attributes {} {
func.func public @test_add(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<13x21x3xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.add"[[C:.*]]
return %0 : tensor<13x21x3xf32>
}
}
17 changes: 16 additions & 1 deletion test/ttmlir/ttmlir-opt.mlir
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
// RUN: ttmlir-opt --show-dialects | FileCheck %s
// CHECK: Available Dialects:
// CHECK-SAME: arith,builtin,cf,emitc,func,linalg,ml_program,scf,tensor,tosa,tt,ttir,ttkernel,ttmetal,ttnn,vector
// CHECK: arith
// CHECK: builtin
// CHECK: cf
// CHECK: emitc
// CHECK: func
// CHECK: linalg
// CHECK: ml_program
// CHECK: scf
// CHECK: tensor
// CHECK: tosa
// CHECK: tt
// CHECK: ttir
// CHECK: ttkernel
// CHECK: ttmetal
// CHECK: ttnn
// CHECK: vector
65 changes: 3 additions & 62 deletions third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,65 +1,6 @@
include(ExternalProject)

if ("$ENV{ARCH_NAME}" STREQUAL "grayskull")
set(ARCH_NAME "grayskull")
set(ARCH_EXTRA_DIR "grayskull")
elseif ("$ENV{ARCH_NAME}" STREQUAL "wormhole_b0")
set(ARCH_NAME "wormhole")
set(ARCH_EXTRA_DIR "wormhole/wormhole_b0_defines")
elseif ("$ENV{ARCH_NAME}" STREQUAL "blackhole")
set(ARCH_NAME "blackhole")
set(ARCH_EXTRA_DIR "blackhole")
else()
message(FATAL_ERROR "Unsupported ARCH_NAME: $ENV{ARCH_NAME}")
include (tt-metal.cmake)
if (TTMLIR_ENABLE_STABLEHLO)
include(stablehlo.cmake)
endif()

set(TTMETAL_INCLUDE_DIRS
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/ttnn/cpp
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/ttnn/cpp/ttnn/deprecated
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/third_party/umd
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/third_party/fmt
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/hw/inc
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/hw/inc/${ARCH_NAME}
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/hw/inc/${ARCH_EXTRA_DIR}
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/third_party/umd/src/firmware/riscv/${ARCH_NAME}
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_eager
${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/reflect/e75434c4c5f669e4a74e4d84e0a30d7249c1e66f
PARENT_SCOPE
)

set(TTMETAL_LIBRARY_DIR ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal-build/lib)
set(TTNN_LIBRARY_PATH ${TTMETAL_LIBRARY_DIR}/_ttnn.so)
set(TTMETAL_LIBRARY_PATH ${TTMETAL_LIBRARY_DIR}/libtt_metal.so)

set(TTMETAL_LIBRARY_DIR ${TTMETAL_LIBRARY_DIR} PARENT_SCOPE)
set(TTNN_LIBRARY_PATH ${TTNN_LIBRARY_PATH} PARENT_SCOPE)
set(TTMETAL_LIBRARY_PATH ${TTMETAL_LIBRARY_PATH} PARENT_SCOPE)

ExternalProject_Add(
tt-metal
PREFIX ${TTMLIR_SOURCE_DIR}/third_party/tt-metal
CMAKE_GENERATOR Ninja
CMAKE_ARGS
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER}
-DENABLE_TRACY=${TT_RUNTIME_ENABLE_PERF_TRACE}
GIT_REPOSITORY https://github.com/tenstorrent/tt-metal.git
GIT_TAG f6a2e5cb2b857bf4c72401bea68adf98c25bbe47
GIT_PROGRESS ON
BUILD_BYPRODUCTS ${TTNN_LIBRARY_PATH} ${TTMETAL_LIBRARY_PATH}
)

set_target_properties(tt-metal PROPERTIES EXCLUDE_FROM_ALL TRUE)

list(APPEND library_names TTNN_LIBRARY TTMETAL_LIBRARY)
list(APPEND library_paths ${TTNN_LIBRARY_PATH} ${TTMETAL_LIBRARY_PATH})

foreach(lib_name lib_path IN ZIP_LISTS library_names library_paths)
add_library(${lib_name} SHARED IMPORTED GLOBAL)
set_target_properties(${lib_name} PROPERTIES EXCLUDE_FROM_ALL TRUE IMPORTED_LOCATION ${lib_path})
add_dependencies(${lib_name} tt-metal)
endforeach()
16 changes: 16 additions & 0 deletions third_party/stablehlo.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
include(ExternalProject)

ExternalProject_Add(
tt-metal
PREFIX ${TTMLIR_SOURCE_DIR}/third_party/tt-metal
CMAKE_GENERATOR Ninja
CMAKE_ARGS
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER}
-DENABLE_TRACY=${TT_RUNTIME_ENABLE_PERF_TRACE}
GIT_REPOSITORY https://github.com/openxla/stablehlo.git
GIT_TAG v1.5.0
GIT_PROGRESS ON
)
Loading
Loading