diff --git a/tt-train/sources/ttml/core/distributed_mapping.hpp b/tt-train/sources/ttml/core/distributed_mapping.hpp
index 102240e51e2..d40644486da 100644
--- a/tt-train/sources/ttml/core/distributed_mapping.hpp
+++ b/tt-train/sources/ttml/core/distributed_mapping.hpp
@@ -74,7 +74,7 @@ class XTensorToMesh {
     tt::tt_metal::distributed::MeshShape m_mesh_shape;

     size_t get_num_devices() const {
-        return m_mesh_shape.first * m_mesh_shape.second;
+        return m_mesh_shape.num_rows * m_mesh_shape.num_cols;
     }
 };

@@ -130,8 +130,8 @@ class ShardTensor2dMesh : public XTensorToMesh<ShardTensor2dMesh<T>, T> {
             throw std::invalid_argument("ShardTensor2dMesh requires at least one dimension to shard");
         }

-        int rows = Base::m_mesh_shape.first;
-        int cols = Base::m_mesh_shape.second;
+        int rows = Base::m_mesh_shape.num_rows;
+        int cols = Base::m_mesh_shape.num_cols;

         auto row_dim = m_dims.first;
         auto col_dim = m_dims.second;
@@ -178,8 +178,8 @@ class ShardTensor2dMesh : public XTensorToMesh<ShardTensor2dMesh<T>, T> {
     std::unordered_map<std::string, std::string> config_impl() const {
         return {
             {"strategy", "shard_2d"},
-            {"mesh_shape_y", std::to_string(Base::m_mesh_shape.first)},
-            {"mesh_shape_x", std::to_string(Base::m_mesh_shape.second)}};
+            {"mesh_shape_y", std::to_string(Base::m_mesh_shape.num_rows)},
+            {"mesh_shape_x", std::to_string(Base::m_mesh_shape.num_cols)}};
     }

 private:
@@ -193,16 +193,16 @@ class ConcatMesh2dToTensor : public MeshToXTensor<ConcatMesh2dToTensor<T>, T> {
     ConcatMesh2dToTensor(
         tt::tt_metal::distributed::MeshShape mesh_shape, const tt::tt_metal::distributed::MeshShape& dims) :
         Base(std::move(mesh_shape)), m_dims(dims) {
-        if (m_dims.first == m_dims.second) {
+        if (m_dims.num_rows == m_dims.num_cols) {
             throw std::invalid_argument("Dimensions in 'dims' must be different");
         }
     }

     std::vector<xt::xarray<T>> compose_impl(const std::vector<xt::xarray<T>>& tensors) const {
-        int rows = Base::m_mesh_shape.first;
-        int cols = Base::m_mesh_shape.second;
-        size_t row_dim = m_dims.first;
-        size_t col_dim = m_dims.second;
+        int rows = Base::m_mesh_shape.num_rows;
+        int cols = Base::m_mesh_shape.num_cols;
+        size_t row_dim = m_dims.num_rows;
+        size_t col_dim = m_dims.num_cols;

         std::vector<xt::xarray<T>> row_concatenated;
         row_concatenated.reserve(static_cast<size_t>(rows));
diff --git a/tt-train/tests/core/distributed_test.cpp b/tt-train/tests/core/distributed_test.cpp
index e273aaa4973..0f304788ca3 100644
--- a/tt-train/tests/core/distributed_test.cpp
+++ b/tt-train/tests/core/distributed_test.cpp
@@ -83,7 +83,7 @@ TYPED_TEST(MeshOpsTest, ShardTensor2dMeshTwoDimSharding) {

 TYPED_TEST(MeshOpsTest, ReplicateXTensorToMeshReplication) {
     tt::tt_metal::distributed::MeshShape mesh_shape = {2, 2};
-    int num_devices = mesh_shape.first * mesh_shape.second;  // 4
+    int num_devices = mesh_shape.num_rows * mesh_shape.num_cols;  // 4

     auto tensor = xt::arange<TypeParam>(4);  // [0,1,2,3]
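
For context on the rename this diff performs throughout: tt::tt_metal::distributed::MeshShape is accessed through named num_rows/num_cols members rather than pair-style .first/.second. Below is a minimal, self-contained C++ sketch of that pattern; the sketch::MeshShape struct and get_num_devices() free function are stand-ins written for illustration, not the real tt-metal API.

    #include <cstddef>
    #include <iostream>

    namespace sketch {

    // Stand-in for tt::tt_metal::distributed::MeshShape (illustrative only):
    // a 2D device mesh described by named dimensions instead of a pair.
    struct MeshShape {
        std::size_t num_rows = 0;
        std::size_t num_cols = 0;
    };

    // Mirrors what XTensorToMesh::get_num_devices() computes after this
    // change: the device count of a 2D mesh is rows * cols.
    std::size_t get_num_devices(const MeshShape& shape) {
        return shape.num_rows * shape.num_cols;
    }

    }  // namespace sketch

    int main() {
        sketch::MeshShape mesh_shape{2, 2};
        std::cout << sketch::get_num_devices(mesh_shape) << '\n';  // prints 4
        return 0;
    }

The named members also make validation code like the ConcatMesh2dToTensor constructor check above (rejecting dims with num_rows == num_cols) read unambiguously, where .first == .second gave no hint of which axis was which.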