Skip to content

Commit

Permalink
#13454: namespace tt_metal::distributed; add scaffolding for C++ tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Oct 14, 2024
1 parent db31023 commit 1267b81
Show file tree
Hide file tree
Showing 50 changed files with 541 additions and 392 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,4 @@ ttnn/ttnn/runtime

ClangBuildAnalyzer.ini
ttnn/ttnn/.rpath_checked__ttnn
Testing
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ project(Metalium
HOMEPAGE_URL "https://github.com/tenstorrent/tt-metal"
LANGUAGES CXX
)
include(CTest)

CHECK_COMPILERS()

Expand Down Expand Up @@ -295,4 +296,3 @@ add_custom_target(clean-built

# Debian Package
include(cmake/packaging.cmake)

2 changes: 1 addition & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ add_library(test_common_libs INTERFACE)
target_link_libraries(test_common_libs INTERFACE pthread gtest gtest_main magic_enum fmt)

if(TT_METAL_BUILD_TESTS)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_metal/tt_metal)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_metal)
endif(TT_METAL_BUILD_TESTS)

if(TTNN_BUILD_TESTS)
Expand Down
2 changes: 2 additions & 0 deletions tests/tt_metal/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
add_subdirectory(distributed)
add_subdirectory(tt_metal)
15 changes: 15 additions & 0 deletions tests/tt_metal/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
set(UNIT_TESTS_DISTRIBUTED_SRC
${CMAKE_CURRENT_SOURCE_DIR}/test_distributed.cpp
)

add_executable(distributed_unit_tests ${UNIT_TESTS_DISTRIBUTED_SRC})
target_link_libraries(distributed_unit_tests PRIVATE tt_metal test_common_libs)

target_include_directories(distributed_unit_tests PRIVATE
${PROJECT_SOURCE_DIR}/tt_metal
${PROJECT_SOURCE_DIR}/tt_metal/distributed
)

set_target_properties(distributed_unit_tests PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/test/tt_metal/distributed)

gtest_discover_tests(distributed_unit_tests)
45 changes: 45 additions & 0 deletions tests/tt_metal/distributed/test_distributed.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>

#include "tt_metal/distributed/mesh_device.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"
#include "tt_metal/llrt/tt_cluster.hpp"

namespace tt::tt_metal::distributed::test {

static inline void skip_test_if_not_t3000() {
auto slow_dispatch = getenv("TT_METAL_SLOW_DISPATCH_MODE");
const auto arch = tt::Cluster::instance().arch();
const size_t num_devices = tt::Cluster::instance().number_of_devices();

if (slow_dispatch) {
GTEST_SKIP() << "Skipping Multi-Device test suite, since it can only be run in Fast Dispatch Mode.";
}
if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) {
GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine.";
}
}
class MeshDevice_T3000 : public ::testing::Test {
protected:
void SetUp() override {
skip_test_if_not_t3000();
this->mesh_device_ = MeshDevice::create(MeshDeviceConfig(MeshShape(2, 4)));
}

void TearDown() override {
mesh_device_->close_devices();
mesh_device_.reset();
}
std::shared_ptr<MeshDevice> mesh_device_;
};

TEST_F(MeshDevice_T3000, SimpleMeshDeviceTest) {
EXPECT_EQ(mesh_device_->num_devices(), 8);
EXPECT_EQ(mesh_device_->num_rows(), 2);
EXPECT_EQ(mesh_device_->num_cols(), 4);
}

} // namespace tt::tt_metal::distributed::test
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
#include "tt_metal/distributed/mesh_device.hpp"

using tt::tt_metal::Device;
using tt::tt_metal::distributed::MeshShape;
using tt::tt_metal::distributed::MeshDevice;
using tt::tt_metal::distributed::MeshDeviceView;
using tt::tt_metal::distributed::MeshDeviceConfig;

class T3000TestDevice {
public:
Expand All @@ -43,7 +47,7 @@ class T3000TestDevice {
num_devices_ = tt::tt_metal::GetNumAvailableDevices();
if (arch_ == tt::ARCH::WORMHOLE_B0 and tt::tt_metal::GetNumAvailableDevices() == 8 and
tt::tt_metal::GetNumPCIeDevices() == 4) {
mesh_device_ = tt::tt_metal::MeshDevice::create(tt::tt_metal::MeshDeviceConfig(tt::tt_metal::MeshShape{2, 4}));
mesh_device_ = MeshDevice::create(MeshDeviceConfig(MeshShape{2, 4}));

} else {
TT_THROW("This suite can only be run on T3000 Wormhole devices");
Expand All @@ -63,7 +67,7 @@ class T3000TestDevice {

tt::ARCH arch_;
size_t num_devices_;
std::shared_ptr<tt::tt_metal::MeshDevice> mesh_device_;
std::shared_ptr<MeshDevice> mesh_device_;

private:
bool device_open;
Expand Down
19 changes: 10 additions & 9 deletions tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp"
#include "ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp"
#include "ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_types.hpp"
#include "ttnn/cpp/ttnn/distributed/mesh_device.hpp"
#include "ttnn/distributed/types.hpp"
#include "ttnn/distributed/api.hpp"
#include "ttnn/async_runtime.hpp"
#include "ttnn_multi_command_queue_fixture.hpp"

Expand Down Expand Up @@ -66,9 +67,9 @@ bool is_tgg_system()
return is_galaxy_system && (num_mmio_devices == 8) && (num_devices == 64);
}

MeshShape get_mesh_shape()
ttnn::MeshShape get_mesh_shape()
{
MeshShape shape;
ttnn::MeshShape shape;
if (is_tg_system())
{
shape = {8, 4};
Expand Down Expand Up @@ -116,8 +117,8 @@ TEST(GalaxyTests, TestAllGatherDeadlock) {
}
validate_num_tunnels_and_tunnel_depth();

MeshShape mesh_shape = get_mesh_shape();
std::shared_ptr<MeshDevice> mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
ttnn::MeshShape mesh_shape = get_mesh_shape();
std::shared_ptr<ttnn::MeshDevice> mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);

// Setup input data and output data containers
MemoryConfig mem_cfg = MemoryConfig{
Expand All @@ -137,7 +138,7 @@ TEST(GalaxyTests, TestAllGatherDeadlock) {
}
// Iterate over each row and run line all-gather multiple times.
// For each row, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged.
auto view = MeshDeviceView(*mesh);
auto view = ttnn::MeshDeviceView(*mesh);
for (uint32_t row = 0; row < 8; row++) {
auto devs = view.get_devices_on_row(row);
std::vector<uint32_t> device_ids = {};
Expand Down Expand Up @@ -193,11 +194,11 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) {
}
validate_num_tunnels_and_tunnel_depth();

MeshShape mesh_shape = get_mesh_shape();
std::shared_ptr<MeshDevice> mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
ttnn::MeshShape mesh_shape = get_mesh_shape();
std::shared_ptr<ttnn::MeshDevice> mesh = ttnn::distributed::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
// Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the
// first tunnel (forward path).
auto view = MeshDeviceView(*mesh);
auto view = ttnn::MeshDeviceView(*mesh);
std::vector<Device*> ring_devices = view.get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks
ring_devices_1 = std::vector<Device*>(ring_devices_1.begin() + 1, ring_devices_1.end());
Expand Down
5 changes: 2 additions & 3 deletions tt_metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/llrt)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/impl)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/detail)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)

set(TT_METAL_OBJECTS
${CMAKE_CURRENT_SOURCE_DIR}/tt_metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph/graph_tracking.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed/mesh_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed/mesh_device_view.cpp
$<TARGET_OBJECTS:profiler>
$<TARGET_OBJECTS:common>
$<TARGET_OBJECTS:jit_build>
$<TARGET_OBJECTS:llrt>
$<TARGET_OBJECTS:impl>
$<TARGET_OBJECTS:detail>
$<TARGET_OBJECTS:distributed>
)

add_library(tt_metal ${TT_METAL_OBJECTS})
Expand Down Expand Up @@ -53,4 +53,3 @@ set_target_properties(tt_metal PROPERTIES
if(BUILD_PROGRAMMING_EXAMPLES)
add_subdirectory(programming_examples)
endif(BUILD_PROGRAMMING_EXAMPLES)

7 changes: 7 additions & 0 deletions tt_metal/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
set(DISTRIBUTED_SRC
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mesh_device_view.cpp)

add_library(distributed OBJECT ${DISTRIBUTED_SRC})
target_link_libraries(distributed PUBLIC common)
9 changes: 9 additions & 0 deletions tt_metal/distributed/distributed.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "tt_metal/distributed/distributed.hpp"

namespace tt::tt_metal::distributed {

} // namespace tt::tt_metal::distributed
21 changes: 21 additions & 0 deletions tt_metal/distributed/distributed.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

namespace tt::tt_metal {

inline namespace v0 {

class Device;
class Tensor;

} // namespace v0

namespace distributed {

class MeshDevice;

} // namespace tt::tt_metal::distributed
} // namespace tt::tt_metal
25 changes: 3 additions & 22 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "tt_metal/distributed/mesh_device_view.hpp"
#include "tt_metal/distributed/mesh_device.hpp"

namespace tt::tt_metal {
namespace tt::tt_metal::distributed {

using LogicalCoordinate = Coordinate;
using PhysicalCoordinate = eth_coord_t;
Expand Down Expand Up @@ -295,7 +295,7 @@ void MeshDevice::initialize(
auto& instance = SystemMesh::instance();
this->devices = instance.map_mesh_device(
shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, config);
this->primary_view = std::make_shared<tt::tt_metal::MeshDeviceView>(*this);
this->primary_view = std::make_shared<MeshDeviceView>(*this);
}

MeshDevice::~MeshDevice() {
Expand Down Expand Up @@ -393,15 +393,6 @@ std::vector<std::shared_ptr<MeshDevice>> MeshDevice::get_submeshes() const { ret

std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); }

bool validate_worker_modes(const std::vector<Device*>& workers) {
bool worker_modes_match = true;
auto first_worker_mode = workers.at(0)->get_worker_mode();
for (auto worker : workers) {
worker_modes_match &= (worker->get_worker_mode() == first_worker_mode);
}
return worker_modes_match;
}

void MeshDevice::enable_async(bool enable) {
for (auto device : this->devices) {
device->enable_async(enable);
Expand All @@ -420,14 +411,4 @@ void MeshDevice::disable_and_clear_program_cache() {
}
}

std::vector<int> get_t3k_physical_device_ids_ring() {
auto& instance = SystemMesh::instance();
auto num_devices = instance.get_num_devices();
TT_FATAL(num_devices == 8, "T3000 ring topology only works with 8 devices");

auto physical_device_ids = instance.get_mapped_physical_device_ids(
MeshDeviceConfig(MeshShape{1, 8}, MeshOffset{0, 0}));
return physical_device_ids;
}

} // namespace tt::tt_metal
} // namespace tt::tt_metal::distributed
7 changes: 3 additions & 4 deletions tt_metal/distributed/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "tt_metal/impl/device/device.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"

namespace tt::tt_metal {
namespace tt::tt_metal::distributed {

using DeviceIds = std::vector<int>;
using MeshDeviceID = size_t;
Expand Down Expand Up @@ -185,7 +185,6 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
};

std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device);
bool validate_worker_modes(const std::vector<Device *> &workers);
std::vector<int> get_t3k_physical_device_ids_ring();

} // namespace tt::tt_metal

} // namespace tt::tt_metal::distributed
6 changes: 2 additions & 4 deletions tt_metal/distributed/mesh_device_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@

#include "tt_metal/distributed/mesh_device.hpp"

namespace tt::tt_metal {

using MeshDevice = tt::tt_metal::MeshDevice;
namespace tt::tt_metal::distributed {

static std::vector<MeshDeviceView::device_pointer> get_devices_from_coordinates(MeshDeviceView& mesh, const std::vector<Coordinate>& coords) {
std::vector<MeshDeviceView::device_pointer> devices;
Expand Down Expand Up @@ -287,4 +285,4 @@ MeshDeviceView::DeviceView MeshDeviceView::get_devices(MeshType type) {
}
}

} // namespace tt::tt_metal
} // namespace tt::tt_metal::distributed
12 changes: 6 additions & 6 deletions tt_metal/distributed/mesh_device_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

#include "tt_metal/impl/device/device.hpp"

namespace tt::tt_metal {
namespace tt::tt_metal::distributed {

// Forward declaration of MeshDevice
class MeshDevice;
Expand Down Expand Up @@ -128,19 +128,19 @@ inline MeshDeviceView make_mesh_device_view(std::vector<Device*> devices, MeshDe
return MeshDeviceView(std::move(devices), std::move(mapper));
}

} // namespace tt::tt_metal
} // namespace tt::tt_metal::distributed

namespace std {
// Specializations to enable structured bindings
template<> struct tuple_size<tt::tt_metal::Coordinate> : std::integral_constant<size_t, 2> {};
template<size_t I> struct tuple_element<I, tt::tt_metal::Coordinate> {
template<> struct tuple_size<tt::tt_metal::distributed::Coordinate> : std::integral_constant<size_t, 2> {};
template<size_t I> struct tuple_element<I, tt::tt_metal::distributed::Coordinate> {
using type = size_t;
};

// Specialization to enable hashing of Coordinate
template <>
struct hash<tt::tt_metal::Coordinate> {
size_t operator()(const tt::tt_metal::Coordinate& coord) const noexcept {
struct hash<tt::tt_metal::distributed::Coordinate> {
size_t operator()(const tt::tt_metal::distributed::Coordinate& coord) const noexcept {
size_t seed = 0;
tt::utils::hash_combine(seed, coord.row);
tt::utils::hash_combine(seed, coord.col);
Expand Down
2 changes: 1 addition & 1 deletion ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/events.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/run_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/mesh_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/api.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/distributed_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_processor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_trace_utils.cpp
Expand Down
5 changes: 3 additions & 2 deletions ttnn/cpp/pybind11/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "ttnn/tensor/tensor_impl.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/distributed/types.hpp"
#include "tt_metal/host_api.hpp"


Expand Down Expand Up @@ -294,13 +295,13 @@ void tensor_mem_config_module(py::module& m_tensor) {

m_tensor.def(
"load_tensor",
static_cast<Tensor (*)(const std::string&, Device*)>(&load_tensor<Device*>),
py::overload_cast<const std::string&, Device*>(&load_tensor),
py::arg("file_name"),
py::arg("device") = nullptr,
R"doc(Load tensor to file)doc");
m_tensor.def(
"load_tensor",
static_cast<Tensor (*)(const std::string&, MeshDevice*)>(&load_tensor<MeshDevice*>),
py::overload_cast<const std::string&, MeshDevice*>(&load_tensor),
py::arg("file_name"),
py::arg("device") = nullptr,
R"doc(Load tensor to file)doc");
Expand Down
Loading

0 comments on commit 1267b81

Please sign in to comment.