diff --git a/tests/ttnn/unit_tests/gtests/CMakeLists.txt b/tests/ttnn/unit_tests/gtests/CMakeLists.txt index ef4c1177c7e6..d884c453334f 100644 --- a/tests/ttnn/unit_tests/gtests/CMakeLists.txt +++ b/tests/ttnn/unit_tests/gtests/CMakeLists.txt @@ -26,8 +26,10 @@ set(TTNN_TENSOR_UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor_multi_device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor_with_layout.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_distributed_tensor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_partition.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_shape_base.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_sharding_with_alignment.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_vector_conversion.cpp ) add_executable(unit_tests_ttnn ${TTNN_UNIT_TESTS_SRC}) diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp index a66c1b737497..e60ee4912e3e 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp @@ -4,12 +4,10 @@ #include -#include "common/bfloat16.hpp" #include "ttnn/distributed/api.hpp" #include "ttnn/operations/functions.hpp" #include "ttnn/tensor/xtensor/conversion_utils.hpp" #include "ttnn_test_fixtures.hpp" -#include #include #include @@ -20,7 +18,8 @@ using ::ttnn::experimental::xtensor::from_vector; using TensorDistributionTest = T3kMultiDeviceFixture; TEST_F(TensorDistributionTest, Replication) { - Tensor input_tensor = from_vector(std::vector{42.F, 13.F, -99.F}, ttnn::Shape{1, 1, 1, 3}); + Tensor input_tensor = + from_vector(std::vector{42.F, 13.F, -99.F}, ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32); auto mapper = api::replicate_tensor_to_mesh_mapper(*mesh_device_); Tensor replicated_tensor = api::distribute_tensor(input_tensor, *mesh_device_, *mapper); @@ -28,13 +27,14 @@ TEST_F(TensorDistributionTest, Replication) { std::vector device_tensors = api::get_device_tensors(replicated_tensor); EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices()); for (const auto& device_tensor : device_tensors) { - EXPECT_TRUE(ttnn::allclose(device_tensor.cpu(), input_tensor)); + EXPECT_TRUE(ttnn::allclose(device_tensor.cpu(), input_tensor)); } } TEST_F(TensorDistributionTest, Shard1DInvalidDim) { const int num_devices = mesh_device_->num_devices(); - Tensor input_tensor = from_vector(std::vector(num_devices, 0), ttnn::Shape{1, 1, 1, num_devices}); + Tensor input_tensor = + from_vector(std::vector(num_devices, 0), ttnn::SimpleShape{1, 1, 1, num_devices}, DataType::FLOAT32); EXPECT_ANY_THROW({ auto mapper = api::shard_tensor_to_mesh_mapper(*mesh_device_, -1); @@ -50,7 +50,8 @@ TEST_F(TensorDistributionTest, Shard1DInvalidDim) { TEST_F(TensorDistributionTest, Shard1DTooFewShards) { const int num_devices = mesh_device_->num_devices(); ASSERT_LT(3, num_devices); - Tensor input_tensor = from_vector(std::vector{42.F, 13.F, -99.F}, ttnn::Shape{1, 1, 1, 3}); + Tensor input_tensor = + from_vector(std::vector{42.F, 13.F, -99.F}, ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32); EXPECT_ANY_THROW({ auto mapper = api::shard_tensor_to_mesh_mapper(*mesh_device_, 3); @@ -64,7 +65,7 @@ TEST_F(TensorDistributionTest, Shard1D) { for (int i = 0; i < num_devices; i++) { test_data.insert(test_data.end(), {i * 1.F, i * 2.F, i * 3.F}); } - Tensor input_tensor = from_vector(test_data, ttnn::Shape{1, num_devices, 3, 1}); + Tensor input_tensor = from_vector(test_data, ttnn::SimpleShape{1, num_devices, 3, 1}, DataType::FLOAT32); auto mapper = api::shard_tensor_to_mesh_mapper(*mesh_device_, 1); Tensor sharded_tensor = api::distribute_tensor(input_tensor, *mesh_device_, *mapper); @@ -72,15 +73,16 @@ TEST_F(TensorDistributionTest, Shard1D) { std::vector device_tensors = api::get_device_tensors(sharded_tensor); EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices()); for (int i = 0; i < device_tensors.size(); i++) { - auto expected = from_vector(std::vector{i * 1.F, i * 2.F, i * 3.F}, ttnn::Shape{1, 1, 3, 1}); - EXPECT_TRUE(ttnn::allclose(device_tensors[i].cpu(), expected)); + auto expected = from_vector( + std::vector{i * 1.F, i * 2.F, i * 3.F}, ttnn::SimpleShape{1, 1, 3, 1}, DataType::FLOAT32); + EXPECT_TRUE(ttnn::allclose(device_tensors[i].cpu(), expected)); } auto composer = api::concat_mesh_to_tensor_composer(/*dim=*/0); Tensor concatenated_tensor = api::aggregate_tensor(sharded_tensor, *composer); - Tensor expected_tensor = from_vector(test_data, ttnn::Shape{num_devices, 1, 3, 1}); - EXPECT_TRUE(ttnn::allclose(concatenated_tensor, expected_tensor)); + Tensor expected_tensor = from_vector(test_data, ttnn::SimpleShape{num_devices, 1, 3, 1}, DataType::FLOAT32); + EXPECT_TRUE(ttnn::allclose(concatenated_tensor, expected_tensor)); } TEST_F(TensorDistributionTest, Shard2DInvalidMeshShape) { @@ -110,7 +112,7 @@ TEST_F(TensorDistributionTest, Shard2DReplicateDim) { const int num_devices = num_rows * num_cols; std::vector test_data = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; - Tensor input_tensor = from_vector(test_data, ttnn::Shape{1, num_rows, num_cols, 1}); + Tensor input_tensor = from_vector(test_data, ttnn::SimpleShape{1, num_rows, num_cols, 1}, DataType::FLOAT32); input_tensor.print(); auto mapper = api::shard_tensor_2d_to_mesh_mapper( @@ -127,12 +129,14 @@ TEST_F(TensorDistributionTest, Shard2DReplicateDim) { int i = 0; for (; i < 4; i++) { - auto expected = from_vector(std::vector{0.0, 1.0, 2.0, 3.0}, ttnn::Shape{1, 1, 4, 1}); - EXPECT_TRUE(ttnn::allclose(device_tensors[i].cpu(), expected)); + auto expected = + from_vector(std::vector{0.0, 1.0, 2.0, 3.0}, ttnn::SimpleShape{1, 1, 4, 1}, DataType::FLOAT32); + EXPECT_TRUE(ttnn::allclose(device_tensors[i].cpu(), expected)); } for (; i < device_tensors.size(); i++) { - auto expected = from_vector(std::vector{4.0, 5.0, 6.0, 7.0}, ttnn::Shape{1, 1, 4, 1}); - EXPECT_TRUE(ttnn::allclose(device_tensors[i].cpu(), expected)); + auto expected = + from_vector(std::vector{4.0, 5.0, 6.0, 7.0}, ttnn::SimpleShape{1, 1, 4, 1}, DataType::FLOAT32); + EXPECT_TRUE(ttnn::allclose(device_tensors[i].cpu(), expected)); } } @@ -146,7 +150,7 @@ TEST_F(TensorDistributionTest, Shard2D) { for (int i = 0; i < num_devices; i++) { test_data.insert(test_data.end(), {i * 1.F, i * 2.F, i * 3.F}); } - Tensor input_tensor = from_vector(test_data, ttnn::Shape{1, num_rows, num_cols, 3}); + Tensor input_tensor = from_vector(test_data, ttnn::SimpleShape{1, num_rows, num_cols, 3}, DataType::FLOAT32); auto mapper = api::shard_tensor_2d_to_mesh_mapper( *mesh_device_, @@ -160,8 +164,9 @@ TEST_F(TensorDistributionTest, Shard2D) { std::vector device_tensors = api::get_device_tensors(sharded_tensor); EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices()); for (int i = 0; i < device_tensors.size(); i++) { - auto expected = from_vector(std::vector{i * 1.F, i * 2.F, i * 3.F}, ttnn::Shape{1, 1, 1, 3}); - EXPECT_TRUE(ttnn::allclose(device_tensors[i].cpu(), expected)); + auto expected = from_vector( + std::vector{i * 1.F, i * 2.F, i * 3.F}, ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32); + EXPECT_TRUE(ttnn::allclose(device_tensors[i].cpu(), expected)); } auto composer = api::concat_mesh_2d_to_tensor_composer( @@ -172,8 +177,8 @@ TEST_F(TensorDistributionTest, Shard2D) { }); Tensor concatenated_tensor = api::aggregate_tensor(sharded_tensor, *composer); - Tensor expected_tensor = from_vector(test_data, ttnn::Shape{num_rows, 1, num_cols, 3}); - EXPECT_TRUE(ttnn::allclose(concatenated_tensor, expected_tensor)); + Tensor expected_tensor = from_vector(test_data, ttnn::SimpleShape{num_rows, 1, num_cols, 3}, DataType::FLOAT32); + EXPECT_TRUE(ttnn::allclose(concatenated_tensor, expected_tensor)); } } // namespace ttnn::distributed::test diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp new file mode 100644 index 000000000000..0d8a6883a76b --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/xtensor/conversion_utils.hpp" +#include "ttnn/tensor/xtensor/partition.hpp" +#include "ttnn/tensor/xtensor/xtensor_all_includes.hpp" + +namespace ttnn { +namespace { + +using ::tt::tt_metal::Tensor; +using ::ttnn::experimental::xtensor::chunk; +using ::ttnn::experimental::xtensor::concatenate; +using ::ttnn::experimental::xtensor::from_vector; + +} // namespace +} // namespace ttnn diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_vector_conversion.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_vector_conversion.cpp new file mode 100644 index 000000000000..8a72b8e9b360 --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/tensor/test_vector_conversion.cpp @@ -0,0 +1,99 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include + +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/xtensor/conversion_utils.hpp" + +namespace ttnn { +namespace { + +using ::testing::Eq; +using ::testing::Pointwise; +using ::tt::tt_metal::Tensor; +using ::ttnn::experimental::xtensor::from_vector; +using ::ttnn::experimental::xtensor::to_vector; + +const std::vector& GetShapesForTest() { + static auto* shapes = new std::vector{ + ttnn::SimpleShape{1, 1, 1, 1}, + ttnn::SimpleShape{1, 1, 1, 10}, + ttnn::SimpleShape{1, 32, 32, 16}, + ttnn::SimpleShape{1, 40, 3, 128}, + ttnn::SimpleShape{2, 2}, + ttnn::SimpleShape{1, 1, 1, 1, 10}, + }; + return *shapes; +} + +template +std::vector Arange(int64_t start, int64_t end, int64_t step) { + std::vector result; + for (int64_t i = start; i < end; i += step) { + if constexpr (std::is_same_v) { + result.push_back(T(static_cast(i))); + } else { + result.push_back(static_cast(i)); + } + } + return result; +} + +template +class VectorConversionTest : public ::testing::Test {}; + +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(VectorConversionTest, TestTypes); + +TYPED_TEST(VectorConversionTest, Basic) { + for (const auto& shape : GetShapesForTest()) { + auto input = Arange(0, static_cast(shape.volume()), 1); + auto output = to_vector(from_vector(input, shape, convert_to_data_type())); + EXPECT_THAT(output, Pointwise(Eq(), input)) << "for shape: " << shape; + } +} + +TYPED_TEST(VectorConversionTest, InvalidSize) { + ttnn::SimpleShape shape{32, 32}; + auto input = Arange(0, 42, 1); + + ASSERT_NE(input.size(), shape.volume()); + EXPECT_ANY_THROW(from_vector(input, shape, convert_to_data_type())); +} + +TYPED_TEST(VectorConversionTest, InvalidDtype) { + ttnn::SimpleShape shape{32, 32}; + auto input = Arange(0, 42, 1); + + ASSERT_NE(input.size(), shape.volume()); + EXPECT_ANY_THROW(from_vector( + input, + shape, + // Use INT32 for verification, except for when the actual type is int32_t. + (std::is_same_v ? DataType::FLOAT32 : DataType::INT32))); +} + +TEST(FloatVectorConversionTest, Bfloat16Representation) { + for (const auto& shape : GetShapesForTest()) { + auto input_bf16 = Arange(0, static_cast(shape.volume()), 1); + std::vector input_ft; + input_ft.reserve(input_bf16.size()); + std::transform(input_bf16.begin(), input_bf16.end(), std::back_inserter(input_ft), [](bfloat16 bf) { + return bf.to_float(); + }); + + auto output_bf16 = to_vector(from_vector(input_ft, shape, DataType::BFLOAT16)); + EXPECT_THAT(output_bf16, Pointwise(Eq(), input_bf16)) << "for shape: " << shape; + + auto output_ft = to_vector(from_vector(input_bf16, shape, DataType::BFLOAT16)); + EXPECT_THAT(output_ft, Pointwise(Eq(), input_ft)) << "for shape: " << shape; + } +} + +} // namespace +} // namespace ttnn diff --git a/tt_metal/third_party/tt_llk_blackhole b/tt_metal/third_party/tt_llk_blackhole index 973288fb014a..7536fbacd75a 160000 --- a/tt_metal/third_party/tt_llk_blackhole +++ b/tt_metal/third_party/tt_llk_blackhole @@ -1 +1 @@ -Subproject commit 973288fb014a22ce72cdba1c38a9f41f48532d6d +Subproject commit 7536fbacd75a4ad62047c63c9c54176fae079e06 diff --git a/tt_metal/third_party/tt_llk_wormhole_b0 b/tt_metal/third_party/tt_llk_wormhole_b0 index 33a7f6a02671..0f57d4e9dec6 160000 --- a/tt_metal/third_party/tt_llk_wormhole_b0 +++ b/tt_metal/third_party/tt_llk_wormhole_b0 @@ -1 +1 @@ -Subproject commit 33a7f6a026719af509a119d8a4e8e36c7c31854c +Subproject commit 0f57d4e9dec602b68671be8891e7af876285f275 diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index 15c373e814d5..a2c64b7d99ae 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -16,35 +16,33 @@ namespace { class ReplicateTensorToMesh : public TensorToMesh { public: - ReplicateTensorToMesh(MeshDevice& mesh_device) : mesh_device_(mesh_device) {} + ReplicateTensorToMesh(int num_devices) : num_devices_(num_devices) {} std::vector map(const Tensor& tensor) override { std::vector tensors; - tensors.reserve(mesh_device_.num_devices()); - std::fill_n(std::back_inserter(tensors), mesh_device_.num_devices(), tensor); + tensors.reserve(num_devices_); + std::fill_n(std::back_inserter(tensors), num_devices_, tensor); return tensors; } - DistributedTensorConfig config() const override { - return DistributedTensorConfig{ReplicateTensor{mesh_device_.num_devices()}}; - } + DistributedTensorConfig config() const override { return DistributedTensorConfig{ReplicateTensor{num_devices_}}; } private: - MeshDevice& mesh_device_; + int num_devices_ = -1; }; class ShardTensorToMesh : public TensorToMesh { public: - ShardTensorToMesh(MeshDevice& mesh_device, int dim) : mesh_device_(mesh_device), shard_dim_(dim) {} + ShardTensorToMesh(int num_devices, int dim) : num_devices_(num_devices), shard_dim_(dim) {} std::vector map(const Tensor& tensor) override { - return experimental::xtensor::chunk(tensor, mesh_device_.num_devices(), shard_dim_); + return experimental::xtensor::chunk(tensor, num_devices_, shard_dim_); } DistributedTensorConfig config() const override { return DistributedTensorConfig{ShardTensor{shard_dim_}}; } private: - MeshDevice& mesh_device_; + int num_devices_ = -1; int shard_dim_ = -1; }; @@ -144,11 +142,11 @@ class ConcatMesh2dToTensor : public MeshToTensor { } // namespace std::unique_ptr replicate_tensor_to_mesh_mapper(MeshDevice& mesh_device) { - return std::make_unique(mesh_device); + return std::make_unique(mesh_device.num_devices()); } std::unique_ptr shard_tensor_to_mesh_mapper(MeshDevice& mesh_device, int dim) { - return std::make_unique(mesh_device, dim); + return std::make_unique(mesh_device.num_devices(), dim); } std::unique_ptr shard_tensor_2d_to_mesh_mapper( diff --git a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.cpp b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.cpp index 3b9906cd4edd..4e4d829df324 100644 --- a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.cpp +++ b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.cpp @@ -12,50 +12,21 @@ namespace ttnn::experimental::xtensor { namespace { +using ::tt::tt_metal::DataType; using ::tt::tt_metal::Tensor; -// copypaste from deprecated tensor pybinds ttnn -tt::tt_metal::OwnedBuffer create_owned_buffer(const std::vector& data, DataType data_type) { - using ::tt::tt_metal::owned_buffer::create; - - switch (data_type) { - case DataType::BFLOAT8_B: { - auto uint32_vector = pack_fp32_vec_as_bfp8_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false); - return create(std::move(uint32_vector)); - } - case DataType::BFLOAT4_B: { - auto uint32_vector = pack_fp32_vec_as_bfp4_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false); - return create(std::move(uint32_vector)); - } - case DataType::FLOAT32: { - auto data_copy = data; - return create(std::move(data_copy)); - } - case DataType::BFLOAT16: { - std::vector bfloat16_data(data.size()); - std::transform(std::begin(data), std::end(data), std::begin(bfloat16_data), [](float value) { - return bfloat16(value); - }); - return create(std::move(bfloat16_data)); - } - default: { - TT_THROW("Cannot create a host buffer for data type: {}", data_type); - } - } -} - template Tensor create_owned_tensor( - std::vector data, const ttnn::Shape& shape, tt::tt_metal::DataType data_type, tt::tt_metal::Layout layout) { - auto buffer = tt::tt_metal::owned_buffer::create(std::move(data)); + tt::stl::Span data, const ttnn::SimpleShape& shape, DataType data_type, tt::tt_metal::Layout layout) { + auto buffer = tt::tt_metal::owned_buffer::create(std::vector(data.begin(), data.end())); auto storage = OwnedStorage{std::move(buffer)}; return Tensor{std::move(storage), shape, data_type, layout}; } // TODO: optimize precomputing multipliers -template +template std::vector untile_tensor_to_vec(const Tensor& cpu_tensor) { - auto tiled_buffer = tt::tt_metal::host_buffer::get_as(cpu_tensor); + auto tiled_buffer = tt::tt_metal::host_buffer::get_as(cpu_tensor); auto untiled_shape = cpu_tensor.get_logical_shape(); auto tiled_shape = cpu_tensor.get_padded_shape(); @@ -79,7 +50,7 @@ std::vector untile_tensor_to_vec(const Tensor& cpu_tensor) { for (size_t idx = 0; idx < total_size; ++idx) { uint32_t untiled_index = compute_flat_index(indices, untiled_shape); uint32_t tiled_index = compute_flat_index(indices, tiled_shape); - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { untiled_data[untiled_index] = tiled_buffer[tiled_index].to_float(); } else { untiled_data[untiled_index] = tiled_buffer[tiled_index]; @@ -99,41 +70,51 @@ std::vector untile_tensor_to_vec(const Tensor& cpu_tensor) { } // namespace template <> -Tensor from_vector(const std::vector& buffer, const ttnn::Shape& shape) { - const DataType data_type = DataType::BFLOAT16; - auto logical_shape = shape.logical_shape(); - size_t volume = logical_shape.volume(); +Tensor from_span(tt::stl::Span buffer, const ttnn::SimpleShape& shape, DataType dtype) { + size_t volume = shape.volume(); TT_FATAL( buffer.size() == volume, "Current buffer size is {} different from shape volume {}", buffer.size(), volume); - auto owned_buffer = create_owned_buffer(buffer, data_type); - return Tensor(OwnedStorage{owned_buffer}, logical_shape, data_type, Layout::ROW_MAJOR); + if (dtype == DataType::FLOAT32) { + return create_owned_tensor(buffer, shape, dtype, Layout::ROW_MAJOR); + } else if (dtype == DataType::BFLOAT16) { + std::vector bfloat16_data; + bfloat16_data.reserve(buffer.size()); + std::transform(std::begin(buffer), std::end(buffer), std::back_inserter(bfloat16_data), [](float value) { + return bfloat16(value); + }); + return create_owned_tensor( + tt::stl::Span(bfloat16_data.data(), bfloat16_data.size()), shape, dtype, Layout::ROW_MAJOR); + } else { + // TODO: support bf8 and bf4 + TT_THROW("Unsupported data type for from_span: {}", dtype); + } } -// Workaround implementation due to issue with tilize for float32 -// it is expected that tilize will be fixed in the after next tt-metal main update template <> -Tensor from_vector(const std::vector& buffer, const ttnn::Shape& shape) { - auto tensor = from_vector(buffer, shape); - return ttnn::typecast(tensor, DataType::FLOAT32); +Tensor from_span(tt::stl::Span buffer, const ttnn::SimpleShape& shape, DataType dtype) { + size_t volume = shape.volume(); + TT_FATAL( + buffer.size() == volume, "Current buffer size is {} different from shape volume {}", buffer.size(), volume); + TT_FATAL(dtype == DataType::BFLOAT16, "Unsupported data type for from_span: {}", dtype); + return create_owned_tensor(buffer, shape, dtype, Layout::ROW_MAJOR); } template <> -Tensor from_vector(const std::vector& buffer, const ttnn::Shape& shape) { - MemoryConfig output_mem_config{}; - auto logical_shape = shape.logical_shape(); - auto volume = logical_shape.volume(); +Tensor from_span(tt::stl::Span buffer, const ttnn::SimpleShape& shape, DataType dtype) { + size_t volume = shape.volume(); TT_FATAL( buffer.size() == volume, "Current buffer size is {} different from shape volume {}", buffer.size(), volume); - return create_owned_tensor(buffer, logical_shape, DataType::UINT32, Layout::ROW_MAJOR); + TT_FATAL(dtype == DataType::UINT32, "Unsupported data type for from_span: {}", dtype); + return create_owned_tensor(buffer, shape, DataType::UINT32, Layout::ROW_MAJOR); } template <> -Tensor from_vector(const std::vector& buffer, const ttnn::Shape& shape) { - auto logical_shape = shape.logical_shape(); - auto volume = logical_shape.volume(); +Tensor from_span(tt::stl::Span buffer, const ttnn::SimpleShape& shape, DataType dtype) { + size_t volume = shape.volume(); TT_FATAL( buffer.size() == volume, "Current buffer size is {} different from shape volume {}", buffer.size(), volume); - return create_owned_tensor(buffer, logical_shape, DataType::INT32, Layout::ROW_MAJOR); + TT_FATAL(dtype == DataType::INT32, "Unsupported data type for from_span: {}", dtype); + return create_owned_tensor(buffer, shape, DataType::INT32, Layout::ROW_MAJOR); } template <> @@ -144,19 +125,38 @@ std::vector to_vector(const Tensor& tensor) { } else if (cpu_tensor.get_dtype() == DataType::FLOAT32) { return untile_tensor_to_vec(cpu_tensor); } else { + // TODO: support bf4, bf8. TT_THROW("Cannot convert tensor to vector for data type: {}", cpu_tensor.get_dtype()); } } +template <> +std::vector to_vector(const Tensor& tensor) { + auto cpu_tensor = tensor.cpu().to(Layout::ROW_MAJOR); + TT_FATAL( + cpu_tensor.get_dtype() == DataType::BFLOAT16, + "Unsupported data type for to_vector: {}", + cpu_tensor.get_dtype()); + return untile_tensor_to_vec(cpu_tensor); +} + template <> std::vector to_vector(const Tensor& tensor) { auto cpu_tensor = tensor.cpu().to(Layout::ROW_MAJOR); + TT_FATAL( + cpu_tensor.get_dtype() == DataType::UINT32, + "Unsupported data type for to_vector: {}", + cpu_tensor.get_dtype()); return untile_tensor_to_vec(cpu_tensor); } template <> std::vector to_vector(const Tensor& tensor) { auto cpu_tensor = tensor.cpu().to(Layout::ROW_MAJOR); + TT_FATAL( + cpu_tensor.get_dtype() == DataType::INT32, + "Unsupported data type for to_vector: {}", + cpu_tensor.get_dtype()); return untile_tensor_to_vec(cpu_tensor); } diff --git a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp index fe84efd50bbb..446e1626314c 100644 --- a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp +++ b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp @@ -11,43 +11,65 @@ namespace ttnn::experimental::xtensor { template -ttnn::Shape get_shape_from_xarray(const E& xarr) { +ttnn::SimpleShape get_shape_from_xarray(const E& xarr) { ttnn::SmallVector shape_dims; for (size_t i = 0; i < xarr.shape().size(); ++i) { shape_dims.push_back(xarr.shape()[i]); } - return ttnn::Shape(shape_dims); + return ttnn::SimpleShape(shape_dims); } -template -tt::tt_metal::Tensor from_vector(const std::vector& buffer, const ttnn::Shape& shape); +template +tt::tt_metal::Tensor from_span( + tt::stl::Span buffer, const ttnn::SimpleShape& shape, tt::tt_metal::DataType dtype); -template +template +tt::tt_metal::Tensor from_vector( + const std::vector& buffer, const ttnn::SimpleShape& shape, tt::tt_metal::DataType dtype) { + return from_span(tt::stl::Span(buffer.data(), buffer.size()), shape, dtype); +} + +template std::vector to_vector(const tt::tt_metal::Tensor& tensor); -template +template +xt::xarray tt_span_to_xtensor(tt::stl::Span vec, const ttnn::SimpleShape& shape) { + std::vector shape_vec(shape.cbegin(), shape.cend()); + return xt::adapt(vec.data(), vec.size(), xt::no_ownership(), shape_vec); +} + +// TODO: make the usage of std::span / tt::stl::Span consistent. +template xt::xarray span_to_xtensor(std::span vec, const ttnn::SimpleShape& shape) { std::vector shape_vec(shape.cbegin(), shape.cend()); return xt::adapt(vec.data(), vec.size(), xt::no_ownership(), shape_vec); } -template + +template +auto xtensor_to_tt_span(const xt::xarray& xtensor) { + auto adaptor = xt::adapt(xtensor.data(), xtensor.size(), xt::no_ownership()); + return tt::stl::Span(adaptor.data(), adaptor.size()); +} + +// TODO: make the usage of std::span / tt::stl::Span consistent. +template auto xtensor_to_span(const xt::xarray& xtensor) { auto adaptor = xt::adapt(xtensor.data(), xtensor.size(), xt::no_ownership()); return std::span(adaptor.data(), adaptor.size()); } -template -tt::tt_metal::Tensor from_xtensor(const xt::xarray& buffer) { +template +tt::tt_metal::Tensor from_xtensor(const xt::xarray& buffer, tt::tt_metal::DataType dtype) { auto shape = get_shape_from_xarray(buffer); - auto buffer_view = xtensor_to_span(buffer); - return from_vector(std::vector(buffer_view.begin(), buffer_view.end()), shape); + auto buffer_view = xtensor_to_tt_span(buffer); + return from_span(buffer_view, shape, dtype); } -template +template xt::xarray to_xtensor(const tt::tt_metal::Tensor& tensor) { auto vec = to_vector(tensor); auto shape = tensor.get_shape().logical_shape(); - return span_to_xtensor(std::span(vec.data(), vec.size()), shape); + return tt_span_to_xtensor(tt::stl::Span(vec.data(), vec.size()), shape); } } // namespace ttnn::experimental::xtensor diff --git a/ttnn/cpp/ttnn/tensor/xtensor/partition.cpp b/ttnn/cpp/ttnn/tensor/xtensor/partition.cpp index 161ee993373e..52b52c6671e0 100644 --- a/ttnn/cpp/ttnn/tensor/xtensor/partition.cpp +++ b/ttnn/cpp/ttnn/tensor/xtensor/partition.cpp @@ -82,6 +82,7 @@ std::vector> chunk(const xt::xarray& xtensor, int num_chunks, i auto chunk_view = xt::strided_view(xtensor, indices); + // TODO: optimize away this copy. // Construct xarray from the view // This forces a copy of that slice into a new xarray chunks.push_back(xt::xarray(chunk_view)); @@ -119,25 +120,25 @@ template xt::xarray concatenate(const std::vector>& namespace adaptor { namespace { -template -Tensor concatenate_impl(const std::vector& tensors, int dim) { - std::vector> xtensors; +template +Tensor concatenate_impl(const std::vector& tensors, DataType dtype, int dim) { + std::vector> xtensors; for (const auto& tensor : tensors) { - xtensors.push_back(to_xtensor(tensor)); + xtensors.push_back(to_xtensor(tensor)); } - xt::xarray result = concatenate(xtensors, dim); - return from_xtensor(result); + xt::xarray result = concatenate(xtensors, dim); + return from_xtensor(result, dtype); } -template -std::vector chunk_impl(const Tensor& tensor, int num_chunks, int dim) { - xt::xarray xtensor = to_xtensor(tensor); - auto xtensor_chunks = chunk(xtensor, num_chunks, dim); +template +std::vector chunk_impl(const Tensor& tensor, DataType dtype, int num_chunks, int dim) { + xt::xarray xtensor = to_xtensor(tensor); + auto xtensor_chunks = chunk(xtensor, num_chunks, dim); std::vector tensors; tensors.reserve(xtensor_chunks.size()); for (const auto& c : xtensor_chunks) { - tensors.push_back(from_xtensor(c)); + tensors.push_back(from_xtensor(c, dtype)); } return tensors; } @@ -147,20 +148,21 @@ std::vector chunk_impl(const Tensor& tensor, int num_chunks, int dim) { std::vector chunk(const Tensor& tensor, int num_chunks, int dim) { switch (tensor.dtype()) { - case DataType::BFLOAT16: return adaptor::chunk_impl(tensor, num_chunks, dim); - case DataType::FLOAT32: return adaptor::chunk_impl(tensor, num_chunks, dim); - case DataType::INT32: return adaptor::chunk_impl(tensor, num_chunks, dim); - case DataType::UINT32: return adaptor::chunk_impl(tensor, num_chunks, dim); + case DataType::BFLOAT16: return adaptor::chunk_impl(tensor, DataType::BFLOAT16, num_chunks, dim); + case DataType::FLOAT32: return adaptor::chunk_impl(tensor, DataType::FLOAT32, num_chunks, dim); + case DataType::INT32: return adaptor::chunk_impl(tensor, DataType::INT32, num_chunks, dim); + case DataType::UINT32: return adaptor::chunk_impl(tensor, DataType::UINT32, num_chunks, dim); default: TT_THROW("Unsupported data type: {}", tensor.dtype()); } } Tensor concatenate(const std::vector& tensors, int dim) { + TT_FATAL(tensors.size() > 0, "Cannot concatenate an empty list of tensors"); switch (tensors.front().dtype()) { - case DataType::BFLOAT16: return adaptor::concatenate_impl(tensors, dim); - case DataType::FLOAT32: return adaptor::concatenate_impl(tensors, dim); - case DataType::INT32: return adaptor::concatenate_impl(tensors, dim); - case DataType::UINT32: return adaptor::concatenate_impl(tensors, dim); + case DataType::BFLOAT16: return adaptor::concatenate_impl(tensors, DataType::BFLOAT16, dim); + case DataType::FLOAT32: return adaptor::concatenate_impl(tensors, DataType::FLOAT32, dim); + case DataType::INT32: return adaptor::concatenate_impl(tensors, DataType::INT32, dim); + case DataType::UINT32: return adaptor::concatenate_impl(tensors, DataType::UINT32, dim); default: TT_THROW("Unsupported data type: {}", tensors.front().dtype()); } }