Skip to content

Commit

Permalink
ttrt deallocate_buffers (#588)
Browse files Browse the repository at this point in the history
Plumb through deallocate_buffers API to clear metal allocator between
programs.  Previously we were running OOM because of leaked buffers.
Leaked buffers should largely be cleaned up after the allocate=false
support is in #408.

It might still be needed even after #408 lands, because metal internally
uses its allocator for some fast dispatch things, but we can revisit
what policy to adopt.
  • Loading branch information
nsmithtt authored Sep 3, 2024
1 parent c75811b commit 4ce9a3c
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 1 deletion.
2 changes: 2 additions & 0 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ Device openDevice(std::vector<int> const &deviceIds = {0},

void closeDevice(Device device);

void deallocateBuffers(Device device);

Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);
Expand Down
2 changes: 2 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ Device openDevice(std::vector<int> const &deviceIds = {0},

void closeDevice(Device device);

void deallocateBuffers(Device device);

Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);
Expand Down
4 changes: 4 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ namespace system_desc {
std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc();
} // namespace system_desc

namespace detail {
void deallocateBuffers(Device device);
}

DeviceRuntime getCurrentRuntime();

std::vector<DeviceRuntime> getAvailableRuntimes();
Expand Down
15 changes: 15 additions & 0 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ DeviceRuntime globalCurrentRuntime = DeviceRuntime::TTMetal;
DeviceRuntime globalCurrentRuntime = DeviceRuntime::Disabled;
#endif

void deallocateBuffers(Device device) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::deallocateBuffers(device);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::deallocateBuffers(device);
}
#endif
throw std::runtime_error("runtime is not enabled");
}

} // namespace detail

DeviceRuntime getCurrentRuntime() {
Expand Down
7 changes: 7 additions & 0 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ void closeDevice(Device device) {
}
}

void deallocateBuffers(Device deviceHandle) {
DeviceMesh &deviceMesh = deviceHandle.as<DeviceMesh>(DeviceRuntime::TTMetal);
for (::tt::tt_metal::Device *device : deviceMesh) {
device->deallocate_buffers();
}
}

static std::pair<std::shared_ptr<::tt::tt_metal::Buffer>,
std::shared_ptr<::tt::tt_metal::Event>>
prepareInput(::tt::tt_metal::Device *device, MetalTensor const &metalTensor,
Expand Down
4 changes: 4 additions & 0 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ void closeDevice(Device device) {
::ttnn::close_device(ttnn_device);
}

void deallocateBuffers(Device deviceHandle) {
deviceHandle.as<::ttnn::Device>(DeviceRuntime::TTNN).deallocate_buffers();
}

static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) {
bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier(
binary.handle.get());
Expand Down
2 changes: 2 additions & 0 deletions runtime/tools/python/ttrt/common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,8 @@ def _execute(binaries):
)
for tensor in program.output_tensors:
self.logging.debug(f"{tensor}\n")

device.deallocate_buffers()
finally:
ttrt.runtime.close_device(device)

Expand Down
3 changes: 2 additions & 1 deletion runtime/tools/python/ttrt/runtime/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ PYBIND11_MODULE(_C, m) {
m.doc() = "ttrt.runtime python extension for interacting with the "
"Tenstorrent devies";

py::class_<tt::runtime::Device>(m, "Device");
py::class_<tt::runtime::Device>(m, "Device")
.def("deallocate_buffers", &tt::runtime::detail::deallocateBuffers);
py::class_<tt::runtime::Event>(m, "Event");
py::class_<tt::runtime::Tensor>(m, "Tensor");
py::enum_<::tt::target::DataType>(m, "DataType")
Expand Down

0 comments on commit 4ce9a3c

Please sign in to comment.