From 8e5f07660e3eb441d8d4bb2b31b491553f7b54a1 Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Wed, 11 Dec 2024 04:49:36 +0000 Subject: [PATCH 1/2] Add gmock to cmake dependenciesOC --- tests/CMakeLists.txt | 1 + .../tensor/test_create_tensor_multi_device.cpp | 14 ++++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 544d624088c..c6f28a5c558 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -7,6 +7,7 @@ target_link_libraries( pthread gtest gtest_main + gmock magic_enum fmt::fmt-header-only span diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp index f4279cc8753..7ef367335f6 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp @@ -8,6 +8,7 @@ #include "buffers/buffer_constants.hpp" #include "gtest/gtest.h" +#include "gmock/gmock.h" #include "ttnn/cpp/ttnn/operations/creation.hpp" #include "ttnn/cpp/ttnn/tensor/types.hpp" #include "ttnn/distributed/api.hpp" @@ -17,6 +18,7 @@ namespace ttnn::distributed::test { namespace { +using ::testing::SizeIs; using ::tt::tt_metal::BufferType; using ::tt::tt_metal::Layout; using ::tt::tt_metal::MemoryConfig; @@ -57,7 +59,7 @@ TEST_P(MultiDeviceTensorCreationTest, EmptyLike) { MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt}); EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE); - EXPECT_EQ(tensor.get_workers().size(), 1); + EXPECT_THAT(tensor.get_workers(), SizeIs(1)); const Tensor mesh_replicated_tensor = ttnn::empty_like( tensor, @@ -67,7 +69,7 @@ TEST_P(MultiDeviceTensorCreationTest, EmptyLike) { MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt}); EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE); - EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices()); + EXPECT_THAT(mesh_replicated_tensor.get_workers(), SizeIs(mesh_device->num_devices())); const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor); EXPECT_TRUE(std::holds_alternative(distributed_tensor_config)); @@ -86,7 +88,7 @@ TEST_P(MultiDeviceTensorCreationTest, Full) { MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt}); EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE); - EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices()); + EXPECT_THAT(mesh_replicated_tensor.get_workers(), SizeIs(mesh_device->num_devices())); EXPECT_EQ(mesh_replicated_tensor.shape(), ttnn::SimpleShape({32, 32})); EXPECT_EQ(mesh_replicated_tensor.dtype(), DataType::BFLOAT16); EXPECT_EQ(mesh_replicated_tensor.layout(), Layout::ROW_MAJOR); @@ -109,7 +111,7 @@ TEST_P(MultiDeviceTensorCreationTest, FullLike) { MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt}); EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE); - EXPECT_EQ(tensor.get_workers().size(), 1); + EXPECT_THAT(tensor.get_workers(), SizeIs(1)); Tensor mesh_replicated_tensor = ttnn::full_like( tensor, @@ -119,7 +121,7 @@ TEST_P(MultiDeviceTensorCreationTest, FullLike) { std::ref(*mesh_device)); EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE); - EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices()); + EXPECT_THAT(mesh_replicated_tensor.get_workers(), SizeIs(mesh_device->num_devices())); EXPECT_EQ(mesh_replicated_tensor.shape(), tensor.shape()); EXPECT_EQ(mesh_replicated_tensor.dtype(), tensor.dtype()); EXPECT_EQ(mesh_replicated_tensor.layout(), tensor.layout()); @@ -161,7 +163,7 @@ TEST_P(MultiDeviceTensorCreationTest, FullLikeWithOptTensor) { opt_output); EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE); - EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices()); + EXPECT_THAT(mesh_replicated_tensor.get_workers(), SizeIs(mesh_device->num_devices())); EXPECT_EQ(mesh_replicated_tensor.shape(), tensor.shape()); EXPECT_EQ(mesh_replicated_tensor.dtype(), tensor.dtype()); EXPECT_EQ(mesh_replicated_tensor.layout(), tensor.layout()); From ea224810e0a1967d7e66c588f1cb41566a97a470 Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Wed, 11 Dec 2024 19:43:56 +0000 Subject: [PATCH 2/2] Feedback from review, included tt-train --- tests/CMakeLists.txt | 4 +--- tests/ttnn/CMakeLists.txt | 2 +- tt-train/tests/CMakeLists.txt | 2 +- tt-train/tests/core/distributed_test.cpp | 16 +++++++++++----- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c6f28a5c558..6a15d0c6db4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -5,9 +5,7 @@ target_link_libraries( test_common_libs INTERFACE pthread - gtest - gtest_main - gmock + gmock_main magic_enum fmt::fmt-header-only span diff --git a/tests/ttnn/CMakeLists.txt b/tests/ttnn/CMakeLists.txt index c14a587dd72..3117e6b8920 100644 --- a/tests/ttnn/CMakeLists.txt +++ b/tests/ttnn/CMakeLists.txt @@ -6,7 +6,7 @@ function(setup_ttnn_test_target target_name) test_common_libs ttnn Metalium::Metal - GTest::gtest_main + GTest::gmock_main ) target_include_directories( ${target_name} diff --git a/tt-train/tests/CMakeLists.txt b/tt-train/tests/CMakeLists.txt index 0faac2d3ee3..5eee6b8e77b 100644 --- a/tt-train/tests/CMakeLists.txt +++ b/tt-train/tests/CMakeLists.txt @@ -13,7 +13,7 @@ file( add_executable(ttml_tests ${SOURCES}) target_link_libraries( ttml_tests - GTest::gtest_main + GTest::gmock_main ttml ) add_definitions(-DTEST_DATA_DIR="${CMAKE_SOURCE_DIR}/data") diff --git a/tt-train/tests/core/distributed_test.cpp b/tt-train/tests/core/distributed_test.cpp index 0f304788ca3..4d9bc0e8ae6 100644 --- a/tt-train/tests/core/distributed_test.cpp +++ b/tt-train/tests/core/distributed_test.cpp @@ -2,12 +2,17 @@ // // SPDX-License-Identifier: Apache-2.0 +#include #include #include #include "core/distributed_mapping.hpp" +namespace { + +using ::testing::SizeIs; + template class MeshOpsTest : public ::testing::Test { protected: @@ -25,7 +30,7 @@ TYPED_TEST(MeshOpsTest, ChunkBasicNonDivisible3) { // Chunk into 3 parts along dimension 0 auto chunks = ttml::core::chunk(tensor, 3, 0); - ASSERT_EQ(chunks.size(), 3u); + ASSERT_THAT(chunks, SizeIs(3)); EXPECT_EQ(chunks[0].shape()[0], 4u); // first chunk size 4 EXPECT_EQ(chunks[1].shape()[0], 4u); // next chunk size 4 EXPECT_EQ(chunks[2].shape()[0], 2u); // last chunk size 2 @@ -38,7 +43,7 @@ TYPED_TEST(MeshOpsTest, ChunkBasicLessChunksThanProvided) { // Chunk into 6 parts along dimension 0 auto chunks = ttml::core::chunk(tensor, 6, 0); - ASSERT_EQ(chunks.size(), 5u); + ASSERT_THAT(chunks, SizeIs(5)); EXPECT_EQ(chunks[0].shape()[0], 3u); // first chunk size 3 EXPECT_EQ(chunks[1].shape()[0], 3u); // next chunk size 3 EXPECT_EQ(chunks[2].shape()[0], 3u); // next chunk size 3 @@ -56,7 +61,7 @@ TYPED_TEST(MeshOpsTest, ShardXTensorToMeshBasicShard) { auto shards = sharder.map(tensor); // With 4 shards, each shard should have size 2 - ASSERT_EQ(shards.size(), 4u); + ASSERT_THAT(shards, SizeIs(4)); for (auto& s : shards) { EXPECT_EQ(s.size(), 2u); } @@ -73,7 +78,7 @@ TYPED_TEST(MeshOpsTest, ShardTensor2dMeshTwoDimSharding) { ttml::core::ShardTensor2dMesh sharder(mesh_shape, {0, 1}); auto shards = sharder.map(tensor); - ASSERT_EQ(shards.size(), 4u); + ASSERT_THAT(shards, SizeIs(4)); // Check shapes of shards for (auto& shard : shards) { EXPECT_EQ(shard.shape()[0], 2u); @@ -90,7 +95,7 @@ TYPED_TEST(MeshOpsTest, ReplicateXTensorToMeshReplication) { ttml::core::ReplicateXTensorToMesh replicator(mesh_shape); auto replicas = replicator.map(tensor); - ASSERT_EQ(static_cast(replicas.size()), num_devices); + ASSERT_THAT(replicas, SizeIs(num_devices)); for (const auto& t : replicas) { EXPECT_TRUE(xt::allclose(t, tensor)); } @@ -243,3 +248,4 @@ TYPED_TEST(MeshOpsTest, ConcatenateSameParametersAsCompose) { TypeParam(0), TypeParam(1), TypeParam(2), TypeParam(3), TypeParam(4), TypeParam(5)}; EXPECT_TRUE(xt::allclose(composed, expected)); } +} // namespace