From 1e3bb7965537a3f6dfb41c4c67db03da0aa905e6 Mon Sep 17 00:00:00 2001
From: Muhammad Asif Manzoor
Date: Fri, 30 Aug 2024 07:56:20 -0400
Subject: [PATCH] Determine data type of target tensor during runtime. (#524)

---
 runtime/include/tt/runtime/detail/ttmetal.h |  4 ++++
 runtime/include/tt/runtime/detail/ttnn.h    |  4 ++++
 runtime/include/tt/runtime/runtime.h        |  4 ++++
 runtime/lib/runtime.cpp                     | 17 +++++++++++++++++
 runtime/lib/ttmetal/runtime.cpp             | 15 +++++++++++++++
 runtime/lib/ttnn/runtime.cpp                |  8 ++++++++
 runtime/lib/ttnn/utils.h                    | 22 ++++++++++++++++++++++
 7 files changed, 74 insertions(+)

diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h
index f7838b39b..964caa6a5 100644
--- a/runtime/include/tt/runtime/detail/ttmetal.h
+++ b/runtime/include/tt/runtime/detail/ttmetal.h
@@ -61,6 +61,10 @@ inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) {
                       desc.dataType);
 }
 
+// Returns the data type stored in the tensor's TensorDesc; throws
+// std::runtime_error if the tensor holds anything else.
+tt::target::DataType getTensorDataType(Tensor tensor);
+
 Device openDevice(std::vector const &deviceIds = {0},
                   std::vector const &numHWCQs = {});
 
diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h
index b9a1b79ea..89405df43 100644
--- a/runtime/include/tt/runtime/detail/ttnn.h
+++ b/runtime/include/tt/runtime/detail/ttnn.h
@@ -71,6 +71,10 @@ inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) {
                       desc.dataType);
 }
 
+// Returns the tt::target::DataType matching the ttnn tensor's dtype; throws
+// std::runtime_error for dtypes that have no mapping.
+tt::target::DataType getTensorDataType(Tensor tensor);
+
 Device openDevice(std::vector const &deviceIds = {0},
                   std::vector const &numHWCQs = {});
 
diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h
index dfe98a7a9..80b0eb783 100644
--- a/runtime/include/tt/runtime/runtime.h
+++ b/runtime/include/tt/runtime/runtime.h
@@ -37,6 +37,10 @@ inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) {
                       desc.dataType);
 }
 
+// Returns the data type of `tensor` by asking the currently selected device
+// runtime; throws std::runtime_error if no enabled runtime handles it.
+tt::target::DataType getTensorDataType(Tensor tensor);
+
 Device openDevice(std::vector const &deviceIds = {0},
                   std::vector const &numHWCQs = {});
 
diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp
index f225361eb..888f12b46 100644
--- a/runtime/lib/runtime.cpp
+++ b/runtime/lib/runtime.cpp
@@ -107,6 +107,23 @@ Tensor createTensor(std::shared_ptr data,
   throw std::runtime_error("runtime is not enabled");
 }
 
+// Forwards to the getTensorDataType of whichever compiled-in runtime is
+// current; throws std::runtime_error when neither branch applies.
+tt::target::DataType getTensorDataType(Tensor tensor) {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getTensorDataType(tensor);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getTensorDataType(tensor);
+  }
+#endif
+  throw std::runtime_error("runtime is not enabled");
+}
+
 Device openDevice(std::vector const &deviceIds,
                   std::vector const &numHWCQs) {
 #if defined(TT_RUNTIME_ENABLE_TTNN)
diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp
index dd278bc7b..4842f0b47 100644
--- a/runtime/lib/ttmetal/runtime.cpp
+++ b/runtime/lib/ttmetal/runtime.cpp
@@ -43,6 +43,21 @@ Tensor createTensor(std::shared_ptr data,
                 DeviceRuntime::TTMetal);
 }
 
+// Returns the data type recorded in the tensor's TensorDesc.
+// Throws std::runtime_error for every other alternative: there is no
+// descriptor to read a data type from, and guessing one would be wrong.
+tt::target::DataType getTensorDataType(Tensor tensor) {
+  const MetalTensor &metalTensor =
+      tensor.as<MetalTensor>(DeviceRuntime::TTMetal);
+  // get_if reads the descriptor in place; no TensorDesc copy is made.
+  if (const TensorDesc *desc = std::get_if<TensorDesc>(&metalTensor)) {
+    return desc->dataType;
+  }
+  // Throw rather than assert: with NDEBUG an assert compiles away and the
+  // caller would be handed a made-up Float32.
+  throw std::runtime_error("Datatype mapping from buffer not supported yet.");
+}
+
 Device openDevice(std::vector const &deviceIds,
                   std::vector const &numHWCQs) {
   assert(numHWCQs.empty() || numHWCQs.size() == deviceIds.size());
diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp
index 4d647b806..37a69c154 100644
--- a/runtime/lib/ttnn/runtime.cpp
+++ b/runtime/lib/ttnn/runtime.cpp
@@ -50,6 +50,14 @@ Tensor createTensor(std::shared_ptr data,
   return Tensor(tensor, data, DeviceRuntime::TTNN);
 }
 
+// Reads the dtype of the wrapped ::ttnn::Tensor and converts it with
+// utils::fromTTNNDataType.
+tt::target::DataType getTensorDataType(Tensor tensor) {
+  const ::ttnn::Tensor &nnTensor =
+      tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
+  return utils::fromTTNNDataType(nnTensor.dtype());
+}
+
 Device openDevice(std::vector const &deviceIds,
                   std::vector const &numHWCQs) {
   assert(deviceIds.size() == 1 && "Only one device is supported for now");
diff --git a/runtime/lib/ttnn/utils.h b/runtime/lib/ttnn/utils.h
index 5bb5b84ea..2cc0ffde8 100644
--- a/runtime/lib/ttnn/utils.h
+++ b/runtime/lib/ttnn/utils.h
@@ -37,6 +37,28 @@ inline ::ttnn::DataType toTTNNDataType(::tt::target::DataType dataType) {
   }
 }
 
+// Converts a ::ttnn::DataType to the corresponding ::tt::target::DataType.
+// Throws std::runtime_error for any value not listed in the switch.
+inline ::tt::target::DataType fromTTNNDataType(::ttnn::DataType dataType) {
+  switch (dataType) {
+  case ::ttnn::DataType::FLOAT32:
+    return ::tt::target::DataType::Float32;
+  case ::ttnn::DataType::BFLOAT16:
+    return ::tt::target::DataType::BFloat16;
+  case ::ttnn::DataType::BFLOAT8_B:
+    return ::tt::target::DataType::BFP_BFloat8;
+  case ::ttnn::DataType::BFLOAT4_B:
+    return ::tt::target::DataType::BFP_BFloat4;
+  case ::ttnn::DataType::UINT32:
+    return ::tt::target::DataType::UInt32;
+  case ::ttnn::DataType::UINT16:
+    return ::tt::target::DataType::UInt16;
+
+  default:
+    throw std::runtime_error("Unsupported data type");
+  }
+}
+
 inline std::vector
 toShapeFromFBShape(const flatbuffers::Vector &vec) {
   return std::vector(vec.begin(), vec.end());