Skip to content

Commit

Permalink
Added load_from_capsule functionality to TTRT and Python Bindings + S…
Browse files Browse the repository at this point in the history
…mall Bug Fix (#477)
  • Loading branch information
vprajapati-tt authored Aug 26, 2024
1 parent 3969de3 commit 6263b91
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 34 deletions.
10 changes: 10 additions & 0 deletions python/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ void populatePassesModule(py::module &m) {
mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module));
data = mlir::tt::ttnn::ttnnToFlatbuffer(moduleOp);
});

m.def("ttnn_to_flatbuffer_binary", [](MlirModule module) {
mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module));
std::shared_ptr<void> *binary = new std::shared_ptr<void>();
*binary = mlir::tt::ttnn::ttnnToFlatbuffer(moduleOp);
return py::capsule((void *)binary, [](void *data) {
std::shared_ptr<void> *bin = static_cast<std::shared_ptr<void> *>(data);
delete bin;
});
});
}

} // namespace mlir::ttmlir::python
1 change: 1 addition & 0 deletions runtime/tools/python/ttrt/binary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ._C import (
load_from_path,
load_binary_from_path,
load_binary_from_capsule,
load_system_desc_from_path,
Flatbuffer,
)
Expand Down
6 changes: 6 additions & 0 deletions runtime/tools/python/ttrt/binary/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,11 @@ PYBIND11_MODULE(_C, m) {
.def("store", &tt::runtime::SystemDesc::store);
m.def("load_from_path", &tt::runtime::Flatbuffer::loadFromPath);
m.def("load_binary_from_path", &tt::runtime::Binary::loadFromPath);
m.def("load_binary_from_capsule", [](py::capsule capsule) {
std::shared_ptr<void> *binary =
static_cast<std::shared_ptr<void> *>(capsule.get_pointer());
return tt::runtime::Flatbuffer(
*binary); // Dereference capsule, and then dereference shared_ptr*
});
m.def("load_system_desc_from_path", &tt::runtime::SystemDesc::loadFromPath);
}
87 changes: 58 additions & 29 deletions runtime/tools/python/ttrt/common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ class Read:
]

def __init__(self, args={}, logging=None, artifacts=None):
for name, _ in API.Read.registered_args.items():
for name, attributes in API.Read.registered_args.items():
name = name if not name.startswith("-") else name.lstrip("-")
name = name.replace("-", "_")

Expand Down Expand Up @@ -695,7 +695,7 @@ class Run:
api_only_arg = []

def __init__(self, args={}, logging=None, artifacts=None):
for name, _ in API.Run.registered_args.items():
for name, attributes in API.Run.registered_args.items():
name = name if not name.startswith("-") else name.lstrip("-")
name = name.replace("-", "_")

Expand Down Expand Up @@ -1039,7 +1039,7 @@ class Perf:
api_only_arg = []

def __init__(self, args={}, logging=None, artifacts=None):
for name, _ in API.Perf.registered_args.items():
for name, attributes in API.Perf.registered_args.items():
name = name if not name.startswith("-") else name.lstrip("-")
name = name.replace("-", "_")

Expand Down Expand Up @@ -1092,49 +1092,78 @@ def check_constraints(self):
self.tracy_csvexport_tool_path
), f"perf tool={self.tracy_csvexport_tool_path} does not exist - rebuild using perf mode"

ttnn_binary_paths = self.file_manager.find_ttnn_binary_paths(self["binary"])
ttmetal_binary_paths = self.file_manager.find_ttmetal_binary_paths(
self["binary"]
)

self.logging.debug(f"ttnn_binary_paths={ttnn_binary_paths}")
self.logging.debug(f"ttmetal_binary_paths={ttmetal_binary_paths}")

for path in ttnn_binary_paths:
bin = Binary(self.logger, self.file_manager, path)
if "binary" not in self:
# Load from Capsule instead. only TTNN Path is supported for now
bin = Binary(self.logger, self.file_manager, "", self["capsule"])
if not bin.check_version():
continue
self.logger.warning(
"Flatbuffer version not present, are you sure that the binary is valid? - Skipped"
)
return

if not bin.check_system_desc(self.query):
continue
self.logger.warning(
"System desc does not match, are you sure that the binary is valid? - Skipped"
)
return

if self["program_index"] != "all":
if not bin.check_program_index_exists(int(self["program_index"])):
self.logging.warning(
f"program index={int(self['program_index'])} is greater than number of programs in: {bin.file_path} - skipping this test"
)
return
self.ttnn_binaries.append(bin)
else:
ttnn_binary_paths = self.file_manager.find_ttnn_binary_paths(
self["binary"]
)
ttmetal_binary_paths = self.file_manager.find_ttmetal_binary_paths(
self["binary"]
)

self.logging.debug(f"ttnn_binary_paths={ttnn_binary_paths}")
self.logging.debug(f"ttmetal_binary_paths={ttmetal_binary_paths}")

for path in ttnn_binary_paths:
bin = Binary(self.logger, self.file_manager, path)
if not bin.check_version():
continue

self.ttnn_binaries.append(bin)
if not bin.check_system_desc(self.query):
continue

self.logging.debug(f"finished checking constraints for run API")
if self["program_index"] != "all":
if not bin.check_program_index_exists(
int(self["program_index"])
):
self.logging.warning(
f"program index={int(self['program_index'])} is greater than number of programs in: {bin.file_path} - skipping this test"
)
continue

for path in ttmetal_binary_paths:
bin = Binary(self.logger, self.file_manager, path)
if not bin.check_version():
continue
self.ttnn_binaries.append(bin)

if not bin.check_system_desc(self.query):
continue
self.logging.debug(f"finished checking constraints for run API")

if self["program_index"] != "all":
if not bin.check_program_index_exists(int(self["program_index"])):
self.logging.warning(
f"program index={int(self['program_index'])} is greater than number of programs in: {bin.file_path} - skipping this test"
)
for path in ttmetal_binary_paths:
bin = Binary(self.logger, self.file_manager, path)
if not bin.check_version():
continue

self.ttmetal_binaries.append(bin)
if not bin.check_system_desc(self.query):
continue

if self["program_index"] != "all":
if not bin.check_program_index_exists(
int(self["program_index"])
):
self.logging.warning(
f"program index={int(self['program_index'])} is greater than number of programs in: {bin.file_path} - skipping this test"
)
continue

self.ttmetal_binaries.append(bin)

self.logging.debug(f"finished checking constraints for perf API")

Expand Down
13 changes: 8 additions & 5 deletions runtime/tools/python/ttrt/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,13 @@ class Flatbuffer:
ttmetal_file_extension = ".ttm"
ttsys_file_extension = ".ttsys"

def __init__(self, logger, file_manager, file_path):
def __init__(self, logger, file_manager, file_path, capsule=None):
import ttrt.binary

self.logger = logger
self.logging = self.logger.get_logger()
self.file_manager = file_manager
self.file_path = file_path
self.file_path = file_path if file_path != None else "<binary-from-capsule>"
self.name = self.file_manager.get_file_name(file_path)
self.extension = self.file_manager.get_file_extension(file_path)
self.version = None
Expand Down Expand Up @@ -506,12 +506,15 @@ def get_ttsys_file_extension():


class Binary(Flatbuffer):
def __init__(self, logger, file_manager, file_path):
super().__init__(logger, file_manager, file_path)
def __init__(self, logger, file_manager, file_path, capsule=None):
super().__init__(logger, file_manager, file_path, capsule=capsule)

import ttrt.binary

self.fbb = ttrt.binary.load_binary_from_path(file_path)
if not capsule:
self.fbb = ttrt.binary.load_binary_from_path(file_path)
else:
self.fbb = ttrt.binary.load_binary_from_capsule(capsule)
self.fbb_dict = ttrt.binary.as_dict(self.fbb)
self.version = self.fbb.version
self.programs = []
Expand Down

0 comments on commit 6263b91

Please sign in to comment.