Skip to content

Commit

Permalink
#13454: namespace tt_metal::distributed; add scaffolding for C++ tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Oct 11, 2024
1 parent b1ff88e commit 72aa584
Show file tree
Hide file tree
Showing 24 changed files with 147 additions and 55 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,4 @@ ttnn/ttnn/runtime

ClangBuildAnalyzer.ini
ttnn/ttnn/.rpath_checked__ttnn
Testing
3 changes: 2 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@

include(CTest)
enable_testing()
include(GoogleTest)
add_library(test_common_libs INTERFACE)
target_link_libraries(test_common_libs INTERFACE pthread gtest gtest_main magic_enum fmt)

if(TT_METAL_BUILD_TESTS)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_metal/tt_metal)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_metal)
endif(TT_METAL_BUILD_TESTS)

if(TTNN_BUILD_TESTS)
Expand Down
2 changes: 2 additions & 0 deletions tests/tt_metal/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
add_subdirectory(distributed)
add_subdirectory(tt_metal)
15 changes: 15 additions & 0 deletions tests/tt_metal/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Sources for the distributed (mesh device) unit test binary.
set(UNIT_TESTS_DISTRIBUTED_SRC
    ${CMAKE_CURRENT_SOURCE_DIR}/test_distributed.cpp
)

add_executable(distributed_unit_tests ${UNIT_TESTS_DISTRIBUTED_SRC})
# test_common_libs is the INTERFACE target declared in tests/CMakeLists.txt;
# it carries gtest/gtest_main, pthread, magic_enum and fmt.
target_link_libraries(distributed_unit_tests PRIVATE tt_metal test_common_libs)

# Let the tests include tt_metal headers via project-relative paths.
target_include_directories(distributed_unit_tests PRIVATE
    ${PROJECT_SOURCE_DIR}/tt_metal
    ${PROJECT_SOURCE_DIR}/tt_metal/distributed
)

set_target_properties(distributed_unit_tests PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/test/tt_metal/distributed)

# Register each gtest case with CTest individually (needs include(GoogleTest)).
gtest_discover_tests(distributed_unit_tests)
45 changes: 45 additions & 0 deletions tests/tt_metal/distributed/test_distributed.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>

#include "tt_metal/distributed/mesh_device.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"
#include "tt_metal/llrt/tt_cluster.hpp"

namespace tt::tt_metal::distributed::test {

// Skips the current gtest case unless it is running on a T3000-class machine
// (Wormhole B0, at least 8 devices) in fast dispatch mode.
static inline void skip_test_if_not_t3000() {
    // Query environment and cluster state up front (cluster queries also
    // force cluster initialization, matching the original evaluation order).
    const char* slow_dispatch_env = getenv("TT_METAL_SLOW_DISPATCH_MODE");
    const auto cluster_arch = tt::Cluster::instance().arch();
    const size_t device_count = tt::Cluster::instance().number_of_devices();

    if (slow_dispatch_env != nullptr) {
        GTEST_SKIP() << "Skipping Multi-Device test suite, since it can only be run in Fast Dispatch Mode.";
    }
    if (device_count < 8 or cluster_arch != tt::ARCH::WORMHOLE_B0) {
        GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine.";
    }
}
// Fixture that opens a 2x4 MeshDevice for every test, skipping entirely on
// machines that are not a T3000 system.
class MeshDevice_T3000 : public ::testing::Test {
   protected:
    void SetUp() override {
        skip_test_if_not_t3000();
        // GTEST_SKIP() inside the helper only returns from the helper, not
        // from SetUp; without this guard we would still try to open the mesh
        // on an unsupported machine.
        if (::testing::Test::IsSkipped()) {
            return;
        }
        this->mesh_device_ = MeshDevice::create(MeshDeviceConfig(MeshShape(2, 4)));
    }

    void TearDown() override {
        // mesh_device_ is null when SetUp skipped the test; closing it
        // unconditionally would dereference a null shared_ptr.
        if (mesh_device_) {
            mesh_device_->close_devices();
            mesh_device_.reset();
        }
    }

    // 2x4 mesh shared by all tests in this fixture.
    std::shared_ptr<MeshDevice> mesh_device_;
};

// Sanity check: the fixture's MeshDevice reports the 2x4 (8-device) shape it
// was created with in SetUp.
TEST_F(MeshDevice_T3000, SimpleMeshDeviceTest) {
    EXPECT_EQ(mesh_device_->num_devices(), 8);
    EXPECT_EQ(mesh_device_->num_rows(), 2);
    EXPECT_EQ(mesh_device_->num_cols(), 4);
}

} // namespace tt::tt_metal::distributed::test
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
#include "tt_metal/distributed/mesh_device.hpp"

using tt::tt_metal::Device;
using tt::tt_metal::distributed::MeshShape;
using tt::tt_metal::distributed::MeshDevice;
using tt::tt_metal::distributed::MeshDeviceView;
using tt::tt_metal::distributed::MeshDeviceConfig;

class T3000TestDevice {
public:
Expand All @@ -43,7 +47,7 @@ class T3000TestDevice {
num_devices_ = tt::tt_metal::GetNumAvailableDevices();
if (arch_ == tt::ARCH::WORMHOLE_B0 and tt::tt_metal::GetNumAvailableDevices() == 8 and
tt::tt_metal::GetNumPCIeDevices() == 4) {
mesh_device_ = tt::tt_metal::MeshDevice::create(tt::tt_metal::MeshDeviceConfig(tt::tt_metal::MeshShape{2, 4}));
mesh_device_ = MeshDevice::create(MeshDeviceConfig(MeshShape{2, 4}));

} else {
TT_THROW("This suite can only be run on T3000 Wormhole devices");
Expand All @@ -63,7 +67,7 @@ class T3000TestDevice {

tt::ARCH arch_;
size_t num_devices_;
std::shared_ptr<tt::tt_metal::MeshDevice> mesh_device_;
std::shared_ptr<MeshDevice> mesh_device_;

private:
bool device_open;
Expand Down
16 changes: 8 additions & 8 deletions tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ bool is_tgg_system()
return is_galaxy_system && (num_mmio_devices == 8) && (num_devices == 64);
}

MeshShape get_mesh_shape()
ttnn::MeshShape get_mesh_shape()
{
MeshShape shape;
ttnn::MeshShape shape;
if (is_tg_system())
{
shape = {8, 4};
Expand Down Expand Up @@ -116,8 +116,8 @@ TEST(GalaxyTests, TestAllGatherDeadlock) {
}
validate_num_tunnels_and_tunnel_depth();

MeshShape mesh_shape = get_mesh_shape();
std::shared_ptr<MeshDevice> mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
ttnn::MeshShape mesh_shape = get_mesh_shape();
std::shared_ptr<ttnn::MeshDevice> mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);

// Setup input data and output data containers
MemoryConfig mem_cfg = MemoryConfig{
Expand All @@ -137,7 +137,7 @@ TEST(GalaxyTests, TestAllGatherDeadlock) {
}
// Iterate over each row and run line all-gather multiple times.
// For each row, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged.
auto view = MeshDeviceView(*mesh);
auto view = ttnn::MeshDeviceView(*mesh);
for (uint32_t row = 0; row < 8; row++) {
auto devs = view.get_devices_on_row(row);
std::vector<uint32_t> device_ids = {};
Expand Down Expand Up @@ -193,11 +193,11 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) {
}
validate_num_tunnels_and_tunnel_depth();

MeshShape mesh_shape = get_mesh_shape();
std::shared_ptr<MeshDevice> mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
ttnn::MeshShape mesh_shape = get_mesh_shape();
std::shared_ptr<ttnn::MeshDevice> mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
// Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the
// first tunnel (forward path).
auto view = MeshDeviceView(*mesh);
auto view = ttnn::MeshDeviceView(*mesh);
std::vector<Device*> ring_devices = view.get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks
ring_devices_1 = std::vector<Device*>(ring_devices_1.begin() + 1, ring_devices_1.end());
Expand Down
16 changes: 3 additions & 13 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "tt_metal/distributed/mesh_device_view.hpp"
#include "tt_metal/distributed/mesh_device.hpp"

namespace tt::tt_metal {
namespace tt::tt_metal::distributed {

using LogicalCoordinate = Coordinate;
using PhysicalCoordinate = eth_coord_t;
Expand Down Expand Up @@ -295,7 +295,7 @@ void MeshDevice::initialize(
auto& instance = SystemMesh::instance();
this->devices = instance.map_mesh_device(
shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, config);
this->primary_view = std::make_shared<tt::tt_metal::MeshDeviceView>(*this);
this->primary_view = std::make_shared<MeshDeviceView>(*this);
}

MeshDevice::~MeshDevice() {
Expand Down Expand Up @@ -420,14 +420,4 @@ void MeshDevice::disable_and_clear_program_cache() {
}
}

std::vector<int> get_t3k_physical_device_ids_ring() {
auto& instance = SystemMesh::instance();
auto num_devices = instance.get_num_devices();
TT_FATAL(num_devices == 8, "T3000 ring topology only works with 8 devices");

auto physical_device_ids = instance.get_mapped_physical_device_ids(
MeshDeviceConfig(MeshShape{1, 8}, MeshOffset{0, 0}));
return physical_device_ids;
}

} // namespace tt::tt_metal
} // namespace tt::tt_metal::distributed
5 changes: 2 additions & 3 deletions tt_metal/distributed/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "tt_metal/impl/device/device.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"

namespace tt::tt_metal {
namespace tt::tt_metal::distributed {

using DeviceIds = std::vector<int>;
using MeshDeviceID = size_t;
Expand Down Expand Up @@ -186,6 +186,5 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {

std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device);
bool validate_worker_modes(const std::vector<Device *> &workers);
std::vector<int> get_t3k_physical_device_ids_ring();

} // namespace tt::tt_metal
} // namespace tt::tt_metal::distributed
6 changes: 2 additions & 4 deletions tt_metal/distributed/mesh_device_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@

#include "tt_metal/distributed/mesh_device.hpp"

namespace tt::tt_metal {

using MeshDevice = tt::tt_metal::MeshDevice;
namespace tt::tt_metal::distributed {

static std::vector<MeshDeviceView::device_pointer> get_devices_from_coordinates(MeshDeviceView& mesh, const std::vector<Coordinate>& coords) {
std::vector<MeshDeviceView::device_pointer> devices;
Expand Down Expand Up @@ -287,4 +285,4 @@ MeshDeviceView::DeviceView MeshDeviceView::get_devices(MeshType type) {
}
}

} // namespace tt::tt_metal
} // namespace tt::tt_metal::distributed
12 changes: 6 additions & 6 deletions tt_metal/distributed/mesh_device_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

#include "tt_metal/impl/device/device.hpp"

namespace tt::tt_metal {
namespace tt::tt_metal::distributed {

// Forward declaration of MeshDevice
class MeshDevice;
Expand Down Expand Up @@ -128,19 +128,19 @@ inline MeshDeviceView make_mesh_device_view(std::vector<Device*> devices, MeshDe
return MeshDeviceView(std::move(devices), std::move(mapper));
}

} // namespace tt::tt_metal
} // namespace tt::tt_metal::distributed

namespace std {
// Specializations to enable structured bindings
template<> struct tuple_size<tt::tt_metal::Coordinate> : std::integral_constant<size_t, 2> {};
template<size_t I> struct tuple_element<I, tt::tt_metal::Coordinate> {
template<> struct tuple_size<tt::tt_metal::distributed::Coordinate> : std::integral_constant<size_t, 2> {};
template<size_t I> struct tuple_element<I, tt::tt_metal::distributed::Coordinate> {
using type = size_t;
};

// Specialization to enable hashing of Coordinate
template <>
struct hash<tt::tt_metal::Coordinate> {
size_t operator()(const tt::tt_metal::Coordinate& coord) const noexcept {
struct hash<tt::tt_metal::distributed::Coordinate> {
size_t operator()(const tt::tt_metal::distributed::Coordinate& coord) const noexcept {
size_t seed = 0;
tt::utils::hash_combine(seed, coord.row);
tt::utils::hash_combine(seed, coord.col);
Expand Down
4 changes: 2 additions & 2 deletions ttnn/cpp/pybind11/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,13 +294,13 @@ void tensor_mem_config_module(py::module& m_tensor) {

m_tensor.def(
"load_tensor",
static_cast<Tensor (*)(const std::string&, Device*)>(&load_tensor<Device*>),
py::overload_cast<const std::string&, Device*>(&load_tensor),
py::arg("file_name"),
py::arg("device") = nullptr,
R"doc(Load tensor to file)doc");
m_tensor.def(
"load_tensor",
static_cast<Tensor (*)(const std::string&, MeshDevice*)>(&load_tensor<MeshDevice*>),
py::overload_cast<const std::string&, MeshDevice*>(&load_tensor),
py::arg("file_name"),
py::arg("device") = nullptr,
R"doc(Load tensor to file)doc");
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/distributed/distributed_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ void py_module(py::module& module) {
)doc");
module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only());
module.def("aggregate_as_tensor", &aggregate_as_tensor, py::arg("tensors"), py::kw_only());
module.def("get_t3k_physical_device_ids_ring", &tt::tt_metal::get_t3k_physical_device_ids_ring);
module.def("get_t3k_physical_device_ids_ring", &get_t3k_physical_device_ids_ring);
}

} // namespace ttnn::distributed
11 changes: 11 additions & 0 deletions ttnn/cpp/ttnn/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,15 @@ Tensor aggregate_as_tensor(std::vector<Tensor>& tensor_shards)
}
}

// Returns the physical device ids of a T3000 system ordered as a ring.
// Fatals unless exactly 8 devices are present.
std::vector<int> get_t3k_physical_device_ids_ring() {
    using namespace tt::tt_metal::distributed;
    auto& system_mesh = SystemMesh::instance();
    const auto device_count = system_mesh.get_num_devices();
    TT_FATAL(device_count == 8, "T3000 ring topology only works with 8 devices");

    // Mapping the system mesh as a 1x8 strip with no offset yields the ids in
    // ring order (presumably — the mapping is defined by SystemMesh).
    return system_mesh.get_mapped_physical_device_ids(MeshDeviceConfig(MeshShape{1, 8}, MeshOffset{0, 0}));
}

} // namespace ttnn::distributed
18 changes: 18 additions & 0 deletions ttnn/cpp/ttnn/distributed/types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "tt_metal/distributed/mesh_device.hpp"

// Aliases re-exporting the tt_metal distributed mesh types under
// ttnn::distributed::types, so ttnn code can avoid spelling out the full
// tt::tt_metal::distributed qualification.
namespace ttnn::distributed::types {

using MeshShape = tt::tt_metal::distributed::MeshShape;
using DeviceIds = tt::tt_metal::distributed::DeviceIds;
using MeshDevice = tt::tt_metal::distributed::MeshDevice;
using MeshDeviceView = tt::tt_metal::distributed::MeshDeviceView;
using MeshType = tt::tt_metal::distributed::MeshType;
using MeshDeviceConfig = tt::tt_metal::distributed::MeshDeviceConfig;

} // namespace ttnn::distributed::types
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/events.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <memory>
#include "tt_metal/impl/event/event.hpp"
#include "ttnn/distributed/types.hpp"

namespace ttnn::events {

Expand Down
2 changes: 0 additions & 2 deletions ttnn/cpp/ttnn/events.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

#include <memory>

#include "tt_metal/distributed/mesh_device.hpp"

namespace ttnn::events {

struct MultiDeviceEvent
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/run_operation_inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void launch_op(
return;
}
check_output(output_tensors, workers);
validate_worker_modes(workers);
distributed::validate_worker_modes(workers);
// Record ref counts for all tensors before pushing to worker queue.
std::vector<uint32_t> input_tensor_ref_count(input_tensors.size());
std::vector<uint32_t> optional_input_tensor_ref_count(optional_input_tensors.size());
Expand Down
10 changes: 7 additions & 3 deletions ttnn/cpp/ttnn/tensor/serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ void dump_tensor(const std::string& file_name, const Tensor& tensor, const std::
}

template<typename T>
Tensor load_tensor(const std::string& file_name, T device) {
Tensor load_tensor_helper(const std::string& file_name, T device) {
std::ifstream input_stream(file_name, std::ios::in | std::ios::binary);
if (not input_stream) {
throw std::runtime_error(fmt::format("Cannot open \"{}\"", file_name));
Expand Down Expand Up @@ -320,8 +320,12 @@ Tensor load_tensor(const std::string& file_name, T device) {
}

// Non-template wrappers around load_tensor_helper. They give callers (and
// pybind's overload_cast in tensor.cpp) a concrete overload set while the
// templated helper stays internal to this translation unit.
Tensor load_tensor(const std::string& file_name, Device* device) {
    return load_tensor_helper<Device*>(file_name, device);
}
Tensor load_tensor(const std::string& file_name, MeshDevice* device) {
    return load_tensor_helper<MeshDevice*>(file_name, device);
}

} // namespace tt_metal

Expand Down
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/tensor/serialization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ namespace tt_metal {

void dump_tensor(const std::string& file_name, const Tensor& tensor, const std::unordered_map<std::string, std::string>& strategy);

template <typename T>
Tensor load_tensor(const std::string& file_name, T device = nullptr);
Tensor load_tensor(const std::string& file_name, Device* device = nullptr);
Tensor load_tensor(const std::string& file_name, MeshDevice* device = nullptr);

} // namespace tt_metal

Expand Down
5 changes: 5 additions & 0 deletions ttnn/cpp/ttnn/tensor/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ namespace tt {

namespace tt_metal {


namespace distributed {
class MeshDevice;
}
using MeshDevice = tt::tt_metal::distributed::MeshDevice;
struct Tensor {
struct TensorAttributes : public std::enable_shared_from_this<TensorAttributes> {
Storage storage;
Expand Down
Loading

0 comments on commit 72aa584

Please sign in to comment.