diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index b79bde0e1..78596c36b 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -68,6 +68,8 @@ Device openDevice(std::vector const &deviceIds = {0}, void closeDevice(Device device); +void deallocateBuffers(Device device); + Event submit(Device device, Binary executable, std::uint32_t programIndex, std::vector const &inputs, std::vector const &outputs); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index 487bfdc77..13e7de81a 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -83,6 +83,8 @@ Device openDevice(std::vector const &deviceIds = {0}, void closeDevice(Device device); +void deallocateBuffers(Device device); + Event submit(Device device, Binary executable, std::uint32_t programIndex, std::vector const &inputs, std::vector const &outputs); diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index 80b0eb783..f9ab6127d 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -17,6 +17,10 @@ namespace system_desc { std::pair getCurrentSystemDesc(); } // namespace system_desc +namespace detail { +void deallocateBuffers(Device device); +} + DeviceRuntime getCurrentRuntime(); std::vector getAvailableRuntimes(); diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 888f12b46..30604fc18 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -26,6 +26,21 @@ DeviceRuntime globalCurrentRuntime = DeviceRuntime::TTMetal; DeviceRuntime globalCurrentRuntime = DeviceRuntime::Disabled; #endif +void deallocateBuffers(Device device) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::deallocateBuffers(device); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::deallocateBuffers(device); + } +#endif + throw std::runtime_error("runtime is not enabled"); +} + } // namespace detail DeviceRuntime getCurrentRuntime() { diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 4842f0b47..ede9b7a71 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -78,6 +78,13 @@ void closeDevice(Device device) { } } +void deallocateBuffers(Device deviceHandle) { + DeviceMesh &deviceMesh = deviceHandle.as(DeviceRuntime::TTMetal); + for (::tt::tt_metal::Device *device : deviceMesh) { + device->deallocate_buffers(); + } +} + static std::pair, std::shared_ptr<::tt::tt_metal::Event>> prepareInput(::tt::tt_metal::Device *device, MetalTensor const &metalTensor, diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 9cf057513..e1f786bc4 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -69,6 +69,10 @@ void closeDevice(Device device) { ::ttnn::close_device(ttnn_device); } +void deallocateBuffers(Device deviceHandle) { + deviceHandle.as<::ttnn::Device>(DeviceRuntime::TTNN).deallocate_buffers(); +} + static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) { bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier( binary.handle.get()); diff --git a/runtime/tools/python/ttrt/common/api.py b/runtime/tools/python/ttrt/common/api.py index 4b2d3e418..54f488649 100644 --- a/runtime/tools/python/ttrt/common/api.py +++ b/runtime/tools/python/ttrt/common/api.py @@ -900,6 +900,8 @@ def _execute(binaries): ) for tensor in program.output_tensors: self.logging.debug(f"{tensor}\n") + + device.deallocate_buffers() finally: ttrt.runtime.close_device(device) diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 453e5164a..0aa5d84db 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -13,7 +13,8 @@ PYBIND11_MODULE(_C, m) { m.doc() = "ttrt.runtime python extension for interacting with the " "Tenstorrent devies"; - py::class_(m, "Device"); + py::class_(m, "Device") + .def("deallocate_buffers", &tt::runtime::detail::deallocateBuffers); py::class_(m, "Event"); py::class_(m, "Tensor"); py::enum_<::tt::target::DataType>(m, "DataType")