#8957: Use tt::stl::json for serialization for types MemoryConfig, CoreRange, CoreRangeSet, CoreCoord, ShardSpec
jdesousa-TT committed Jul 26, 2024
1 parent bf5cb31 commit f212d4b
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor.cpp
@@ -61,6 +61,16 @@ void implement_buffer_protocol(PyType& py_buffer_t) {

} // namespace detail

+template <typename T>
+auto tt_pybind_class(py::module m, auto name, auto desc) {
+    return py::class_<T>(m, name, desc)
+        .def("to_json", [](const T& self) -> std::string { return tt::stl::json::to_json(self).dump(); })
+        .def(
+            "from_json",
+            [](const std::string& json_string) -> T { return tt::stl::json::from_json<T>(nlohmann::json::parse(json_string)); })
+        .def("__repr__", [](const T& self) { return fmt::format("{}", self); });
+}

void TensorModule(py::module& m_tensor) {
// ENUM SECTION

@@ -85,7 +95,7 @@ void TensorModule(py::module& m_tensor) {
.value("L1_SMALL", BufferType::L1_SMALL);


-auto py_core_coord = py::class_<CoreCoord>(m_tensor, "CoreCoord", R"doc(
+auto py_core_coord = tt_pybind_class<CoreCoord>(m_tensor, "CoreCoord", R"doc(
Class defining core coordinate
)doc");

@@ -144,7 +154,7 @@ void TensorModule(py::module& m_tensor) {

py::implicitly_convertible<std::vector<uint32_t>, Shape>();

-auto pyMemoryConfig = py::class_<MemoryConfig>(m_tensor, "MemoryConfig", R"doc(
+auto pyMemoryConfig = tt_pybind_class<MemoryConfig>(m_tensor, "MemoryConfig", R"doc(
Class defining memory configuration for storing tensor data on TT Accelerator device.
There are eight DRAM memory banks on TT Accelerator device, indexed as 0, 1, 2, ..., 7.
)doc");
@@ -213,15 +223,15 @@ void TensorModule(py::module& m_tensor) {
py::class_<owned_buffer::Buffer<uint16_t>>(m_tensor, "owned_buffer_for_uint16_t", py::buffer_protocol());
detail::implement_buffer_protocol<owned_buffer::Buffer<uint16_t>, uint16_t>(py_owned_buffer_for_uint16_t);

-auto pyCoreRange = py::class_<CoreRange>(m_tensor, "CoreRange", R"doc(
+auto pyCoreRange = tt_pybind_class<CoreRange>(m_tensor, "CoreRange", R"doc(
Class defining a range of cores)doc");
pyCoreRange.def(py::init<>([](const CoreCoord& start, const CoreCoord& end) { return CoreRange{start, end}; }))
.def("__repr__", [](const CoreRange& core_range) -> std::string { return fmt::format("{}", core_range); })
.def_readonly("start", &CoreRange::start_coord)
.def_readonly("end", &CoreRange::end_coord)
.def("grid_size", &CoreRange::grid_size);

-auto pyCoreRangeSet = py::class_<CoreRangeSet>(m_tensor, "CoreRangeSet", R"doc(
+auto pyCoreRangeSet = tt_pybind_class<CoreRangeSet>(m_tensor, "CoreRangeSet", R"doc(
Class defining a set of CoreRanges required for sharding)doc");
pyCoreRangeSet.def(py::init<>([](const std::set<CoreRange>& core_ranges) { return CoreRangeSet(core_ranges); }))
.def(
@@ -243,7 +253,7 @@ void TensorModule(py::module& m_tensor) {
Returns a CoreRangeSet from number of cores
)doc");

-auto pyShardSpec = py::class_<ShardSpec>(m_tensor, "ShardSpec", R"doc(
+auto pyShardSpec = tt_pybind_class<ShardSpec>(m_tensor, "ShardSpec", R"doc(
Class defining the specs required for sharding.
)doc");

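For reference, the new tt_pybind_class helper wraps the usual py::class_<T> registration so every type bound through it (CoreCoord, MemoryConfig, CoreRange, CoreRangeSet, ShardSpec) gains to_json and from_json methods backed by tt::stl::json and nlohmann::json, plus an fmt-based __repr__. Below is a minimal sketch of the Python-side round trip this enables. The import path and the CoreCoord constructor are assumptions, not shown in this diff; only to_json, from_json, __repr__, and the CoreRange(start, end) initializer come from the commit itself.

    # Round-trip a CoreRange through the new JSON bindings.
    # Assumes the deprecated tt_lib module is importable; the import path
    # and CoreCoord(x, y) constructor are illustrative assumptions.
    import tt_lib as ttl

    core_range = ttl.tensor.CoreRange(
        ttl.tensor.CoreCoord(0, 0),
        ttl.tensor.CoreCoord(3, 3),
    )

    # to_json serializes via tt::stl::json and returns a JSON string.
    serialized = core_range.to_json()

    # from_json is bound with .def rather than .def_static, so call it
    # through the class object with the JSON string as the only argument.
    restored = ttl.tensor.CoreRange.from_json(serialized)

    assert repr(restored) == repr(core_range)  # __repr__ uses fmt::format

Note that because from_json is bound with .def, instance.from_json(s) would pass the instance as an extra argument and fail; calling it on the class, as above, matches the lambda's single string parameter.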