Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Pipeline/Transformation from TTIR -> Flatbuffer (#314) #406

Merged
merged 3 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions include/ttmlir/Bindings/Python/Overrides.h

This file was deleted.

8 changes: 8 additions & 0 deletions include/ttmlir/Bindings/Python/TTMLIRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,27 @@
#ifndef TTMLIR_BINDINGS_PYTHON_TTMLIRMODULE_H
#define TTMLIR_BINDINGS_PYTHON_TTMLIRMODULE_H

#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/CAPI/IR.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "ttmlir-c/Dialects.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include "ttmlir/Dialect/TTNN/Pipelines/Passes.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"
#include "ttmlir/Dialect/TTNN/Transforms/TTNNToCpp.h"
#include "ttmlir/RegisterAll.h"

namespace py = pybind11;

namespace mlir::ttmlir::python {
void populateTTModule(py::module &m);
void populateTTKernelModule(py::module &m);
void populateOverridesModule(py::module &m);
void populatePassesModule(py::module &m);
} // namespace mlir::ttmlir::python

#endif // TTMLIR_BINDINGS_PYTHON_TTMLIRMODULE_H
22 changes: 22 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ set(TTMLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/ttmlir")

add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=ttmlir.")

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)

declare_mlir_python_sources(TTMLIRPythonSources)
declare_mlir_python_sources(TTMLIRPythonExtensions)

Expand Down Expand Up @@ -40,6 +44,18 @@ declare_mlir_dialect_python_bindings(
DIALECT_NAME ttkernel
)

declare_mlir_python_sources(TTMLIRPythonSources.Overrides
ROOT_DIR "${TTMLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TTMLIRPythonSources
SOURCES overrides.py
)

declare_mlir_python_sources(TTMLIRPythonSources.Passes
ROOT_DIR "${TTMLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TTMLIRPythonSources
SOURCES passes.py
)

declare_mlir_python_extension(TTMLIRPythonExtensions.Main
MODULE_NAME _ttmlir
ADD_TO_PARENT TTMLIRPythonExtensions
Expand All @@ -48,11 +64,17 @@ declare_mlir_python_extension(TTMLIRPythonExtensions.Main
TTModule.cpp
TTKernelModule.cpp
Overrides.cpp
Passes.cpp
EMBED_CAPI_LINK_LIBS
MLIRCAPITransforms
TTMLIRCAPI
PRIVATE_LINK_LIBS
LLVMSupport
${dialect_libs}
${conversion_libs}
${translation_libs}
MLIR
TTMLIRStatic
)

set(TTMLIR_PYTHON_SOURCES
Expand Down
2 changes: 1 addition & 1 deletion python/Overrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Bindings/Python/Overrides.h"
#include "ttmlir/Bindings/Python/TTMLIRModule.h"

namespace mlir::ttmlir::python {

Expand Down
54 changes: 54 additions & 0 deletions python/Passes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "mlir/InitAllTranslations.h"
#include "ttmlir/Bindings/Python/TTMLIRModule.h"
#include "ttmlir/RegisterAll.h"
#include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h"

PYBIND11_MAKE_OPAQUE(std::shared_ptr<void>);

namespace mlir::ttmlir::python {

void populatePassesModule(py::module &m) {
// When populating passes, need to first register them

mlir::tt::registerAllPasses();
mlir::registerAllTranslations();

m.def("ttir_to_ttnn_backend_pipeline", [](MlirModule module) {
mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module));
mlir::PassManager pm(moduleOp->getName());

mlir::DialectRegistry registry;
mlir::tt::registerAllDialects(registry);
mlir::MLIRContext *ctx = unwrap(mlirModuleGetContext(module));
ctx->appendDialectRegistry(registry);

const auto pipeline =
mlir::PassPipelineInfo::lookup("ttir-to-ttnn-backend-pipeline");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just call createTTIRToTTNNBackendPipeline directly?

  TTIRToTTNNBackendPipelineOptions options;
  //.. maybe set options ..
  createTTIRToTTNNBackendPipeline(pm, options);
  if (mlir::failed(pm.run(moduleOp))) {
      throw std::runtime_error("Failed to run pass manager");
  }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran it this way in case there was a more involved method to call Pipeline Options. If it's just as simple as editing the struct then that sounds good.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned offline, this task will be completed later while I focus on getting E2E working with explorer. Refer to #446 for when that gets completed.


std::string options = "";

mlir::function_ref<mlir::LogicalResult(const llvm::Twine &)> err_handler =
[](const llvm::Twine &loc) { return mlir::failure(); };

if (mlir::failed(pipeline->addToPipeline(pm, options, err_handler))) {
throw std::runtime_error("Failed to add pipeline to pass manager");
}

if (mlir::failed(pm.run(moduleOp))) {
throw std::runtime_error("Failed to run pass manager");
}
});

py::class_<std::shared_ptr<void>>(m, "SharedVoidPtr")
.def(py::init<>())
.def("from_ttnn", [](std::shared_ptr<void> data, MlirModule module) {
mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module));
data = mlir::tt::ttnn::ttnnToFlatbuffer(moduleOp);
});
}

} // namespace mlir::ttmlir::python
8 changes: 5 additions & 3 deletions python/TTMLIRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Bindings/Python/TTMLIRModule.h"
#include "ttmlir/Bindings/Python/Overrides.h"

PYBIND11_MODULE(_ttmlir, m) {
m.doc() = "ttmlir main python extension";
Expand All @@ -29,6 +28,9 @@ PYBIND11_MODULE(_ttmlir, m) {
mlir::ttmlir::python::populateTTModule(tt_ir);
auto ttkernel_ir = m.def_submodule("ttkernel_ir", "TTKernel IR Bindings");
mlir::ttmlir::python::populateTTKernelModule(ttkernel_ir);
auto overrides_ = m.def_submodule("overrides", "Python-Bound Overrides");
mlir::ttmlir::python::populateOverridesModule(overrides_);
auto overrides = m.def_submodule("overrides", "Python-Bound Overrides");
mlir::ttmlir::python::populateOverridesModule(overrides);
auto passes =
m.def_submodule("passes", "Python-Bound Passes & Transformations");
mlir::ttmlir::python::populatePassesModule(passes);
}
1 change: 0 additions & 1 deletion python/ttmlir/dialects/ttir.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@

from ._ttir_ops_gen import *
from .._mlir_libs._ttmlir import register_dialect
from .._mlir_libs._ttmlir import overrides
5 changes: 5 additions & 0 deletions python/ttmlir/overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from ._mlir_libs._ttmlir import overrides
5 changes: 5 additions & 0 deletions python/ttmlir/passes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from ._mlir_libs._ttmlir import passes
Loading