From 6263b9146d48f3dbbf2eaeac8b4d1054fddc137e Mon Sep 17 00:00:00 2001 From: Vraj Prajapati Date: Mon, 26 Aug 2024 10:38:34 -0500 Subject: [PATCH] Added load_from_capsule functionality to TTRT and Python Bindings + Small Bug Fix (#477) --- python/Passes.cpp | 10 +++ runtime/tools/python/ttrt/binary/__init__.py | 1 + runtime/tools/python/ttrt/binary/module.cpp | 6 ++ runtime/tools/python/ttrt/common/api.py | 87 +++++++++++++------- runtime/tools/python/ttrt/common/util.py | 13 +-- 5 files changed, 83 insertions(+), 34 deletions(-) diff --git a/python/Passes.cpp b/python/Passes.cpp index a8964b7df3..d599b44e3b 100644 --- a/python/Passes.cpp +++ b/python/Passes.cpp @@ -49,6 +49,16 @@ void populatePassesModule(py::module &m) { mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module)); data = mlir::tt::ttnn::ttnnToFlatbuffer(moduleOp); }); + + m.def("ttnn_to_flatbuffer_binary", [](MlirModule module) { + mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module)); + std::shared_ptr *binary = new std::shared_ptr(); + *binary = mlir::tt::ttnn::ttnnToFlatbuffer(moduleOp); + return py::capsule((void *)binary, [](void *data) { + std::shared_ptr *bin = static_cast *>(data); + delete bin; + }); + }); } } // namespace mlir::ttmlir::python diff --git a/runtime/tools/python/ttrt/binary/__init__.py b/runtime/tools/python/ttrt/binary/__init__.py index f50c28f028..c90c2354cc 100644 --- a/runtime/tools/python/ttrt/binary/__init__.py +++ b/runtime/tools/python/ttrt/binary/__init__.py @@ -5,6 +5,7 @@ from ._C import ( load_from_path, load_binary_from_path, + load_binary_from_capsule, load_system_desc_from_path, Flatbuffer, ) diff --git a/runtime/tools/python/ttrt/binary/module.cpp b/runtime/tools/python/ttrt/binary/module.cpp index 392f2e3e8e..9b076ef7f6 100644 --- a/runtime/tools/python/ttrt/binary/module.cpp +++ b/runtime/tools/python/ttrt/binary/module.cpp @@ -38,5 +38,11 @@ PYBIND11_MODULE(_C, m) { .def("store", &tt::runtime::SystemDesc::store); m.def("load_from_path", &tt::runtime::Flatbuffer::loadFromPath); m.def("load_binary_from_path", &tt::runtime::Binary::loadFromPath); + m.def("load_binary_from_capsule", [](py::capsule capsule) { + std::shared_ptr *binary = + static_cast *>(capsule.get_pointer()); + return tt::runtime::Flatbuffer( + *binary); // Dereference capsule, and then dereference shared_ptr* + }); m.def("load_system_desc_from_path", &tt::runtime::SystemDesc::loadFromPath); } diff --git a/runtime/tools/python/ttrt/common/api.py b/runtime/tools/python/ttrt/common/api.py index 2f7c20e6f9..5d3164a3b1 100644 --- a/runtime/tools/python/ttrt/common/api.py +++ b/runtime/tools/python/ttrt/common/api.py @@ -446,7 +446,7 @@ class Read: ] def __init__(self, args={}, logging=None, artifacts=None): - for name, _ in API.Read.registered_args.items(): + for name, attributes in API.Read.registered_args.items(): name = name if not name.startswith("-") else name.lstrip("-") name = name.replace("-", "_") @@ -695,7 +695,7 @@ class Run: api_only_arg = [] def __init__(self, args={}, logging=None, artifacts=None): - for name, _ in API.Run.registered_args.items(): + for name, attributes in API.Run.registered_args.items(): name = name if not name.startswith("-") else name.lstrip("-") name = name.replace("-", "_") @@ -1039,7 +1039,7 @@ class Perf: api_only_arg = [] def __init__(self, args={}, logging=None, artifacts=None): - for name, _ in API.Perf.registered_args.items(): + for name, attributes in API.Perf.registered_args.items(): name = name if not name.startswith("-") else name.lstrip("-") name = name.replace("-", "_") @@ -1092,49 +1092,78 @@ def check_constraints(self): self.tracy_csvexport_tool_path ), f"perf tool={self.tracy_csvexport_tool_path} does not exist - rebuild using perf mode" - ttnn_binary_paths = self.file_manager.find_ttnn_binary_paths(self["binary"]) - ttmetal_binary_paths = self.file_manager.find_ttmetal_binary_paths( - self["binary"] - ) - - self.logging.debug(f"ttnn_binary_paths={ttnn_binary_paths}") - self.logging.debug(f"ttmetal_binary_paths={ttmetal_binary_paths}") - - for path in ttnn_binary_paths: - bin = Binary(self.logger, self.file_manager, path) + if "binary" not in self: + # Load from Capsule instead. only TTNN Path is supported for now + bin = Binary(self.logger, self.file_manager, "", self["capsule"]) if not bin.check_version(): - continue + self.logger.warning( + "Flatbuffer version not present, are you sure that the binary is valid? - Skipped" + ) + return if not bin.check_system_desc(self.query): - continue + self.logger.warning( + "System desc does not match, are you sure that the binary is valid? - Skipped" + ) + return if self["program_index"] != "all": if not bin.check_program_index_exists(int(self["program_index"])): self.logging.warning( f"program index={int(self['program_index'])} is greater than number of programs in: {bin.file_path} - skipping this test" ) + return + self.ttnn_binaries.append(bin) + else: + ttnn_binary_paths = self.file_manager.find_ttnn_binary_paths( + self["binary"] + ) + ttmetal_binary_paths = self.file_manager.find_ttmetal_binary_paths( + self["binary"] + ) + + self.logging.debug(f"ttnn_binary_paths={ttnn_binary_paths}") + self.logging.debug(f"ttmetal_binary_paths={ttmetal_binary_paths}") + + for path in ttnn_binary_paths: + bin = Binary(self.logger, self.file_manager, path) + if not bin.check_version(): continue - self.ttnn_binaries.append(bin) + if not bin.check_system_desc(self.query): + continue - self.logging.debug(f"finished checking constraints for run API") + if self["program_index"] != "all": + if not bin.check_program_index_exists( + int(self["program_index"]) + ): + self.logging.warning( + f"program index={int(self['program_index'])} is greater than number of programs in: {bin.file_path} - skipping this test" + ) + continue - for path in ttmetal_binary_paths: - bin = Binary(self.logger, self.file_manager, path) - if not bin.check_version(): - continue + self.ttnn_binaries.append(bin) - if not bin.check_system_desc(self.query): - continue + self.logging.debug(f"finished checking constraints for run API") - if self["program_index"] != "all": - if not bin.check_program_index_exists(int(self["program_index"])): - self.logging.warning( - f"program index={int(self['program_index'])} is greater than number of programs in: {bin.file_path} - skipping this test" - ) + for path in ttmetal_binary_paths: + bin = Binary(self.logger, self.file_manager, path) + if not bin.check_version(): continue - self.ttmetal_binaries.append(bin) + if not bin.check_system_desc(self.query): + continue + + if self["program_index"] != "all": + if not bin.check_program_index_exists( + int(self["program_index"]) + ): + self.logging.warning( + f"program index={int(self['program_index'])} is greater than number of programs in: {bin.file_path} - skipping this test" + ) + continue + + self.ttmetal_binaries.append(bin) self.logging.debug(f"finished checking constraints for perf API") diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py index d332e3b818..d8e12f1687 100644 --- a/runtime/tools/python/ttrt/common/util.py +++ b/runtime/tools/python/ttrt/common/util.py @@ -465,13 +465,13 @@ class Flatbuffer: ttmetal_file_extension = ".ttm" ttsys_file_extension = ".ttsys" - def __init__(self, logger, file_manager, file_path): + def __init__(self, logger, file_manager, file_path, capsule=None): import ttrt.binary self.logger = logger self.logging = self.logger.get_logger() self.file_manager = file_manager - self.file_path = file_path + self.file_path = file_path if file_path != None else "" self.name = self.file_manager.get_file_name(file_path) self.extension = self.file_manager.get_file_extension(file_path) self.version = None @@ -506,12 +506,15 @@ def get_ttsys_file_extension(): class Binary(Flatbuffer): - def __init__(self, logger, file_manager, file_path): - super().__init__(logger, file_manager, file_path) + def __init__(self, logger, file_manager, file_path, capsule=None): + super().__init__(logger, file_manager, file_path, capsule=capsule) import ttrt.binary - self.fbb = ttrt.binary.load_binary_from_path(file_path) + if not capsule: + self.fbb = ttrt.binary.load_binary_from_path(file_path) + else: + self.fbb = ttrt.binary.load_binary_from_capsule(capsule) self.fbb_dict = ttrt.binary.as_dict(self.fbb) self.version = self.fbb.version self.programs = []