diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp index 77809e52a424..a09069347515 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp @@ -7,7 +7,6 @@ #include "ttnn/distributed/api.hpp" #include "ttnn/operations/functions.hpp" -#include "ttnn/tensor/xtensor/conversion_utils.hpp" #include "ttnn_test_fixtures.hpp" #include #include @@ -97,18 +96,18 @@ TEST_F(TensorDistributionTest, Shard2DInvalidMeshShape) { ASSERT_EQ(num_cols, 4); EXPECT_ANY_THROW( - shard_tensor_2d_to_mesh_mapper(*mesh_device_, MeshShape{3, 1}, Shard2dConfig{.row_dim = 1, .col_dim = 2})); + shard_tensor_to_2d_mesh_mapper(*mesh_device_, MeshShape{3, 1}, Shard2dConfig{.row_dim = 1, .col_dim = 2})); EXPECT_ANY_THROW( - shard_tensor_2d_to_mesh_mapper(*mesh_device_, MeshShape{2, 5}, Shard2dConfig{.row_dim = 1, .col_dim = 2})); + shard_tensor_to_2d_mesh_mapper(*mesh_device_, MeshShape{2, 5}, Shard2dConfig{.row_dim = 1, .col_dim = 2})); } TEST_F(TensorDistributionTest, Shard2DInvalidShardConfig) { - EXPECT_ANY_THROW(shard_tensor_2d_to_mesh_mapper(*mesh_device_, MeshShape{2, 4}, Shard2dConfig{})); + EXPECT_ANY_THROW(shard_tensor_to_2d_mesh_mapper(*mesh_device_, MeshShape{2, 4}, Shard2dConfig{})); } TEST_F(TensorDistributionTest, Concat2DInvalidConfig) { - EXPECT_ANY_THROW(concat_mesh_2d_to_tensor_composer(*mesh_device_, Concat2dConfig{.row_dim = 2, .col_dim = 2})); + EXPECT_ANY_THROW(concat_2d_mesh_to_tensor_composer(*mesh_device_, Concat2dConfig{.row_dim = 2, .col_dim = 2})); } TEST_F(TensorDistributionTest, Shard2DReplicateDim) { @@ -122,7 +121,7 @@ TEST_F(TensorDistributionTest, Shard2DReplicateDim) { Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 1}, DataType::FLOAT32)); input_tensor.print(); - auto mapper = shard_tensor_2d_to_mesh_mapper( + auto mapper = shard_tensor_to_2d_mesh_mapper( *mesh_device_, MeshShape{num_rows, num_cols}, Shard2dConfig{ @@ -156,7 +155,7 @@ TEST_F(TensorDistributionTest, Shard2D) { Tensor input_tensor = Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 3}, DataType::FLOAT32)); - auto mapper = shard_tensor_2d_to_mesh_mapper( + auto mapper = shard_tensor_to_2d_mesh_mapper( *mesh_device_, MeshShape{num_rows, num_cols}, Shard2dConfig{ @@ -171,7 +170,7 @@ TEST_F(TensorDistributionTest, Shard2D) { EXPECT_THAT(device_tensors[i].to_vector(), ElementsAre(i * 1.F, i * 2.F, i * 3.F)); } - auto composer = concat_mesh_2d_to_tensor_composer( + auto composer = concat_2d_mesh_to_tensor_composer( *mesh_device_, Concat2dConfig{ .row_dim = 0, diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp index 822a688732e6..4b0062581c0b 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp @@ -16,7 +16,7 @@ namespace { using ::testing::SizeIs; using ::tt::tt_metal::Tensor; using ::ttnn::experimental::xtensor::chunk; -using ::ttnn::experimental::xtensor::concatenate; +using ::ttnn::experimental::xtensor::concat; TEST(PartitionTest, ChunkBasicNonDivisible3) { // Create a 1D tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] @@ -51,7 +51,7 @@ TEST(PartitionTest, DefaultAxis) { xt::xarray b = {{5.0, 6.0}, {7.0, 8.0}}; std::vector> input = {a, b}; - xt::xarray result = concatenate(input); // axis=0 by default + xt::xarray result = concat(input); // axis=0 by default xt::xarray expected = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}, {7.0, 8.0}}; xt::allclose(result, expected); @@ -62,7 +62,7 @@ TEST(PartitionTest, AxisOne) { xt::xarray y = {{7, 8}, {9, 10}}; std::vector> input = {x, y}; - xt::xarray result = concatenate(input, 1); + xt::xarray result = concat(input, 1); xt::xarray expected = {{1, 2, 3, 7, 8}, {4, 5, 6, 9, 10}}; xt::allclose(result, expected); @@ -74,7 +74,7 @@ TEST(PartitionTest, MultipleArraysAxis0) { xt::xarray c = {5.0f, 6.0f}; std::vector> input = {a, b, c}; - xt::xarray result = concatenate(input, 0); + xt::xarray result = concat(input, 0); xt::xarray expected = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; xt::allclose(result, expected); @@ -85,7 +85,7 @@ TEST(PartitionTest, EmptyArray) { xt::xarray b; // Empty std::vector> input = {a, b}; - EXPECT_ANY_THROW({ xt::xarray result = concatenate(input, 0); }); + EXPECT_ANY_THROW({ xt::xarray result = concat(input, 0); }); } TEST(PartitionTest, HigherDimensions) { @@ -95,10 +95,10 @@ TEST(PartitionTest, HigherDimensions) { arr2.reshape({2, 2, 2}); std::vector> input = {arr1, arr2}; - xt::xarray result = concatenate(input, 0); + xt::xarray result = concat(input, 0); // Expected: shape (4,2,2) with arr1 stacked over arr2 along axis 0 - xt::xarray expected = concatenate(xt::xtuple(arr1, arr2), 0); + xt::xarray expected = xt::concatenate(xt::xtuple(arr1, arr2), 0); xt::allclose(result, expected); } @@ -109,7 +109,7 @@ TEST(PartitionTest, HigherAxis) { // Both have shape (2,2,2) std::vector> input = {arr1, arr2}; - xt::xarray result = concatenate(input, 2); + xt::xarray result = concat(input, 2); // Expected shape: (2,2,4) xt::xarray expected = {{{1, 2, 9, 10}, {3, 4, 11, 12}}, {{5, 6, 13, 14}, {7, 8, 15, 16}}}; diff --git a/tt-train/sources/ttml/core/distributed_mapping.hpp b/tt-train/sources/ttml/core/distributed_mapping.hpp index 2ff8b7eaccb3..1ba3a9e5c02b 100644 --- a/tt-train/sources/ttml/core/distributed_mapping.hpp +++ b/tt-train/sources/ttml/core/distributed_mapping.hpp @@ -172,11 +172,11 @@ class ConcatMesh2dToTensor : public MeshToXTensor, T> { auto row_end = row_start + cols; std::vector> row_tensors(row_start, row_end); - auto concatenated_row = core::concatenate(row_tensors, col_dim); + auto concatenated_row = core::concat(row_tensors, col_dim); row_concatenated.push_back(std::move(concatenated_row)); } - auto result = core::concatenate(row_concatenated, row_dim); + auto result = core::concat(row_concatenated, row_dim); return {result}; } @@ -216,7 +216,7 @@ class ConcatMeshToXTensor : public MeshToXTensor, T> { } std::vector> compose_impl(const std::vector>& tensors) const { - return {core::concatenate(tensors, m_concat_dim)}; + return {core::concat(tensors, m_concat_dim)}; } private: diff --git a/tt-train/sources/ttml/core/xtensor_utils.hpp b/tt-train/sources/ttml/core/xtensor_utils.hpp index ef292d6869c9..074cc4a58519 100644 --- a/tt-train/sources/ttml/core/xtensor_utils.hpp +++ b/tt-train/sources/ttml/core/xtensor_utils.hpp @@ -29,8 +29,8 @@ auto xtensor_to_span(const xt::xarray& xtensor) { } template -xt::xarray concatenate(const std::vector>& v, size_t axis = 0) { - return ttnn::experimental::xtensor::concatenate(v, axis); +xt::xarray concat(const std::vector>& v, size_t axis = 0) { + return ttnn::experimental::xtensor::concat(v, axis); } } // namespace ttml::core diff --git a/tt-train/tests/core/distributed_test.cpp b/tt-train/tests/core/distributed_test.cpp index 67ad0327831e..0617c317ef34 100644 --- a/tt-train/tests/core/distributed_test.cpp +++ b/tt-train/tests/core/distributed_test.cpp @@ -143,7 +143,7 @@ TYPED_TEST(MeshOpsTest, ConcatenateSameParametersAsCompose) { std::vector> shards = {s1, s2, s3}; ttml::core::ConcatMeshToXTensor composer(mesh_shape, 0); - auto composed = ttml::core::concatenate(shards); + auto composed = ttml::core::concat(shards); xt::xarray expected = { TypeParam(0), TypeParam(1), TypeParam(2), TypeParam(3), TypeParam(4), TypeParam(5)}; diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index 3c29a6432618..b8dcc46f340b 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -46,9 +46,9 @@ class ShardTensorToMesh : public TensorToMesh { int shard_dim_ = -1; }; -class Shard2dTensorToMesh : public TensorToMesh { +class ShardTensorTo2dMesh : public TensorToMesh { public: - Shard2dTensorToMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : + ShardTensorTo2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : mesh_shape_(mesh_shape), config_(config) {} std::vector map(const Tensor& tensor) override { @@ -85,7 +85,7 @@ class Shard2dTensorToMesh : public TensorToMesh { TT_FATAL( static_cast(tensor_shards.size()) == rows * cols, - "ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh " + "ShardTensorTo2dMesh: Sharding failed. Number of shards should match the product of the mesh " "dimensions. Size: {}, rows: {}, cols: {}", tensor_shards.size(), rows, @@ -106,16 +106,16 @@ class ConcatMeshToTensor : public MeshToTensor { ConcatMeshToTensor(int dim) : concat_dim_(dim) {} Tensor compose(const std::vector& tensors) override { - return experimental::xtensor::concatenate(tensors, concat_dim_); + return experimental::xtensor::concat(tensors, concat_dim_); } private: int concat_dim_ = -1; }; -class ConcatMesh2dToTensor : public MeshToTensor { +class Concat2dMeshToTensor : public MeshToTensor { public: - ConcatMesh2dToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : + Concat2dMeshToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : mesh_shape_(mesh_device.shape()), config_(config) {} Tensor compose(const std::vector& tensors) override { @@ -128,10 +128,10 @@ class ConcatMesh2dToTensor : public MeshToTensor { auto row_start = tensors.begin() + i * cols; auto row_end = row_start + cols; std::vector row_tensors(row_start, row_end); - row_concatenated.push_back(experimental::xtensor::concatenate(row_tensors, col_dim)); + row_concatenated.push_back(experimental::xtensor::concat(row_tensors, col_dim)); } - return experimental::xtensor::concatenate(row_concatenated, row_dim); + return experimental::xtensor::concat(row_concatenated, row_dim); } private: @@ -149,29 +149,29 @@ std::unique_ptr shard_tensor_to_mesh_mapper(MeshDevice& mesh_devic return std::make_unique(mesh_device.num_devices(), dim); } -std::unique_ptr shard_tensor_2d_to_mesh_mapper( +std::unique_ptr shard_tensor_to_2d_mesh_mapper( MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) { TT_FATAL( config.row_dim.has_value() || config.col_dim.has_value(), - "ShardTensor2dMesh requires at least one dimension to shard"); + "Sharding a tensor to 2D mesh requires at least one dimension to shard"); TT_FATAL( mesh_shape.num_rows <= mesh_device.shape().num_rows && // mesh_shape.num_cols <= mesh_device.shape().num_cols, - "ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape."); - return std::make_unique(mesh_shape, config); + "Device mesh shape does not match the provided mesh shape."); + return std::make_unique(mesh_shape, config); } std::unique_ptr concat_mesh_to_tensor_composer(int dim) { return std::make_unique(dim); } -std::unique_ptr concat_mesh_2d_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config) { +std::unique_ptr concat_2d_mesh_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config) { TT_FATAL( config.row_dim != config.col_dim, "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", config.row_dim, config.col_dim); - return std::make_unique(mesh_device, config); + return std::make_unique(mesh_device, config); } Tensor distribute_tensor(const Tensor& tensor, MeshDevice& mesh_device, TensorToMesh& mapper) { diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index d97a329b63c6..7aaee73caa4c 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -36,7 +36,7 @@ struct Shard2dConfig { std::optional row_dim; std::optional col_dim; }; -std::unique_ptr shard_tensor_2d_to_mesh_mapper( +std::unique_ptr shard_tensor_to_2d_mesh_mapper( MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config); // Creates a composer that concatenates a tensor across a single dimension. @@ -47,7 +47,7 @@ struct Concat2dConfig { int row_dim = -1; int col_dim = -1; }; -std::unique_ptr concat_mesh_2d_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config); +std::unique_ptr concat_2d_mesh_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config); // Distributes a host tensor onto multi-device configuration according to the `mapper`. Tensor distribute_tensor(const Tensor& tensor, MeshDevice& mesh_device, TensorToMesh& mapper); diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index ddaf038faa4c..30e18978b8e5 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -140,9 +140,20 @@ struct Tensor { std::vector get_workers(bool blocking = false) const; // Converts a buffer of elements of type `T` to a `Tensor`. - // Elements are assumed to be stored in row-major order. The size of the span and the type have to match `spec`. + // Elements in the buffer are assumed to be stored in row-major order. The size of the buffer and the type of the + // elements have to match `spec`. // - // TODO: tilized layouts and reduced precision types are currently not supported. + // The data in the buffer is copied into a tensor with an owned storage. + // + // IMPORTANT: this function supports a limited subset of types (float32, bfloat16, uint32_t, int32_t), + // and only row-major layout. + // + // TODO: + // 1. add support for returning a tensor with a borrowed storage based off the buffer. + // 2. add support for sharding. + // 3. add support for block float formats. + // 4. add support for tilized layouts. + // 5. add support for on-device tensor creation. template static Tensor from_span(tt::stl::Span buffer, const TensorSpec& spec); @@ -152,9 +163,17 @@ struct Tensor { return from_span(tt::stl::Span(buffer.data(), buffer.size()), spec); } - // Converts a `Tensor` to a buffer of elements of type `T`. - // Elements in the buffer will be stored in row-major order. The type of the elements has to match that of the - // `Tensor`. + // Converts a `Tensor` to a `std::vector`. + // Elements in the vector will be stored in row-major order. The type of the requested vector has to match that of + // the `Tensor`. + // + // If the tensor resides on a device, it will be brough back to host. + // + // IMPORTANT: this function supports a limited subset of types (float32, bfloat16, uint32_t, int32_t). + // + // TODO: + // 1. add support for sharding. + // 2. add support for block float formats. template std::vector to_vector() const; diff --git a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp index 881cd453f3f1..40705ef1d740 100644 --- a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp +++ b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp @@ -45,6 +45,7 @@ auto xtensor_to_span(const xt::xarray& xtensor) { } // Converts an xtensor to a Tensor. +// IMPORTANT: this copies the data into the returned Tensor, which can be an expensive operation. template tt::tt_metal::Tensor from_xtensor(const xt::xarray& buffer, const TensorSpec& spec) { auto shape = get_shape_from_xarray(buffer); @@ -54,6 +55,7 @@ tt::tt_metal::Tensor from_xtensor(const xt::xarray& buffer, const TensorSpec& } // Converts a Tensor to an xtensor. +// IMPORTANT: this copies the data into the returned Tensor, which can be an expensive operation. template xt::xarray to_xtensor(const tt::tt_metal::Tensor& tensor) { auto vec = tensor.to_vector(); diff --git a/ttnn/cpp/ttnn/tensor/xtensor/partition.cpp b/ttnn/cpp/ttnn/tensor/xtensor/partition.cpp index a77d3ad3a7cf..e01a6838bd8c 100644 --- a/ttnn/cpp/ttnn/tensor/xtensor/partition.cpp +++ b/ttnn/cpp/ttnn/tensor/xtensor/partition.cpp @@ -94,7 +94,7 @@ std::vector> chunk(const xt::xarray& xtensor, int num_chunks, i } template -xt::xarray concatenate(const std::vector>& v, int dim) { +xt::xarray concat(const std::vector>& v, int dim) { constexpr size_t MAX_TUPLE_SIZE = 64; if (v.empty()) { @@ -112,22 +112,22 @@ xt::xarray concatenate(const std::vector>& v, int dim) { } } -template xt::xarray concatenate(const std::vector>& v, int dim); -template xt::xarray concatenate(const std::vector>& v, int dim); -template xt::xarray concatenate(const std::vector>& v, int dim); -template xt::xarray concatenate(const std::vector>& v, int dim); +template xt::xarray concat(const std::vector>& v, int dim); +template xt::xarray concat(const std::vector>& v, int dim); +template xt::xarray concat(const std::vector>& v, int dim); +template xt::xarray concat(const std::vector>& v, int dim); // Adaptor APIs from xtensor to ttnn::Tensor. namespace adaptor { namespace { template -Tensor concatenate_impl(const std::vector& tensors, const TensorLayout& layout, int dim) { +Tensor concat_impl(const std::vector& tensors, const TensorLayout& layout, int dim) { std::vector> xtensors; for (const auto& tensor : tensors) { xtensors.push_back(to_xtensor(tensor)); } - xt::xarray result = concatenate(xtensors, dim); + xt::xarray result = concat(xtensors, dim); return from_xtensor(result, TensorSpec(get_shape_from_xarray(result), layout)); } @@ -159,14 +159,14 @@ std::vector chunk(const Tensor& tensor, int num_chunks, int dim) { } } -Tensor concatenate(const std::vector& tensors, int dim) { +Tensor concat(const std::vector& tensors, int dim) { TT_FATAL(tensors.size() > 0, "Cannot concatenate an empty list of tensors"); const auto& reference_layout = tensors.front().tensor_spec().tensor_layout(); switch (reference_layout.get_data_type()) { - case DataType::BFLOAT16: return adaptor::concatenate_impl(tensors, reference_layout, dim); - case DataType::FLOAT32: return adaptor::concatenate_impl(tensors, reference_layout, dim); - case DataType::INT32: return adaptor::concatenate_impl(tensors, reference_layout, dim); - case DataType::UINT32: return adaptor::concatenate_impl(tensors, reference_layout, dim); + case DataType::BFLOAT16: return adaptor::concat_impl(tensors, reference_layout, dim); + case DataType::FLOAT32: return adaptor::concat_impl(tensors, reference_layout, dim); + case DataType::INT32: return adaptor::concat_impl(tensors, reference_layout, dim); + case DataType::UINT32: return adaptor::concat_impl(tensors, reference_layout, dim); default: TT_THROW("Unsupported data type: {}", reference_layout.get_data_type()); } } diff --git a/ttnn/cpp/ttnn/tensor/xtensor/partition.hpp b/ttnn/cpp/ttnn/tensor/xtensor/partition.hpp index 2cae6db2947f..59a888f446fe 100644 --- a/ttnn/cpp/ttnn/tensor/xtensor/partition.hpp +++ b/ttnn/cpp/ttnn/tensor/xtensor/partition.hpp @@ -19,9 +19,9 @@ template std::vector> chunk(const xt::xarray& tensor, int num_chunks, int dim = 0); // Concatenates a list of tensors along the specified dimension. -tt::tt_metal::Tensor concatenate(const std::vector& tensors, int dim = 0); +tt::tt_metal::Tensor concat(const std::vector& tensors, int dim = 0); template -xt::xarray concatenate(const std::vector>& v, int dim = 0); +xt::xarray concat(const std::vector>& v, int dim = 0); } // namespace ttnn::experimental::xtensor