diff --git a/include/ttmlir/Bindings/Python/Overrides.h b/include/ttmlir/Bindings/Python/Overrides.h deleted file mode 100644 index 8c5604380..000000000 --- a/include/ttmlir/Bindings/Python/Overrides.h +++ /dev/null @@ -1,26 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#ifndef TTMLIR_BINDINGS_PYTHON_OVERRIDES_H -#define TTMLIR_BINDINGS_PYTHON_OVERRIDES_H - -#include "mlir/Bindings/Python/PybindAdaptors.h" -#include "mlir/InitAllDialects.h" -#include "mlir/InitAllPasses.h" -#include "mlir/Parser/Parser.h" -#include "ttmlir-c/Dialects.h" -#include "ttmlir/Dialect/TT/IR/TT.h" -#include "ttmlir/Dialect/TTIR/IR/TTIR.h" -#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" -#include -#include -#include - -namespace py = pybind11; - -namespace mlir::ttmlir::python { -void populateOverridesModule(py::module &m); -} // namespace mlir::ttmlir::python - -#endif // TTMLIR_BINDINGS_PYTHON_OVERRIDES_H diff --git a/include/ttmlir/Bindings/Python/TTMLIRModule.h b/include/ttmlir/Bindings/Python/TTMLIRModule.h index ed93109ce..ca5f758a5 100644 --- a/include/ttmlir/Bindings/Python/TTMLIRModule.h +++ b/include/ttmlir/Bindings/Python/TTMLIRModule.h @@ -5,19 +5,27 @@ #ifndef TTMLIR_BINDINGS_PYTHON_TTMLIRMODULE_H #define TTMLIR_BINDINGS_PYTHON_TTMLIRMODULE_H +#include "mlir-c/Bindings/Python/Interop.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/CAPI/IR.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "ttmlir-c/Dialects.h" #include "ttmlir/Dialect/TT/IR/TT.h" #include "ttmlir/Dialect/TTIR/IR/TTIR.h" #include "ttmlir/Dialect/TTIR/Transforms/Passes.h" +#include "ttmlir/Dialect/TTNN/Pipelines/Passes.h" +#include "ttmlir/Dialect/TTNN/Transforms/Passes.h" +#include "ttmlir/Dialect/TTNN/Transforms/TTNNToCpp.h" +#include "ttmlir/RegisterAll.h" namespace py = pybind11; namespace mlir::ttmlir::python { void populateTTModule(py::module &m); void populateTTKernelModule(py::module &m); +void populateOverridesModule(py::module &m); +void populatePassesModule(py::module &m); } // namespace mlir::ttmlir::python #endif // TTMLIR_BINDINGS_PYTHON_TTMLIRMODULE_H diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 999a2e5d0..ca34abf65 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -6,6 +6,10 @@ set(TTMLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/ttmlir") add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=ttmlir.") +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) + declare_mlir_python_sources(TTMLIRPythonSources) declare_mlir_python_sources(TTMLIRPythonExtensions) @@ -40,6 +44,18 @@ declare_mlir_dialect_python_bindings( DIALECT_NAME ttkernel ) +declare_mlir_python_sources(TTMLIRPythonSources.Overrides + ROOT_DIR "${TTMLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TTMLIRPythonSources + SOURCES overrides.py +) + +declare_mlir_python_sources(TTMLIRPythonSources.Passes + ROOT_DIR "${TTMLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TTMLIRPythonSources + SOURCES passes.py +) + declare_mlir_python_extension(TTMLIRPythonExtensions.Main MODULE_NAME _ttmlir ADD_TO_PARENT TTMLIRPythonExtensions @@ -48,11 +64,17 @@ declare_mlir_python_extension(TTMLIRPythonExtensions.Main TTModule.cpp TTKernelModule.cpp Overrides.cpp + Passes.cpp EMBED_CAPI_LINK_LIBS MLIRCAPITransforms TTMLIRCAPI PRIVATE_LINK_LIBS LLVMSupport + ${dialect_libs} + ${conversion_libs} + ${translation_libs} + MLIR + TTMLIRStatic ) set(TTMLIR_PYTHON_SOURCES diff --git a/python/Overrides.cpp b/python/Overrides.cpp index c70f6270f..b4aa1623f 100644 --- a/python/Overrides.cpp +++ b/python/Overrides.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttmlir/Bindings/Python/Overrides.h" +#include "ttmlir/Bindings/Python/TTMLIRModule.h" namespace mlir::ttmlir::python { diff --git a/python/Passes.cpp b/python/Passes.cpp new file mode 100644 index 000000000..a8964b7df --- /dev/null +++ b/python/Passes.cpp @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "mlir/InitAllTranslations.h" +#include "ttmlir/Bindings/Python/TTMLIRModule.h" +#include "ttmlir/RegisterAll.h" +#include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h" + +PYBIND11_MAKE_OPAQUE(std::shared_ptr); + +namespace mlir::ttmlir::python { + +void populatePassesModule(py::module &m) { + // When populating passes, need to first register them + + mlir::tt::registerAllPasses(); + mlir::registerAllTranslations(); + + m.def("ttir_to_ttnn_backend_pipeline", [](MlirModule module) { + mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module)); + mlir::PassManager pm(moduleOp->getName()); + + mlir::DialectRegistry registry; + mlir::tt::registerAllDialects(registry); + mlir::MLIRContext *ctx = unwrap(mlirModuleGetContext(module)); + ctx->appendDialectRegistry(registry); + + const auto pipeline = + mlir::PassPipelineInfo::lookup("ttir-to-ttnn-backend-pipeline"); + + std::string options = ""; + + mlir::function_ref err_handler = + [](const llvm::Twine &loc) { return mlir::failure(); }; + + if (mlir::failed(pipeline->addToPipeline(pm, options, err_handler))) { + throw std::runtime_error("Failed to add pipeline to pass manager"); + } + + if (mlir::failed(pm.run(moduleOp))) { + throw std::runtime_error("Failed to run pass manager"); + } + }); + + py::class_>(m, "SharedVoidPtr") + .def(py::init<>()) + .def("from_ttnn", [](std::shared_ptr data, MlirModule module) { + mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module)); + data = mlir::tt::ttnn::ttnnToFlatbuffer(moduleOp); + }); +} + +} // namespace mlir::ttmlir::python diff --git a/python/TTMLIRModule.cpp b/python/TTMLIRModule.cpp index 0aa31452e..d9f62ae7b 100644 --- a/python/TTMLIRModule.cpp +++ b/python/TTMLIRModule.cpp @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttmlir/Bindings/Python/TTMLIRModule.h" -#include "ttmlir/Bindings/Python/Overrides.h" PYBIND11_MODULE(_ttmlir, m) { m.doc() = "ttmlir main python extension"; @@ -29,6 +28,9 @@ PYBIND11_MODULE(_ttmlir, m) { mlir::ttmlir::python::populateTTModule(tt_ir); auto ttkernel_ir = m.def_submodule("ttkernel_ir", "TTKernel IR Bindings"); mlir::ttmlir::python::populateTTKernelModule(ttkernel_ir); - auto overrides_ = m.def_submodule("overrides", "Python-Bound Overrides"); - mlir::ttmlir::python::populateOverridesModule(overrides_); + auto overrides = m.def_submodule("overrides", "Python-Bound Overrides"); + mlir::ttmlir::python::populateOverridesModule(overrides); + auto passes = + m.def_submodule("passes", "Python-Bound Passes & Transformations"); + mlir::ttmlir::python::populatePassesModule(passes); } diff --git a/python/ttmlir/dialects/ttir.py b/python/ttmlir/dialects/ttir.py index 023d159c0..335f276d4 100644 --- a/python/ttmlir/dialects/ttir.py +++ b/python/ttmlir/dialects/ttir.py @@ -4,4 +4,3 @@ from ._ttir_ops_gen import * from .._mlir_libs._ttmlir import register_dialect -from .._mlir_libs._ttmlir import overrides diff --git a/python/ttmlir/overrides.py b/python/ttmlir/overrides.py new file mode 100644 index 000000000..fe6a90c80 --- /dev/null +++ b/python/ttmlir/overrides.py @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from ._mlir_libs._ttmlir import overrides diff --git a/python/ttmlir/passes.py b/python/ttmlir/passes.py new file mode 100644 index 000000000..0e447a9bf --- /dev/null +++ b/python/ttmlir/passes.py @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from ._mlir_libs._ttmlir import passes