Skip to content

Commit

Permalink
#0: wip
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Oct 2, 2024
1 parent 73e8bed commit 434a7fc
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 158 deletions.
12 changes: 7 additions & 5 deletions tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ TEST(GalaxyTests, TestAllGatherDeadlock) {
}
// Iterate over each row and run line all-gather multiple times.
// For each row, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged.
auto view = MeshDeviceView(*mesh);
for (uint32_t row = 0; row < 8; row++) {
auto devs = mesh->get_devices_on_row(row);
auto devs = view.get_devices_on_row(row);
std::vector<uint32_t> device_ids = {};
for (auto dev : devs) {
device_ids.push_back(dev->id());
Expand Down Expand Up @@ -189,13 +190,14 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) {
std::shared_ptr<MeshDevice> mesh = ttnn::multi_device::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER);
// Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the
// first tunnel (forward path).
std::vector<Device*> ring_devices = mesh->get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = mesh->get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks
auto view = MeshDeviceView(*mesh);
std::vector<Device*> ring_devices = view.get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks
ring_devices_1 = std::vector<Device*>(ring_devices_1.begin() + 1, ring_devices_1.end());
std::vector<Device*> ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
std::vector<Device*> ring_devices_2 = view.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
std::reverse(ring_devices_2.begin(), ring_devices_2.end());
ring_devices_2 = std::vector<Device*>(ring_devices_2.begin() + 1, ring_devices_2.end());
std::vector<Device*> ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
std::vector<Device*> ring_devices_3 = view.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
std::reverse(ring_devices_3.begin(), ring_devices_3.end());
ring_devices_3 = std::vector<Device*>(ring_devices_3.begin() + 1, ring_devices_3.end() - 1);

Expand Down
190 changes: 140 additions & 50 deletions tt_metal/impl/device/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ static std::string get_config_path(const std::string& filename) {
return root_path + "/tt_metal/impl/device/mesh_configurations/" + filename;
}

static std::map<LogicalCoordinate, PhysicalCoordinate> load_translation_map(const std::string& filename, const std::string& key) {
static std::unordered_map<LogicalCoordinate, PhysicalCoordinate> load_translation_map(const std::string& filename, const std::string& key) {
std::ifstream file(filename);
if (!file.is_open()) {
throw std::runtime_error("Unable to open file: " + filename);
Expand All @@ -40,7 +40,7 @@ static std::map<LogicalCoordinate, PhysicalCoordinate> load_translation_map(cons
throw std::runtime_error("Key '" + key + "' not found in JSON file: " + filename);
}

std::map<LogicalCoordinate, PhysicalCoordinate> result;
std::unordered_map<LogicalCoordinate, PhysicalCoordinate> result;
for (const auto& mapping : j[key]) {
if (mapping.size() != 2 || mapping[0].size() != 2 || mapping[1].size() != 4) {
throw std::runtime_error("Invalid coordinate format in JSON file: " + filename);
Expand All @@ -65,7 +65,7 @@ MeshShape SystemMesh::get_system_mesh_shape(std::size_t system_num_devices) {
return shape;
}

std::map<LogicalCoordinate, PhysicalCoordinate> SystemMesh::get_system_mesh_translation_map(std::size_t system_num_devices) {
std::unordered_map<LogicalCoordinate, PhysicalCoordinate> SystemMesh::get_system_mesh_translation_map(std::size_t system_num_devices) {
const std::unordered_map<std::size_t, std::string> system_mesh_translation_map = {
{1, "device.json"},
{2, "N300.json"},
Expand Down Expand Up @@ -140,14 +140,22 @@ std::vector<chip_id_t> SystemMesh::get_mapped_physical_device_ids(const MeshDevi
}
return physical_device_ids;
}
void SystemMesh::register_mesh_device(const std::shared_ptr<MeshDevice> &mesh_device, const std::vector<Device*>& devices) {
std::vector<chip_id_t> physical_device_ids;
for (auto device : devices) {
physical_device_ids.push_back(device->id());
}
this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device});
this->assigned_devices.insert({mesh_device->get_mesh_id(), physical_device_ids});
}

std::vector<Device*> SystemMesh::map_mesh_device(
std::shared_ptr<MeshDevice> mesh_device,
size_t num_command_queues,
size_t l1_small_size,
size_t trace_region_size,
std::size_t num_command_queues,
std::size_t l1_small_size,
std::size_t trace_region_size,
DispatchCoreType dispatch_core_type,
const std::pair<size_t, size_t>& offset,
const std::pair<std::size_t, std::size_t>& offset,
const std::vector<chip_id_t>& user_provided_physical_device_ids) {

auto [requested_num_rows, requested_num_cols] = mesh_device->shape();
Expand All @@ -158,7 +166,6 @@ std::vector<Device*> SystemMesh::map_mesh_device(
TT_FATAL(requested_num_rows <= max_num_rows, "Requested too many rows: {} > {}", requested_num_rows, max_num_rows);
TT_FATAL(requested_num_rows*requested_num_cols <= max_num_rows*max_num_cols, "Requested submesh is too big: {}x{}", requested_num_rows, requested_num_cols);

this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device});

auto physical_device_ids = user_provided_physical_device_ids.empty() ?
this->get_mapped_physical_device_ids(MeshDeviceConfig{mesh_device->shape(), offset}) :
Expand All @@ -171,43 +178,51 @@ std::vector<Device*> SystemMesh::map_mesh_device(
for (auto physical_device_id : physical_device_ids) {
auto mapped_device = this->opened_devices[mesh_device->get_mesh_id()].at(physical_device_id);
mapped_devices.push_back(mapped_device);
this->assigned_devices[mesh_device->get_mesh_id()].push_back(physical_device_id);
this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device});
}

this->register_mesh_device(mesh_device, mapped_devices); // TODO: change this
return mapped_devices;
}

void SystemMesh::unmap_mesh_device(const std::shared_ptr<MeshDevice>& mesh_device) {
void SystemMesh::unmap_mesh_device(const MeshDevice* mesh_device) {
auto mesh_id = mesh_device->get_mesh_id();

// Clean up all state related to this virtual mesh
this->assigned_mesh_device_devices.erase(mesh_id);

// Remove the devices from assigned_physical_id_to_device
for (auto physical_id : this->assigned_devices.at(mesh_id)) {
this->assigned_physical_id_to_device.erase(physical_id);
// Close the devices
if (mesh_device->is_parent_mesh()) {
for (auto physical_id : this->assigned_devices.at(mesh_id)) {
this->assigned_physical_id_to_device.erase(physical_id);
}
tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id));
this->opened_devices.erase(mesh_id);
}
this->assigned_devices.erase(mesh_id);
}

// Close the devices
tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id));
this->opened_devices.erase(mesh_id);
// Looks up the Device registered for `physical_device_id`.
// Throws (via TT_THROW) when the id is not among the assigned devices.
Device* SystemMesh::get_device(const chip_id_t physical_device_id) const {
    const auto& id_to_device = this->assigned_physical_id_to_device;
    const auto entry = id_to_device.find(physical_device_id);
    if (entry != id_to_device.end()) {
        return entry->second;
    }
    TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id);
}

// Hands out mesh ids from a function-local atomic counter that starts at 0;
// each call returns the current value and advances the counter by one.
static MeshDeviceID generate_unique_mesh_id() {
    static std::atomic<MeshDeviceID> id_counter{0};
    return id_counter.fetch_add(1);
}

MeshDevice::MeshDevice(const MeshShape& mesh_device_shape) : mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()) {}
MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, std::shared_ptr<MeshDevice> parent_mesh)
: mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()), parent_mesh(parent_mesh) {}

std::shared_ptr<MeshDevice> MeshDevice::create(
const MeshShape& mesh_device_shape,
size_t l1_small_size,
size_t trace_region_size,
size_t num_command_queues,
std::size_t l1_small_size,
std::size_t trace_region_size,
std::size_t num_command_queues,
DispatchCoreType dispatch_core_type,
const std::pair<size_t, size_t>& offset,
const std::pair<std::size_t, std::size_t>& offset,
const std::vector<chip_id_t>& user_provided_physical_device_ids)
{
auto mesh_device = std::make_shared<MeshDevice>(mesh_device_shape);
Expand All @@ -216,12 +231,42 @@ std::shared_ptr<MeshDevice> MeshDevice::create(
return mesh_device;
}

// Creates a submesh of shape `submesh_shape` whose top-left corner sits at
// `offset` inside this mesh.
//
// The submesh is built over a MeshDeviceView spanning the requested
// coordinate range of this mesh, registered with the SystemMesh singleton,
// and appended to this mesh's `submeshes` list.
//
// Throws (via TT_THROW) when the shape or offset is invalid, or when the
// requested region does not fit inside this mesh.
std::shared_ptr<MeshDevice> MeshDevice::create_submesh(const MeshShape &submesh_shape, const MeshOffset &offset) {
    if (submesh_shape.first <= 0 || submesh_shape.second <= 0) {
        TT_THROW("Invalid submesh shape: ({}, {}). Both dimensions must be positive.", submesh_shape.first, submesh_shape.second);
    }

    // NOTE(review): if MeshOffset holds unsigned values this check can never fire — confirm the element type.
    if (offset.first < 0 || offset.second < 0) {
        TT_THROW("Invalid offset: ({}, {}). Offset must be non-negative.", offset.first, offset.second);
    }

    // Compare against the remaining extent rather than summing offset and
    // shape, so a very large offset cannot wrap around and pass the check.
    const auto& parent_shape = this->mesh_device_shape;
    const bool rows_fit = offset.first <= parent_shape.first &&
                          submesh_shape.first <= parent_shape.first - offset.first;
    const bool cols_fit = offset.second <= parent_shape.second &&
                          submesh_shape.second <= parent_shape.second - offset.second;
    if (!rows_fit || !cols_fit) {
        TT_THROW("Submesh ({}x{}) with offset ({}, {}) does not fit within parent mesh ({}x{}).",
            submesh_shape.first, submesh_shape.second,
            offset.first, offset.second,
            this->mesh_device_shape.first, this->mesh_device_shape.second);
    }

    // The submesh keeps a shared_ptr to this mesh as its parent.
    auto submesh = std::make_shared<MeshDevice>(submesh_shape, shared_from_this());
    auto start_coordinate = Coordinate{offset.first, offset.second};
    // End coordinate is inclusive, hence the -1 on each axis.
    auto end_coordinate = Coordinate{offset.first + submesh_shape.first - 1, offset.second + submesh_shape.second - 1};
    submesh->primary_view = std::make_shared<MeshDeviceView>(*this, start_coordinate, end_coordinate);
    submesh->devices = submesh->primary_view->get_devices();
    SystemMesh::instance().register_mesh_device(submesh, submesh->devices);
    this->submeshes.push_back(submesh);
    log_trace(LogMetal, "Instantiating submesh {}: {}x{} with offset: {} {}", submesh->get_mesh_id(), submesh_shape.first, submesh_shape.second, offset.first, offset.second);
    // Log the device count; the placeholder reads "{} devices", so it expects a number, not the pointer list.
    log_trace(LogMetal, "Submesh {} instantiated with {} devices", submesh->get_mesh_id(), submesh->devices.size());

    return submesh;
}

void MeshDevice::initialize(
size_t l1_small_size,
size_t trace_region_size,
size_t num_command_queues,
std::size_t l1_small_size,
std::size_t trace_region_size,
std::size_t num_command_queues,
DispatchCoreType dispatch_core_type,
const std::pair<size_t, size_t>& offset,
const std::pair<std::size_t, std::size_t>& offset,
const std::vector<chip_id_t>& physical_device_ids)
{
auto [num_rows, num_cols] = this->shape();
Expand All @@ -235,42 +280,36 @@ void MeshDevice::initialize(
auto& instance = SystemMesh::instance();
this->devices = instance.map_mesh_device(
shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, offset, physical_device_ids);
this->primary_view = std::make_unique<tt::tt_metal::MeshDeviceView>(*this);

for (int device_index = 0; device_index < this->devices.size(); device_index++) {
this->physical_id_to_device_index.insert({this->devices[device_index]->id(), device_index});
}
this->primary_view = std::make_shared<tt::tt_metal::MeshDeviceView>(*this);
}

// Tears the mesh down: closes its own devices if any are still held, closes
// every submesh, then drops the view, device list and parent reference.
MeshDevice::~MeshDevice() {
    if (!this->devices.empty()) {
        this->close_devices();
    }
    for (const auto& child : this->submeshes) {
        child->close_devices();
    }
    this->primary_view.reset();
    this->devices.clear();
    this->parent_mesh.reset();
}

Device* MeshDevice::get_device_index(int logical_device_id) const {
Device* MeshDevice::get_device_index(std::size_t logical_device_id) const {
TT_FATAL(logical_device_id >= 0 and logical_device_id < num_devices(), "Invalid device index");
return this->devices.at(logical_device_id);
}

Device* MeshDevice::get_device(int physical_device_id) const {
return this->devices.at(this->physical_id_to_device_index.at(physical_device_id));
Device* MeshDevice::get_device(chip_id_t physical_device_id) const {
return SystemMesh::instance().get_device(physical_device_id);
}

std::vector<Device*> MeshDevice::get_devices() const { return this->devices; }
std::vector<Device*> MeshDevice::get_devices() const { return this->primary_view->get_devices(IterationOrder::LINE); }

Device* MeshDevice::get_device(int row_idx, int col_idx) const {
Device* MeshDevice::get_device(std::size_t row_idx, std::size_t col_idx) const {
return this->get_device_index(row_idx * num_cols() + col_idx);
}

std::vector<Device*> MeshDevice::get_devices_on_row(int row_idx) const {
return this->primary_view->get_devices_on_row(row_idx);
}

std::vector<Device*> MeshDevice::get_devices_on_column(int col_idx) const {
return this->primary_view->get_devices_on_column(col_idx);
}

const DeviceIds MeshDevice::get_device_ids() const {
DeviceIds device_ids;
for (auto device : this->get_devices()) {
Expand All @@ -279,24 +318,23 @@ const DeviceIds MeshDevice::get_device_ids() const {
return device_ids;
}

int MeshDevice::num_devices() const { return num_rows() * num_cols(); }
std::size_t MeshDevice::num_devices() const { return this->devices.size(); }

// Returns the compute-with-storage grid size reported by the device at index 0.
CoreCoord MeshDevice::compute_with_storage_grid_size() const { return get_device_index(0)->compute_with_storage_grid_size(); }

// Returns the DRAM grid size reported by the device at index 0.
CoreCoord MeshDevice::dram_grid_size() const { return get_device_index(0)->dram_grid_size(); }

// Returns the architecture reported by the device at index 0.
tt::ARCH MeshDevice::arch() const { return get_device_index(0)->arch(); }

int MeshDevice::num_rows() const { return this->mesh_device_shape.first; }
std::size_t MeshDevice::num_rows() const { return this->mesh_device_shape.first; }

int MeshDevice::num_cols() const { return this->mesh_device_shape.second; }
std::size_t MeshDevice::num_cols() const { return this->mesh_device_shape.second; }

// Returns a copy of the shape this mesh was constructed with.
MeshShape MeshDevice::shape() const { return this->mesh_device_shape; }

void MeshDevice::close_devices() {
SystemMesh::instance().unmap_mesh_device(shared_from_this());
SystemMesh::instance().unmap_mesh_device(this);
this->devices.clear();
this->physical_id_to_device_index.clear();
this->primary_view.reset();
}

Expand All @@ -308,8 +346,60 @@ std::shared_ptr<const MeshDeviceView> MeshDevice::get_view() const { return this

// Non-const overload: returns the shared pointer to this mesh's primary view.
std::shared_ptr<MeshDeviceView> MeshDevice::get_view() { return this->primary_view; }

std::vector<std::shared_ptr<MeshDeviceView>> MeshDevice::get_submesh_views() {
std::vector<std::shared_ptr<MeshDeviceView>> submesh_views;
if (this->submeshes.empty()) {
submesh_views.push_back(this->get_view());
}
else {
for (auto submesh : this->submeshes) {
submesh_views.push_back(submesh->get_view());
}
}
return submesh_views;
}

// Returns the id stored in this mesh's `mesh_id` member.
MeshDeviceID MeshDevice::get_mesh_id() const { return this->mesh_id; }

// True when this mesh has no parent, i.e. `parent_mesh` is null.
bool MeshDevice::is_parent_mesh() const { return this->parent_mesh == nullptr; }

// Finds the registered mesh device whose assigned physical device ids match
// `physical_device_ids` as a set (order and duplicates are ignored).
// Throws (via TT_THROW) when no registered mesh matches.
std::shared_ptr<MeshDevice> SystemMesh::get_mesh_device(const std::vector<chip_id_t>& physical_device_ids) {
    log_trace(LogMetal, "Getting mesh device for {} physical devices: {}", physical_device_ids.size(), physical_device_ids);
    const std::unordered_set<chip_id_t> requested_ids(physical_device_ids.begin(), physical_device_ids.end());

    for (const auto& entry : this->assigned_mesh_device_devices) {
        // entry.first is the mesh id, entry.second the mesh device.
        const auto& chips_for_mesh = this->assigned_devices.at(entry.first);
        const std::unordered_set<chip_id_t> mesh_chip_ids(chips_for_mesh.begin(), chips_for_mesh.end());
        log_trace(LogMetal, "Assigned devices: {}", chips_for_mesh);

        if (requested_ids == mesh_chip_ids) {
            return entry.second;
        }
    }
    TT_THROW("No mesh device found for the provided devices");
}

// Resolves the mesh device registered for exactly the given set of devices,
// by translating them to physical ids and asking the SystemMesh singleton.
// Fails (TT_FATAL) when `devices` is empty.
std::shared_ptr<MeshDevice> MeshDevice::fetch_mesh_device(const std::vector<Device*>& devices) {
    TT_FATAL(devices.size() > 0, "No devices provided");

    std::vector<chip_id_t> chip_ids;
    chip_ids.reserve(devices.size());
    for (const auto* dev : devices) {
        chip_ids.push_back(dev->id());
    }

    auto& system_mesh = SystemMesh::instance();
    return system_mesh.get_mesh_device(chip_ids);
}

// Returns a copy of the list of submeshes held by this mesh.
std::vector<std::shared_ptr<MeshDevice>> MeshDevice::get_submeshes() const { return this->submeshes; }

// Returns the first view from get_submesh_views() that contains `device`.
// Throws (via TT_THROW) when no view contains it.
std::shared_ptr<MeshDeviceView> MeshDevice::get_view(const Device* device) {
    const auto candidate_views = this->get_submesh_views();
    for (const auto& view : candidate_views) {
        if (view->contains_device(device->id())) {
            return view;
        }
    }
    TT_THROW("Device {} not found in any submesh view", device->id());
}

// Streams the mesh's to_string() representation into `os` and returns `os`.
std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); }

bool validate_worker_modes(const std::vector<Device*>& workers) {
Expand Down
Loading

0 comments on commit 434a7fc

Please sign in to comment.