From 59f5c5ed9c7b563fb1afd487c3293254de483eae Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Fri, 13 Dec 2024 23:06:48 +0000 Subject: [PATCH] Moved from/to vector to tensor.hpp --- .../gtests/tensor/test_distributed_tensor.cpp | 28 ++-- .../gtests/tensor/test_partition.cpp | 1 - .../gtests/tensor/test_vector_conversion.cpp | 23 +-- .../sources/ttml/core/tt_tensor_utils.cpp | 2 +- .../sources/ttml/core/tt_tensor_utils.hpp | 4 +- ttnn/cpp/ttnn/tensor/CMakeLists.txt | 1 - ttnn/cpp/ttnn/tensor/tensor.cpp | 134 +++++++++++++++-- ttnn/cpp/ttnn/tensor/tensor.hpp | 19 +++ .../ttnn/tensor/xtensor/conversion_utils.cpp | 141 ------------------ .../ttnn/tensor/xtensor/conversion_utils.hpp | 20 +-- 10 files changed, 173 insertions(+), 200 deletions(-) delete mode 100644 ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.cpp diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp index 0070339dd38..77809e52a42 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp @@ -15,8 +15,6 @@ namespace ttnn::distributed::test { using ::testing::ElementsAre; -using ::ttnn::experimental::xtensor::from_vector; -using ::ttnn::experimental::xtensor::to_vector; using TensorDistributionTest = T3kMultiDeviceFixture; @@ -25,7 +23,7 @@ TensorSpec get_tensor_spec(const ttnn::SimpleShape& shape, DataType dtype) { } TEST_F(TensorDistributionTest, Replication) { - Tensor input_tensor = from_vector( + Tensor input_tensor = Tensor::from_vector( std::vector{42.F, 13.F, -99.F}, get_tensor_spec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32)); auto mapper = replicate_tensor_to_mesh_mapper(*mesh_device_); @@ -34,13 +32,13 @@ TEST_F(TensorDistributionTest, Replication) { std::vector device_tensors = get_device_tensors(replicated_tensor); EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices()); for (const auto& device_tensor : device_tensors) { - EXPECT_THAT(to_vector(device_tensor), ElementsAre(42.F, 13.F, -99.F)); + EXPECT_THAT(device_tensor.to_vector(), ElementsAre(42.F, 13.F, -99.F)); } } TEST_F(TensorDistributionTest, Shard1DInvalidDim) { const int num_devices = mesh_device_->num_devices(); - Tensor input_tensor = from_vector( + Tensor input_tensor = Tensor::from_vector( std::vector(num_devices, 0), get_tensor_spec(ttnn::SimpleShape{1, 1, 1, num_devices}, DataType::FLOAT32)); @@ -58,7 +56,7 @@ TEST_F(TensorDistributionTest, Shard1DInvalidDim) { TEST_F(TensorDistributionTest, Shard1DTooFewShards) { const int num_devices = mesh_device_->num_devices(); ASSERT_LT(3, num_devices); - Tensor input_tensor = from_vector( + Tensor input_tensor = Tensor::from_vector( std::vector{42.F, 13.F, -99.F}, get_tensor_spec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32)); EXPECT_ANY_THROW({ @@ -74,7 +72,7 @@ TEST_F(TensorDistributionTest, Shard1D) { test_data.insert(test_data.end(), {i * 1.F, i * 2.F, i * 3.F}); } Tensor input_tensor = - from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_devices, 3, 1}, DataType::FLOAT32)); + Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_devices, 3, 1}, DataType::FLOAT32)); auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 1); Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper); @@ -82,14 +80,14 @@ TEST_F(TensorDistributionTest, Shard1D) { std::vector device_tensors = get_device_tensors(sharded_tensor); EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices()); for (int i = 0; i < device_tensors.size(); i++) { - EXPECT_THAT(to_vector(device_tensors[i]), ElementsAre(i * 1.F, i * 2.F, i * 3.F)); + EXPECT_THAT(device_tensors[i].to_vector(), ElementsAre(i * 1.F, i * 2.F, i * 3.F)); } auto composer = concat_mesh_to_tensor_composer(/*dim=*/0); Tensor concatenated_tensor = aggregate_tensor(sharded_tensor, *composer); Tensor expected_tensor = - from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{num_devices, 1, 3, 1}, DataType::FLOAT32)); + Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{num_devices, 1, 3, 1}, DataType::FLOAT32)); EXPECT_TRUE(ttnn::allclose(concatenated_tensor, expected_tensor)); } @@ -121,7 +119,7 @@ TEST_F(TensorDistributionTest, Shard2DReplicateDim) { std::vector test_data = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; Tensor input_tensor = - from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 1}, DataType::FLOAT32)); + Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 1}, DataType::FLOAT32)); input_tensor.print(); auto mapper = shard_tensor_2d_to_mesh_mapper( @@ -138,10 +136,10 @@ TEST_F(TensorDistributionTest, Shard2DReplicateDim) { int i = 0; for (; i < 4; i++) { - EXPECT_THAT(to_vector(device_tensors[i]), ElementsAre(0.0, 1.0, 2.0, 3.0)); + EXPECT_THAT(device_tensors[i].to_vector(), ElementsAre(0.0, 1.0, 2.0, 3.0)); } for (; i < device_tensors.size(); i++) { - EXPECT_THAT(to_vector(device_tensors[i]), ElementsAre(4.0, 5.0, 6.0, 7.0)); + EXPECT_THAT(device_tensors[i].to_vector(), ElementsAre(4.0, 5.0, 6.0, 7.0)); } } @@ -156,7 +154,7 @@ TEST_F(TensorDistributionTest, Shard2D) { test_data.insert(test_data.end(), {i * 1.F, i * 2.F, i * 3.F}); } Tensor input_tensor = - from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 3}, DataType::FLOAT32)); + Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 3}, DataType::FLOAT32)); auto mapper = shard_tensor_2d_to_mesh_mapper( *mesh_device_, @@ -170,7 +168,7 @@ TEST_F(TensorDistributionTest, Shard2D) { std::vector device_tensors = get_device_tensors(sharded_tensor); EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices()); for (int i = 0; i < device_tensors.size(); i++) { - EXPECT_THAT(to_vector(device_tensors[i]), ElementsAre(i * 1.F, i * 2.F, i * 3.F)); + EXPECT_THAT(device_tensors[i].to_vector(), ElementsAre(i * 1.F, i * 2.F, i * 3.F)); } auto composer = concat_mesh_2d_to_tensor_composer( @@ -182,7 +180,7 @@ TEST_F(TensorDistributionTest, Shard2D) { Tensor concatenated_tensor = aggregate_tensor(sharded_tensor, *composer); Tensor expected_tensor = - from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{num_rows, 1, num_cols, 3}, DataType::FLOAT32)); + Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{num_rows, 1, num_cols, 3}, DataType::FLOAT32)); EXPECT_TRUE(ttnn::allclose(concatenated_tensor, expected_tensor)); } diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp index d5e16748ddb..822a688732e 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp @@ -17,7 +17,6 @@ using ::testing::SizeIs; using ::tt::tt_metal::Tensor; using ::ttnn::experimental::xtensor::chunk; using ::ttnn::experimental::xtensor::concatenate; -using ::ttnn::experimental::xtensor::from_vector; TEST(PartitionTest, ChunkBasicNonDivisible3) { // Create a 1D tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_vector_conversion.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_vector_conversion.cpp index d48e1cf4ba2..8dc25e1abf4 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_vector_conversion.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_vector_conversion.cpp @@ -16,8 +16,6 @@ namespace { using ::testing::Eq; using ::testing::Pointwise; -using ::ttnn::experimental::xtensor::from_vector; -using ::ttnn::experimental::xtensor::to_vector; const std::vector& get_shapes_for_test() { static auto* shapes = new std::vector{ @@ -58,8 +56,8 @@ TYPED_TEST_SUITE(VectorConversionTest, TestTypes); TYPED_TEST(VectorConversionTest, Roundtrip) { for (const auto& shape : get_shapes_for_test()) { auto input = arange(0, static_cast(shape.volume()), 1); - auto output = - to_vector(from_vector(input, get_tensor_spec(shape, convert_to_data_type()))); + auto output = Tensor::from_vector(input, get_tensor_spec(shape, convert_to_data_type())) + .template to_vector(); EXPECT_THAT(output, Pointwise(Eq(), input)) << "for shape: " << shape; } } @@ -69,7 +67,7 @@ TYPED_TEST(VectorConversionTest, InvalidSize) { auto input = arange(0, 42, 1); ASSERT_NE(input.size(), shape.volume()); - EXPECT_ANY_THROW(from_vector(input, get_tensor_spec(shape, convert_to_data_type()))); + EXPECT_ANY_THROW(Tensor::from_vector(input, get_tensor_spec(shape, convert_to_data_type()))); } TYPED_TEST(VectorConversionTest, RoundtripTilezedLayout) { @@ -77,10 +75,12 @@ TYPED_TEST(VectorConversionTest, RoundtripTilezedLayout) { auto input = arange(0, shape.volume(), 1); // TODO: Support this. - EXPECT_ANY_THROW(from_vector(input, get_tensor_spec(shape, convert_to_data_type(), Layout::TILE))); + EXPECT_ANY_THROW( + Tensor::from_vector(input, get_tensor_spec(shape, convert_to_data_type(), Layout::TILE))); - auto output = to_vector( - from_vector(input, get_tensor_spec(shape, convert_to_data_type())).to(Layout::TILE)); + auto output = Tensor::from_vector(input, get_tensor_spec(shape, convert_to_data_type())) + .to(Layout::TILE) + .template to_vector(); EXPECT_THAT(output, Pointwise(Eq(), input)); } @@ -89,7 +89,7 @@ TYPED_TEST(VectorConversionTest, InvalidDtype) { auto input = arange(0, 42, 1); ASSERT_NE(input.size(), shape.volume()); - EXPECT_ANY_THROW(from_vector( + EXPECT_ANY_THROW(Tensor::from_vector( input, get_tensor_spec( shape, @@ -106,10 +106,11 @@ TEST(FloatVectorConversionTest, RoundtripBfloat16Representation) { return bf.to_float(); }); - auto output_bf16 = to_vector(from_vector(input_ft, get_tensor_spec(shape, DataType::BFLOAT16))); + auto output_bf16 = + Tensor::from_vector(input_ft, get_tensor_spec(shape, DataType::BFLOAT16)).to_vector(); EXPECT_THAT(output_bf16, Pointwise(Eq(), input_bf16)) << "for shape: " << shape; - auto output_ft = to_vector(from_vector(input_bf16, get_tensor_spec(shape, DataType::BFLOAT16))); + auto output_ft = Tensor::from_vector(input_bf16, get_tensor_spec(shape, DataType::BFLOAT16)).to_vector(); EXPECT_THAT(output_ft, Pointwise(Eq(), input_ft)) << "for shape: " << shape; } } diff --git a/tt-train/sources/ttml/core/tt_tensor_utils.cpp b/tt-train/sources/ttml/core/tt_tensor_utils.cpp index 1e5244e999a..8ca99db8453 100644 --- a/tt-train/sources/ttml/core/tt_tensor_utils.cpp +++ b/tt-train/sources/ttml/core/tt_tensor_utils.cpp @@ -32,7 +32,7 @@ T get_median(std::vector& vec) { template void print_tensor_stats_(const tt::tt_metal::Tensor& tensor, const std::string& name) { auto tensor_shape = tensor.get_shape(); - auto tensor_vec = ttml::core::to_vector(tensor); + auto tensor_vec = tensor.to_vector(); auto median = get_median(tensor_vec); auto mean = std::accumulate(tensor_vec.begin(), tensor_vec.end(), 0.F) / static_cast(tensor_vec.size()); diff --git a/tt-train/sources/ttml/core/tt_tensor_utils.hpp b/tt-train/sources/ttml/core/tt_tensor_utils.hpp index dba8cdaa030..89ab8e01b5b 100644 --- a/tt-train/sources/ttml/core/tt_tensor_utils.hpp +++ b/tt-train/sources/ttml/core/tt_tensor_utils.hpp @@ -39,7 +39,7 @@ template template [[nodiscard]] std::vector to_vector(const tt::tt_metal::Tensor& tensor) { - return ttnn::experimental::xtensor::to_vector(tensor); + return tensor.to_vector(); } [[nodiscard]] bool is_tensor_initialized(const tt::tt_metal::Tensor& tensor); @@ -56,7 +56,7 @@ template template [[nodiscard]] xt::xarray to_xtensor(const tt::tt_metal::Tensor& tensor) { - auto vec = to_vector(tensor); + auto vec = tensor.to_vector(); const auto& shape = tensor.get_shape().logical_shape(); std::vector shape_vec(shape.cbegin(), shape.cend()); return xt::adapt(std::move(vec), shape_vec); diff --git a/ttnn/cpp/ttnn/tensor/CMakeLists.txt b/ttnn/cpp/ttnn/tensor/CMakeLists.txt index e67a87c6c41..583b3ca9f43 100644 --- a/ttnn/cpp/ttnn/tensor/CMakeLists.txt +++ b/ttnn/cpp/ttnn/tensor/CMakeLists.txt @@ -11,7 +11,6 @@ set(TENSOR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/layout/page_config.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout/size.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout/tensor_layout.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/xtensor/conversion_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/xtensor/partition.cpp CACHE INTERNAL "Tensor sources to reuse in ttnn build" diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index eab2d044d5b..06925930e13 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -30,12 +30,9 @@ using namespace tt::constants; -namespace tt { - -namespace tt_metal { - +namespace tt::tt_metal { namespace { -namespace CMAKE_UNIQUE_NAMESPACE { + MemoryConfig extract_memory_config(const Storage& storage) { return std::visit( [](const auto& storage) -> MemoryConfig { @@ -50,7 +47,60 @@ MemoryConfig extract_memory_config(const Storage& storage) { }, storage); } -} // namespace CMAKE_UNIQUE_NAMESPACE + +template +Tensor create_owned_tensor_from_span(tt::stl::Span data, const TensorSpec& spec) { + // TODO: support tilized layouts. + TT_FATAL(spec.layout() == Layout::ROW_MAJOR, "Unsupported layout: {}", spec.layout()); + auto buffer = tt::tt_metal::owned_buffer::create(std::vector(data.begin(), data.end())); + auto storage = OwnedStorage{std::move(buffer)}; + return Tensor{std::move(storage), spec}; +} + +// TODO: optimize precomputing multipliers +template +std::vector untile_tensor_to_vec(const Tensor& cpu_tensor) { + auto tiled_buffer = tt::tt_metal::host_buffer::get_as(cpu_tensor); + auto untiled_shape = cpu_tensor.get_logical_shape(); + auto tiled_shape = cpu_tensor.get_padded_shape(); + + // Calculate total size of the untiled tensor + size_t total_size = untiled_shape.volume(); + + std::vector untiled_data(total_size); + + auto compute_flat_index = [](const std::vector& indices, ttnn::SimpleShape& shape) -> uint32_t { + uint32_t flat_index = 0; + uint32_t multiplier = 1; + for (int i = (int)indices.size() - 1; i >= 0; --i) { + flat_index += indices[i] * multiplier; + multiplier *= shape[i]; + } + return flat_index; + }; + + std::vector indices(tiled_shape.rank(), 0); + + for (size_t idx = 0; idx < total_size; ++idx) { + uint32_t untiled_index = compute_flat_index(indices, untiled_shape); + uint32_t tiled_index = compute_flat_index(indices, tiled_shape); + if constexpr (std::is_same_v) { + untiled_data[untiled_index] = tiled_buffer[tiled_index].to_float(); + } else { + untiled_data[untiled_index] = tiled_buffer[tiled_index]; + } + + for (int dim = (int)tiled_shape.rank() - 1; dim >= 0; --dim) { + if (++indices[dim] < untiled_shape[dim]) { + break; + } + indices[dim] = 0; + } + } + + return untiled_data; +} + } // namespace Tensor::TensorAttributes::TensorAttributes() : @@ -111,7 +161,7 @@ Tensor::Tensor( tile->get_tile_shape()); } } - auto memory_config = CMAKE_UNIQUE_NAMESPACE::extract_memory_config(storage); + auto memory_config = extract_memory_config(storage); init( std::move(storage), TensorSpec( @@ -559,6 +609,72 @@ const Storage& Tensor::get_storage() const { return this->tensor_attributes->storage; } +template <> +Tensor Tensor::from_span(tt::stl::Span buffer, const TensorSpec& spec) { + size_t volume = spec.logical_shape().volume(); + TT_FATAL( + buffer.size() == volume, "Current buffer size is {} different from shape volume {}", buffer.size(), volume); + if (spec.data_type() == DataType::FLOAT32) { + return create_owned_tensor_from_span(buffer, spec); + } else if (spec.data_type() == DataType::BFLOAT16) { + std::vector bfloat16_data; + bfloat16_data.reserve(buffer.size()); + std::transform(std::begin(buffer), std::end(buffer), std::back_inserter(bfloat16_data), [](float value) { + return bfloat16(value); + }); + return create_owned_tensor_from_span( + tt::stl::Span(bfloat16_data.data(), bfloat16_data.size()), spec); + } else { + // TODO: support bf8 and bf4 + TT_THROW("Unsupported data type for from_span: {}", spec.data_type()); + } +} + +template +Tensor Tensor::from_span(tt::stl::Span buffer, const TensorSpec& spec) { + size_t volume = spec.logical_shape().volume(); + TT_FATAL( + buffer.size() == volume, "Current buffer size is {} different from shape volume {}", buffer.size(), volume); + TT_FATAL( + spec.data_type() == convert_to_data_type(), + "Unsupported data type for from_span: got {}, expected: {}", + spec.data_type(), + convert_to_data_type()); + return create_owned_tensor_from_span(buffer, spec); +} + +template <> +std::vector Tensor::to_vector() const { + auto cpu_tensor = this->cpu().to(Layout::ROW_MAJOR); + if (cpu_tensor.get_dtype() == DataType::BFLOAT16) { + return untile_tensor_to_vec(cpu_tensor); + } else if (cpu_tensor.get_dtype() == DataType::FLOAT32) { + return untile_tensor_to_vec(cpu_tensor); + } else { + // TODO: support bf4, bf8. + TT_THROW("Cannot convert tensor to vector for data type: {}", cpu_tensor.get_dtype()); + } +} + +template +std::vector Tensor::to_vector() const { + auto cpu_tensor = this->cpu().to(Layout::ROW_MAJOR); + TT_FATAL( + cpu_tensor.get_dtype() == convert_to_data_type(), + "Unsupported data type for to_vector: got {}, expected: {}", + cpu_tensor.get_dtype(), + convert_to_data_type()); + return untile_tensor_to_vec(cpu_tensor); +} + +// Instantiate explicitly for the supported types. +template Tensor Tensor::from_span(tt::stl::Span buffer, const TensorSpec& spec); +template Tensor Tensor::from_span(tt::stl::Span buffer, const TensorSpec& spec); +template Tensor Tensor::from_span(tt::stl::Span buffer, const TensorSpec& spec); +template std::vector Tensor::to_vector() const; +template std::vector Tensor::to_vector() const; +template std::vector Tensor::to_vector() const; + Tensor Tensor::to(Device* target_device, const MemoryConfig& mem_config,uint8_t cq_id, const std::vector& sub_device_ids) const { return tensor_ops::tensor_to(*this, target_device, mem_config, cq_id, sub_device_ids); @@ -959,6 +1075,4 @@ bool validate_worker_modes(const std::vector& workers) { return worker_modes_match; } -} // namespace tt_metal - -} // namespace tt +} // namespace tt::tt_metal diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index 08b7a653389..4f96c1f18ee 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -139,6 +139,25 @@ struct Tensor { std::vector get_workers(bool blocking = false) const; + // Converts a buffer of elements of type `T` to a `Tensor`. + // Elements are assumed to be stored in row-major order. The size of the span and the type have to match `spec`. + // + // TODO: tilized layouts and reduced precision types are currently not supported. + template + static Tensor from_span(tt::stl::Span buffer, const TensorSpec& spec); + + // Same as `from_span`, but takes a vector instead. + template + static Tensor from_vector(const std::vector& buffer, const TensorSpec& spec) { + return from_span(tt::stl::Span(buffer.data(), buffer.size()), spec); + } + + // Converts a `Tensor` to a buffer of elements of type `T`. + // Elements in the buffer will be stored in row-major order. The type of the elements has to match that of the + // `Tensor`. + template + std::vector to_vector() const; + Tensor to( Device* target_device, const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, diff --git a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.cpp b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.cpp deleted file mode 100644 index 5ba963c2786..00000000000 --- a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.cpp +++ /dev/null @@ -1,141 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "common/assert.hpp" -#include "common/bfloat16.hpp" -#include "ttnn/tensor/host_buffer/functions.hpp" -#include "ttnn/tensor/tensor.hpp" -#include "ttnn/tensor/types.hpp" -#include "ttnn/tensor/xtensor/conversion_utils.hpp" -#include - -namespace ttnn::experimental::xtensor { -namespace { - -using ::tt::tt_metal::DataType; -using ::tt::tt_metal::Tensor; - -template -Tensor create_owned_tensor(tt::stl::Span data, const TensorSpec& spec) { - // TODO: support tilized layouts. - TT_FATAL(spec.layout() == Layout::ROW_MAJOR, "Unsupported layout: {}", spec.layout()); - auto buffer = tt::tt_metal::owned_buffer::create(std::vector(data.begin(), data.end())); - auto storage = OwnedStorage{std::move(buffer)}; - return Tensor{std::move(storage), spec}; -} - -// TODO: optimize precomputing multipliers -template -std::vector untile_tensor_to_vec(const Tensor& cpu_tensor) { - auto tiled_buffer = tt::tt_metal::host_buffer::get_as(cpu_tensor); - auto untiled_shape = cpu_tensor.get_logical_shape(); - auto tiled_shape = cpu_tensor.get_padded_shape(); - - // Calculate total size of the untiled tensor - size_t total_size = untiled_shape.volume(); - - std::vector untiled_data(total_size); - - auto compute_flat_index = [](const std::vector& indices, ttnn::SimpleShape& shape) -> uint32_t { - uint32_t flat_index = 0; - uint32_t multiplier = 1; - for (int i = (int)indices.size() - 1; i >= 0; --i) { - flat_index += indices[i] * multiplier; - multiplier *= shape[i]; - } - return flat_index; - }; - - std::vector indices(tiled_shape.rank(), 0); - - for (size_t idx = 0; idx < total_size; ++idx) { - uint32_t untiled_index = compute_flat_index(indices, untiled_shape); - uint32_t tiled_index = compute_flat_index(indices, tiled_shape); - if constexpr (std::is_same_v) { - untiled_data[untiled_index] = tiled_buffer[tiled_index].to_float(); - } else { - untiled_data[untiled_index] = tiled_buffer[tiled_index]; - } - - for (int dim = (int)tiled_shape.rank() - 1; dim >= 0; --dim) { - if (++indices[dim] < untiled_shape[dim]) { - break; - } - indices[dim] = 0; - } - } - - return untiled_data; -} - -} // namespace - -template <> -Tensor from_span(tt::stl::Span buffer, const TensorSpec& spec) { - size_t volume = spec.logical_shape().volume(); - TT_FATAL( - buffer.size() == volume, "Current buffer size is {} different from shape volume {}", buffer.size(), volume); - if (spec.data_type() == DataType::FLOAT32) { - return create_owned_tensor(buffer, spec); - } else if (spec.data_type() == DataType::BFLOAT16) { - std::vector bfloat16_data; - bfloat16_data.reserve(buffer.size()); - std::transform(std::begin(buffer), std::end(buffer), std::back_inserter(bfloat16_data), [](float value) { - return bfloat16(value); - }); - return create_owned_tensor(tt::stl::Span(bfloat16_data.data(), bfloat16_data.size()), spec); - } else { - // TODO: support bf8 and bf4 - TT_THROW("Unsupported data type for from_span: {}", spec.data_type()); - } -} - -template -Tensor from_span(tt::stl::Span buffer, const TensorSpec& spec) { - size_t volume = spec.logical_shape().volume(); - TT_FATAL( - buffer.size() == volume, "Current buffer size is {} different from shape volume {}", buffer.size(), volume); - TT_FATAL( - spec.data_type() == convert_to_data_type(), - "Unsupported data type for from_span: got {}, expected: {}", - spec.data_type(), - convert_to_data_type()); - return create_owned_tensor(buffer, spec); -} - -// Instantiate explicitly for the supported types. -template Tensor from_span(tt::stl::Span buffer, const TensorSpec& spec); -template Tensor from_span(tt::stl::Span buffer, const TensorSpec& spec); -template Tensor from_span(tt::stl::Span buffer, const TensorSpec& spec); - -template <> -std::vector to_vector(const Tensor& tensor) { - auto cpu_tensor = tensor.cpu().to(Layout::ROW_MAJOR); - if (cpu_tensor.get_dtype() == DataType::BFLOAT16) { - return untile_tensor_to_vec(cpu_tensor); - } else if (cpu_tensor.get_dtype() == DataType::FLOAT32) { - return untile_tensor_to_vec(cpu_tensor); - } else { - // TODO: support bf4, bf8. - TT_THROW("Cannot convert tensor to vector for data type: {}", cpu_tensor.get_dtype()); - } -} - -template -std::vector to_vector(const Tensor& tensor) { - auto cpu_tensor = tensor.cpu().to(Layout::ROW_MAJOR); - TT_FATAL( - cpu_tensor.get_dtype() == convert_to_data_type(), - "Unsupported data type for to_vector: got {}, expected: {}", - cpu_tensor.get_dtype(), - convert_to_data_type()); - return untile_tensor_to_vec(cpu_tensor); -} - -// Instantiate explicitly for the supported types. -template std::vector to_vector(const Tensor& tensor); -template std::vector to_vector(const Tensor& tensor); -template std::vector to_vector(const Tensor& tensor); - -} // namespace ttnn::experimental::xtensor diff --git a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp index 9588dd45516..881cd453f3f 100644 --- a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp +++ b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp @@ -20,22 +20,6 @@ ttnn::SimpleShape get_shape_from_xarray(const E& xarr) { return ttnn::SimpleShape(shape_dims); } -// Converts a buffer of elements of type `T` to a Tensor. -// Elements are assumed to be stored in row-major order. The size of the span and the type have to match Tensor spec. -// TODO: tilized layouts and reduced precision types are currently not supported. -template -tt::tt_metal::Tensor from_span(tt::stl::Span buffer, const TensorSpec& spec); - -// Converts a Tensor to a buffer of elements of type `T`. -// Elements in the buffer will be stored in row-major order. The type of the elements has to match that of the Tensor. -template -std::vector to_vector(const tt::tt_metal::Tensor& tensor); - -template -tt::tt_metal::Tensor from_vector(const std::vector& buffer, const TensorSpec& spec) { - return from_span(tt::stl::Span(buffer.data(), buffer.size()), spec); -} - // Converts a span to an xtensor view. // IMPORTANT: the lifetime of the returned xtensor view is tied to the lifetime of the underlying buffer. template @@ -66,13 +50,13 @@ tt::tt_metal::Tensor from_xtensor(const xt::xarray& buffer, const TensorSpec& auto shape = get_shape_from_xarray(buffer); TT_FATAL(shape == spec.logical_shape(), "xtensor has a different shape than the supplied TensorSpec"); auto buffer_view = xtensor_to_span(buffer); - return from_span(buffer_view, spec); + return tt::tt_metal::Tensor::from_span(buffer_view, spec); } // Converts a Tensor to an xtensor. template xt::xarray to_xtensor(const tt::tt_metal::Tensor& tensor) { - auto vec = to_vector(tensor); + auto vec = tensor.to_vector(); auto shape = tensor.get_shape().logical_shape(); return xt::xarray(span_to_xtensor_view(tt::stl::Span(vec.data(), vec.size()), shape)); }