Skip to content

Commit

Permalink
#5560: Initial commit to get reduce_scatter as common
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Oct 22, 2024
1 parent 204e3aa commit 385dce2
Showing 1 changed file with 61 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,55 @@

namespace ttnn {

ReduceScatter create_reduce_scatter_struct (
const Tensor& input_tensor,
const ttnn::operations::binary::BinaryOpType binary_op_type,
const uint32_t scatter_dim,
const uint32_t num_links,
const MemoryConfig output_mem_config,
const std::optional<size_t> user_defined_num_workers,
const std::optional<size_t> user_defined_num_buffers_per_channel,
const std::vector<Device*>& devices,
const ttnn::ccl::Topology topology
){
uint32_t num_devices = devices.size();

bool is_linear = topology == ttnn::ccl::Topology::Linear;

uint32_t device_index = 0; // Initialize device index
std::optional<chip_id_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
std::optional<chip_id_t> sender_device_id = std::nullopt; // Initialize sender device ID
for (uint32_t i = 0; i < num_devices; ++i) {
if (devices.at(i) == input_tensor.device()) {

bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1);
bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0;
device_index = i;
receiver_device_id = is_last_chip_in_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at((i + 1) % num_devices)->id());
sender_device_id = is_last_chip_in_counter_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at((i + num_devices - 1) % num_devices)->id());
break;
}
}
TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect");

return ttnn::ReduceScatter{
binary_op_type,
scatter_dim,
num_links,
num_devices,
device_index,
receiver_device_id,
sender_device_id,
output_mem_config,
topology,
user_defined_num_workers,
user_defined_num_buffers_per_channel};
}

void ReduceScatter::validate(const std::vector<Tensor>& input_tensors) const {
for (auto const& t : input_tensors) {
TT_FATAL(
Expand Down Expand Up @@ -77,54 +126,32 @@ Tensor reduce_scatter(
ttnn::operations::binary::BinaryOpType binary_op_type = convert_reduce_type_to_eltwise_type(math_op);
TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "reduce_scatter op is only supported for Fast Dispatch");

ttnn::ccl::Topology ccl_topology = topology;
auto devices = input_tensor.get_workers();
uint32_t num_devices = devices.size();
if (num_devices == 2){
ccl_topology = ttnn::ccl::Topology::Linear;
}

std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
operation::launch_op(
[binary_op_type, scatter_dim, num_links, output_mem_config, topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel](
[binary_op_type, scatter_dim, num_links, output_mem_config, ccl_topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel](
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {

uint32_t num_devices = devices.size();
if (num_devices == 2){
topology = ttnn::ccl::Topology::Linear;
}
bool is_linear = topology == ttnn::ccl::Topology::Linear;

const auto& input_tensor = input_tensors.at(0);
uint32_t device_index = 0; // Initialize device index
std::optional<chip_id_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
std::optional<chip_id_t> sender_device_id = std::nullopt; // Initialize sender device ID
for (uint32_t i = 0; i < num_devices; ++i) {
if (devices.at(i) == input_tensor.device()) {

bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1);
bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0;
device_index = i;
receiver_device_id = is_last_chip_in_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at((i + 1) % num_devices)->id());
sender_device_id = is_last_chip_in_counter_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at((i + num_devices - 1) % num_devices)->id());
break;
}
}
TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect");

return operation::run(
ttnn::ReduceScatter{
create_reduce_scatter_struct(
input_tensor,
binary_op_type,
scatter_dim,
num_links,
num_devices,
device_index,
receiver_device_id,
sender_device_id,
output_mem_config,
topology,
user_defined_num_workers,
user_defined_num_buffers_per_channel},
user_defined_num_buffers_per_channel,
devices,
ccl_topology),
{input_tensor});
},
{input_tensor},
Expand Down

0 comments on commit 385dce2

Please sign in to comment.