Skip to content

Commit

Permalink
Fix tt train
Browse files Browse the repository at this point in the history
  • Loading branch information
omilyutin-tt committed Dec 9, 2024
1 parent 7b89e93 commit c1d4f1b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
20 changes: 10 additions & 10 deletions tt-train/sources/ttml/core/distributed_mapping.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class XTensorToMesh {
tt::tt_metal::distributed::MeshShape m_mesh_shape;

// Returns the product of the two extents stored in m_mesh_shape.
// NOTE(review): this text is a diff capture whose +/- markers were lost; the
// two return lines below appear to be the pre-change and post-change versions
// of a single statement, not two consecutive statements.
size_t get_num_devices() const {
// Pre-change version: pair-style .first / .second accessors.
return m_mesh_shape.first * m_mesh_shape.second;
// Post-change version: named num_rows / num_cols fields.
return m_mesh_shape.num_rows * m_mesh_shape.num_cols;
}
};

Expand Down Expand Up @@ -130,8 +130,8 @@ class ShardTensor2dMesh : public XTensorToMesh<ShardTensor2dMesh<T>, T> {
throw std::invalid_argument("ShardTensor2dMesh requires at least one dimension to shard");
}

int rows = Base::m_mesh_shape.first;
int cols = Base::m_mesh_shape.second;
int rows = Base::m_mesh_shape.num_rows;
int cols = Base::m_mesh_shape.num_cols;
auto row_dim = m_dims.first;
auto col_dim = m_dims.second;

Expand Down Expand Up @@ -178,8 +178,8 @@ class ShardTensor2dMesh : public XTensorToMesh<ShardTensor2dMesh<T>, T> {
// Builds a string-to-string map describing this mapper: a fixed "strategy"
// entry of "shard_2d" plus the two mesh extents rendered with std::to_string.
// NOTE(review): this text is a diff capture whose +/- markers were lost; the
// "mesh_shape_y"/"mesh_shape_x" entries appear twice because the old and new
// versions of the same two lines are shown back to back.
std::unordered_map<std::string, std::string> config_impl() const {
return {
{"strategy", "shard_2d"},
// Pre-change version: pair-style .first / .second accessors.
{"mesh_shape_y", std::to_string(Base::m_mesh_shape.first)},
{"mesh_shape_x", std::to_string(Base::m_mesh_shape.second)}};
// Post-change version: "mesh_shape_y" is num_rows, "mesh_shape_x" is num_cols.
{"mesh_shape_y", std::to_string(Base::m_mesh_shape.num_rows)},
{"mesh_shape_x", std::to_string(Base::m_mesh_shape.num_cols)}};
}

private:
Expand All @@ -193,16 +193,16 @@ class ConcatMesh2dToTensor : public MeshToXTensor<ConcatMesh2dToTensor<T>, T> {
// Constructs the composer: mesh_shape is taken by value and moved into Base,
// and dims is copied into m_dims.
// Throws std::invalid_argument when the two values held in dims are equal.
// NOTE(review): dims is typed as MeshShape even though the message below calls
// its contents "dimensions" -- confirm this type reuse is intentional.
// NOTE(review): this text is a diff capture whose +/- markers were lost; the
// two `if` lines below appear to be the pre-change and post-change versions of
// a single condition, not nested checks.
ConcatMesh2dToTensor(
tt::tt_metal::distributed::MeshShape mesh_shape, const tt::tt_metal::distributed::MeshShape& dims) :
Base(std::move(mesh_shape)), m_dims(dims) {
// Pre-change version: pair-style .first / .second accessors.
if (m_dims.first == m_dims.second) {
// Post-change version: named num_rows / num_cols fields.
if (m_dims.num_rows == m_dims.num_cols) {
throw std::invalid_argument("Dimensions in 'dims' must be different");
}
}

std::vector<xt::xarray<T>> compose_impl(const std::vector<xt::xarray<T>>& tensors) const {
int rows = Base::m_mesh_shape.first;
int cols = Base::m_mesh_shape.second;
size_t row_dim = m_dims.first;
size_t col_dim = m_dims.second;
int rows = Base::m_mesh_shape.num_rows;
int cols = Base::m_mesh_shape.num_cols;
size_t row_dim = m_dims.num_rows;
size_t col_dim = m_dims.num_cols;

std::vector<xt::xarray<T>> row_concatenated;
row_concatenated.reserve(static_cast<size_t>(rows));
Expand Down
2 changes: 1 addition & 1 deletion tt-train/tests/core/distributed_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TYPED_TEST(MeshOpsTest, ShardTensor2dMeshTwoDimSharding) {

TYPED_TEST(MeshOpsTest, ReplicateXTensorToMeshReplication) {
tt::tt_metal::distributed::MeshShape mesh_shape = {2, 2};
int num_devices = mesh_shape.first * mesh_shape.second; // 4
int num_devices = mesh_shape.num_rows * mesh_shape.num_cols; // 4

auto tensor = xt::arange<TypeParam>(4); // [0,1,2,3]

Expand Down

0 comments on commit c1d4f1b

Please sign in to comment.