diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp
index e92e583d59ea..585326afc8b1 100644
--- a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp
+++ b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp
@@ -3,6 +3,7 @@
 //
 // SPDX-License-Identifier: Apache-2.0

+#include
 #include

 #include "buffers/buffer_constants.hpp"
@@ -24,11 +25,11 @@ using ::tt::tt_metal::TensorMemoryLayout;

 class MultiDeviceTensorCreationTest : public T3kMultiDeviceFixture, public ::testing::WithParamInterface<bool> {};

-TEST_P(MultiDeviceTensorCreationTest, CreateEmpty) {
+TEST_P(MultiDeviceTensorCreationTest, Empty) {
     MeshDevice* mesh_device = this->mesh_device_.get();
     mesh_device->enable_async(GetParam());

-    const auto mesh_replicated_tensor = ttnn::empty(
+    const Tensor mesh_replicated_tensor = ttnn::empty(
         ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
         DataType::BFLOAT16,
         Layout::ROW_MAJOR,
@@ -39,7 +40,133 @@ TEST_P(MultiDeviceTensorCreationTest, CreateEmpty) {
     EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());

     const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
+    EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
+}
+
+TEST_P(MultiDeviceTensorCreationTest, EmptyLike) {
+    MeshDevice* mesh_device = this->mesh_device_.get();
+    mesh_device->enable_async(GetParam());
+
+    ASSERT_FALSE(mesh_device->get_devices().empty());
+
+    const Tensor tensor = ttnn::empty(
+        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        mesh_device->get_devices().at(0),
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+
+    EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE);
+    EXPECT_EQ(tensor.get_workers().size(), 1);
+
+    const Tensor mesh_replicated_tensor = ttnn::empty_like(
+        tensor,
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        *mesh_device,
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+    EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
+    EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());
+
+    const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
+    EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
+}
+
+TEST_P(MultiDeviceTensorCreationTest, Full) {
+    MeshDevice* mesh_device = this->mesh_device_.get();
+    mesh_device->enable_async(GetParam());
+
+    const Tensor mesh_replicated_tensor = ttnn::full(
+        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
+        /*fill_value=*/42,
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        std::ref(*mesh_device),
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+
+    EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
+    EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());
+    EXPECT_EQ(mesh_replicated_tensor.shape(), ttnn::SimpleShape({32, 32}));
+    EXPECT_EQ(mesh_replicated_tensor.dtype(), DataType::BFLOAT16);
+    EXPECT_EQ(mesh_replicated_tensor.layout(), Layout::ROW_MAJOR);
+
+    const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
+    EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
+}
+
+TEST_P(MultiDeviceTensorCreationTest, FullLike) {
+    MeshDevice* mesh_device = this->mesh_device_.get();
+    mesh_device->enable_async(GetParam());
+
+    ASSERT_FALSE(mesh_device->get_devices().empty());
+
+    Tensor tensor = ttnn::empty(
+        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        mesh_device->get_devices().at(0),
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+
+    EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE);
+    EXPECT_EQ(tensor.get_workers().size(), 1);
+
+    Tensor mesh_replicated_tensor = ttnn::full_like(
+        tensor,
+        /*fill_value=*/42,
+        /*dtype=*/std::nullopt,
+        /*layout=*/std::nullopt,
+        std::ref(*mesh_device));
+
+    EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
+    EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());
+    EXPECT_EQ(mesh_replicated_tensor.shape(), tensor.shape());
+    EXPECT_EQ(mesh_replicated_tensor.dtype(), tensor.dtype());
+    EXPECT_EQ(mesh_replicated_tensor.layout(), tensor.layout());
+
+    const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
+    EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
+}
+
+TEST_P(MultiDeviceTensorCreationTest, FullLikeWithOptTensor) {
+    MeshDevice* mesh_device = this->mesh_device_.get();
+    mesh_device->enable_async(GetParam());
+
+    ASSERT_FALSE(mesh_device->get_devices().empty());
+
+    Tensor tensor = ttnn::empty(
+        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        mesh_device->get_devices().at(0),
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+
+    EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE);
+    EXPECT_EQ(tensor.get_workers().size(), 1);
+
+    Tensor opt_output = ttnn::empty(
+        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        mesh_device,
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+
+    Tensor mesh_replicated_tensor = ttnn::full_like(
+        tensor,
+        /*fill_value=*/42,
+        /*dtype=*/std::nullopt,
+        /*layout=*/std::nullopt,
+        /*device=*/std::nullopt,
+        /*memory_config=*/std::nullopt,
+        opt_output);
+
+    EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
+    EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());
+    EXPECT_EQ(mesh_replicated_tensor.shape(), tensor.shape());
+    EXPECT_EQ(mesh_replicated_tensor.dtype(), tensor.dtype());
+    EXPECT_EQ(mesh_replicated_tensor.layout(), tensor.layout());
+
+    const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
     EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
 }
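Note on the gtest changes above: the fixture is parameterized over the async flag handed to enable_async(), so every TEST_P case runs once with the async worker off and once with it on. The suite instantiation lives outside these hunks; a minimal sketch, assuming the stock value-parameterized gtest API (the suite prefix here is hypothetical):

    // Runs each TEST_P above with GetParam() == false, then again with true.
    INSTANTIATE_TEST_SUITE_P(
        MultiDeviceTensorCreation, MultiDeviceTensorCreationTest, ::testing::Bool());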
diff --git a/tests/ttnn/unit_tests/operations/test_creation.py b/tests/ttnn/unit_tests/operations/test_creation.py
index 07f13d5708f1..f6f6773dc815 100644
--- a/tests/ttnn/unit_tests/operations/test_creation.py
+++ b/tests/ttnn/unit_tests/operations/test_creation.py
@@ -32,62 +32,6 @@ def test_zeros_like(device, input_shape):
     assert torch.allclose(torch_output_tensor, output_tensor)


-@pytest.mark.parametrize(
-    "input_shape",
-    [
-        [32, 32],
-        [5, 96, 64],
-    ],
-)
-def test_zeros_like_bf8b(device, input_shape):
-    torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
-    torch_output_tensor = torch.zeros_like(torch_input_tensor)
-
-    input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
-    output_tensor = ttnn.zeros_like(input_tensor)
-    assert ttnn.is_tensor_storage_on_device(output_tensor)
-    output_tensor = ttnn.from_device(output_tensor)
-    output_tensor = ttnn.to_torch(output_tensor).to(torch.bfloat16)
-
-    assert_with_pcc(torch_output_tensor, output_tensor, 0.9999)
-    assert torch.allclose(torch_output_tensor, output_tensor)
-
-
-@pytest.mark.parametrize(
-    "input_shape",
-    [
-        [32, 32],
-        [5, 96, 64],
-    ],
-)
-@pytest.mark.parametrize(
-    "layout",
-    [ttnn.Layout.ROW_MAJOR, ttnn.Layout.TILE],
-)
-def test_zeros_like_opt(device, layout, input_shape):
-    torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
-    torch_output_tensor = torch.zeros_like(torch_input_tensor)
-    opt_tensor = torch.ones(input_shape, dtype=torch.bfloat16)
-    opt_tensor = ttnn.from_torch(
-        opt_tensor, ttnn.bfloat16, layout=layout, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
-    )
-
-    input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout)
-    input_tensor = ttnn.to_device(input_tensor, device)
-
-    cq_id = 0
-    pages_before = ttnn._ttnn.reports.get_buffer_pages()
-    ttnn.zeros_like(input_tensor, optional_tensor=opt_tensor, queue_id=cq_id)
-    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
-
-    assert ttnn.is_tensor_storage_on_device(opt_tensor)
-    opt_tensor = ttnn.from_device(opt_tensor)
-    opt_tensor = ttnn.to_torch(opt_tensor)
-
-    assert_with_pcc(torch_output_tensor, opt_tensor, 0.9999)
-    assert torch.allclose(torch_output_tensor, opt_tensor)
-
-
 @pytest.mark.parametrize(
     "input_shape",
     [
@@ -110,35 +54,13 @@ def test_ones_like(device, input_shape):
     assert torch.allclose(torch_output_tensor, output_tensor)


-@pytest.mark.parametrize(
-    "input_shape",
-    [
-        [32, 32],
-        [5, 96, 64],
-    ],
-)
-def test_ones_like_bf8b(device, input_shape):
-    torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
-    torch_output_tensor = torch.ones_like(torch_input_tensor)
-
-    input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
-    input_tensor = ttnn.to_device(input_tensor, device)
-    output_tensor = ttnn.ones_like(input_tensor)
-    assert ttnn.is_tensor_storage_on_device(output_tensor)
-    output_tensor = ttnn.from_device(output_tensor)
-    output_tensor = ttnn.to_torch(output_tensor).to(torch.bfloat16)
-
-    assert_with_pcc(torch_output_tensor, output_tensor, 0.9999)
-    assert torch.allclose(torch_output_tensor, output_tensor)
-
-
 @pytest.mark.parametrize(
     "input_shape",
     [[32, 32], [5, 96, 64], [1, 2, 64, 64], [1, 2, 4, 64, 64]],
 )
 @pytest.mark.parametrize(
     "fill_value",
-    [-5, 3, 15, 25],
+    [-5.25, 0, 1.0],
 )
 def test_full_like(device, input_shape, fill_value):
     torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
@@ -161,7 +83,7 @@ def test_full_like(device, input_shape, fill_value):
 )
 @pytest.mark.parametrize(
     "fill_value",
-    [-5, 3, 15, 25],
+    [-5.25, 0, 1.0],
 )
 def test_full_like_bf8b(device, input_shape, fill_value):
     torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
@@ -187,7 +109,7 @@ def test_full_like_bf8b(device, input_shape, fill_value):
 )
 @pytest.mark.parametrize(
     "fill_value",
-    [-5, 3, 15, 25],
+    [-5.25, 0, 1.0],
 )
 @pytest.mark.parametrize(
     "layout",
@@ -286,6 +208,7 @@ def test_full(device, input_shape, fill_value, layout):
     [
         [32, 32],
         [5, 96, 64],
+        [1, 50257],
     ],
 )
 @pytest.mark.parametrize(
@@ -314,6 +237,34 @@ def test_full_with_opt_tensor(device, input_shape, layout, fill_value):
     assert torch.allclose(torch_tensor, opt_tensor)


+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        [32, 32],
+        [5, 96, 64],
+        [1, 50257],
+    ],
+)
+@pytest.mark.parametrize(
+    "fill_value",
+    [-5.25, 0, 1.0],
+)
+@pytest.mark.parametrize(
+    "layout",
+    [ttnn.Layout.ROW_MAJOR, ttnn.Layout.TILE],
+)
+def test_full_multi_device(mesh_device, input_shape, fill_value, layout):
+    torch_tensor = torch.full(input_shape, dtype=torch.bfloat16, fill_value=fill_value)
+
+    tensor = ttnn.full(input_shape, device=mesh_device, fill_value=fill_value, layout=layout)
+    assert ttnn.is_tensor_storage_on_device(tensor)
+    output_tensors = ttnn.to_torch(tensor, mesh_composer=ttnn.ListMeshToTensor(mesh_device))
+
+    for output_tensor in output_tensors:
+        assert_with_pcc(torch_tensor, output_tensor, 0.9999)
+        assert torch.allclose(torch_tensor, output_tensor)
+
+
 @pytest.mark.parametrize(
     "start",
     [4, 8, 16, 32],
@@ -403,7 +354,6 @@ def test_empty_multi_device(mesh_device, input_shapes):
 )
 def test_empty_like(device, input_shapes):
     torch_input_tensor = torch.ones((input_shapes), dtype=torch.bfloat16)
-    torch_output_tensor = torch.empty(torch_input_tensor.shape, dtype=torch.bfloat16)

     input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT)
     input_tensor = ttnn.to_device(input_tensor, device)
@@ -412,4 +362,28 @@ def test_empty_like(device, input_shapes):
     output_tensor = ttnn.from_device(output_tensor)
     output_tensor = ttnn.to_torch(output_tensor)

-    assert list(torch_output_tensor.shape) == list(output_tensor.shape)
+    assert list(torch_input_tensor.shape) == list(output_tensor.shape)
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    [
+        [2, 1, 4, 4],  # 256x256
+        [2, 1280, 8, 8],
+        [2, 640, 16, 16],
+        [2, 1280, 8, 8],  # 512x512
+        [2, 1280, 16, 16],
+        [2, 1280, 16, 16],
+    ],
+)
+def test_empty_like_multi_device(mesh_device, input_shapes):
+    torch_input_tensor = torch.empty((input_shapes), dtype=torch.bfloat16)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT)
+    input_tensor = ttnn.to_device(input_tensor, mesh_device)
+    output_tensor = ttnn.empty_like(input_tensor, layout=ttnn.TILE_LAYOUT)
+    output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
+    output_tensor = ttnn.from_device(output_tensor)
+    output_tensors = ttnn.to_torch(output_tensor, mesh_composer=ttnn.ListMeshToTensor(mesh_device))
+    for output_tensor in output_tensors:
+        assert list(torch_input_tensor.shape) == list(output_tensor.shape)
diff --git a/tt-train/sources/ttml/core/tt_tensor_utils.cpp b/tt-train/sources/ttml/core/tt_tensor_utils.cpp
index e7c625d1c7f6..f228b1a95b06 100644
--- a/tt-train/sources/ttml/core/tt_tensor_utils.cpp
+++ b/tt-train/sources/ttml/core/tt_tensor_utils.cpp
@@ -166,11 +166,10 @@ tt::tt_metal::Tensor full(
             (padded[2] + additional_padding_h),
             (padded[3] + additional_padding_w),
         });
-        // temporary solution to avoid using the device, and use only MeshDevice in highlevel api
-        return ttnn::full(padded_shape, value, dtype, Layout::TILE, std::ref(*device->get_device(0)));
+        return ttnn::full(padded_shape, value, dtype, Layout::TILE, std::ref(*device));
     }
     // if not padding available, we can just create a tensor with the given shape
-    return ttnn::full(shape, value, dtype, Layout::TILE, std::ref(*device->get_device(0)));
+    return ttnn::full(shape, value, dtype, Layout::TILE, std::ref(*device));
 }

 tt::tt_metal::Tensor zeros(const ttnn::Shape& shape, ttnn::distributed::MeshDevice* device, DataType dtype) {
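Note on the tt-train change above: with the temporary workaround removed, full() hands the MeshDevice itself to ttnn::full instead of unwrapping device 0 via get_device(0), and the fill is replicated across every device in the mesh. A minimal usage sketch, assuming a ttml-style ttnn::distributed::MeshDevice* named device, as in the full()/zeros()/ones() helpers:

    // Sketch: replicated 32x32 tile of ones on each device in the mesh.
    tt::tt_metal::Tensor t = ttnn::full(
        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
        /*value=*/1.0F,
        DataType::BFLOAT16,
        Layout::TILE,
        std::ref(*device));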
diff --git a/ttnn/cpp/ttnn/operations/creation.hpp b/ttnn/cpp/ttnn/operations/creation.hpp
index 8a2b510f2857..dc3a298cbfa8 100644
--- a/ttnn/cpp/ttnn/operations/creation.hpp
+++ b/ttnn/cpp/ttnn/operations/creation.hpp
@@ -424,14 +424,13 @@ struct Arange {

 }  // namespace creation
 }  // namespace operations

-constexpr auto full =
-    ttnn::decorators::register_operation_with_auto_launch_op<"ttnn::full", ttnn::operations::creation::Full>();
+constexpr auto full = ttnn::decorators::register_operation<"ttnn::full", ttnn::operations::creation::Full>();
 constexpr auto zeros = ttnn::decorators::register_operation<"ttnn::zeros", ttnn::operations::creation::Zeros>();
 constexpr auto ones = ttnn::decorators::register_operation<"ttnn::ones", ttnn::operations::creation::Ones>();
 constexpr auto empty = ttnn::decorators::register_operation<"ttnn::empty", ttnn::operations::creation::Empty>();
 constexpr auto full_like =
-    ttnn::decorators::register_operation_with_auto_launch_op<"ttnn::full_like", ttnn::operations::creation::FullLike>();
+    ttnn::decorators::register_operation<"ttnn::full_like", ttnn::operations::creation::FullLike>();
 constexpr auto zeros_like =
     ttnn::decorators::register_operation<"ttnn::zeros_like", ttnn::operations::creation::ZerosLike>();
 constexpr auto ones_like =
@@ -439,7 +438,6 @@ constexpr auto ones_like =
 constexpr auto empty_like =
     ttnn::decorators::register_operation<"ttnn::empty_like", ttnn::operations::creation::EmptyLike>();

-constexpr auto arange =
-    ttnn::decorators::register_operation_with_auto_launch_op<"ttnn::arange", ttnn::operations::creation::Arange>();
+constexpr auto arange = ttnn::decorators::register_operation<"ttnn::arange", ttnn::operations::creation::Arange>();

 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/numpy/functions.hpp b/ttnn/cpp/ttnn/operations/numpy/functions.hpp
index c7a5ad1149dc..120528359563 100644
--- a/ttnn/cpp/ttnn/operations/numpy/functions.hpp
+++ b/ttnn/cpp/ttnn/operations/numpy/functions.hpp
@@ -61,6 +61,7 @@ static Tensor full(
         shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout), MemoryConfig{}, shape));
     auto owned_buffer = tt::tt_metal::owned_buffer::create<T>(tensor_spec.padded_shape().volume());
+    // TODO: 15061 - Generalize the header to support generic vector / view types.
     std::fill(std::begin(owned_buffer), std::end(owned_buffer), value);

     if (!optional_output_tensor.has_value()) {
@@ -133,6 +134,7 @@ static Tensor full_impl(
     }
 }

+// TODO: #14974 - Can this be deleted, as it is only used in tests?
 template <typename T>
 static Tensor full(
     const tt::tt_metal::LegacyShape& shape,
@@ -153,7 +155,7 @@ static Tensor full(
         std::nullopt);
 }

-// TODO: #14974 - Can this be deleted?
+// TODO: #14974 - Can this be deleted, as it is only used in tests?
 static Tensor zeros(
     const tt::tt_metal::LegacyShape& shape,
     const DataType data_type = DataType::BFLOAT16,
@@ -164,7 +166,7 @@ static Tensor zeros(
     return full(shape, 0.0f, data_type, layout, device, output_mem_config);
 }

-// TODO: #14974 - Can this be deleted?
+// TODO: #14974 - Can this be deleted, as it is only used in tests?
 static Tensor ones(
     const tt::tt_metal::LegacyShape& shape,
     const DataType data_type = DataType::BFLOAT16,
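Note on the two headers above: switching full, full_like, and arange from register_operation_with_auto_launch_op to register_operation drops the auto-launch wrapper, so these creation ops are now registered like the rest of the family (zeros, ones, empty); presumably this direct dispatch is what lets them accept a MeshDevice, as exercised by the new gtests. Separately, the new TODO in numpy/functions.hpp asks for the fill to stop being hard-coded to owned_buffer; a minimal sketch of that shape of generalization (fill_view is a hypothetical name, not part of this diff):

    #include <algorithm>
    #include <iterator>

    // Fills any contiguous range that exposes std::begin/std::end, rather
    // than accepting only an owned_buffer.
    template <typename Range, typename T>
    void fill_view(Range& view, const T& value) {
        std::fill(std::begin(view), std::end(view), value);
    }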
diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp
index f715080ca92b..518ce518a957 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.hpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.hpp
@@ -239,7 +239,7 @@ struct Tensor {
         } else if (storage_type == tt::tt_metal::StorageType::MULTI_DEVICE) {
             std::vector<Buffer*> buffers;
             auto storage = std::get<tt::tt_metal::MultiDeviceStorage>(this->get_storage());
-            for (auto buffer : storage.get_buffers()) {
+            for (const auto& buffer : storage.get_buffers()) {
                 buffers.push_back(buffer.get());
             }
             return buffers;
diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
index 2afd19567e34..12e543db69ec 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
@@ -623,8 +623,9 @@ Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor)
     // Tensor has workers (on device) or runtime mode is synchronous or tensor has multiple buffers.
     // No need to check for borrowed storage.
     if (worker->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or
-        tensor.tensor_attributes->num_shards_to_be_populated > 1)
+        tensor.tensor_attributes->num_shards_to_be_populated > 1) {
         return tensor;
+    }

     if (tensor.storage_type() == StorageType::BORROWED) {
         ZoneScopedN("CopyBorrowedStorage");
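Note on the tensor.hpp hunk: get_buffers() evidently yields smart pointers (the loop body calls buffer.get()), so iterating by value copied a shared_ptr, paying an atomic reference-count increment and decrement per buffer; binding a const reference avoids that churn. A self-contained illustration of the pattern, assuming shared_ptr elements:

    #include <memory>
    #include <vector>

    int main() {
        std::vector<std::shared_ptr<int>> buffers{std::make_shared<int>(1), std::make_shared<int>(2)};
        for (const auto& buffer : buffers) {  // reference binding: no refcount traffic
            (void)buffer.get();               // raw-pointer access, as in the hunk above
        }
        return 0;
    }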