From 608ed8916b26793f2629606d6afcc2e2b4b0d284 Mon Sep 17 00:00:00 2001 From: Jackson Nie Date: Tue, 20 Aug 2024 01:13:38 +0000 Subject: [PATCH 1/2] First Iteration/Prototype: Runtime refactor to support runtime stitching --- runtime/include/tt/runtime/detail/ttnn.h | 15 +++ runtime/include/tt/runtime/runtime.h | 7 + runtime/include/tt/runtime/types.h | 26 +++- runtime/lib/CMakeLists.txt | 8 ++ runtime/lib/runtime.cpp | 36 +++++ runtime/lib/ttnn/program.cpp | 165 ++++++++++++++++------- runtime/lib/ttnn/runtime.cpp | 31 +++++ runtime/lib/types.cpp | 32 +++++ 8 files changed, 271 insertions(+), 49 deletions(-) create mode 100644 runtime/lib/types.cpp diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index ff0c80f45..8947603d6 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -90,12 +90,27 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex, std::vector const &inputs, std::vector const &outputs); +std::vector submit(Device device, Binary executable, + std::uint32_t programIndex, + std::vector const &inputs); + void wait(Event event); void runProgram(::ttnn::Device &device, ::tt::target::ttnn::Program const *program, std::vector<::ttnn::Tensor *> const &inputs, std::vector<::ttnn::Tensor *> const &outputs); + +std::vector runProgram(::ttnn::Device &device, + ::tt::target::ttnn::Program const *program, + std::vector<::ttnn::Tensor *> const &inputs); + +Tensor toLayout(Device device, Binary executable, std::uint32_t programIndex, + std::uint32_t inputIndex, Tensor const &input); + +Tensor updateProgramTensorLayout(Device device, + ::tt::target::ttnn::Program const *program, + std::uint32_t inputIndex, Tensor const &input); } // namespace tt::runtime::ttnn diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index 395e7551b..99d4af802 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -51,9 +51,16 @@ void closeDevice(Device device); Event submit(Device device, Binary executable, std::uint32_t programIndex, std::vector const &inputs, std::vector const &outputs); + +std::vector submit(Device device, Binary executable, + std::uint32_t programIndex, + std::vector const &inputs); void wait(Event event); +Tensor toLayout(Device device, Binary executable, std::uint32_t programIndex, + std::uint32_t inputIndex, Tensor const &input); + } // namespace tt::runtime #endif diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index bfb7e4ba5..e29215044 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -111,14 +111,36 @@ struct Device : public detail::RuntimeCheckedObjectImpl { }; struct Event : public detail::RuntimeCheckedObjectImpl { - using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl; + Event(std::shared_ptr handle, DeviceRuntime runtime) + : detail::RuntimeCheckedObjectImpl(handle, runtime) {} + + bool isTTNNEvent() const { + return this->matchesRuntime(DeviceRuntime::TTNN) and this->handle.get(); + } + + bool isTTMetalEvent() const { + return this->matchesRuntime(DeviceRuntime::TTMetal) and this->handle.get(); + } }; struct Tensor : public detail::RuntimeCheckedObjectImpl { std::shared_ptr data; + Event event; + Tensor(std::shared_ptr handle, std::shared_ptr data, DeviceRuntime runtime) - : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data) {} + : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), + event(Event(nullptr, runtime)) {} + + + Tensor(std::shared_ptr handle, std::shared_ptr data, + DeviceRuntime runtime, Event event) + : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), + event(event) {} + + // Users need to manually deallocate tensors returned from submit + // As the storage is now owned instead of borrowed + void deallocate(); }; } // namespace tt::runtime diff --git a/runtime/lib/CMakeLists.txt b/runtime/lib/CMakeLists.txt index 1792f24bf..6583b81f3 100644 --- a/runtime/lib/CMakeLists.txt +++ b/runtime/lib/CMakeLists.txt @@ -24,6 +24,14 @@ target_include_directories(TTBinary ) add_dependencies(TTBinary FBS_GENERATION) +add_library(TTRuntimeTypes STATIC types.cpp) +target_include_directories(TTRuntimeTypes + PUBLIC + ${PROJECT_SOURCE_DIR}/runtime/include + ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common +) +add_dependencies(TTBinary FBS_GENERATION) + if (TTMLIR_ENABLE_RUNTIME AND (TT_RUNTIME_ENABLE_TTNN OR TT_RUNTIME_ENABLE_TTMETAL)) add_subdirectory(common) else() diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 7b34f04e5..f98bb6d80 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -191,6 +191,25 @@ Event submit(Device deviceHandle, Binary executableHandle, throw std::runtime_error("runtime is not enabled"); } +std::vector submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputHandles) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle, + programIndex, inputHandles); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + throw std::runtime_error("Currently not supported after refactor"); + } +#endif + + throw std::runtime_error("runtime is not enabled"); +} + void wait(Event event) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { @@ -206,4 +225,21 @@ void wait(Event event) { throw std::runtime_error("runtime is not enabled"); } +Tensor toLayout(Device device, Binary executable, std::uint32_t programIndex, + std::uint32_t inputIndex, Tensor const &input) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::toLayout(device, executable, programIndex, + inputIndex, input); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + throw std::runtime_error("Not implemented"); + } + +#endif + throw std::runtime_error("runtime is not enabled"); +} } // namespace tt::runtime diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index a14896569..77ed88c06 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -63,6 +63,10 @@ class ProgramTensorPool { return liveTensors.contains(global_id); } + size_t size() const { + return liveTensors.size(); + } + private: // A superset of intermedTensors, containing all tensors created by the // program and the input/output tensors passed in by the user @@ -89,6 +93,20 @@ static ::ttnn::DataType getDataType(const ::tt::target::TensorRef *tensorRef) { tensorRef->desc()->layout()->memory_desc()->data_type()); } +static Tensor toTypeErasedTensor(const ::ttnn::Tensor &tensor) { + std::shared_ptr<::ttnn::Tensor> tensorHandle = std::make_shared<::ttnn::Tensor>(tensor); + void *dataPtr = isOnHost(*tensorHandle) ? ::tt::tt_metal::get_raw_host_data_ptr(*tensorHandle) : nullptr; + return Tensor(tensorHandle, ::tt::runtime::utils::unsafe_borrow_shared(dataPtr), DeviceRuntime::TTNN); +} + +static void tensorMemcpy(::ttnn::Tensor &dst, ::ttnn::Tensor &src) { + assert(isOnHost(src) and dst.storage_type() == ::tt::tt_metal::StorageType::BORROWED); + void *srcDataPtr = ::tt::tt_metal::get_raw_host_data_ptr(src); + void *dstDataPtr = ::tt::tt_metal::get_raw_host_data_ptr(dst); + std::uint32_t size = src.volume() * src.element_size(); + std::memcpy(dstDataPtr, srcDataPtr, size); + +} static CoreRangeSet toCoreRangeSet( const ::flatbuffers::Vector *coreRangeSet) { std::set coreRanges; @@ -214,10 +232,9 @@ updateLayoutAndDataType(const ::ttnn::Tensor &inputTensor, return outputTensor; } -static void +static ::ttnn::Tensor handleToHostMemoryConfigOp(const ::ttnn::Tensor &inputTensor, - const ::tt::target::TensorRef *outputTensorRef, - ProgramTensorPool &tensorPool) { + const ::tt::target::TensorRef *outputTensorRef) { ::ttnn::Tensor result; ::ttnn::DataType targetDataTypeTTNN = getDataType(outputTensorRef); bool shouldTilize, shouldUntilize; @@ -232,24 +249,13 @@ handleToHostMemoryConfigOp(const ::ttnn::Tensor &inputTensor, result = updateLayoutAndDataType(inputTensor.cpu(), targetDataTypeTTNN, shouldTilize, shouldUntilize); } - // copy the output to the output tensor if it exists - if (tensorPool.contains(outputTensorRef->global_id())) { - ::ttnn::Tensor &outputTensor = tensorPool.at(outputTensorRef->global_id()); - void *src = ::tt::tt_metal::get_raw_host_data_ptr(result); - void *dst = ::tt::tt_metal::get_raw_host_data_ptr(outputTensor); - std::uint32_t size = result.volume() * result.element_size(); - std::memcpy(dst, src, size); - } else { - tensorPool.insert_or_assign(outputTensorRef->global_id(), - std::move(result)); - } + return result; } -static void +static ::ttnn::Tensor handleToDramMemoryConfigOp(::ttnn::Device &device, const ::ttnn::Tensor &inputTensor, - const ::tt::target::TensorRef *outputTensorRef, - ProgramTensorPool &tensorPool) { + const ::tt::target::TensorRef *outputTensorRef) { ::ttnn::DataType targetDataTypeTTNN = getDataType(outputTensorRef); ::tt::tt_metal::MemoryConfig targetMemoryConfig = createMemoryConfig(outputTensorRef); @@ -266,24 +272,23 @@ handleToDramMemoryConfigOp(::ttnn::Device &device, result = ::ttnn::to_device(result, &device, targetMemoryConfig); result = updateLayoutAndDataType(result, targetDataTypeTTNN, shouldTilize, shouldUntilize); - tensorPool.insert_or_assign(outputTensorRef->global_id(), - std::move(result)); + return result; } else if (isOnDevice(inputTensor)) { shouldTilize = false; shouldUntilize = false; ::ttnn::Tensor result = updateLayoutAndDataType( inputTensor, targetDataTypeTTNN, shouldTilize, shouldUntilize); result = ::ttnn::to_memory_config(result, targetMemoryConfig, std::nullopt); - tensorPool.insert_or_assign(outputTensorRef->global_id(), - std::move(result)); + return result; + } else { + throw std::runtime_error("Unsupported input tensor storage type"); } } -static void +static ::ttnn::Tensor handleToL1MemoryConfigOp(::ttnn::Device &device, const ::ttnn::Tensor &inputTensor, - const ::tt::target::TensorRef *outputTensorRef, - ProgramTensorPool &tensorPool) { + const ::tt::target::TensorRef *outputTensorRef) { ::ttnn::DataType targetDataTypeTTNN = getDataType(outputTensorRef); ::tt::tt_metal::MemoryConfig targetMemoryConfig = createMemoryConfig(outputTensorRef); @@ -309,53 +314,65 @@ handleToL1MemoryConfigOp(::ttnn::Device &device, result = ::ttnn::to_memory_config(result, targetMemoryConfig, std::nullopt); } - tensorPool.insert_or_assign(outputTensorRef->global_id(), - std::move(result)); + return result; } else if (isOnDevice(inputTensor)) { shouldTilize = false; shouldUntilize = false; ::ttnn::Tensor result = updateLayoutAndDataType( inputTensor, targetDataTypeTTNN, shouldTilize, shouldUntilize); result = ::ttnn::to_memory_config(result, targetMemoryConfig, std::nullopt); - tensorPool.insert_or_assign(outputTensorRef->global_id(), - std::move(result)); + return result; + } else { + throw std::runtime_error("Unsupported input tensor storage type"); } } -// TODO(bug #272): right now hardcoding tilize/untilize, should determine with -// tile shape blocked by issue #272 -static void run(::tt::target::ttnn::ToMemoryConfigOp const *op, - ::ttnn::Device &device, ProgramTensorPool &tensorPool) { - - const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in0()->global_id()); - assert(isOnHost(inputTensor) or - isOnDevice(inputTensor) && "Unsupported storage type"); - - const ::tt::target::Dim2d *targetTileShape = - op->out()->desc()->layout()->memory_desc()->tile_shape(); - assert(utils::isValidTileShape(targetTileShape) && "Invalid tile shape"); - +static ::ttnn::Tensor updateTensorMemoryConfig(::ttnn::Device &device, + const ::ttnn::Tensor &inputTensor, + const ::tt::target::TensorRef *outputTensorRef) { + const ::tt::target::MemoryDesc *targetMemoryDesc = + outputTensorRef->desc()->layout()->memory_desc(); const ::tt::target::MemorySpace targetMemorySpace = - op->out()->desc()->layout()->memory_desc()->memory_space(); + targetMemoryDesc->memory_space(); switch (targetMemorySpace) { - // This case should only be used when gathering outputs at the end of the - // program case ::tt::target::MemorySpace::System: case ::tt::target::MemorySpace::SystemMMIO: { - handleToHostMemoryConfigOp(inputTensor, op->out(), tensorPool); + return handleToHostMemoryConfigOp(inputTensor, outputTensorRef); break; } case ::tt::target::MemorySpace::DeviceDRAM: { - handleToDramMemoryConfigOp(device, inputTensor, op->out(), tensorPool); + return handleToDramMemoryConfigOp(device, inputTensor, outputTensorRef); break; } case ::tt::target::MemorySpace::DeviceL1: { - handleToL1MemoryConfigOp(device, inputTensor, op->out(), tensorPool); + return handleToL1MemoryConfigOp(device, inputTensor, outputTensorRef); break; } } } +// TODO(bug #272): right now hardcoding tilize/untilize, should determine with +// tile shape blocked by issue #272 +static void run(::tt::target::ttnn::ToMemoryConfigOp const *op, + ::ttnn::Device &device, ProgramTensorPool &tensorPool) { + + const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in0()->global_id()); + assert(isOnHost(inputTensor) or + isOnDevice(inputTensor) && "Unsupported storage type"); + + const ::tt::target::Dim2d *targetTileShape = + op->out()->desc()->layout()->memory_desc()->tile_shape(); + assert(utils::isValidTileShape(targetTileShape) && "Invalid tile shape"); + + ::ttnn::Tensor result = updateTensorMemoryConfig(device, inputTensor, op->out()); + // copy the output to the output tensor if it exists + if (tensorPool.contains(op->out()->global_id()) and tensorPool.at(op->out()->global_id()).storage_type() == ::tt::tt_metal::StorageType::BORROWED) { + tensorMemcpy(tensorPool.at(op->out()->global_id()), result); + } else { + tensorPool.insert_or_assign(op->out()->global_id(), + std::move(result)); + } +} static void run(::tt::target::ttnn::EmptyOp const *op, ::ttnn::Device &device, ProgramTensorPool &tensorPool) { @@ -837,4 +854,58 @@ void runProgram(::ttnn::Device &device, run(op, device, tensorPool); } } + +std::vector runProgram(::ttnn::Device &device, + ::tt::target::ttnn::Program const *program, + std::vector<::ttnn::Tensor *> const &inputs) { + + ProgramTensorPool tensorPool({}); + int inputIndex = 0; + + // convert inputs to the desired layout/memory config + for (::tt::target::TensorRef const *inputRef : *program->inputs()) { + const ::ttnn::Tensor *inputTensor = inputs[inputIndex++]; + ::ttnn::Tensor updatedInputTensor = updateTensorMemoryConfig(device, *inputTensor, inputRef); + auto [iter, inserted] = tensorPool.try_emplace(inputRef->global_id(), std::move(updatedInputTensor)); + assert(inserted && "Duplicate input tensor"); + } + + for (::tt::target::ttnn::Operation const *op : *program->operations()) { + run(op, device, tensorPool); + } + + // convert outputs to the desired layout/memory config + // then convert them to type erased tensors and return + std::vector outputs; + for (::tt::target::TensorRef const *outputRef : *program->outputs()) { + size_t outputId = outputRef->global_id(); + assert(tensorPool.contains(outputId) && + "Program output tensor not found in tensorPool"); + const ::ttnn::Tensor &outputTensor = tensorPool.at(outputId); + ::ttnn::Tensor updatedOutputTensor = updateTensorMemoryConfig(device, outputTensor, outputRef); + outputs.push_back(toTypeErasedTensor(updatedOutputTensor)); + } + + return outputs; +} + +Tensor updateProgramTensorLayout(Device device, + ::tt::target::ttnn::Program const *program, + std::uint32_t inputIndex, + Tensor const &input) { + TT_FATAL(inputIndex < program->inputs()->size(), + "Input index {} out of range {}", inputIndex, + program->inputs()->size()); + const ::tt::target::TensorRef *inputRef = program->inputs()->Get(inputIndex); + + ::ttnn::Device &ttnnDevice = device.as<::ttnn::Device>(DeviceRuntime::TTNN); + const ::ttnn::Tensor &ttnnInput = + input.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + + ::ttnn::Tensor result = + updateTensorMemoryConfig(ttnnDevice, ttnnInput, inputRef); + + return toTypeErasedTensor(result); +} + } // namespace tt::runtime::ttnn diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index e1f786bc4..13d630cd8 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -105,6 +105,37 @@ Event submit(Device deviceHandle, Binary executableHandle, return Event(nullptr, DeviceRuntime::TTNN); } +std::vector submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputHandles) { + ::ttnn::Device &device = deviceHandle.as<::ttnn::Device>(DeviceRuntime::TTNN); + ::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle); + std::vector<::ttnn::Tensor *> inputs; + inputs.reserve(inputHandles.size()); + for (auto &input : inputHandles) { + assert(input.matchesRuntime(DeviceRuntime::TTNN)); + inputs.push_back(static_cast<::ttnn::Tensor *>(input.handle.get())); + } + std::vector outputs = ::tt::runtime::ttnn::runProgram( + device, fbb.programs()->Get(programIndex), inputs); + return outputs; +} + +Tensor toLayout(Device device, Binary executable, std::uint32_t programIndex, + std::uint32_t inputIndex, Tensor const &input) { + + const ::tt::target::ttnn::TTNNBinary *fbb = getBinary(executable); + + TT_FATAL(programIndex < fbb->programs()->size(), + "Program index {} out of range {}", programIndex, + fbb->programs()->size()); + const ::tt::target::ttnn::Program *program = + fbb->programs()->Get(programIndex); + + return ::tt::runtime::ttnn::updateProgramTensorLayout(device, program, + inputIndex, input); +} + void wait(Event event) { // Not implemented assert(event.matchesRuntime(DeviceRuntime::TTNN)); diff --git a/runtime/lib/types.cpp b/runtime/lib/types.cpp new file mode 100644 index 000000000..9d7a8e7e8 --- /dev/null +++ b/runtime/lib/types.cpp @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt/runtime/types.h" + +#if defined(TT_RUNTIME_ENABLE_TTNN) +#include "tt/runtime/detail/ttnn.h" +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) +#include "tt/runtime/detail/ttmetal.h" +#endif + +namespace tt::runtime { + +void Tensor::deallocate() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (this->matchesRuntime(DeviceRuntime::TTNN)) { + ::ttnn::Tensor &tensor = this->as<::ttnn::Tensor>(DeviceRuntime::TTNN); + tensor.deallocate(); + return; + } +#elif defined(TT_RUNTIME_ENABLE_TTMETAL) + if (this->matchesRuntime(DeviceRuntime::TTMetal)) { + throw std::runtime_error("Not implemented"); + } +#endif + throw std::runtime_error("Runtime not enabled"); +} + +} // namespace tt::runtime From 4555dd8ae8aa994943940e86955d148ed16a02c7 Mon Sep 17 00:00:00 2001 From: Jackson Nie Date: Sun, 8 Sep 2024 03:18:38 +0000 Subject: [PATCH 2/2] Add API to move tensors to CPU --- runtime/include/tt/runtime/detail/ttnn.h | 6 +++++- runtime/include/tt/runtime/types.h | 5 +++-- runtime/lib/ttnn/runtime.cpp | 14 ++++++++++++++ runtime/lib/types.cpp | 24 +++++++++++++++++++----- 4 files changed, 41 insertions(+), 8 deletions(-) diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index 8947603d6..f7d0cd9a5 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -79,6 +79,10 @@ inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) { tt::target::DataType getTensorDataType(Tensor tensor); +void deallocateTensor(Tensor tensor, bool force); + +Tensor toCpu(Tensor tensor); + Device openDevice(std::vector const &deviceIds = {0}, std::vector const &numHWCQs = {}); @@ -100,7 +104,7 @@ void runProgram(::ttnn::Device &device, ::tt::target::ttnn::Program const *program, std::vector<::ttnn::Tensor *> const &inputs, std::vector<::ttnn::Tensor *> const &outputs); - + std::vector runProgram(::ttnn::Device &device, ::tt::target::ttnn::Program const *program, std::vector<::ttnn::Tensor *> const &inputs); diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index e29215044..46d8da326 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -132,7 +132,6 @@ struct Tensor : public detail::RuntimeCheckedObjectImpl { : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), event(Event(nullptr, runtime)) {} - Tensor(std::shared_ptr handle, std::shared_ptr data, DeviceRuntime runtime, Event event) : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), @@ -140,7 +139,9 @@ struct Tensor : public detail::RuntimeCheckedObjectImpl { // Users need to manually deallocate tensors returned from submit // As the storage is now owned instead of borrowed - void deallocate(); + void deallocate(bool force = false); + + Tensor cpu() const; }; } // namespace tt::runtime diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 13d630cd8..a1417266c 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -56,6 +56,20 @@ tt::target::DataType getTensorDataType(Tensor tensor) { return utils::fromTTNNDataType(nnTensor.dtype()); } +void deallocateTensor(Tensor tensor, bool force) { + ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + ttnnTensor.deallocate(force); +} + +Tensor toCpu(Tensor tensor) { + ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + std::shared_ptr<::ttnn::Tensor> cpuTensor = + std::make_shared<::ttnn::Tensor>(ttnnTensor.cpu()); + void *dataPtr = ::tt::tt_metal::get_raw_host_data_ptr(*cpuTensor); + return Tensor(cpuTensor, ::tt::runtime::utils::unsafe_borrow_shared(dataPtr), + DeviceRuntime::TTNN); +} + Device openDevice(std::vector const &deviceIds, std::vector const &numHWCQs) { assert(deviceIds.size() == 1 && "Only one device is supported for now"); diff --git a/runtime/lib/types.cpp b/runtime/lib/types.cpp index 9d7a8e7e8..b9cf17ad6 100644 --- a/runtime/lib/types.cpp +++ b/runtime/lib/types.cpp @@ -14,14 +14,14 @@ namespace tt::runtime { -void Tensor::deallocate() { +void Tensor::deallocate(bool force) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (this->matchesRuntime(DeviceRuntime::TTNN)) { - ::ttnn::Tensor &tensor = this->as<::ttnn::Tensor>(DeviceRuntime::TTNN); - tensor.deallocate(); - return; + ::tt::runtime::ttnn::deallocateTensor(*this, force); } -#elif defined(TT_RUNTIME_ENABLE_TTMETAL) +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) if (this->matchesRuntime(DeviceRuntime::TTMetal)) { throw std::runtime_error("Not implemented"); } @@ -29,4 +29,18 @@ void Tensor::deallocate() { throw std::runtime_error("Runtime not enabled"); } +Tensor Tensor::cpu() const { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (this->matchesRuntime(DeviceRuntime::TTNN)) { + return ::tt::runtime::ttnn::toCpu(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (this->matchesRuntime(DeviceRuntime::TTMetal)) { + throw std::runtime_error("Not implemented"); + } +#endif + throw std::runtime_error("Runtime not enabled"); +} } // namespace tt::runtime