Skip to content

Commit

Permalink
#0: done
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Oct 2, 2024
1 parent 434a7fc commit e49454b
Show file tree
Hide file tree
Showing 12 changed files with 161 additions and 62 deletions.
9 changes: 7 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,11 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic
request.node.pci_ids = device_ids[:num_pcie_devices_requested]

mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(1, 4), dispatch_core_type=get_dispatch_core_type(), **device_params, offset=(0, 1)
ttnn.MeshShape(2, 2),
dispatch_core_type=get_dispatch_core_type(),
**device_params,
offset=(0, 1),
mesh_type=ttnn.MeshType.Ring,
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
Expand Down Expand Up @@ -280,9 +284,10 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device

request.node.pci_ids = ttnn.get_pcie_device_ids()
mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(1, 8),
ttnn.MeshShape(2, 4),
dispatch_core_type=get_dispatch_core_type(),
**device_params,
mesh_type=ttnn.MeshType.Ring,
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
Expand Down
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ class T3kMultiDeviceFixture : public ::testing::Test {
}
constexpr auto DEFAULT_NUM_COMMAND_QUEUES = 1;
mesh_device_ = MeshDevice::create(
MeshShape{2, 4},
DEFAULT_L1_SMALL_SIZE,
DEFAULT_TRACE_REGION_SIZE,
DEFAULT_NUM_COMMAND_QUEUES,
DispatchCoreType::WORKER);
DispatchCoreType::WORKER,
MeshDeviceConfig(MeshShape{2, 4}, MeshType::Ring));
}

void TearDown() override {
Expand Down
23 changes: 23 additions & 0 deletions tests/ttnn/unit_tests/test_multi_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,3 +591,26 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width):

def test_visualize_mesh_device(t3k_mesh_device):
ttnn.visualize_mesh_device(t3k_mesh_device)


def test_matmul_multiple_submeshes(t3k_mesh_device):
    """Run all_gather on every 2x2 ring submesh of the T3K mesh and check each device's result."""

    def check_all_gather(submesh):
        ttnn.visualize_mesh_device(submesh)

        # Reference tensor: one 32-wide column block per device, block k filled with the value k.
        num_devices = submesh.get_num_devices()
        shard_width = 32
        expected = torch.ones((1, 1, 32, shard_width * num_devices), dtype=torch.bfloat16)
        for block in range(num_devices):
            expected[..., block * shard_width : (block + 1) * shard_width] = block

        # Shard on dim 3 across the submesh, move to device, then gather along the same dim.
        tensor = ttnn.from_torch(expected, mesh_mapper=ShardTensorToMesh(submesh, dim=3))
        tensor = ttnn.to_device(tensor, submesh)
        tensor = ttnn.all_gather(tensor, dim=3, num_links=1)

        # After the gather, every device tensor must equal the full reference tensor.
        for per_device in ttnn.get_device_tensors(tensor):
            assert torch.all(ttnn.to_torch(per_device) == expected)

    for submesh in t3k_mesh_device.create_submeshes((2, 2), ttnn.MeshType.Ring):
        check_all_gather(submesh)
56 changes: 35 additions & 21 deletions tt_metal/impl/device/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,21 +155,20 @@ std::vector<Device*> SystemMesh::map_mesh_device(
std::size_t l1_small_size,
std::size_t trace_region_size,
DispatchCoreType dispatch_core_type,
const std::pair<std::size_t, std::size_t>& offset,
const std::vector<chip_id_t>& user_provided_physical_device_ids) {
const MeshDeviceConfig& config) {

auto [requested_num_rows, requested_num_cols] = mesh_device->shape();
auto [max_num_rows, max_num_cols] = this->logical_mesh_shape;
auto [row_offset, col_offset] = offset;
auto [row_offset, col_offset] = config.offset;

log_debug(LogMetal, "Mapping MeshDevice ({}x{}) with offset: {}, {}", requested_num_rows, requested_num_cols, row_offset, col_offset);
TT_FATAL(requested_num_rows <= max_num_rows, "Requested too many rows: {} > {}", requested_num_rows, max_num_rows);
TT_FATAL(requested_num_rows*requested_num_cols <= max_num_rows*max_num_cols, "Requested submesh is too big: {}x{}", requested_num_rows, requested_num_cols);


auto physical_device_ids = user_provided_physical_device_ids.empty() ?
this->get_mapped_physical_device_ids(MeshDeviceConfig{mesh_device->shape(), offset}) :
user_provided_physical_device_ids;
auto physical_device_ids = config.physical_device_ids.empty() ?
this->get_mapped_physical_device_ids(config) :
config.physical_device_ids;

this->opened_devices[mesh_device->get_mesh_id()] = tt::tt_metal::detail::CreateDevices(
physical_device_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_type);
Expand All @@ -181,7 +180,7 @@ std::vector<Device*> SystemMesh::map_mesh_device(
this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device});
}

this->register_mesh_device(mesh_device, mapped_devices); // TODO: change this
this->register_mesh_device(mesh_device, mapped_devices); // here
return mapped_devices;
}

Expand Down Expand Up @@ -213,25 +212,27 @@ static MeshDeviceID generate_unique_mesh_id() {
return next_id++;
}

MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, std::shared_ptr<MeshDevice> parent_mesh)
: mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()), parent_mesh(parent_mesh) {}
// Records the requested shape and device ordering, draws a fresh id from
// generate_unique_mesh_id(), and keeps a handle to the parent mesh (if any).
// The constructor body is empty: no devices are mapped here.
MeshDevice::MeshDevice(
    const MeshShape& mesh_device_shape,
    MeshType order,
    std::shared_ptr<MeshDevice> parent_mesh) :
    mesh_device_shape(mesh_device_shape),
    order(order),
    mesh_id(generate_unique_mesh_id()),
    parent_mesh(parent_mesh) {}

std::shared_ptr<MeshDevice> MeshDevice::create(
const MeshShape& mesh_device_shape,
std::size_t l1_small_size,
std::size_t trace_region_size,
std::size_t num_command_queues,
DispatchCoreType dispatch_core_type,
const std::pair<std::size_t, std::size_t>& offset,
const std::vector<chip_id_t>& user_provided_physical_device_ids)
const MeshDeviceConfig& config)
{
auto mesh_device = std::make_shared<MeshDevice>(mesh_device_shape);
mesh_device->initialize(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, offset, user_provided_physical_device_ids);
auto mesh_device = std::make_shared<MeshDevice>(config.mesh_shape, config.mesh_type);
mesh_device->initialize(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, config);

return mesh_device;
}

std::shared_ptr<MeshDevice> MeshDevice::create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset) {
std::shared_ptr<MeshDevice> MeshDevice::create_submesh(
const MeshShape &submesh_shape,
const MeshOffset &offset,
MeshType order)
{
if (submesh_shape.first <= 0 || submesh_shape.second <= 0) {
TT_THROW("Invalid submesh shape: ({}, {}). Both dimensions must be positive.", submesh_shape.first, submesh_shape.second);
}
Expand All @@ -248,7 +249,7 @@ std::shared_ptr<MeshDevice> MeshDevice::create_submesh(const MeshShape &submesh_
this->mesh_device_shape.first, this->mesh_device_shape.second);
}

auto submesh = std::make_shared<MeshDevice>(submesh_shape, shared_from_this());
auto submesh = std::make_shared<MeshDevice>(submesh_shape, order, shared_from_this());
auto start_coordinate = Coordinate{offset.first, offset.second};
auto end_coordinate = Coordinate{offset.first + submesh_shape.first - 1, offset.second + submesh_shape.second - 1};
submesh->primary_view = std::make_shared<MeshDeviceView>(*this, start_coordinate, end_coordinate);
Expand All @@ -261,13 +262,26 @@ std::shared_ptr<MeshDevice> MeshDevice::create_submesh(const MeshShape &submesh_
return submesh;
}

// Tiles this mesh with submeshes of `submesh_shape` and returns them in creation
// order: tile origins advance column-first within a row of tiles, then move down
// by one tile height.
//
// Every tile is built through create_submesh() with the given `order`, so shape
// and bounds validation (and any error it raises) comes from that call.
std::vector<std::shared_ptr<MeshDevice>> MeshDevice::create_submeshes(
    const MeshShape &submesh_shape,
    MeshType order)
{
    // Hoist the loop bounds and strides; they do not change while tiling.
    const std::size_t total_rows = this->num_rows();
    const std::size_t total_cols = this->num_cols();
    const std::size_t row_step = submesh_shape.first;
    const std::size_t col_step = submesh_shape.second;

    std::vector<std::shared_ptr<MeshDevice>> submeshes;
    // Use unsigned indices so the comparisons against the mesh dimensions and the
    // MeshOffset coordinates involve no signed/unsigned conversion.
    // A zero stride cannot spin: create_submesh() rejects a non-positive shape on
    // the first call.
    for (std::size_t row = 0; row < total_rows; row += row_step) {
        for (std::size_t col = 0; col < total_cols; col += col_step) {
            submeshes.push_back(this->create_submesh(submesh_shape, MeshOffset{row, col}, order));
        }
    }
    return submeshes;
}

void MeshDevice::initialize(
std::size_t l1_small_size,
std::size_t trace_region_size,
std::size_t num_command_queues,
DispatchCoreType dispatch_core_type,
const std::pair<std::size_t, std::size_t>& offset,
const std::vector<chip_id_t>& physical_device_ids)
const MeshDeviceConfig& config)
{
auto [num_rows, num_cols] = this->shape();
auto num_requested_devices = num_rows * num_cols;
Expand All @@ -279,7 +293,7 @@ void MeshDevice::initialize(

auto& instance = SystemMesh::instance();
this->devices = instance.map_mesh_device(
shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, offset, physical_device_ids);
shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, config);
this->primary_view = std::make_shared<tt::tt_metal::MeshDeviceView>(*this);
}

Expand All @@ -304,7 +318,7 @@ Device* MeshDevice::get_device(chip_id_t physical_device_id) const {
return SystemMesh::instance().get_device(physical_device_id);
}

std::vector<Device*> MeshDevice::get_devices() const { return this->primary_view->get_devices(IterationOrder::LINE); }
// Returns the devices of this mesh as enumerated by its primary view, using the
// ordering (MeshType) this mesh was constructed with.
std::vector<Device*> MeshDevice::get_devices() const {
    const auto& view = this->primary_view;
    return view->get_devices(this->order);
}

Device* MeshDevice::get_device(std::size_t row_idx, std::size_t col_idx) const {
return this->get_device_index(row_idx * num_cols() + col_idx);
Expand Down Expand Up @@ -417,7 +431,7 @@ std::vector<int> get_t3k_physical_device_ids_ring() {
TT_FATAL(num_devices == 8, "T3000 ring topology only works with 8 devices");

auto physical_device_ids = instance.get_mapped_physical_device_ids(
MeshDeviceConfig{MeshShape{1, 8}, MeshOffset{0, 0}});
MeshDeviceConfig(MeshShape{1, 8}, MeshOffset{0, 0}));
return physical_device_ids;
}

Expand Down
42 changes: 33 additions & 9 deletions tt_metal/impl/device/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@ class MeshDeviceView;
// Parameters describing how a MeshDevice is carved out of the system mesh.
struct MeshDeviceConfig {
    // Requested mesh dimensions as (rows, cols).
    MeshShape mesh_shape;
    // (row, col) offset of the mesh within the system mesh.
    MeshOffset offset;
    // Explicit physical devices to use; when empty, SystemMesh derives the ids
    // from the rest of this config.
    std::vector<chip_id_t> physical_device_ids;
    // Device ordering handed to the MeshDevice built from this config.
    MeshType mesh_type;

    // Shape-only form: zero offset, no explicit device ids.
    // Delegates to the full form so the member initialisation lives in one place.
    MeshDeviceConfig(
        const MeshShape &mesh_shape,
        MeshType mesh_type = MeshType::RowMajor) :
        MeshDeviceConfig(mesh_shape, MeshOffset{0, 0}, std::vector<chip_id_t>{}, mesh_type) {}

    // Full form. `offset` intentionally has no default argument: with one, both
    // constructors would accept a lone MeshShape and `MeshDeviceConfig(shape)`
    // would be an ambiguous overload. Every call that supplies an offset is
    // unaffected.
    MeshDeviceConfig(
        const MeshShape &mesh_shape,
        const MeshOffset &offset,
        const std::vector<chip_id_t> &physical_device_ids = {},
        MeshType mesh_type = MeshType::RowMajor) :
        mesh_shape(mesh_shape),
        offset(offset),
        physical_device_ids(physical_device_ids),
        mesh_type(mesh_type) {}
};

// SystemMesh creates a virtualization over the physical devices in the system.
Expand Down Expand Up @@ -80,8 +100,7 @@ class SystemMesh {
std::size_t l1_small_size,
std::size_t trace_region_size,
DispatchCoreType dispatch_core_type,
const std::pair<std::size_t, std::size_t> &offset = {0, 0},
const std::vector<chip_id_t> &physical_device_ids = {});
const MeshDeviceConfig &config);

// Unmap MeshDevice, releasing the associated physical devices.
void unmap_mesh_device(const MeshDevice* mesh_device);
Expand All @@ -93,6 +112,7 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
private:
MeshDeviceID mesh_id;
MeshShape mesh_device_shape;
MeshType order;
std::shared_ptr<MeshDeviceView> primary_view;
std::vector<Device *> devices;
std::shared_ptr<MeshDevice> parent_mesh;
Expand All @@ -103,11 +123,10 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
std::size_t trace_region_size,
std::size_t num_command_queues,
DispatchCoreType dispatch_core_type,
const std::pair<std::size_t, std::size_t> &offset,
const std::vector<chip_id_t> &physical_device_ids);
const MeshDeviceConfig &config);

public:
MeshDevice(const MeshShape &mesh_device_shape, std::shared_ptr<MeshDevice> parent_mesh = nullptr);
MeshDevice(const MeshShape &mesh_device_shape, MeshType order, std::shared_ptr<MeshDevice> parent_mesh = nullptr);
~MeshDevice();

MeshDevice(const MeshDevice &) = delete;
Expand Down Expand Up @@ -146,17 +165,22 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
std::vector<std::shared_ptr<MeshDeviceView>> get_submesh_views();
std::shared_ptr<MeshDeviceView> get_view(const Device* device);

std::shared_ptr<MeshDevice> create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset = {0, 0});
std::shared_ptr<MeshDevice> create_submesh(
const MeshShape &submesh_shape,
const MeshOffset &offset = MeshOffset{0, 0},
MeshType order = MeshType::RowMajor);

std::vector<std::shared_ptr<MeshDevice>> create_submeshes(
const MeshShape &submesh_shape,
MeshType order = MeshType::RowMajor);

static std::shared_ptr<MeshDevice> fetch_mesh_device(const std::vector<Device*>& devices);
static std::shared_ptr<MeshDevice> create(
const MeshShape &mesh_device_shape,
std::size_t l1_small_size,
std::size_t trace_region_size,
std::size_t num_command_queues,
DispatchCoreType dispatch_core_type,
const std::pair<std::size_t, std::size_t> &offset = {0, 0},
const std::vector<chip_id_t> &physical_device_ids = {});
const MeshDeviceConfig &config);
};

std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device);
Expand Down
19 changes: 11 additions & 8 deletions tt_metal/impl/device/mesh_device_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,14 @@ std::vector<Coordinate> MeshDeviceView::get_ring_coordinates(const MeshShape& ri

// Traverse the bottom row from right to left, if there is more than one row
if (ring_rows > 1 and ring_cols > 1) {
for (std::size_t col = end_col - 1; col >= start_col; --col) {
boundary_coords.emplace_back(Coordinate{end_row, col});
// Traverse the bottom row from right to left
for (int col = static_cast<int>(end_col - 1); col >= static_cast<int>(start_col); --col) {
boundary_coords.emplace_back(Coordinate{end_row, static_cast<std::size_t>(col)});
}
for (std::size_t row = end_row - 1; row >= start_row; --row) {
boundary_coords.emplace_back(Coordinate{row, start_col});

// Traverse the leftmost column from bottom-1 to top+1
for (int row = static_cast<int>(end_row - 1); row > static_cast<int>(start_row); --row) {
boundary_coords.emplace_back(Coordinate{static_cast<std::size_t>(row), start_col});
}
}

Expand All @@ -271,13 +274,13 @@ std::vector<MeshDeviceView::device_pointer> MeshDeviceView::get_ring_devices() {
return get_devices_from_coordinates(*this, boundary_coords);
}

MeshDeviceView::DeviceView MeshDeviceView::get_devices(IterationOrder order) {
MeshDeviceView::DeviceView MeshDeviceView::get_devices(MeshType order) {
switch (order) {
case IterationOrder::ROW_MAJOR:
case MeshType::RowMajor:
return this->devices_;
case IterationOrder::RING:
case MeshType::Ring:
return this->get_ring_devices();
case IterationOrder::LINE:
case MeshType::Line:
return this->get_line_devices();
default:
TT_THROW("Unsupported iteration order: {}", order);
Expand Down
14 changes: 7 additions & 7 deletions tt_metal/impl/device/mesh_device_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ struct Coordinate {
* (CCL-ops), such as line all-gather, which require column or row views of the device mesh.
*/

enum class IterationOrder {
ROW_MAJOR,
RING,
LINE
// Selects the order in which MeshDeviceView::get_devices(MeshType) returns devices.
enum class MeshType {
// The view's stored device list, returned as-is.
RowMajor,
// The devices produced by get_ring_devices().
Ring,
// The devices produced by get_line_devices().
Line
};

class MeshDeviceView {
Expand All @@ -79,7 +79,7 @@ class MeshDeviceView {
// devices are returned in row-major order with start/end coordinates inclusive
[[nodiscard]] DeviceView get_devices(const Coordinate& start, const Coordinate& end);
[[nodiscard]] DeviceView get_devices(const MeshShape& shape);
[[nodiscard]] DeviceView get_devices(IterationOrder order = IterationOrder::ROW_MAJOR);
[[nodiscard]] DeviceView get_devices(MeshType order = MeshType::RowMajor);

[[nodiscard]] DeviceView get_devices_on_row(std::size_t row) const;
[[nodiscard]] DeviceView get_devices_on_column(std::size_t col) const;
Expand Down Expand Up @@ -110,10 +110,10 @@ class MeshDeviceView {
// The current support only provides left-to-right and right-to-left snaking of the line.
[[nodiscard]] static std::vector<Coordinate> get_line_coordinates(std::size_t length, const Coordinate& offset, std::size_t num_rows, std::size_t num_cols);
[[nodiscard]] std::vector<Coordinate> get_ring_coordinates(const MeshShape& ring_shape, const Coordinate& offset, std::size_t num_rows, std::size_t num_cols);

private:
[[nodiscard]] std::vector<device_pointer> get_ring_devices();
[[nodiscard]] std::vector<device_pointer> get_line_devices();

private:
std::vector<device_pointer> devices_;
std::unordered_map<chip_id_t, Coordinate> device_coordinates_;
Coordinate top_left_;
Expand Down
Loading

0 comments on commit e49454b

Please sign in to comment.