Skip to content

Commit

Permalink
#14267: Move sender receiver computation to ccl_common.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Oct 25, 2024
1 parent e966f77 commit 31ee90b
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 52 deletions.
35 changes: 2 additions & 33 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,8 @@ AllGather create_all_gather_struct(
const ttnn::ccl::Topology topology
) {
uint32_t num_devices = devices.size();

uint32_t device_index = 0; // Initialize device index
std::optional<uint32_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
std::optional<uint32_t> sender_device_id = std::nullopt; // Initialize sender device ID

for (uint32_t i = 0; i < num_devices; ++i) {
if (devices[i] == input_tensor.device()) {
device_index = i;
switch(topology){
case ttnn::ccl::Topology::Ring:{
// Ring topology
receiver_device_id = devices[(i + 1) % num_devices]->id(); // Next device in the ring
sender_device_id = devices[(i + num_devices - 1) % num_devices]->id(); // Previous device in the ring
break;
}
case ttnn::ccl::Topology::Linear:{
// Linear topology
bool is_last_chip_in_clockwise_direction = i == (num_devices - 1);
bool is_last_chip_in_counter_clockwise_direction = i == 0;
receiver_device_id = is_last_chip_in_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at(i+1)->id());
sender_device_id = is_last_chip_in_counter_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at(i-1)->id());
break;
}
default:
TT_FATAL(false, "Invalid Topology {}, Accepted topologies are Ring and Linear currently", topology);
}
break;
}
}
auto [device_index, sender_device_id, receiver_device_id] =
getDeviceIndexAndSenderReceiverIDs(input_tensor, devices, topology);

return ttnn::AllGather{
dim, num_links, num_devices, device_index, user_defined_num_workers, user_defined_num_buffers_per_channel, receiver_device_id, sender_device_id, memory_config.value_or(input_tensor.memory_config()), topology};
Expand Down
29 changes: 29 additions & 0 deletions ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,35 @@
namespace ttnn {
namespace ccl {

std::tuple<uint32_t, std::optional<chip_id_t>, std::optional<chip_id_t>> getDeviceIndexAndSenderReceiverIDs(
const Tensor& input_tensor,
const std::vector<Device*>& devices,
const ttnn::ccl::Topology& topology) {

uint32_t num_devices = devices.size();
bool is_linear = topology == ttnn::ccl::Topology::Linear;
uint32_t device_index = 0; // Initialize device index
for (uint32_t i = 0; i < num_devices; ++i) {
if (devices.at(i) == input_tensor.device()) {
device_index = i;
bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1);
bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0;

std::optional<chip_id_t> receiver_device_id = is_last_chip_in_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at((i + 1) % num_devices)->id());

std::optional<chip_id_t> sender_device_id = is_last_chip_in_counter_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at((i + num_devices - 1) % num_devices)->id());

return {device_index, sender_device_id, receiver_device_id};
}
}

return {device_index, std::nullopt, std::nullopt}; // Return null if the device is not found
}

RingTopology::RingTopology(
Device const* device,
Topology topology,
Expand Down
5 changes: 5 additions & 0 deletions ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
namespace ttnn {
namespace ccl {

std::tuple<uint32_t, std::optional<chip_id_t>, std::optional<chip_id_t>> getDeviceIndexAndSenderReceiverIDs(
const Tensor& input_tensor,
const std::vector<Device*>& devices,
const ttnn::ccl::Topology& topology);

// Eventual home: ccl_topology_descriptors
struct RingTopology {
RingTopology(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,27 +89,11 @@ Tensor reduce_scatter(
if (num_devices == 2){
topology = ttnn::ccl::Topology::Linear;
}
bool is_linear = topology == ttnn::ccl::Topology::Linear;

const auto& input_tensor = input_tensors.at(0);
uint32_t device_index = 0; // Initialize device index
std::optional<chip_id_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
std::optional<chip_id_t> sender_device_id = std::nullopt; // Initialize sender device ID
for (uint32_t i = 0; i < num_devices; ++i) {
if (devices.at(i) == input_tensor.device()) {

bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1);
bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0;
device_index = i;
receiver_device_id = is_last_chip_in_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at((i + 1) % num_devices)->id());
sender_device_id = is_last_chip_in_counter_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at((i + num_devices - 1) % num_devices)->id());
break;
}
}
auto [device_index, sender_device_id, receiver_device_id] =
getDeviceIndexAndSenderReceiverIDs(input_tensor, devices, topology);

TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect");

return operation::run(
Expand Down

0 comments on commit 31ee90b

Please sign in to comment.