From 8fc5928f8a63609ac66eecbc28426c5b0300a9f9 Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Mon, 16 Dec 2024 20:21:46 +0000 Subject: [PATCH] Revert changes to distributed tensor config, as it breaks backward compatibility with tensor serialization --- ttnn/cpp/ttnn/distributed/api.cpp | 4 +++- ttnn/cpp/ttnn/distributed/distributed_tensor.cpp | 4 +++- .../cpp/ttnn/distributed/distributed_tensor_config.cpp | 9 +++------ .../cpp/ttnn/distributed/distributed_tensor_config.hpp | 10 ++++++---- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index 1562eb85f99..fee7fa1566c 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -161,7 +161,9 @@ std::vector get_mapped_devices(const Tensor& tensor, MeshDevice& mesh_d return std::visit( tt::stl::overloaded{ - [&](const ShardTensor2D& s) { return mesh_device.get_view()->get_devices(s.shard_mesh); }, + [&](const ShardTensor2D& s) { + return mesh_device.get_view()->get_devices(MeshShape{s.shard_mesh.y, s.shard_mesh.x}); + }, [&](const auto&) { return get_workers_for_tensor(); }}, host_storage.strategy); } else if (std::holds_alternative(tensor.get_storage())) { diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index b8dcc46f340..4908413132f 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -94,7 +94,9 @@ class ShardTensorTo2dMesh : public TensorToMesh { return tensor_shards; } - DistributedTensorConfig config() const override { return DistributedTensorConfig{ShardTensor2D(mesh_shape_)}; } + DistributedTensorConfig config() const override { + return DistributedTensorConfig{ShardTensor2D{ShardMesh{mesh_shape_.num_rows, mesh_shape_.num_cols}}}; + } private: MeshShape mesh_shape_; diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor_config.cpp 
b/ttnn/cpp/ttnn/distributed/distributed_tensor_config.cpp index 6a06c1f8603..6e69a86b8be 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor_config.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor_config.cpp @@ -17,10 +17,7 @@ DistributedTensorConfig create_shard_distributed_tensor_config( } DistributedTensorConfig create_shard_2d_distributed_tensor_config( const std::unordered_map& metadata) { - return ShardTensor2D(distributed::MeshShape{ - .num_rows = std::stoi(metadata.at("mesh_shape_y")), - .num_cols = std::stoi(metadata.at("mesh_shape_x")), - }); + return ShardTensor2D(ShardMesh(std::stoi(metadata.at("mesh_shape_y")), std::stoi(metadata.at("mesh_shape_x")))); } DistributedTensorConfig create_replicate_distributed_tensor_config( const std::unordered_map& metadata) { @@ -54,8 +51,8 @@ bool operator==(const AllGatherTensor&, const AllGatherTensor&) { } bool operator==(const ShardTensor& lhs, const ShardTensor& rhs) { return lhs.shard_dimension == rhs.shard_dimension; } bool operator==(const ShardTensor2D& lhs, const ShardTensor2D& rhs) { - return lhs.shard_mesh.num_rows == rhs.shard_mesh.num_rows && // - lhs.shard_mesh.num_cols == rhs.shard_mesh.num_cols; + return lhs.shard_mesh.x == rhs.shard_mesh.x && // + lhs.shard_mesh.y == rhs.shard_mesh.y; } } // namespace tt::tt_metal diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor_config.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor_config.hpp index 9fbc57f9ae9..6d7c11099d0 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor_config.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor_config.hpp @@ -7,8 +7,6 @@ #include #include -#include - namespace tt::tt_metal { struct ReplicateTensor { @@ -23,9 +21,13 @@ struct ShardTensor { }; bool operator==(const ShardTensor& lhs, const ShardTensor& rhs); +struct ShardMesh { + std::uint16_t y = 0; + std::uint16_t x = 0; +}; struct ShardTensor2D { - distributed::MeshShape shard_mesh; - ShardTensor2D(distributed::MeshShape mesh) : 
shard_mesh(std::move(mesh)) {} + ShardMesh shard_mesh; // logical 2D grid that defines the mapping of shards to devices + ShardTensor2D(ShardMesh mesh) : shard_mesh(std::move(mesh)) {} }; bool operator==(const ShardTensor2D& lhs, const ShardTensor2D& rhs);