#16067: [WIP] Flatbuffers-based serialization / deserialization for Tensors #16295

Draft · wants to merge 1 commit into base: main
17 changes: 17 additions & 0 deletions cmake/flatbuffers.cmake
@@ -0,0 +1,17 @@
# Function to generate FlatBuffers headers
function(GENERATE_FBS_HEADER FBS_FILE)
    # Strip the .fbs extension so the declared output matches flatc's <name>_generated.h naming.
    get_filename_component(FBS_FILE_NAME ${FBS_FILE} NAME_WE)
    set(FBS_GENERATED_HEADER "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers_gen/${FBS_FILE_NAME}_generated.h")
    add_custom_command(
        OUTPUT
            ${FBS_GENERATED_HEADER}
        COMMAND
            flatc --cpp --scoped-enums -I ${CMAKE_CURRENT_SOURCE_DIR} -o "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers_gen/"
            ${FBS_FILE}
        DEPENDS
            flatc
            ${FBS_FILE}
        COMMENT "Building C++ header for ${FBS_FILE}"
    )
    set(FBS_GENERATED_HEADER ${FBS_GENERATED_HEADER} PARENT_SCOPE)
endfunction()
15 changes: 15 additions & 0 deletions dependencies/CMakeLists.txt
@@ -111,3 +111,18 @@ CPMAddPackage(
OPTIONS
"XTENSOR_ENABLE_TESTS OFF"
)

############################################################################################################################
# flatbuffers : https://github.com/google/flatbuffers
############################################################################################################################

CPMAddPackage(
    NAME flatbuffers
    GITHUB_REPOSITORY google/flatbuffers
    GIT_TAG v24.3.25
    OPTIONS
        "FLATBUFFERS_BUILD_FLATC ON"
        "FLATBUFFERS_BUILD_TESTS OFF"
        "FLATBUFFERS_SKIP_MONSTER_EXTRA ON"
        "FLATBUFFERS_STRICT_MODE ON"
)
1 change: 1 addition & 0 deletions tests/ttnn/unit_tests/gtests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@ set(TTNN_TENSOR_UNIT_TESTS_SRC
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor_multi_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor_with_layout.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_distributed_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_flatbuffers_conversion.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_partition.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_shape_base.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_sharding_with_alignment.cpp
14 changes: 14 additions & 0 deletions tests/ttnn/unit_tests/gtests/tensor/test_flatbuffers_conversion.cpp
@@ -0,0 +1,14 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>
#include <gmock/gmock.h>

#include "ttnn/tensor/serialization/flatbuf_utils.hpp"

namespace tt::tt_metal {

TEST(FlatbuffersConversion, Roundtrip) {}

} // namespace tt::tt_metal
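The roundtrip test body is still empty. A minimal sketch of what it could grow into once both conversion directions are implemented; the tensor construction is left as a placeholder, and only the standard flatbuffers builder API (Finish, GetBufferPointer, GetRoot) plus the helpers declared in this PR are assumed:

// Sketch only: assumes from_flatbuf_tensor is implemented and that a small
// host tensor can be constructed (factory call omitted below).
TEST(FlatbuffersConversion, RoundtripSketch) {
    Tensor input;  // TODO: construct a small host tensor with known shape/dtype/layout.

    // Serialize into a flatbuffer and finish the buffer.
    flatbuffers::FlatBufferBuilder builder;
    builder.Finish(to_flatbuf_tensor(input, builder));

    // Read the root table back and reconstruct the tensor.
    const auto* fb_tensor = flatbuffers::GetRoot<flatbuf::Tensor>(builder.GetBufferPointer());
    Tensor output = from_flatbuf_tensor(fb_tensor);

    // Metadata should survive the roundtrip; data comparison would follow once
    // the TensorBuffer payload is populated.
    EXPECT_EQ(input.get_logical_shape(), output.get_logical_shape());
    EXPECT_EQ(input.dtype(), output.dtype());
    EXPECT_EQ(input.layout(), output.layout());
}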
16 changes: 14 additions & 2 deletions ttnn/CMakeLists.txt
@@ -567,8 +567,18 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/expand/device/expand_device_operation.cpp
)

#Split src and python bindings
#We only build python bindings optionally
# Include helper functions and generate headers from flatbuffer schemas
include(${PROJECT_SOURCE_DIR}/cmake/flatbuffers.cmake)

set(FLATBUFFER_SCHEMAS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/tensor/serialization/tensor.fbs)

foreach(FBS_FILE ${FLATBUFFER_SCHEMAS})
    GENERATE_FBS_HEADER(${FBS_FILE})
    list(APPEND ALL_TTNN_SRCS ${FBS_GENERATED_HEADER})
endforeach()

# Split src and python bindings
# We only build python bindings optionally
set(TTNN_SRC)

set(PYBIND_SRC
@@ -609,8 +619,10 @@ set(TTNN_PUBLIC_INCLUDE_DIRS
${CMAKE_CURRENT_SOURCE_DIR} # ${PROJECT_SOURCE_DIR}/ttnn
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/deprecated # symlink to tt_eager; should become native folder once merge complete
${CMAKE_CURRENT_SOURCE_DIR}/cpp
${CMAKE_CURRENT_BINARY_DIR}/flatbuffers_gen
)
set(TTNN_PUBLIC_LINK_LIBRARIES
FlatBuffers::FlatBuffers
metal_common_libs
Metalium::Metal
Boost::container
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/tensor/CMakeLists.txt
@@ -5,6 +5,7 @@ set(TENSOR_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/types.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/serialization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/serialization/flatbuf_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/shape/shape_base.cpp
${CMAKE_CURRENT_SOURCE_DIR}/shape/shape.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/alignment.cpp
102 changes: 102 additions & 0 deletions ttnn/cpp/ttnn/tensor/serialization/flatbuf_utils.cpp
@@ -0,0 +1,102 @@
#include "buffers/buffer_constants.hpp"
#include "flatbuffers/flatbuffer_builder.h"
#include "tensor_generated.h"
Contributor
⚠️ clang-diagnostic-error ⚠️
tensor_generated.h file not found
#include "ttnn/tensor/tensor.hpp"

namespace tt::tt_metal {
namespace {

flatbuf::Layout to_flatbuf_layout(Layout layout) {
    switch (layout) {
        case Layout::ROW_MAJOR: return flatbuf::Layout::ROW_MAJOR;
        case Layout::TILE: return flatbuf::Layout::TILE;
        case Layout::INVALID: return flatbuf::Layout::INVALID;
    }
}

flatbuf::DataType to_flatbuf_data_type(DataType dtype) {
    switch (dtype) {
        case DataType::FLOAT32: return flatbuf::DataType::FLOAT32;
        case DataType::BFLOAT16: return flatbuf::DataType::BFLOAT16;
        case DataType::UINT8: return flatbuf::DataType::UINT8;
        case DataType::UINT16: return flatbuf::DataType::UINT16;
        case DataType::INT32: return flatbuf::DataType::INT32;
        case DataType::UINT32: return flatbuf::DataType::UINT32;
        case DataType::BFLOAT8_B: return flatbuf::DataType::BFLOAT8_B;
        case DataType::BFLOAT4_B: return flatbuf::DataType::BFLOAT4_B;
        case DataType::INVALID: return flatbuf::DataType::INVALID;
    }
}

Layout from_flatbuf_layout(flatbuf::Layout layout) {
    switch (layout) {
        case flatbuf::Layout::ROW_MAJOR: return Layout::ROW_MAJOR;
        case flatbuf::Layout::TILE: return Layout::TILE;
        case flatbuf::Layout::INVALID: return Layout::INVALID;
    }
}

DataType from_flatbuf_data_type(flatbuf::DataType dtype) {
    switch (dtype) {
        case flatbuf::DataType::FLOAT32: return DataType::FLOAT32;
        case flatbuf::DataType::BFLOAT16: return DataType::BFLOAT16;
        case flatbuf::DataType::UINT8: return DataType::UINT8;
        case flatbuf::DataType::UINT16: return DataType::UINT16;
        case flatbuf::DataType::INT32: return DataType::INT32;
        case flatbuf::DataType::UINT32: return DataType::UINT32;
        case flatbuf::DataType::BFLOAT8_B: return DataType::BFLOAT8_B;
        case flatbuf::DataType::BFLOAT4_B: return DataType::BFLOAT4_B;
        case flatbuf::DataType::INVALID: return DataType::INVALID;
    }
}

flatbuf::TensorMemoryLayout to_flatbuf_memory_layout(TensorMemoryLayout layout) {
    switch (layout) {
        case TensorMemoryLayout::INTERLEAVED: return flatbuf::TensorMemoryLayout::Interleaved;
        case TensorMemoryLayout::SINGLE_BANK: return flatbuf::TensorMemoryLayout::SingleBank;
        case TensorMemoryLayout::HEIGHT_SHARDED: return flatbuf::TensorMemoryLayout::HeightSharded;
        case TensorMemoryLayout::WIDTH_SHARDED: return flatbuf::TensorMemoryLayout::WidthSharded;
        case TensorMemoryLayout::BLOCK_SHARDED: return flatbuf::TensorMemoryLayout::BlockSharded;
    }
}

flatbuf::BufferType to_flatbuf_buffer_type(BufferType buffer_type) {
    switch (buffer_type) {
        case BufferType::DRAM: return flatbuf::BufferType::DRAM;
        case BufferType::L1: return flatbuf::BufferType::L1;
        case BufferType::SYSTEM_MEMORY: return flatbuf::BufferType::SYSTEM_MEMORY;
        case BufferType::L1_SMALL: return flatbuf::BufferType::L1_SMALL;
        case BufferType::TRACE: return flatbuf::BufferType::TRACE;
    }
}

} // namespace

flatbuffers::Offset<flatbuf::Tensor> to_flatbuf_tensor(const Tensor& tensor, flatbuffers::FlatBufferBuilder& builder) {
    std::vector<uint32_t> shape_values(tensor.get_logical_shape().cbegin(), tensor.get_logical_shape().cend());

    auto shape_fb = flatbuf::CreateTensorShape(builder, builder.CreateVector(shape_values));

    const auto& tile_spec = tensor.tensor_spec().page_config().get_tile();
    auto tile_fb = flatbuf::Tile(tile_spec.get_height(), tile_spec.get_width(), tile_spec.get_transpose_within_face());
    auto page_config_fb = flatbuf::CreatePageConfig(builder, to_flatbuf_layout(tensor.layout()), &tile_fb);

    auto memory_config_fb = flatbuf::CreateMemoryConfig(
        builder,
        to_flatbuf_memory_layout(tensor.memory_config().memory_layout),
        to_flatbuf_buffer_type(tensor.memory_config().buffer_type));

    auto tensor_spec_fb = flatbuf::CreateTensorSpec(
        builder, shape_fb, to_flatbuf_data_type(tensor.dtype()), page_config_fb, memory_config_fb);

    // TODO: finish with the tensor data.

    return flatbuf::CreateTensor(builder, tensor_spec_fb);
}

Tensor from_flatbuf_tensor(const flatbuf::Tensor* fb_tensor) {
    // TODO: Implement.
    return Tensor();
}

} // namespace tt::tt_metal
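For reference, a hedged sketch of how a caller could turn the builder offset returned by to_flatbuf_tensor into a standalone byte buffer. The helper name serialize_tensor_to_bytes is illustrative and not part of this PR; only standard FlatBufferBuilder calls (Finish, GetBufferPointer, GetSize) are used:

#include <cstdint>
#include <vector>

#include "flatbuffers/flatbuffer_builder.h"
#include "ttnn/tensor/serialization/flatbuf_utils.hpp"

namespace tt::tt_metal {

// Illustrative only: serialize a tensor's flatbuffer representation into a
// contiguous byte vector that can be written to disk or sent over the wire.
std::vector<uint8_t> serialize_tensor_to_bytes(const Tensor& tensor) {
    flatbuffers::FlatBufferBuilder builder;
    builder.Finish(to_flatbuf_tensor(tensor, builder));
    const uint8_t* data = builder.GetBufferPointer();
    return std::vector<uint8_t>(data, data + builder.GetSize());
}

}  // namespace tt::tt_metal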
11 changes: 11 additions & 0 deletions ttnn/cpp/ttnn/tensor/serialization/flatbuf_utils.hpp
@@ -0,0 +1,11 @@
#include "flatbuffers/flatbuffer_builder.h"
#include "tensor_generated.h"
Contributor
⚠️ clang-diagnostic-error ⚠️
tensor_generated.h file not found

#include "ttnn/tensor/tensor.hpp"

namespace tt::tt_metal {

// Serializes the tensor's spec (and, eventually, its data) into a flatbuf::Tensor table.
flatbuffers::Offset<flatbuf::Tensor> to_flatbuf_tensor(const Tensor& tensor, flatbuffers::FlatBufferBuilder& builder);

// Reconstructs a Tensor from its flatbuffer representation (not yet implemented).
Tensor from_flatbuf_tensor(const flatbuf::Tensor* tensor);

} // namespace tt::tt_metal
79 changes: 79 additions & 0 deletions ttnn/cpp/ttnn/tensor/serialization/tensor.fbs
@@ -0,0 +1,79 @@

namespace tt.tt_metal.flatbuf;

// TODO: move to a more appropriate location.
enum BufferType: ushort {
    INVALID = 0,
    DRAM = 1,
    L1 = 2,
    SYSTEM_MEMORY = 3,
    L1_SMALL = 4,
    TRACE = 5,
}

enum Layout: ushort {
    INVALID = 0,
    ROW_MAJOR = 1,
    TILE = 2,
}

struct Tile {
    height: uint;
    width: uint;
    transpose: bool;
}

table PageConfig {
    layout: Layout (id: 0);
    tile: Tile (id: 1);
}

enum DataType: ushort {
    INVALID = 0,
    UINT8 = 1,
    UINT16 = 2,
    UINT32 = 3,
    INT32 = 4,
    BFLOAT4_B = 5,
    BFLOAT8_B = 6,
    BFLOAT16 = 7,
    FLOAT32 = 8,
}

enum TensorMemoryLayout: ushort {
    None = 0,
    Interleaved = 1,
    SingleBank = 2,
    HeightSharded = 3,
    WidthSharded = 4,
    BlockSharded = 5,
}

table MemoryConfig {
    memory_layout: TensorMemoryLayout (id: 0);
    buffer_type: BufferType (id: 1);
    // TODO: shard spec.
}

table TensorShape {
    values: [uint] (id: 0);
}

table TensorSpec {
    shape: TensorShape (id: 0);
    data_type: DataType (id: 1);
    page_config: PageConfig (id: 2);
    memory_config: MemoryConfig (id: 3);
}

// TODO: decide what to do in case of distributed tensors:
// 1. encode distribution strategy.
// 2. encode sub-tensor shapes.
table TensorBuffer {
    data: [ubyte] (id: 0);
}

table Tensor {
    spec: TensorSpec (id: 0);
    buffers: [TensorBuffer] (id: 1);
}
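A short C++ sketch of how a consumer might validate and read a buffer produced against this schema. It relies only on the generic flatbuffers Verifier and GetRoot APIs; the convenience accessors flatc would emit for a root table (e.g. GetTensor, VerifyTensorBuffer) are not available because the schema does not yet declare a root_type:

#include <cstddef>
#include <cstdint>

#include "flatbuffers/flatbuffers.h"
#include "tensor_generated.h"

// Returns the root Tensor table if the buffer verifies, nullptr otherwise.
const tt::tt_metal::flatbuf::Tensor* read_tensor(const uint8_t* data, size_t size) {
    flatbuffers::Verifier verifier(data, size);
    if (!verifier.VerifyBuffer<tt::tt_metal::flatbuf::Tensor>(nullptr)) {
        return nullptr;  // malformed or truncated buffer
    }
    return flatbuffers::GetRoot<tt::tt_metal::flatbuf::Tensor>(data);
}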