Skip to content

Commit

Permalink
Revert changes to distributed tensor config, as it breaks backward co…
Browse files Browse the repository at this point in the history
…mpatibility with tensor serialization
  • Loading branch information
omilyutin-tt committed Dec 16, 2024
1 parent a955137 commit 8fc5928
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 12 deletions.
4 changes: 3 additions & 1 deletion ttnn/cpp/ttnn/distributed/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ std::vector<Device*> get_mapped_devices(const Tensor& tensor, MeshDevice& mesh_d

return std::visit(
tt::stl::overloaded{
[&](const ShardTensor2D& s) { return mesh_device.get_view()->get_devices(s.shard_mesh); },
[&](const ShardTensor2D& s) {
return mesh_device.get_view()->get_devices(MeshShape{s.shard_mesh.y, s.shard_mesh.x});
},
[&](const auto&) { return get_workers_for_tensor(); }},
host_storage.strategy);
} else if (std::holds_alternative<MultiDeviceStorage>(tensor.get_storage())) {
Expand Down
4 changes: 3 additions & 1 deletion ttnn/cpp/ttnn/distributed/distributed_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ class ShardTensorTo2dMesh : public TensorToMesh {
return tensor_shards;
}

DistributedTensorConfig config() const override { return DistributedTensorConfig{ShardTensor2D(mesh_shape_)}; }
DistributedTensorConfig config() const override {
return DistributedTensorConfig{ShardTensor2D{ShardMesh{mesh_shape_.num_rows, mesh_shape_.num_cols}}};
}

private:
MeshShape mesh_shape_;
Expand Down
9 changes: 3 additions & 6 deletions ttnn/cpp/ttnn/distributed/distributed_tensor_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ DistributedTensorConfig create_shard_distributed_tensor_config(
}
DistributedTensorConfig create_shard_2d_distributed_tensor_config(
const std::unordered_map<std::string, std::string>& metadata) {
return ShardTensor2D(distributed::MeshShape{
.num_rows = std::stoi(metadata.at("mesh_shape_y")),
.num_cols = std::stoi(metadata.at("mesh_shape_x")),
});
return ShardTensor2D(ShardMesh(std::stoi(metadata.at("mesh_shape_y")), std::stoi(metadata.at("mesh_shape_x"))));
}
DistributedTensorConfig create_replicate_distributed_tensor_config(
const std::unordered_map<std::string, std::string>& metadata) {
Expand Down Expand Up @@ -54,8 +51,8 @@ bool operator==(const AllGatherTensor&, const AllGatherTensor&) {
}
bool operator==(const ShardTensor& lhs, const ShardTensor& rhs) { return lhs.shard_dimension == rhs.shard_dimension; }
bool operator==(const ShardTensor2D& lhs, const ShardTensor2D& rhs) {
return lhs.shard_mesh.num_rows == rhs.shard_mesh.num_rows && //
lhs.shard_mesh.num_cols == rhs.shard_mesh.num_cols;
return lhs.shard_mesh.x == rhs.shard_mesh.x && //
lhs.shard_mesh.y == rhs.shard_mesh.y;
}

} // namespace tt::tt_metal
10 changes: 6 additions & 4 deletions ttnn/cpp/ttnn/distributed/distributed_tensor_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
#include <unordered_map>
#include <variant>

#include <ttnn/distributed/types.hpp>

namespace tt::tt_metal {

struct ReplicateTensor {
Expand All @@ -23,9 +21,13 @@ struct ShardTensor {
};
bool operator==(const ShardTensor& lhs, const ShardTensor& rhs);

struct ShardMesh {
std::uint16_t y = 0;
std::uint16_t x = 0;
};
struct ShardTensor2D {
distributed::MeshShape shard_mesh;
ShardTensor2D(distributed::MeshShape mesh) : shard_mesh(std::move(mesh)) {}
ShardMesh shard_mesh; // logic 2D grid that defines the mapping of shards to devices
ShardTensor2D(ShardMesh mesh) : shard_mesh(std::move(mesh)) {}
};
bool operator==(const ShardTensor2D& lhs, const ShardTensor2D& rhs);

Expand Down

0 comments on commit 8fc5928

Please sign in to comment.