Commit 81e2007: #5560: Add enum for strategy

Aswinmcw committed Oct 25, 2024
1 parent b7d0d22 commit 81e2007
Showing 3 changed files with 126 additions and 71 deletions.
@@ -6,6 +6,7 @@
import pytest
from loguru import logger
import ttnn
import math
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc
from models.utility_functions import skip_for_grayskull

@@ -18,7 +19,7 @@ def is_unsupported_case(input_shape, math_op, mem_config, num_devices, num_links
num_l1_banks = 64
if mem_config.buffer_type == ttnn.BufferType.L1 and tensor_size_bytes > num_l1_banks * 50 * 1024:
return True, "L1 buffer can't support large tensor sizes"
if input_shape[3] == 32 and input_dtype == ttnn.bfloat8_b:
if (input_shape[2] == 32 or input_shape[3] == 32) and input_dtype == ttnn.bfloat8_b:
return True, "This combination is not supported for now"

return False, ""
@@ -107,16 +108,14 @@ def run_all_reduce_test(
tt_input_tensors = []
input_tensors = []

numel = per_chip_output_shape[0] * per_chip_output_shape[1] * per_chip_output_shape[2] * per_chip_output_shape[3]
numel = math.prod(per_chip_output_shape)
if debug:
input_tensors[-1] = torch.arange(numel).reshape(per_chip_output_shape).bfloat16()
for i in range(num_devices):
input_tensor = torch.rand(per_chip_output_shape).bfloat16()
tt_input_tensors.append(
ttnn.Tensor(input_tensor, input_dtype)
.to(layout)
.to(mesh_device.get_device(mesh_device.get_device_ids()[i]), mem_config)
)
t = ttnn.from_torch(input_tensor, input_dtype, layout=layout)
t = t.to(mesh_device.get_device(mesh_device.get_device_ids()[i]), mem_config)
tt_input_tensors.append(t)
input_tensor = input_tensor.view(1, -1, input_tensor.shape[2], input_tensor.shape[3])
input_tensors.append(input_tensor)

@@ -146,7 +145,7 @@ def run_all_reduce_test(
# Compare
mismatch = False
for i, t in enumerate(tt_out_tensors):
tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
tt_output_tensor = ttnn.to_torch(t)

eq, output = comp_pcc(tt_output_tensor, golden_canonical_out_tensor)
mismatch = mismatch or not eq
@@ -60,6 +60,112 @@ static ttnn::operations::binary::BinaryOpType convert_reduce_type_to_eltwise_typ
namespace operations{
namespace experimental{
namespace ccl{

AllReduceStrategy choose_all_reduce_strategy(const Tensor& input_tensor, uint32_t num_devices, uint32_t num_links) {
auto shape = input_tensor.get_logical_shape();
auto rank = shape.rank();

uint32_t all_reduce_dim = -1;
bool optimized_version = false;

for (uint32_t i = 0; i < rank; ++i) {
if (shape[i] % num_devices == 0) {
all_reduce_dim = i;
optimized_version = true;
}
}

if(optimized_version){
if(shape[2] == tt::constants::TILE_HEIGHT || shape[3] == tt::constants::TILE_WIDTH){
optimized_version = false; // Reduce scatter hangs for this shape
}

if (input_tensor.get_layout() == ttnn::TILE_LAYOUT) {
if ((all_reduce_dim == 2 && shape[all_reduce_dim] % tt::constants::TILE_HEIGHT != 0) ||
(all_reduce_dim == 3 && shape[all_reduce_dim] % tt::constants::TILE_WIDTH != 0)) {
optimized_version = false;
}
}
}

if (optimized_version) {
return AllReduceStrategy::ReduceScatterAllGather;
} else {
return AllReduceStrategy::AllGatherLocalReduce;
}

return AllReduceStrategy::Invalid;
}
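
For orientation, here is a self-contained sketch, not taken from this commit, of how the selection rule above plays out for a few concrete shapes. It assumes eight devices, a 32x32 tile (TILE_HEIGHT == TILE_WIDTH == 32), and a tile-layout input so the alignment check always applies.

#include <array>
#include <cstdint>

enum class Strategy { AllGatherLocalReduce, ReduceScatterAllGather };

constexpr Strategy mock_choose(const std::array<uint32_t, 4>& shape, uint32_t num_devices) {
    constexpr uint32_t TILE = 32;  // assumed tile extent
    int dim = -1;
    for (int i = 0; i < 4; ++i) {
        if (shape[i] % num_devices == 0) { dim = i; }  // the last divisible dim wins
    }
    bool optimized = (dim >= 0);
    if (optimized && (shape[2] == TILE || shape[3] == TILE)) {
        optimized = false;  // mirror of the reduce-scatter-hang guard above
    }
    if (optimized && ((dim == 2 && shape[2] % TILE != 0) || (dim == 3 && shape[3] % TILE != 0))) {
        optimized = false;  // the scatter dim must be tile aligned
    }
    return optimized ? Strategy::ReduceScatterAllGather : Strategy::AllGatherLocalReduce;
}

// shape[2] == 32 trips the tile-extent guard, so the fallback path is chosen.
static_assert(mock_choose({1, 1, 32, 4096}, 8) == Strategy::AllGatherLocalReduce, "");
// Divisible by 8 along dim 3 and tile aligned, so reduce_scatter + all_gather is chosen.
static_assert(mock_choose({1, 1, 256, 4096}, 8) == Strategy::ReduceScatterAllGather, "");
// No dim divisible by the device count, so the fallback path is chosen.
static_assert(mock_choose({1, 1, 100, 60}, 8) == Strategy::AllGatherLocalReduce, "");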


Tensor all_gather_local_reduce(const Tensor& input_tensor, uint32_t num_devices, uint32_t num_links, const MemoryConfig& output_mem_config,
const std::optional<size_t> user_defined_num_workers, const std::optional<size_t> user_defined_num_buffers_per_channel, const std::vector<Device*>& devices, const ttnn::ccl::Topology& topology) {

auto shape = input_tensor.get_logical_shape();
auto rank = shape.rank();
log_warning(
tt::LogOp,
"Falling back to unoptimized version (all_gather + local reduce) as the input tensor shape {} is not handled by optimized version", shape);

TT_FATAL(rank == 4, "Tensor rank must be 4, but is {}", rank);
uint32_t merged_dim_size = 1;
for (uint32_t i = 0; i < rank - 2; ++i) {
    merged_dim_size *= shape[i];
}

std::vector<int32_t> new_shape{1, merged_dim_size, shape[rank - 2], shape[rank - 1]};
auto reshaped_tensor = ttnn::reshape(input_tensor, new_shape);

const auto& gathered_tensor = operation::run(
create_all_gather_struct(reshaped_tensor, 0, num_links, output_mem_config, user_defined_num_workers,
user_defined_num_buffers_per_channel, devices, topology),
{reshaped_tensor});

auto sum_tensor = ttnn::sum(gathered_tensor.at(0), 0);
return ttnn::reshape(sum_tensor, shape);
}
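
As a worked shape trace of this fallback, assume an illustrative rank-4 input of shape {2, 4, 64, 128} spread across 8 devices (a shape chosen for illustration, not one from this commit):

// reshape:               {2, 4, 64, 128} -> {1, 8, 64, 128}   (outer dims merged: 2 * 4 = 8)
// all_gather over dim 0: {1, 8, 64, 128} -> {8, 8, 64, 128}   (one slice per device)
// ttnn::sum over dim 0:  the per-device slices collapse into a single reduced tensor
// reshape:               back to the original {2, 4, 64, 128} on every device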

Tensor reduce_scatter_all_gather(const Tensor& input_tensor, const ttnn::operations::binary::BinaryOpType binary_op_type, uint32_t num_devices, uint32_t num_links, const MemoryConfig& output_mem_config,
const std::optional<size_t> user_defined_num_workers, const std::optional<size_t> user_defined_num_buffers_per_channel, const std::vector<Device*>& devices, const ttnn::ccl::Topology& topology) {
auto shape = input_tensor.get_logical_shape();
auto rank = shape.rank();

uint32_t all_reduce_dim = -1;
for (uint32_t i = 0; i < rank; ++i) {
if (shape[i] % num_devices == 0) {
all_reduce_dim = i;
}
}

const auto& reduced_tensor = operation::run(
create_reduce_scatter_struct(input_tensor, binary_op_type, all_reduce_dim, num_links, output_mem_config,
user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology),
{input_tensor});

const auto& gathered_tensor = operation::run(
create_all_gather_struct(reduced_tensor.at(0), all_reduce_dim, num_links, output_mem_config,
user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology),
{reduced_tensor.at(0)});

return gathered_tensor.at(0);
}
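
For the same illustrative 8-device setup, an input of shape {1, 1, 256, 4096}, divisible along dim 3, would move through this optimized path roughly as:

// reduce_scatter over dim 3: each device keeps a reduced {1, 1, 256, 512} shard (4096 / 8)
// all_gather over dim 3:     the shards are re-concatenated to {1, 1, 256, 4096} on every device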

Tensor run_all_reduce(AllReduceStrategy strategy, const Tensor& input_tensor, const ttnn::operations::binary::BinaryOpType binary_op_type, uint32_t num_devices, uint32_t num_links, const MemoryConfig& output_mem_config,
const std::optional<size_t> user_defined_num_workers, const std::optional<size_t> user_defined_num_buffers_per_channel, const std::vector<Device*>& devices, const ttnn::ccl::Topology& topology) {
switch (strategy) {
case AllReduceStrategy::AllGatherLocalReduce:
return all_gather_local_reduce(input_tensor, num_devices, num_links, output_mem_config,
user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology);
case AllReduceStrategy::ReduceScatterAllGather:
return reduce_scatter_all_gather(input_tensor, binary_op_type, num_devices, num_links, output_mem_config,
user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology);
case AllReduceStrategy::Invalid:
default:
TT_FATAL(false, "Invalid strategy selected {} for input tensor shape: {}", strategy, input_tensor.get_logical_shape());
}
}


Tensor all_reduce(
const Tensor& input_tensor,
ttnn::operations::reduction::ReduceType math_op,
@@ -83,73 +189,16 @@ Tensor all_reduce(
bool is_linear = topology == ttnn::ccl::Topology::Linear;

const auto& input_tensor = input_tensors.at(0);

auto shape = input_tensor.get_logical_shape();
auto rank = shape.rank();
uint32_t num_devices = devices.size();

uint32_t all_reduce_dim = -1;
bool optimized_version = false;
for (uint32_t i = 0; i < rank; ++i) {
if(shape[i] % num_devices == 0){
all_reduce_dim = i;
optimized_version = true;
}
}
if(shape[3] == tt::constants::TILE_WIDTH){
optimized_version = false; // Reduce scatter hangs for this shape
}
if (input_tensor.get_layout() == ttnn::TILE_LAYOUT){
if ((all_reduce_dim == 2 && shape[all_reduce_dim] % tt::constants::TILE_HEIGHT != 0) ||
(all_reduce_dim == 3 && shape[all_reduce_dim] % tt::constants::TILE_WIDTH != 0)) {
optimized_version = false;
}
}

// Choose the appropriate strategy
AllReduceStrategy strategy = choose_all_reduce_strategy(input_tensor, num_devices, num_links);

if(optimized_version){
const auto& reduced_tensor = operation::run(
create_reduce_scatter_struct(
input_tensor,
binary_op_type,
all_reduce_dim,
num_links,
output_mem_config,
user_defined_num_workers,
user_defined_num_buffers_per_channel,
devices,
topology),
{input_tensor});

const auto& gathered_tensor = operation::run(
create_all_gather_struct(reduced_tensor.at(0), all_reduce_dim, num_links, output_mem_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology),
{reduced_tensor.at(0)});

auto final_output = gathered_tensor.at(0);
return {final_output};
}
else{
log_warning(
tt::LogOp,
"Falling back to unoptimized version (all_gather + local reduce) as the input tensor shape {} is not handled by optimized version", shape);
// Run the selected all-reduce operation
Tensor result = run_all_reduce(strategy, input_tensor, binary_op_type, num_devices, num_links, output_mem_config,
user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology);

uint32_t merged_dim_size = 1;
for (uint32_t i = 0; i <= rank - 3; ++i) {
merged_dim_size *= shape[i];
}

std::vector<int32_t> new_shape{1, merged_dim_size, shape[rank - 2], shape[rank - 1]};
auto reshaped_tensor = ttnn::reshape(input_tensor, new_shape);

const auto& gathered_tensor = operation::run(
create_all_gather_struct(reshaped_tensor, 0, num_links, output_mem_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology),
{reshaped_tensor});

auto sum_tensor = ttnn::sum(gathered_tensor.at(0), 0);
auto final_output = ttnn::reshape(sum_tensor, shape);
return {final_output};

}
return {result};

},
{input_tensor},
@@ -11,6 +11,13 @@
#include "ttnn/operations/eltwise/binary/binary.hpp"
namespace ttnn {

enum class AllReduceStrategy {
AllGatherLocalReduce,
ReduceScatterAllGather,
//Fused,
Invalid
};
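
In short, as wired up in run_all_reduce above:

// AllGatherLocalReduce   -- fallback: reshape, all_gather on a fresh leading dim, then a local ttnn::sum.
// ReduceScatterAllGather -- optimized path: reduce_scatter followed by all_gather along the chosen dim.
// Invalid                -- guard value; run_all_reduce raises TT_FATAL on it, and
//                           choose_all_reduce_strategy as written never actually returns it.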

struct AllReduce {
const ttnn::operations::binary::BinaryOpType binary_op_type;
const uint32_t num_links;