diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_vector_conversion.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_vector_conversion.cpp
index 8dc25e1abf4..cd3e9709f49 100644
--- a/tests/ttnn/unit_tests/gtests/tensor/test_vector_conversion.cpp
+++ b/tests/ttnn/unit_tests/gtests/tensor/test_vector_conversion.cpp
@@ -7,7 +7,11 @@
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
+#include "tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp"
+#include "common/bfloat16.hpp"
 #include "ttnn/tensor/tensor.hpp"
+#include "ttnn/tensor/tensor_utils.hpp"
+#include "ttnn/tensor/types.hpp"
 #include "ttnn/tensor/xtensor/conversion_utils.hpp"
 #include "ttnn/tensor/xtensor/xtensor_all_includes.hpp"
 
@@ -15,8 +19,14 @@ namespace ttnn {
 namespace {
 
 using ::testing::Eq;
+using ::testing::FloatNear;
 using ::testing::Pointwise;
 
+template <typename... Args>
+testing::Matcher<ttnn::SimpleShape> ShapeIs(Args... args) {
+    return testing::Eq(ttnn::SimpleShape({args...}));
+}
+
 const std::vector<ttnn::SimpleShape>& get_shapes_for_test() {
     static auto* shapes = new std::vector<ttnn::SimpleShape>{
        ttnn::SimpleShape{1},
@@ -35,13 +45,14 @@ TensorSpec get_tensor_spec(const ttnn::SimpleShape& shape, DataType dtype, Layou
 }
 
 template <typename T>
-std::vector<T> arange(int64_t start, int64_t end, int64_t step) {
+std::vector<T> arange(int64_t start, int64_t end, int64_t step, std::optional<int64_t> cap = std::nullopt) {
     std::vector<T> result;
     for (int el : xt::arange(start, end, step)) {
+        int capped_el = cap ? el % *cap : el;
         if constexpr (std::is_same_v<T, bfloat16>) {
-            result.push_back(T(static_cast<float>(el)));
+            result.push_back(T(static_cast<float>(capped_el)));
         } else {
-            result.push_back(static_cast<T>(el));
+            result.push_back(static_cast<T>(capped_el));
         }
     }
     return result;
 }
@@ -50,7 +61,7 @@ std::vector<T> arange(int64_t start, int64_t end, int64_t step) {
 template <typename T>
 class VectorConversionTest : public ::testing::Test {};
 
-using TestTypes = ::testing::Types<bfloat16, float, uint32_t, int32_t>;
+using TestTypes = ::testing::Types<bfloat16, float, uint8_t, uint16_t, int32_t, uint32_t>;
 TYPED_TEST_SUITE(VectorConversionTest, TestTypes);
 
 TYPED_TEST(VectorConversionTest, Roundtrip) {
@@ -74,21 +85,17 @@ TYPED_TEST(VectorConversionTest, RoundtripTilezedLayout) {
     ttnn::SimpleShape shape{128, 128};
     auto input = arange<TypeParam>(0, shape.volume(), 1);
 
-    // TODO: Support this.
-    EXPECT_ANY_THROW(
-        Tensor::from_vector(input, get_tensor_spec(shape, convert_to_data_type<TypeParam>(), Layout::TILE)));
-
-    auto output = Tensor::from_vector(input, get_tensor_spec(shape, convert_to_data_type<TypeParam>()))
-                      .to(Layout::TILE)
+    auto output = Tensor::from_vector(input, get_tensor_spec(shape, convert_to_data_type<TypeParam>(), Layout::TILE))
                       .template to_vector<TypeParam>();
+
     EXPECT_THAT(output, Pointwise(Eq(), input));
 }
 
 TYPED_TEST(VectorConversionTest, InvalidDtype) {
     ttnn::SimpleShape shape{32, 32};
-    auto input = arange<TypeParam>(0, 42, 1);
+    auto input = arange<TypeParam>(0, shape.volume(), 1);
 
-    ASSERT_NE(input.size(), shape.volume());
     EXPECT_ANY_THROW(Tensor::from_vector(
         input,
         get_tensor_spec(
@@ -97,7 +104,7 @@
             (std::is_same_v<TypeParam, int32_t> ? DataType::FLOAT32 : DataType::INT32))));
 }
 
-TEST(FloatVectorConversionTest, RoundtripBfloat16Representation) {
+TEST(FloatVectorConversionTest, RoundtripBfloat16) {
     for (const auto& shape : get_shapes_for_test()) {
         auto input_bf16 = arange<bfloat16>(0, static_cast<int64_t>(shape.volume()), 1);
         std::vector<float> input_ft;
@@ -115,5 +122,69 @@
     }
 }
 
+class BlockFloatVectorConversionTest : public ::testing::TestWithParam<DataType> {};
+
+TEST_P(BlockFloatVectorConversionTest, InvalidLayout) {
+    ttnn::SimpleShape shape{32, 32};
+    // Block float types are only supported in TILE layout.
+    EXPECT_ANY_THROW(
+        Tensor::from_vector(std::vector<float>(shape.volume()), get_tensor_spec(shape, GetParam(), Layout::ROW_MAJOR)));
+}
+
+TEST_P(BlockFloatVectorConversionTest, Roundtrip) {
+    ttnn::SimpleShape shape{32, 32};
+    std::vector<float> input = arange<float>(0, shape.volume(), 1, /*cap=*/32);
+
+    auto output = Tensor::from_vector(input, get_tensor_spec(shape, GetParam(), Layout::TILE)).to_vector<float>();
+    EXPECT_THAT(output, Pointwise(FloatNear(4.0f), input));
+}
+
+TEST_P(BlockFloatVectorConversionTest, RoundtripWithPadding) {
+    ttnn::SimpleShape shape{14, 47};
+    std::vector<float> input = arange<float>(0, shape.volume(), 1, /*cap=*/32);
+
+    auto output = Tensor::from_vector(input, get_tensor_spec(shape, GetParam(), Layout::TILE));
+
+    EXPECT_THAT(output.get_logical_shape(), ShapeIs(14, 47));
+    EXPECT_THAT(output.get_padded_shape(), ShapeIs(32, 64));
+
+    EXPECT_THAT(output.to_vector<float>(), Pointwise(FloatNear(4.0f), input));
+}
+
+TEST_P(BlockFloatVectorConversionTest, RoundtripWithPaddingAndCustomTile) {
+    ttnn::SimpleShape shape{14, 47};
+    std::vector<float> input = arange<float>(0, shape.volume(), 1, /*cap=*/32);
+
+    TensorSpec spec(shape, TensorLayout(GetParam(), PageConfig(Layout::TILE, Tile({16, 16})), MemoryConfig{}));
+    auto output = Tensor::from_vector(input, spec);
+
+    EXPECT_THAT(output.get_logical_shape(), ShapeIs(14, 47));
+    EXPECT_THAT(output.get_padded_shape(), ShapeIs(16, 48));
+
+    EXPECT_THAT(output.to_vector<float>(), Pointwise(FloatNear(4.0f), input));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    BlockFloatVectorConversionTest,
+    BlockFloatVectorConversionTest,
+    ::testing::Values(DataType::BFLOAT4_B, DataType::BFLOAT8_B));
+
+using DeviceVectorConversionTest = TTNNFixtureWithDevice;
+
+TEST_F(DeviceVectorConversionTest, RoundtripWithMemoryConfig) {
+    ttnn::SimpleShape shape{128, 128};
+
+    auto input = arange<float>(0, shape.volume(), 1);
+
+    TensorSpec spec(
+        shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{.buffer_type = BufferType::L1}));
+    auto output = Tensor::from_vector(input, spec, device_);
+
+    EXPECT_TRUE(is_device_tensor(output));
+    EXPECT_TRUE(output.memory_config().is_l1());
+
+    EXPECT_THAT(output.to_vector<float>(), Pointwise(Eq(), input));
+}
+
 }  // namespace
 }  // namespace ttnn
diff --git a/tt_metal/common/bfloat4.hpp b/tt_metal/common/bfloat4.hpp
index 991a4ec21c2..b94fdd2d9d0 100644
--- a/tt_metal/common/bfloat4.hpp
+++ b/tt_metal/common/bfloat4.hpp
@@ -10,15 +10,16 @@
 #include <vector>
 
 #include "tt_metal/common/assert.hpp"
+#include "tt_metal/common/blockfloat_common.hpp"
 #include "tt_metal/common/tt_backend_api_types.hpp"
+#include "tt_metal/tt_stl/span.hpp"
 #include "tracy/Tracy.hpp"
-#include "blockfloat_common.hpp"
 
 // TODO: empty struct to facilitate Tensor template logic. Reconsider how/why templating is supported in Tensor
 struct bfloat4_b {};
 
 inline std::vector<uint32_t> pack_fp32_vec_as_bfp4_tiles(
-    const std::vector<float>& fp32_vec,
+    tt::stl::Span<const float> fp32_vec,
     bool row_major_input,
     bool is_exp_a,
     const std::optional<Tile>& tile = std::nullopt) {
diff --git a/tt_metal/common/bfloat8.hpp b/tt_metal/common/bfloat8.hpp
index d302cb6ac19..e37405f8c64 100644
--- a/tt_metal/common/bfloat8.hpp
+++ b/tt_metal/common/bfloat8.hpp
@@ -10,9 +10,10 @@
 #include <vector>
 
 #include "tt_metal/common/assert.hpp"
+#include "tt_metal/common/blockfloat_common.hpp"
 #include "tt_metal/common/tt_backend_api_types.hpp"
+#include "tt_metal/tt_stl/span.hpp"
 #include "tracy/Tracy.hpp"
-#include "blockfloat_common.hpp"
 
 // TODO: empty struct to facilitate Tensor template logic. Reconsider how/why templating is supported in Tensor
 struct bfloat8_b {};
 
@@ -99,7 +100,7 @@ inline uint32_t create_packed_bfp8_packed_as_u32(
 }
 
 inline std::vector<uint32_t> pack_fp32_vec_as_bfp8_tiles(
-    const std::vector<float>& fp32_vec,
+    tt::stl::Span<const float> fp32_vec,
     bool row_major_input,
     bool is_exp_a,
     const std::optional<Tile>& tile = std::nullopt) {
diff --git a/tt_metal/common/blockfloat_common.hpp b/tt_metal/common/blockfloat_common.hpp
index b29258fd9b8..c378fc83f60 100644
--- a/tt_metal/common/blockfloat_common.hpp
+++ b/tt_metal/common/blockfloat_common.hpp
@@ -13,6 +13,7 @@
 #include "tt_metal/common/tt_backend_api_types.hpp"
 #include "tracy/Tracy.hpp"
 #include "tt_metal/impl/tile/tile.hpp"
+#include "tt_metal/tt_stl/span.hpp"
 
 inline uint8_t get_max_exp(const std::vector<uint32_t>& vec, bool is_exp_a) {
     TT_ASSERT(vec.size() == 16);
@@ -288,7 +289,7 @@ inline uint32_t create_packed_bfp_packed_as_u32(
 
 template <tt::DataFormat BfpFormat>
 inline std::vector<uint32_t> pack_fp32_vec_as_bfp_tiles(
-    const std::vector<float>& fp32_vec,
+    tt::stl::Span<const float> fp32_vec,
     bool row_major_input,
     bool is_exp_a,
     const std::optional<Tile>& tile = std::nullopt) {
@@ -344,7 +345,7 @@ inline std::vector<uint32_t> pack_fp32_vec_as_bfp_tiles(
                 } else {
                     data_index = fp32_element_index++;
                 }
-                float float_num = fp32_vec.at(data_index);
+                float float_num = fp32_vec[data_index];
                 uint32_t uint32_num = *reinterpret_cast<uint32_t*>(&float_num);
                 single_row.push_back(uint32_num);
             }
diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp
index f72f33a5039..f543e5f200e 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.cpp
@@ -8,6 +8,7 @@
 #include <cstdint>
 #include <memory>
 
+#include "tt_metal/common/assert.hpp"
 #include "tt_metal/common/bfloat16.hpp"
 #include "impl/buffers/buffer_constants.hpp"
 #include "tt_metal/tt_stl/overloaded.hpp"
@@ -32,27 +33,39 @@ namespace tt::tt_metal {
 namespace {
 
 template <typename T>
-Tensor create_owned_tensor_from_span(tt::stl::Span<const T> data, const TensorSpec& spec) {
-    // TODO: support tilized layouts.
-    TT_FATAL(spec.layout() == Layout::ROW_MAJOR, "Unsupported layout: {}", spec.layout());
-    auto buffer = tt::tt_metal::owned_buffer::create(std::vector<T>(data.begin(), data.end()));
-    auto storage = OwnedStorage{std::move(buffer)};
-    return Tensor{std::move(storage), spec};
+Tensor create_owned_tensor_from_row_major_data(
+    std::vector<T>&& data, const TensorSpec& spec, std::optional<ttnn::AnyDevice> device = std::nullopt) {
+    TensorSpec result_cpu_spec(
+        spec.logical_shape(),
+        TensorLayout(spec.data_type(), PageConfig(Layout::ROW_MAJOR, spec.tile()), MemoryConfig{}));
+
+    Tensor output(OwnedStorage{owned_buffer::create(std::move(data))}, result_cpu_spec);
+
+    if (spec.layout() == Layout::TILE) {
+        // TODO: whenever possible, perform tilization on device.
+        output = output.to(Layout::TILE);
+    }
+
+    if (device.has_value()) {
+        output = output.to(device->get_devices(), spec.memory_config());
+    }
+
+    return output;
 }
 
 // TODO: optimize precomputing multipliers
 template <typename T>
-std::vector<T> untile_tensor_to_vec(const Tensor& cpu_tensor) {
-    auto tiled_buffer = tt::tt_metal::host_buffer::get_as<T>(cpu_tensor);
-    auto untiled_shape = cpu_tensor.get_logical_shape();
-    auto tiled_shape = cpu_tensor.get_padded_shape();
+std::vector<T> unpad_tensor_to_vec(const Tensor& cpu_tensor) {
+    auto tiled_buffer = host_buffer::get_as<T>(cpu_tensor);
+    const auto untiled_shape = cpu_tensor.get_logical_shape();
+    const auto tiled_shape = cpu_tensor.get_padded_shape();
 
     // Calculate total size of the untiled tensor
     size_t total_size = untiled_shape.volume();
     std::vector<T> untiled_data(total_size);
 
-    auto compute_flat_index = [](const std::vector<uint32_t>& indices, ttnn::SimpleShape& shape) -> uint32_t {
+    auto compute_flat_index = [](const std::vector<uint32_t>& indices, const ttnn::SimpleShape& shape) -> uint32_t {
         uint32_t flat_index = 0;
         uint32_t multiplier = 1;
         for (int i = (int)indices.size() - 1; i >= 0; --i) {
@@ -589,28 +602,53 @@ const Storage& Tensor::get_storage() const {
 }
 
 template <>
-Tensor Tensor::from_span(tt::stl::Span<const float> buffer, const TensorSpec& spec) {
+Tensor Tensor::from_span(
+    tt::stl::Span<const float> buffer, const TensorSpec& spec, std::optional<ttnn::AnyDevice> device) {
     size_t volume = spec.logical_shape().volume();
     TT_FATAL(
         buffer.size() == volume, "Current buffer size is {} different from shape volume {}", buffer.size(), volume);
-    if (spec.data_type() == DataType::FLOAT32) {
-        return create_owned_tensor_from_span(buffer, spec);
-    } else if (spec.data_type() == DataType::BFLOAT16) {
-        std::vector<bfloat16> bfloat16_data;
-        bfloat16_data.reserve(buffer.size());
-        std::transform(std::begin(buffer), std::end(buffer), std::back_inserter(bfloat16_data), [](float value) {
-            return bfloat16(value);
-        });
-        return create_owned_tensor_from_span(
-            tt::stl::Span<const bfloat16>(bfloat16_data.data(), bfloat16_data.size()), spec);
-    } else {
-        // TODO: support bf8 and bf4
-        TT_THROW("Unsupported data type for from_span: {}", spec.data_type());
+    switch (spec.data_type()) {
+        case DataType::FLOAT32:
+            return create_owned_tensor_from_row_major_data(
+                std::vector<float>(buffer.begin(), buffer.end()), spec, device);
+        case DataType::BFLOAT16: {
+            std::vector<bfloat16> bfloat16_data;
+            bfloat16_data.reserve(buffer.size());
+            std::transform(std::begin(buffer), std::end(buffer), std::back_inserter(bfloat16_data), [](float value) {
+                return bfloat16(value);
+            });
+            return create_owned_tensor_from_row_major_data(std::move(bfloat16_data), spec, device);
+        }
+        case DataType::BFLOAT8_B:
+        case DataType::BFLOAT4_B: {
+            TT_FATAL(
+                spec.tensor_layout().get_layout() == Layout::TILE,
+                "Tile layout is required for BFLOAT8_B and BFLOAT4_B");
+
+            // TODO: Implement `encode_tensor_data` in terms of a Span, avoid tilizing the data, as pack_fp32_vec_as_*
+            // support row-major input.
+            const auto& tile = spec.tensor_layout().get_page_config().get_tile();
+            auto physical_data =
+                tensor_impl::encode_tensor_data(std::vector<float>(buffer.begin(), buffer.end()), spec);
+            std::vector<uint32_t> packed_block_floats =
+                spec.data_type() == DataType::BFLOAT8_B
+                    ? pack_fp32_vec_as_bfp8_tiles(physical_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile)
+                    : pack_fp32_vec_as_bfp4_tiles(physical_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile);
+
+            Tensor tensor(OwnedStorage{owned_buffer::create(std::move(packed_block_floats))}, spec);
+            if (device.has_value()) {
+                tensor = tensor.to(device->get_devices(), spec.memory_config());
+            }
+            return tensor;
+        }
+        default: {
+            TT_THROW("Unsupported data type for from_span: {}", spec.data_type());
+        }
     }
 }
 
 template <typename T>
-Tensor Tensor::from_span(tt::stl::Span<const T> buffer, const TensorSpec& spec) {
+Tensor Tensor::from_span(tt::stl::Span<const T> buffer, const TensorSpec& spec, std::optional<ttnn::AnyDevice> device) {
     size_t volume = spec.logical_shape().volume();
     TT_FATAL(
         buffer.size() == volume, "Current buffer size is {} different from shape volume {}", buffer.size(), volume);
@@ -619,19 +657,32 @@ Tensor Tensor::from_span(tt::stl::Span<const T> buffer, const TensorSpec& spec)
         "Unsupported data type for from_span: got {}, expected: {}",
         spec.data_type(),
         convert_to_data_type<T>());
-    return create_owned_tensor_from_span(buffer, spec);
+    return create_owned_tensor_from_row_major_data(std::vector<T>(buffer.begin(), buffer.end()), spec, device);
 }
 
 template <>
 std::vector<float> Tensor::to_vector() const {
-    auto cpu_tensor = this->cpu().to(Layout::ROW_MAJOR);
-    if (cpu_tensor.get_dtype() == DataType::BFLOAT16) {
-        return untile_tensor_to_vec(cpu_tensor);
-    } else if (cpu_tensor.get_dtype() == DataType::FLOAT32) {
-        return untile_tensor_to_vec(cpu_tensor);
-    } else {
-        // TODO: support bf4, bf8.
-        TT_THROW("Cannot convert tensor to vector for data type: {}", cpu_tensor.get_dtype());
+    Tensor cpu_tensor = this->cpu();
+    switch (cpu_tensor.get_dtype()) {
+        case DataType::BFLOAT16: return unpad_tensor_to_vec(cpu_tensor.to(Layout::ROW_MAJOR));
+        case DataType::FLOAT32: return unpad_tensor_to_vec(cpu_tensor.to(Layout::ROW_MAJOR));
+        case DataType::BFLOAT8_B:
+        case DataType::BFLOAT4_B: {
+            const auto& tile = cpu_tensor.get_tensor_spec().tile();
+            std::vector<uint32_t> packed_data =
+                owned_buffer::get_as<uint32_t>(std::get<OwnedStorage>(cpu_tensor.storage()).buffer).get();
+            std::vector<float> unpacked_data =
+                cpu_tensor.get_tensor_spec().data_type() == DataType::BFLOAT8_B
+                    ? unpack_bfp8_tiles_into_float_vec(
+                          packed_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile)
+                    : unpack_bfp4_tiles_into_float_vec(
+                          packed_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile);
+
+            return tensor_impl::decode_tensor_data(unpacked_data, cpu_tensor.tensor_spec());
+        }
+        default: {
+            TT_THROW("Cannot convert tensor to vector for data type: {}", cpu_tensor.get_dtype());
+        }
     }
 }
 
@@ -643,16 +694,26 @@ std::vector<T> Tensor::to_vector() const {
         "Unsupported data type for to_vector: got {}, expected: {}",
         cpu_tensor.get_dtype(),
         convert_to_data_type<T>());
-    return untile_tensor_to_vec<T>(cpu_tensor);
+    return unpad_tensor_to_vec<T>(cpu_tensor);
 }
 
 // Instantiate explicitly for the supported types.
-template Tensor Tensor::from_span<bfloat16>(tt::stl::Span<const bfloat16> buffer, const TensorSpec& spec);
-template Tensor Tensor::from_span<uint32_t>(tt::stl::Span<const uint32_t> buffer, const TensorSpec& spec);
-template Tensor Tensor::from_span<int32_t>(tt::stl::Span<const int32_t> buffer, const TensorSpec& spec);
+template Tensor Tensor::from_span<bfloat16>(
+    tt::stl::Span<const bfloat16> buffer, const TensorSpec& spec, std::optional<ttnn::AnyDevice> device);
+template Tensor Tensor::from_span<int32_t>(
+    tt::stl::Span<const int32_t> buffer, const TensorSpec& spec, std::optional<ttnn::AnyDevice> device);
+template Tensor Tensor::from_span<uint8_t>(
+    tt::stl::Span<const uint8_t> buffer, const TensorSpec& spec, std::optional<ttnn::AnyDevice> device);
+template Tensor Tensor::from_span<uint16_t>(
+    tt::stl::Span<const uint16_t> buffer, const TensorSpec& spec, std::optional<ttnn::AnyDevice> device);
+template Tensor Tensor::from_span<uint32_t>(
+    tt::stl::Span<const uint32_t> buffer, const TensorSpec& spec, std::optional<ttnn::AnyDevice> device);
+
 template std::vector<bfloat16> Tensor::to_vector<bfloat16>() const;
-template std::vector<uint32_t> Tensor::to_vector<uint32_t>() const;
 template std::vector<int32_t> Tensor::to_vector<int32_t>() const;
+template std::vector<uint8_t> Tensor::to_vector<uint8_t>() const;
+template std::vector<uint16_t> Tensor::to_vector<uint16_t>() const;
+template std::vector<uint32_t> Tensor::to_vector<uint32_t>() const;
 
 Tensor Tensor::to(
     Device* target_device,
diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp
index 6827c421320..76dbed2f516 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.hpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.hpp
@@ -15,6 +15,7 @@
 #include "common/bfloat8.hpp"
 #include "common/test_tiles.hpp"
 #include "common/tt_backend_api_types.hpp"
+#include "ttnn/any_device.hpp"
 #include "ttnn/common/constants.hpp"
 #include "ttnn/distributed/distributed_tensor_config.hpp"
 #include "ttnn/tensor/types.hpp"
@@ -139,39 +140,30 @@ struct Tensor {
     // Converts a buffer of elements of type `T` to a `Tensor`.
     // Elements in the buffer are assumed to be stored in row-major order. The size of the buffer and the type of the
-    // elements have to match `spec`.
+    // elements have to match `spec`; block float formats such as BFLOAT8_B and BFLOAT4_B require `T` equal `float`.
     //
     // The data in the buffer is copied into a tensor with an owned storage.
     //
-    // IMPORTANT: this function supports a limited subset of types (float32, bfloat16, uint32_t, int32_t),
-    // and only row-major layout.
-    //
-    // TODO:
-    // 1. add support for returning a tensor with a borrowed storage based off the buffer.
-    // 2. add support for sharding.
-    // 3. add support for block float formats.
-    // 4. add support for tilized layouts.
-    // 5. add support for on-device tensor creation.
+    // TODO: add support for returning a tensor with borrowed storage based off the buffer.
+    // TODO: handle tilization and padding in face of sharding.
     template <typename T>
-    static Tensor from_span(tt::stl::Span<const T> buffer, const TensorSpec& spec);
+    static Tensor from_span(
+        tt::stl::Span<const T> buffer, const TensorSpec& spec, std::optional<ttnn::AnyDevice> device = std::nullopt);
 
     // Same as `from_span`, but takes a vector instead.
     template <typename T>
-    static Tensor from_vector(const std::vector<T>& buffer, const TensorSpec& spec) {
-        return from_span(tt::stl::Span<const T>(buffer.data(), buffer.size()), spec);
+    static Tensor from_vector(
+        const std::vector<T>& buffer, const TensorSpec& spec, std::optional<ttnn::AnyDevice> device = std::nullopt) {
+        return from_span(tt::stl::Span<const T>(buffer.data(), buffer.size()), spec, device);
     }
 
     // Converts a `Tensor` to a `std::vector<T>`.
     // Elements in the vector will be stored in row-major order. The type of the requested vector has to match that of
-    // the `Tensor`.
+    // the `Tensor`; block float formats such as BFLOAT8_B and BFLOAT4_B require `T` equal `float`.
     //
     // If the tensor resides on a device, it will be brought back to host.
     //
-    // IMPORTANT: this function supports a limited subset of types (float32, bfloat16, uint32_t, int32_t).
-    //
-    // TODO:
-    // 1. add support for sharding.
-    // 2. add support for block float formats.
+    // TODO: handle tilization and padding in face of sharding.
     template <typename T>
     std::vector<T> to_vector() const;
 
diff --git a/ttnn/cpp/ttnn/tensor/xtensor/partition.cpp b/ttnn/cpp/ttnn/tensor/xtensor/partition.cpp
index e01a6838bd8..69a6d96504b 100644
--- a/ttnn/cpp/ttnn/tensor/xtensor/partition.cpp
+++ b/ttnn/cpp/ttnn/tensor/xtensor/partition.cpp
@@ -151,9 +151,13 @@ std::vector<Tensor> chunk_impl(const Tensor& tensor, const TensorLayout& layout,
 std::vector<Tensor> chunk(const Tensor& tensor, int num_chunks, int dim) {
     const auto& reference_layout = tensor.tensor_spec().tensor_layout();
     switch (reference_layout.get_data_type()) {
-        case DataType::BFLOAT16: return adaptor::chunk_impl<bfloat16>(tensor, reference_layout, num_chunks, dim);
+        case DataType::BFLOAT4_B:
+        case DataType::BFLOAT8_B:
+        case DataType::BFLOAT16:
         case DataType::FLOAT32: return adaptor::chunk_impl<float>(tensor, reference_layout, num_chunks, dim);
         case DataType::INT32: return adaptor::chunk_impl<int32_t>(tensor, reference_layout, num_chunks, dim);
+        case DataType::UINT8: return adaptor::chunk_impl<uint8_t>(tensor, reference_layout, num_chunks, dim);
+        case DataType::UINT16: return adaptor::chunk_impl<uint16_t>(tensor, reference_layout, num_chunks, dim);
         case DataType::UINT32: return adaptor::chunk_impl<uint32_t>(tensor, reference_layout, num_chunks, dim);
         default: TT_THROW("Unsupported data type: {}", reference_layout.get_data_type());
     }
 }
@@ -163,9 +167,13 @@ Tensor concat(const std::vector<Tensor>& tensors, int dim) {
     TT_FATAL(tensors.size() > 0, "Cannot concatenate an empty list of tensors");
     const auto& reference_layout = tensors.front().tensor_spec().tensor_layout();
     switch (reference_layout.get_data_type()) {
-        case DataType::BFLOAT16: return adaptor::concat_impl<bfloat16>(tensors, reference_layout, dim);
+        case DataType::BFLOAT4_B:
+        case DataType::BFLOAT8_B:
+        case DataType::BFLOAT16:
         case DataType::FLOAT32: return adaptor::concat_impl<float>(tensors, reference_layout, dim);
         case DataType::INT32: return adaptor::concat_impl<int32_t>(tensors, reference_layout, dim);
+        case DataType::UINT8: return adaptor::concat_impl<uint8_t>(tensors, reference_layout, dim);
+        case DataType::UINT16: return adaptor::concat_impl<uint16_t>(tensors, reference_layout, dim);
         case DataType::UINT32: return adaptor::concat_impl<uint32_t>(tensors, reference_layout, dim);
         default: TT_THROW("Unsupported data type: {}", reference_layout.get_data_type());
     }
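For reference, a minimal host-only round trip through the new block-float path, mirroring BlockFloatVectorConversionTest.Roundtrip above. This is a sketch written against the signatures introduced in this diff (`Tensor::from_vector`, `Tensor::to_vector`, `TensorSpec`/`TensorLayout`); the shape, values, and function name are illustrative, and error handling is omitted:

    #include <vector>
    #include "ttnn/tensor/tensor.hpp"

    using namespace tt::tt_metal;

    std::vector<float> roundtrip_bfp8() {
        // Row-major input; keep values within one block's dynamic range so the
        // shared-exponent quantization error stays small (cf. the tests' cap of 32).
        ttnn::SimpleShape shape{32, 32};
        std::vector<float> input(shape.volume());
        for (size_t i = 0; i < input.size(); ++i) {
            input[i] = static_cast<float>(i % 32);
        }

        // Block float formats require TILE layout; ROW_MAJOR throws (see the
        // InvalidLayout test above).
        TensorSpec spec(shape, TensorLayout(DataType::BFLOAT8_B, Layout::TILE, MemoryConfig{}));

        // from_vector packs the fp32 data into bfp8 tiles; to_vector unpacks it
        // back to row-major fp32. Values come back quantized, not bit-exact --
        // hence the FloatNear(4.0f) tolerance in the tests.
        Tensor tensor = Tensor::from_vector(input, spec);
        return tensor.to_vector<float>();
    }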