Skip to content

Commit

Permalink
Add API to move tensors to CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT committed Sep 8, 2024
1 parent 608ed89 commit 4555dd8
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 8 deletions.
6 changes: 5 additions & 1 deletion runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {

// Returns the tensor's data type as a tt::target::DataType (the definition
// converts a TTNN dtype via utils::fromTTNNDataType).
tt::target::DataType getTensorDataType(Tensor tensor);

// Deallocates the ::ttnn::Tensor wrapped by `tensor`. `force` is passed
// through unchanged to ::ttnn::Tensor::deallocate.
void deallocateTensor(Tensor tensor, bool force);

// Calls cpu() on the ::ttnn::Tensor wrapped by `tensor` and returns a new
// runtime Tensor that owns the resulting ::ttnn::Tensor and borrows its raw
// host data pointer.
Tensor toCpu(Tensor tensor);

// Opens a device for the given device ids.
// NOTE(review): the TTNN definition asserts that exactly one device id is
// supplied; the meaning of `numHWCQs` is not visible here — confirm against
// the definition.
Device openDevice(std::vector<int> const &deviceIds = {0},
std::vector<std::uint8_t> const &numHWCQs = {});

Expand All @@ -100,7 +104,7 @@ void runProgram(::ttnn::Device &device,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs);

// Overload of runProgram that takes no caller-supplied output tensors and
// instead returns a vector of Tensor by value.
// NOTE(review): presumably the returned tensors are the program's outputs —
// confirm against the definition.
std::vector<Tensor> runProgram(::ttnn::Device &device,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs);
Expand Down
5 changes: 3 additions & 2 deletions runtime/include/tt/runtime/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,16 @@ struct Tensor : public detail::RuntimeCheckedObjectImpl {
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
event(Event(nullptr, runtime)) {}


// Constructs a tensor with a caller-supplied event: `handle` and `runtime`
// are forwarded to the RuntimeCheckedObjectImpl base, while `data` and
// `event` are copied into the members of the same names.
Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
DeviceRuntime runtime, Event event)
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
event(event) {}

// Tensors returned from submit own their storage (it is no longer borrowed),
// so callers must release them explicitly by calling deallocate().
// `force` is forwarded to the owning runtime's deallocation routine; it
// defaults to false so existing zero-argument calls keep compiling. Only this
// single declaration is kept: a separate `void deallocate();` overload would
// make `deallocate()` ambiguous with the defaulted parameter.
void deallocate(bool force = false);

// Returns the Tensor produced by ::tt::runtime::ttnn::toCpu when this tensor
// belongs to the TTNN runtime. Throws std::runtime_error("Not implemented")
// for TTMetal tensors and std::runtime_error("Runtime not enabled") when no
// compiled-in runtime matches.
Tensor cpu() const;
};

} // namespace tt::runtime
Expand Down
14 changes: 14 additions & 0 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@ tt::target::DataType getTensorDataType(Tensor tensor) {
return utils::fromTTNNDataType(nnTensor.dtype());
}

// Deallocates the ::ttnn::Tensor wrapped by `tensor`, handing `force` straight
// through to ::ttnn::Tensor::deallocate.
void deallocateTensor(Tensor tensor, bool force) {
  tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN).deallocate(force);
}

// Calls cpu() on the ::ttnn::Tensor wrapped by `tensor` and wraps the result
// in a new runtime Tensor.
//
// The returned Tensor's handle owns the new ::ttnn::Tensor; its data pointer
// is only borrowed from that tensor's raw host buffer, so it relies on the
// handle to keep the buffer alive.
Tensor toCpu(Tensor tensor) {
  ::ttnn::Tensor &sourceTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
  auto hostTensor = std::make_shared<::ttnn::Tensor>(sourceTensor.cpu());
  void *rawHostData = ::tt::tt_metal::get_raw_host_data_ptr(*hostTensor);
  auto borrowedData = ::tt::runtime::utils::unsafe_borrow_shared(rawHostData);
  return Tensor(hostTensor, borrowedData, DeviceRuntime::TTNN);
}

Device openDevice(std::vector<int> const &deviceIds,
std::vector<std::uint8_t> const &numHWCQs) {
assert(deviceIds.size() == 1 && "Only one device is supported for now");
Expand Down
24 changes: 19 additions & 5 deletions runtime/lib/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,33 @@

namespace tt::runtime {

// Releases the storage backing this tensor by dispatching to the runtime that
// owns it.
//
// force: forwarded unchanged to ::tt::runtime::ttnn::deallocateTensor.
//
// Throws std::runtime_error("Not implemented") for TTMetal tensors and
// std::runtime_error("Runtime not enabled") when no compiled-in runtime
// matches this tensor.
void Tensor::deallocate(bool force) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  if (this->matchesRuntime(DeviceRuntime::TTNN)) {
    ::tt::runtime::ttnn::deallocateTensor(*this, force);
    // Return here: falling through would reach the "Runtime not enabled"
    // throw below even though the deallocation succeeded.
    return;
  }
#endif

  // Separate #if (not #elif) so the TTMetal check is still compiled when both
  // runtimes are enabled.
#if defined(TT_RUNTIME_ENABLE_TTMETAL)
  if (this->matchesRuntime(DeviceRuntime::TTMetal)) {
    throw std::runtime_error("Not implemented");
  }
#endif
  throw std::runtime_error("Runtime not enabled");
}

// Dispatches to the runtime that owns this tensor and returns the Tensor it
// produces: ::tt::runtime::ttnn::toCpu for TTNN tensors.
//
// Throws std::runtime_error("Not implemented") for TTMetal tensors and
// std::runtime_error("Runtime not enabled") when no compiled-in runtime
// matches this tensor.
Tensor Tensor::cpu() const {
#ifdef TT_RUNTIME_ENABLE_TTNN
  if (matchesRuntime(DeviceRuntime::TTNN)) {
    return ::tt::runtime::ttnn::toCpu(*this);
  }
#endif

#ifdef TT_RUNTIME_ENABLE_TTMETAL
  if (matchesRuntime(DeviceRuntime::TTMetal)) {
    throw std::runtime_error("Not implemented");
  }
#endif

  throw std::runtime_error("Runtime not enabled");
}
} // namespace tt::runtime

0 comments on commit 4555dd8

Please sign in to comment.