Get rid of ttnn::distributed::api
omilyutin-tt committed Dec 13, 2024
1 parent 19360ce commit c8cc96d
Showing 7 changed files with 36 additions and 42 deletions.
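The change is mechanical at every call site: the nested api namespace goes away, so callers drop one qualifier and the functions themselves are unchanged. A minimal before/after sketch (the mesh_device and input names are illustrative, not taken from the diff):

// Before: functions lived in the nested ttnn::distributed::api namespace.
auto mapper = ttnn::distributed::api::replicate_tensor_to_mesh_mapper(*mesh_device);
Tensor replicated = ttnn::distributed::api::distribute_tensor(input, *mesh_device, *mapper);

// After: the same functions, one namespace level up.
auto mapper = ttnn::distributed::replicate_tensor_to_mesh_mapper(*mesh_device);
Tensor replicated = ttnn::distributed::distribute_tensor(input, *mesh_device, *mapper);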
52 changes: 26 additions & 26 deletions tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp
@@ -28,10 +28,10 @@ TEST_F(TensorDistributionTest, Replication) {
     Tensor input_tensor = from_vector(
         std::vector<float>{42.F, 13.F, -99.F}, GetTensorSpec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32));
 
-    auto mapper = api::replicate_tensor_to_mesh_mapper(*mesh_device_);
-    Tensor replicated_tensor = api::distribute_tensor(input_tensor, *mesh_device_, *mapper);
+    auto mapper = replicate_tensor_to_mesh_mapper(*mesh_device_);
+    Tensor replicated_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
 
-    std::vector<Tensor> device_tensors = api::get_device_tensors(replicated_tensor);
+    std::vector<Tensor> device_tensors = get_device_tensors(replicated_tensor);
     EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
     for (const auto& device_tensor : device_tensors) {
         EXPECT_THAT(to_vector<float>(device_tensor), ElementsAre(42.F, 13.F, -99.F));
@@ -44,13 +44,13 @@ TEST_F(TensorDistributionTest, Shard1DInvalidDim) {
         std::vector<float>(num_devices, 0), GetTensorSpec(ttnn::SimpleShape{1, 1, 1, num_devices}, DataType::FLOAT32));
 
     EXPECT_ANY_THROW({
-        auto mapper = api::shard_tensor_to_mesh_mapper(*mesh_device_, -1);
-        Tensor sharded_tensor = api::distribute_tensor(input_tensor, *mesh_device_, *mapper);
+        auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, -1);
+        Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
     });
 
     EXPECT_ANY_THROW({
-        auto mapper = api::shard_tensor_to_mesh_mapper(*mesh_device_, 4);
-        Tensor sharded_tensor = api::distribute_tensor(input_tensor, *mesh_device_, *mapper);
+        auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 4);
+        Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
     });
 }
 
@@ -61,8 +61,8 @@ TEST_F(TensorDistributionTest, Shard1DTooFewShards) {
         std::vector<float>{42.F, 13.F, -99.F}, GetTensorSpec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32));
 
     EXPECT_ANY_THROW({
-        auto mapper = api::shard_tensor_to_mesh_mapper(*mesh_device_, 3);
-        Tensor sharded_tensor = api::distribute_tensor(input_tensor, *mesh_device_, *mapper);
+        auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 3);
+        Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
     });
 }
 
@@ -75,17 +75,17 @@ TEST_F(TensorDistributionTest, Shard1D) {
     Tensor input_tensor =
        from_vector(test_data, GetTensorSpec(ttnn::SimpleShape{1, num_devices, 3, 1}, DataType::FLOAT32));
 
-    auto mapper = api::shard_tensor_to_mesh_mapper(*mesh_device_, 1);
-    Tensor sharded_tensor = api::distribute_tensor(input_tensor, *mesh_device_, *mapper);
+    auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 1);
+    Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
 
-    std::vector<Tensor> device_tensors = api::get_device_tensors(sharded_tensor);
+    std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
     EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
     for (int i = 0; i < device_tensors.size(); i++) {
         EXPECT_THAT(to_vector<float>(device_tensors[i]), ElementsAre(i * 1.F, i * 2.F, i * 3.F));
     }
 
-    auto composer = api::concat_mesh_to_tensor_composer(/*dim=*/0);
-    Tensor concatenated_tensor = api::aggregate_tensor(sharded_tensor, *composer);
+    auto composer = concat_mesh_to_tensor_composer(/*dim=*/0);
+    Tensor concatenated_tensor = aggregate_tensor(sharded_tensor, *composer);
 
     Tensor expected_tensor =
        from_vector(test_data, GetTensorSpec(ttnn::SimpleShape{num_devices, 1, 3, 1}, DataType::FLOAT32));
@@ -98,18 +98,18 @@ TEST_F(TensorDistributionTest, Shard2DInvalidMeshShape) {
     ASSERT_EQ(num_cols, 4);
 
     EXPECT_ANY_THROW(
-        api::shard_tensor_2d_to_mesh_mapper(*mesh_device_, MeshShape{3, 1}, Shard2dConfig{.row_dim = 1, .col_dim = 2}));
+        shard_tensor_2d_to_mesh_mapper(*mesh_device_, MeshShape{3, 1}, Shard2dConfig{.row_dim = 1, .col_dim = 2}));
 
     EXPECT_ANY_THROW(
-        api::shard_tensor_2d_to_mesh_mapper(*mesh_device_, MeshShape{2, 5}, Shard2dConfig{.row_dim = 1, .col_dim = 2}));
+        shard_tensor_2d_to_mesh_mapper(*mesh_device_, MeshShape{2, 5}, Shard2dConfig{.row_dim = 1, .col_dim = 2}));
 }
 
 TEST_F(TensorDistributionTest, Shard2DInvalidShardConfig) {
-    EXPECT_ANY_THROW(api::shard_tensor_2d_to_mesh_mapper(*mesh_device_, MeshShape{2, 4}, Shard2dConfig{}));
+    EXPECT_ANY_THROW(shard_tensor_2d_to_mesh_mapper(*mesh_device_, MeshShape{2, 4}, Shard2dConfig{}));
 }
 
 TEST_F(TensorDistributionTest, Concat2DInvalidConfig) {
-    EXPECT_ANY_THROW(api::concat_mesh_2d_to_tensor_composer(*mesh_device_, Concat2dConfig{.row_dim = 2, .col_dim = 2}));
+    EXPECT_ANY_THROW(concat_mesh_2d_to_tensor_composer(*mesh_device_, Concat2dConfig{.row_dim = 2, .col_dim = 2}));
 }
 
 TEST_F(TensorDistributionTest, Shard2DReplicateDim) {
@@ -123,16 +123,16 @@ TEST_F(TensorDistributionTest, Shard2DReplicateDim) {
        from_vector(test_data, GetTensorSpec(ttnn::SimpleShape{1, num_rows, num_cols, 1}, DataType::FLOAT32));
     input_tensor.print();
 
-    auto mapper = api::shard_tensor_2d_to_mesh_mapper(
+    auto mapper = shard_tensor_2d_to_mesh_mapper(
         *mesh_device_,
         MeshShape{num_rows, num_cols},
         Shard2dConfig{
             .row_dim = 1,
         });
-    Tensor sharded_tensor = api::distribute_tensor(input_tensor, *mesh_device_, *mapper);
+    Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
     sharded_tensor.print();
 
-    std::vector<Tensor> device_tensors = api::get_device_tensors(sharded_tensor);
+    std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
     EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
 
     int i = 0;
@@ -157,28 +157,28 @@ TEST_F(TensorDistributionTest, Shard2D) {
     Tensor input_tensor =
        from_vector(test_data, GetTensorSpec(ttnn::SimpleShape{1, num_rows, num_cols, 3}, DataType::FLOAT32));
 
-    auto mapper = api::shard_tensor_2d_to_mesh_mapper(
+    auto mapper = shard_tensor_2d_to_mesh_mapper(
         *mesh_device_,
         MeshShape{num_rows, num_cols},
         Shard2dConfig{
             .row_dim = 1,
             .col_dim = 2,
         });
-    Tensor sharded_tensor = api::distribute_tensor(input_tensor, *mesh_device_, *mapper);
+    Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);
 
-    std::vector<Tensor> device_tensors = api::get_device_tensors(sharded_tensor);
+    std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
     EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
     for (int i = 0; i < device_tensors.size(); i++) {
         EXPECT_THAT(to_vector<float>(device_tensors[i]), ElementsAre(i * 1.F, i * 2.F, i * 3.F));
     }
 
-    auto composer = api::concat_mesh_2d_to_tensor_composer(
+    auto composer = concat_mesh_2d_to_tensor_composer(
         *mesh_device_,
         Concat2dConfig{
             .row_dim = 0,
             .col_dim = 2,
         });
-    Tensor concatenated_tensor = api::aggregate_tensor(sharded_tensor, *composer);
+    Tensor concatenated_tensor = aggregate_tensor(sharded_tensor, *composer);
 
     Tensor expected_tensor =
        from_vector(test_data, GetTensorSpec(ttnn::SimpleShape{num_rows, 1, num_cols, 3}, DataType::FLOAT32));
4 changes: 2 additions & 2 deletions tt-train/sources/ttml/core/mesh_device.cpp
@@ -7,7 +7,7 @@
 namespace ttml::core {
 
 MeshDevice::MeshDevice(tt::tt_metal::distributed::MeshShape shape) :
-    m_mesh_device(ttnn::distributed::api::open_mesh_device(
+    m_mesh_device(ttnn::distributed::open_mesh_device(
         shape,
         DEFAULT_L1_SMALL_SIZE,
         DEFAULT_TRACE_REGION_SIZE,
@@ -24,7 +24,7 @@ MeshDevice::MeshDevice(tt::tt_metal::distributed::MeshShape shape) :
 
 MeshDevice::~MeshDevice() {
     assert(m_mesh_device);
-    ttnn::distributed::api::close_mesh_device(m_mesh_device);
+    ttnn::distributed::close_mesh_device(m_mesh_device);
 }
 
 } // namespace ttml::core
2 changes: 1 addition & 1 deletion tt-train/sources/ttml/core/tt_tensor_utils.hpp
@@ -66,7 +66,7 @@ template <class T = float>
 auto to_xtensor(const tt::tt_metal::Tensor& tensor, const MeshToXTensorVariant<T>& composer) {
     auto cpu_tensor = tensor.cpu();
     cpu_tensor = cpu_tensor.to(Layout::ROW_MAJOR);
-    auto cpu_tensors = ttnn::distributed::api::get_device_tensors(cpu_tensor);
+    auto cpu_tensors = ttnn::distributed::get_device_tensors(cpu_tensor);
     std::vector<xt::xarray<T>> res;
     res.reserve(cpu_tensors.size());
     for (const auto& shard : cpu_tensors) {
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/distributed/api.cpp
@@ -15,7 +15,7 @@
 
 using namespace tt::tt_metal;
 
-namespace ttnn::distributed::api {
+namespace ttnn::distributed {
 
 std::shared_ptr<MeshDevice> open_mesh_device(
     const MeshShape& mesh_shape,
@@ -298,4 +298,4 @@ Tensor create_multi_device_tensor(
     }
 }
 
-} // namespace ttnn::distributed::api
+} // namespace ttnn::distributed
8 changes: 1 addition & 7 deletions ttnn/cpp/ttnn/distributed/api.hpp
@@ -11,7 +11,7 @@
 #include "ttnn/distributed/types.hpp"
 #include "ttnn/distributed/distributed_tensor_config.hpp"
 
-namespace ttnn::distributed::api {
+namespace ttnn::distributed {
 
 std::shared_ptr<MeshDevice> open_mesh_device(
     const MeshShape& mesh_shape,
@@ -56,10 +56,4 @@ Tensor create_multi_device_tensor(
     tt::tt_metal::StorageType storage_type,
     const tt::tt_metal::DistributedTensorConfig& strategy);
 
-} // namespace ttnn::distributed::api
-
-namespace ttnn::distributed {
-
-using namespace api;
-
-} // namespace ttnn::distributed
+} // namespace ttnn::distributed
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/distributed/distributed_tensor.cpp
@@ -11,7 +11,7 @@
 #include "ttnn/distributed/types.hpp"
 #include "ttnn/tensor/xtensor/partition.hpp"
 
-namespace ttnn::distributed::api {
+namespace ttnn::distributed {
 namespace {
 
 class ReplicateTensorToMesh : public TensorToMesh {
@@ -189,4 +189,4 @@ Tensor aggregate_tensor(const Tensor& tensor, MeshToTensor& composer) {
                              : composer.compose({tensor});
 }
 
-} // namespace ttnn::distributed::api
+} // namespace ttnn::distributed
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/distributed/distributed_tensor.hpp
@@ -7,7 +7,7 @@
 #include "ttnn/tensor/tensor.hpp"
 #include "ttnn/distributed/types.hpp"
 
-namespace ttnn::distributed::api {
+namespace ttnn::distributed {
 
 // Mapper interface that distributes a host tensor onto a multi-device configuration.
 class TensorToMesh {
@@ -55,4 +55,4 @@ Tensor distribute_tensor(const Tensor& tensor, MeshDevice& mesh_device, TensorTo
 // Aggregates a multi-device tensor into a host tensor according to the `composer`.
 Tensor aggregate_tensor(const Tensor& tensor, MeshToTensor& composer);
 
-} // namespace ttnn::distributed::api
+} // namespace ttnn::distributed
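For callers, the distribute/aggregate flow under the flattened namespace reads as in the tests above. A minimal sketch, assuming an already-open mesh_device and a host-side host_tensor (both names illustrative):

using namespace ttnn::distributed;

// Replicate one host tensor onto every device in the mesh.
auto mapper = replicate_tensor_to_mesh_mapper(*mesh_device);
Tensor replicated = distribute_tensor(host_tensor, *mesh_device, *mapper);

// Each per-device shard now holds a full copy of the input.
for (const Tensor& shard : get_device_tensors(replicated)) {
    // ... inspect or compute on each shard ...
}

// Or shard along a dimension, then stitch the shards back into one host tensor.
auto shard_mapper = shard_tensor_to_mesh_mapper(*mesh_device, /*dim=*/1);
Tensor sharded = distribute_tensor(host_tensor, *mesh_device, *shard_mapper);
auto composer = concat_mesh_to_tensor_composer(/*dim=*/0);
Tensor restored = aggregate_tensor(sharded, *composer);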
