Skip to content

Commit

Permalink
Determine data type of target tensor during runtime. (#524)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmanzoorTT authored Aug 30, 2024
1 parent f8da4b7 commit 1e3bb79
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 0 deletions.
2 changes: 2 additions & 0 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {
desc.dataType);
}

tt::target::DataType getTensorDataType(Tensor tensor);

Device openDevice(std::vector<int> const &deviceIds = {0},
std::vector<uint8_t> const &numHWCQs = {});

Expand Down
2 changes: 2 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {
desc.dataType);
}

tt::target::DataType getTensorDataType(Tensor tensor);

Device openDevice(std::vector<int> const &deviceIds = {0},
std::vector<std::uint8_t> const &numHWCQs = {});

Expand Down
2 changes: 2 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {
desc.dataType);
}

tt::target::DataType getTensorDataType(Tensor tensor);

Device openDevice(std::vector<int> const &deviceIds = {0},
std::vector<std::uint8_t> const &numHWCQs = {});

Expand Down
15 changes: 15 additions & 0 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@ Tensor createTensor(std::shared_ptr<void> data,
throw std::runtime_error("runtime is not enabled");
}

tt::target::DataType getTensorDataType(Tensor tensor) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getTensorDataType(tensor);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getTensorDataType(tensor);
}
#endif
throw std::runtime_error("runtime is not enabled");
}

Device openDevice(std::vector<int> const &deviceIds,
std::vector<std::uint8_t> const &numHWCQs) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
Expand Down
15 changes: 15 additions & 0 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ Tensor createTensor(std::shared_ptr<void> data,
DeviceRuntime::TTMetal);
}

tt::target::DataType getTensorDataType(Tensor tensor) {
const MetalTensor &metalTensor =
tensor.as<MetalTensor>(DeviceRuntime::TTMetal);
if (std::holds_alternative<TensorDesc>(metalTensor)) {
TensorDesc desc = std::get<TensorDesc>(metalTensor);
return desc.dataType;
}
if (std::holds_alternative<std::shared_ptr<::tt::tt_metal::Buffer>>(
metalTensor)) {
throw std::runtime_error("Datatype mapping from buffer not supported yet.");
}
assert(false && "Unsupported tensor type");
return ::tt::target::DataType::Float32;
}

Device openDevice(std::vector<int> const &deviceIds,
std::vector<std::uint8_t> const &numHWCQs) {
assert(numHWCQs.empty() || numHWCQs.size() == deviceIds.size());
Expand Down
6 changes: 6 additions & 0 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ Tensor createTensor(std::shared_ptr<void> data,
return Tensor(tensor, data, DeviceRuntime::TTNN);
}

tt::target::DataType getTensorDataType(Tensor tensor) {
const ::ttnn::Tensor &nnTensor =
tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
return utils::fromTTNNDataType(nnTensor.dtype());
}

Device openDevice(std::vector<int> const &deviceIds,
std::vector<std::uint8_t> const &numHWCQs) {
assert(deviceIds.size() == 1 && "Only one device is supported for now");
Expand Down
20 changes: 20 additions & 0 deletions runtime/lib/ttnn/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@ inline ::ttnn::DataType toTTNNDataType(::tt::target::DataType dataType) {
}
}

inline ::tt::target::DataType fromTTNNDataType(::ttnn::DataType dataType) {
switch (dataType) {
case ::ttnn::DataType::FLOAT32:
return ::tt::target::DataType::Float32;
case ::ttnn::DataType::BFLOAT16:
return ::tt::target::DataType::BFloat16;
case ::ttnn::DataType::BFLOAT8_B:
return ::tt::target::DataType::BFP_BFloat8;
case ::ttnn::DataType::BFLOAT4_B:
return ::tt::target::DataType::BFP_BFloat4;
case ::ttnn::DataType::UINT32:
return ::tt::target::DataType::UInt32;
case ::ttnn::DataType::UINT16:
return ::tt::target::DataType::UInt16;

default:
throw std::runtime_error("Unsupported data type");
}
}

inline std::vector<uint32_t>
toShapeFromFBShape(const flatbuffers::Vector<int32_t> &vec) {
return std::vector<uint32_t>(vec.begin(), vec.end());
Expand Down

0 comments on commit 1e3bb79

Please sign in to comment.