Skip to content

Commit

Permalink
Feedback from review, included tt-train
Browse files Browse the repository at this point in the history
  • Loading branch information
omilyutin-tt committed Dec 11, 2024
1 parent 8e5f076 commit ea22481
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 10 deletions.
4 changes: 1 addition & 3 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ target_link_libraries(
test_common_libs
INTERFACE
pthread
gtest
gtest_main
gmock
gmock_main
magic_enum
fmt::fmt-header-only
span
Expand Down
2 changes: 1 addition & 1 deletion tests/ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function(setup_ttnn_test_target target_name)
test_common_libs
ttnn
Metalium::Metal
GTest::gtest_main
GTest::gmock_main
)
target_include_directories(
${target_name}
Expand Down
2 changes: 1 addition & 1 deletion tt-train/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ file(
add_executable(ttml_tests ${SOURCES})
target_link_libraries(
ttml_tests
GTest::gtest_main
GTest::gmock_main
ttml
)
add_definitions(-DTEST_DATA_DIR="${CMAKE_SOURCE_DIR}/data")
Expand Down
16 changes: 11 additions & 5 deletions tt-train/tests/core/distributed_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <core/xtensor_all_includes.hpp>

#include "core/distributed_mapping.hpp"

namespace {

using ::testing::SizeIs;

template <typename T>
class MeshOpsTest : public ::testing::Test {
protected:
Expand All @@ -25,7 +30,7 @@ TYPED_TEST(MeshOpsTest, ChunkBasicNonDivisible3) {
// Chunk into 3 parts along dimension 0
auto chunks = ttml::core::chunk(tensor, 3, 0);

ASSERT_EQ(chunks.size(), 3u);
ASSERT_THAT(chunks, SizeIs(3));
EXPECT_EQ(chunks[0].shape()[0], 4u); // first chunk size 4
EXPECT_EQ(chunks[1].shape()[0], 4u); // next chunk size 4
EXPECT_EQ(chunks[2].shape()[0], 2u); // last chunk size 2
Expand All @@ -38,7 +43,7 @@ TYPED_TEST(MeshOpsTest, ChunkBasicLessChunksThanProvided) {
// Chunk into 6 parts along dimension 0
auto chunks = ttml::core::chunk(tensor, 6, 0);

ASSERT_EQ(chunks.size(), 5u);
ASSERT_THAT(chunks, SizeIs(5));
EXPECT_EQ(chunks[0].shape()[0], 3u); // first chunk size 3
EXPECT_EQ(chunks[1].shape()[0], 3u); // next chunk size 3
EXPECT_EQ(chunks[2].shape()[0], 3u); // next chunk size 3
Expand All @@ -56,7 +61,7 @@ TYPED_TEST(MeshOpsTest, ShardXTensorToMeshBasicShard) {
auto shards = sharder.map(tensor);

// With 4 shards, each shard should have size 2
ASSERT_EQ(shards.size(), 4u);
ASSERT_THAT(shards, SizeIs(4));
for (auto& s : shards) {
EXPECT_EQ(s.size(), 2u);
}
Expand All @@ -73,7 +78,7 @@ TYPED_TEST(MeshOpsTest, ShardTensor2dMeshTwoDimSharding) {
ttml::core::ShardTensor2dMesh<TypeParam> sharder(mesh_shape, {0, 1});
auto shards = sharder.map(tensor);

ASSERT_EQ(shards.size(), 4u);
ASSERT_THAT(shards, SizeIs(4));
// Check shapes of shards
for (auto& shard : shards) {
EXPECT_EQ(shard.shape()[0], 2u);
Expand All @@ -90,7 +95,7 @@ TYPED_TEST(MeshOpsTest, ReplicateXTensorToMeshReplication) {
ttml::core::ReplicateXTensorToMesh<TypeParam> replicator(mesh_shape);
auto replicas = replicator.map(tensor);

ASSERT_EQ(static_cast<int>(replicas.size()), num_devices);
ASSERT_THAT(replicas, SizeIs(num_devices));
for (const auto& t : replicas) {
EXPECT_TRUE(xt::allclose(t, tensor));
}
Expand Down Expand Up @@ -243,3 +248,4 @@ TYPED_TEST(MeshOpsTest, ConcatenateSameParametersAsCompose) {
TypeParam(0), TypeParam(1), TypeParam(2), TypeParam(3), TypeParam(4), TypeParam(5)};
EXPECT_TRUE(xt::allclose(composed, expected));
}
} // namespace

0 comments on commit ea22481

Please sign in to comment.