Skip to content

Commit

Permalink
#13136: Modify create_all_gather_struct to work with linear
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Oct 1, 2024
1 parent 7013d95 commit c1f11a4
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,27 @@ AllGather create_all_gather_struct(
uint32_t num_devices = devices.size();

uint32_t device_index = 0; // Initialize device index
uint32_t receiver_device_id = 0; // Initialize receiver device ID
uint32_t sender_device_id = 0; // Initialize sender device ID
std::optional<uint32_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
std::optional<uint32_t> sender_device_id = std::nullopt; // Initialize sender device ID

for (uint32_t i = 0; i < num_devices; ++i) {
if (devices[i] == input_tensor.device()) {
device_index = i;
receiver_device_id = devices[(i + 1) % num_devices]->id(); // Next device in the ring
sender_device_id = devices[(i + num_devices - 1) % num_devices]->id(); // Previous device in the ring
if (topology == ttnn::ccl::Topology::Ring) {
// Ring topology
receiver_device_id = devices[(i + 1) % num_devices]->id(); // Next device in the ring
sender_device_id = devices[(i + num_devices - 1) % num_devices]->id(); // Previous device in the ring
} else if (topology == ttnn::ccl::Topology::Linear) {
// Linear topology
bool is_last_chip_in_clockwise_direction = i == (num_devices - 1);
bool is_last_chip_in_counter_clockwise_direction = i == 0;
receiver_device_id = is_last_chip_in_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at(i+1)->id());
sender_device_id = is_last_chip_in_counter_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at(i-1)->id());
}
break;
}
}
Expand Down Expand Up @@ -220,7 +233,7 @@ Tensor all_gather(
const std::optional<size_t> user_defined_num_buffers_per_channel,
const ttnn::ccl::Topology topology) {

TT_FATAL(topology != ttnn::ccl::Topology::Linear, "This api currently supported only for Linear topology");
TT_FATAL(topology == ttnn::ccl::Topology::Linear, "This api currently supported only for Linear topology");
const auto mesh_view = mesh_device.get_view();
std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols();

Expand Down

0 comments on commit c1f11a4

Please sign in to comment.