diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_t3000_frequent.py b/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_t3000_frequent.py
index c28f34a9879..b8a2d143cee 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_t3000_frequent.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_t3000_frequent.py
@@ -6,6 +6,7 @@
 import pytest
 from loguru import logger
 import ttnn
+import math
 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc
 from models.utility_functions import skip_for_grayskull
 
@@ -18,7 +19,7 @@ def is_unsupported_case(input_shape, math_op, mem_config, num_devices, num_links
     num_l1_banks = 64
     if mem_config.buffer_type == ttnn.BufferType.L1 and tensor_size_bytes > num_l1_banks * 50 * 1024:
         return True, "L1 buffer can't support large tensor sizes"
-    if input_shape[3] == 32 and input_dtype == ttnn.bfloat8_b:
+    if (input_shape[2] == 32 or input_shape[3] == 32) and input_dtype == ttnn.bfloat8_b:
         return True, "This combination is not supported for now"
 
     return False, ""
@@ -107,16 +108,14 @@ def run_all_reduce_test(
     tt_input_tensors = []
     input_tensors = []
-    numel = per_chip_output_shape[0] * per_chip_output_shape[1] * per_chip_output_shape[2] * per_chip_output_shape[3]
+    numel = math.prod(per_chip_output_shape)
     if debug:
         input_tensors[-1] = torch.arange(numel).reshape(per_chip_output_shape).bfloat16()
     for i in range(num_devices):
         input_tensor = torch.rand(per_chip_output_shape).bfloat16()
-        tt_input_tensors.append(
-            ttnn.Tensor(input_tensor, input_dtype)
-            .to(layout)
-            .to(mesh_device.get_device(mesh_device.get_device_ids()[i]), mem_config)
-        )
+        t = ttnn.from_torch(input_tensor, input_dtype, layout=layout)
+        t = t.to(mesh_device.get_device(mesh_device.get_device_ids()[i]), mem_config)
+        tt_input_tensors.append(t)
         input_tensor = input_tensor.view(1, -1, input_tensor.shape[2], input_tensor.shape[3])
         input_tensors.append(input_tensor)
 
@@ -146,7 +145,7 @@ def run_all_reduce_test(
     # Compare
     mismatch = False
     for i, t in enumerate(tt_out_tensors):
-        tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
+        tt_output_tensor = ttnn.to_torch(t)
         eq, output = comp_pcc(tt_output_tensor, golden_canonical_out_tensor)
         mismatch = mismatch or not eq
 
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp
index 1a0712a1699..dbaddc6f814 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp
@@ -60,6 +60,112 @@ static ttnn::operations::binary::BinaryOpType convert_reduce_type_to_eltwise_type(
 namespace operations{
 namespace experimental{
 namespace ccl{
+
+AllReduceStrategy choose_all_reduce_strategy(const Tensor& input_tensor, uint32_t num_devices, uint32_t num_links) {
+    auto shape = input_tensor.get_logical_shape();
+    auto rank = shape.rank();
+
+    uint32_t all_reduce_dim = -1;
+    bool optimized_version = false;
+
+    for (uint32_t i = 0; i < rank; ++i) {
+        if (shape[i] % num_devices == 0) {
+            all_reduce_dim = i;
+            optimized_version = true;
+        }
+    }
+
+    if (optimized_version) {
+        if (shape[2] == tt::constants::TILE_HEIGHT || shape[3] == tt::constants::TILE_WIDTH) {
+            optimized_version = false;  // Reduce scatter hangs for this shape
+        }
+
+        if (input_tensor.get_layout() == ttnn::TILE_LAYOUT) {
+            if ((all_reduce_dim == 2 && shape[all_reduce_dim] % tt::constants::TILE_HEIGHT != 0) ||
+                (all_reduce_dim == 3 && shape[all_reduce_dim] % tt::constants::TILE_WIDTH != 0)) {
+                optimized_version = false;
+            }
+        }
+    }
+
+    if (optimized_version) {
+        return AllReduceStrategy::ReduceScatterAllGather;
+    } else {
+        return AllReduceStrategy::AllGatherLocalReduce;
+    }
+
+    return AllReduceStrategy::Invalid;
+}
+
+Tensor all_gather_local_reduce(
+    const Tensor& input_tensor,
+    uint32_t num_devices,
+    uint32_t num_links,
+    const MemoryConfig& output_mem_config,
+    const std::optional<size_t> user_defined_num_workers,
+    const std::optional<size_t> user_defined_num_buffers_per_channel,
+    const std::vector<Device*>& devices,
+    const ttnn::ccl::Topology& topology) {
+    auto shape = input_tensor.get_logical_shape();
+    auto rank = shape.rank();
+    log_warning(
+        tt::LogOp,
+        "Falling back to unoptimized version (all_gather + local reduce) as the input tensor shape {} is not handled by optimized version",
+        shape);
+
+    TT_FATAL(rank == 4, "Tensor rank must be 4, but has {}", rank);
+    uint32_t merged_dim_size = 1;
+    for (uint32_t i = 2; i < rank; ++i) {
+        merged_dim_size *= shape[i - 2];
+    }
+
+    std::vector<uint32_t> new_shape{1, merged_dim_size, shape[rank - 2], shape[rank - 1]};
+    auto reshaped_tensor = ttnn::reshape(input_tensor, new_shape);
+
+    const auto& gathered_tensor = operation::run(
+        create_all_gather_struct(
+            reshaped_tensor, 0, num_links, output_mem_config,
+            user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology),
+        {reshaped_tensor});
+
+    auto sum_tensor = ttnn::sum(gathered_tensor.at(0), 0);
+    return ttnn::reshape(sum_tensor, shape);
+}
+
+Tensor reduce_scatter_all_gather(
+    const Tensor& input_tensor,
+    const ttnn::operations::binary::BinaryOpType binary_op_type,
+    uint32_t num_devices,
+    uint32_t num_links,
+    const MemoryConfig& output_mem_config,
+    const std::optional<size_t> user_defined_num_workers,
+    const std::optional<size_t> user_defined_num_buffers_per_channel,
+    const std::vector<Device*>& devices,
+    const ttnn::ccl::Topology& topology) {
+    auto shape = input_tensor.get_logical_shape();
+    auto rank = shape.rank();
+
+    uint32_t all_reduce_dim = -1;
+    for (uint32_t i = 0; i < rank; ++i) {
+        if (shape[i] % num_devices == 0) {
+            all_reduce_dim = i;
+        }
+    }
+
+    const auto& reduced_tensor = operation::run(
+        create_reduce_scatter_struct(
+            input_tensor, binary_op_type, all_reduce_dim, num_links, output_mem_config,
+            user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology),
+        {input_tensor});
+
+    const auto& gathered_tensor = operation::run(
+        create_all_gather_struct(
+            reduced_tensor.at(0), all_reduce_dim, num_links, output_mem_config,
+            user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology),
+        {reduced_tensor.at(0)});
+
+    return gathered_tensor.at(0);
+}
+
+Tensor run_all_reduce(
+    AllReduceStrategy strategy,
+    const Tensor& input_tensor,
+    const ttnn::operations::binary::BinaryOpType binary_op_type,
+    uint32_t num_devices,
+    uint32_t num_links,
+    const MemoryConfig& output_mem_config,
+    const std::optional<size_t> user_defined_num_workers,
+    const std::optional<size_t> user_defined_num_buffers_per_channel,
+    const std::vector<Device*>& devices,
+    const ttnn::ccl::Topology& topology) {
+    switch (strategy) {
+        case AllReduceStrategy::AllGatherLocalReduce:
+            return all_gather_local_reduce(
+                input_tensor, num_devices, num_links, output_mem_config,
+                user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology);
+        case AllReduceStrategy::ReduceScatterAllGather:
+            return reduce_scatter_all_gather(
+                input_tensor, binary_op_type, num_devices, num_links, output_mem_config,
+                user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology);
+        case AllReduceStrategy::Invalid:
+        default:
+            TT_FATAL(false, "Invalid strategy selected {} for input tensor shape: {}", strategy, input_tensor.get_logical_shape());
+    }
+}
+
+
 Tensor all_reduce(
     const Tensor& input_tensor,
     ttnn::operations::reduction::ReduceType math_op,
@@ -83,73 +189,16 @@ Tensor all_reduce(
             bool is_linear = topology == ttnn::ccl::Topology::Linear;
             const auto& input_tensor = input_tensors.at(0);
-
-            auto shape = input_tensor.get_logical_shape();
-            auto rank = shape.rank();
             uint32_t num_devices = devices.size();
-            uint32_t all_reduce_dim = -1;
-            bool optimized_version = false;
-            for (uint32_t i = 0; i < rank; ++i) {
-                if(shape[i] % num_devices == 0){
-                    all_reduce_dim = i;
-                    optimized_version = true;
-                }
-            }
-            if(shape[3] == tt::constants::TILE_WIDTH){
-                optimized_version = false; // Reduce scatter hangs for this shape
-            }
-            if (input_tensor.get_layout() == ttnn::TILE_LAYOUT){
-                if ((all_reduce_dim == 2 && shape[all_reduce_dim] % tt::constants::TILE_HEIGHT != 0) ||
-                    (all_reduce_dim == 3 && shape[all_reduce_dim] % tt::constants::TILE_WIDTH != 0)) {
-                    optimized_version = false;
-                }
-            }
-
+            // Choose the appropriate strategy
+            AllReduceStrategy strategy = choose_all_reduce_strategy(input_tensor, num_devices, num_links);
 
-            if(optimized_version){
-                const auto& reduced_tensor = operation::run(
-                    create_reduce_scatter_struct(
-                        input_tensor,
-                        binary_op_type,
-                        all_reduce_dim,
-                        num_links,
-                        output_mem_config,
-                        user_defined_num_workers,
-                        user_defined_num_buffers_per_channel,
-                        devices,
-                        topology),
-                    {input_tensor});
-
-                const auto& gathered_tensor = operation::run(
-                    create_all_gather_struct(reduced_tensor.at(0), all_reduce_dim, num_links, output_mem_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology),
-                    {reduced_tensor.at(0)});
-
-                auto final_output = gathered_tensor.at(0);
-                return {final_output};
-            }
-            else{
-                log_warning(
-                    tt::LogOp,
-                    "Falling back to unoptimized version (all_gather + local reduce) as the input tensor shape {} is not handled by optimized version", shape);
+            // Run the selected all-reduce operation
+            Tensor result = run_all_reduce(strategy, input_tensor, binary_op_type, num_devices, num_links, output_mem_config,
+                user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology);
 
-                uint32_t merged_dim_size = 1;
-                for (uint32_t i = 0; i <= rank - 3; ++i) {
-                    merged_dim_size *= shape[i];
-                }
-
-                std::vector<uint32_t> new_shape{1, merged_dim_size, shape[rank - 2], shape[rank - 1]};
-                auto reshaped_tensor = ttnn::reshape(input_tensor, new_shape);
-
-                const auto& gathered_tensor = operation::run(
-                    create_all_gather_struct(reshaped_tensor, 0, num_links, output_mem_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, topology),
-                    {reshaped_tensor});
-
-                auto sum_tensor = ttnn::sum(gathered_tensor.at(0), 0);
-                auto final_output = ttnn::reshape(sum_tensor, shape);
-                return {final_output};
-
-            }
+            return {result};
         },
         {input_tensor},
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.hpp
index ed5caade9c9..3829f4d9df9 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.hpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.hpp
@@ -11,6 +11,13 @@
 #include "ttnn/operations/eltwise/binary/binary.hpp"
 
 namespace ttnn {
+enum class AllReduceStrategy {
+    AllGatherLocalReduce,
+    ReduceScatterAllGather,
+    // Fused,
+    Invalid
+};
+
 struct AllReduce {
     const ttnn::operations::binary::BinaryOpType binary_op_type;
     const uint32_t num_links;
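
Note (reviewer sketch, not part of the patch): the dispatch rule in choose_all_reduce_strategy can be exercised in isolation. The sketch below mirrors the selection logic over plain std::vector<uint32_t> rank-4 shapes; the names choose, TILE_H, and TILE_W are invented for the sketch, the `tiled` flag stands in for the ttnn::TILE_LAYOUT check, and 32x32 tiles are assumed for tt::constants::TILE_HEIGHT/TILE_WIDTH.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    enum class AllReduceStrategy { AllGatherLocalReduce, ReduceScatterAllGather };

    constexpr uint32_t TILE_H = 32;  // stands in for tt::constants::TILE_HEIGHT
    constexpr uint32_t TILE_W = 32;  // stands in for tt::constants::TILE_WIDTH

    // Mirrors choose_all_reduce_strategy() for rank-4 shapes.
    AllReduceStrategy choose(const std::vector<uint32_t>& shape, uint32_t num_devices, bool tiled) {
        int all_reduce_dim = -1;
        bool optimized = false;
        // The last dimension divisible by the device count becomes the scatter/gather dim.
        for (size_t i = 0; i < shape.size(); ++i) {
            if (shape[i] % num_devices == 0) {
                all_reduce_dim = static_cast<int>(i);
                optimized = true;
            }
        }
        if (optimized) {
            // Single-tile rows/columns hang in reduce_scatter, so fall back.
            if (shape[2] == TILE_H || shape[3] == TILE_W) {
                optimized = false;
            }
            // Tiled tensors must stay tile-aligned along the scattered dim.
            if (tiled && ((all_reduce_dim == 2 && shape[2] % TILE_H != 0) ||
                          (all_reduce_dim == 3 && shape[3] % TILE_W != 0))) {
                optimized = false;
            }
        }
        return optimized ? AllReduceStrategy::ReduceScatterAllGather
                         : AllReduceStrategy::AllGatherLocalReduce;
    }

    int main() {
        // [1, 1, 256, 256] on 4 devices: divisible and tile-aligned -> reduce_scatter + all_gather.
        std::printf("%d\n", choose({1, 1, 256, 256}, 4, true) == AllReduceStrategy::ReduceScatterAllGather);
        // [1, 1, 32, 300] on 4 devices: dim 2 is a single tile -> all_gather + local reduce.
        std::printf("%d\n", choose({1, 1, 32, 300}, 4, true) == AllReduceStrategy::AllGatherLocalReduce);
        return 0;
    }

As in the op, the last divisible dimension wins, and the single-tile and tile-alignment guards push the shape onto the all_gather + local-reduce fallback.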
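
For intuition on the AllGatherLocalReduce fallback, here is a second host-side sketch (assumed semantics, invented names, not the ttnn API) showing why gathering the [1, W*Z, Y, X] views on dim 0 and summing over that dim yields the elementwise sum an all-reduce is expected to produce:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // Four "devices", each holding an identically shaped tensor flattened to numel values.
        const size_t num_devices = 4;
        const size_t numel = 2 * 3 * 4 * 8;  // e.g. a [2, 3, 4, 8] per-chip tensor
        std::vector<std::vector<float>> chips(num_devices, std::vector<float>(numel));
        for (size_t d = 0; d < num_devices; ++d) {
            for (size_t i = 0; i < numel; ++i) {
                chips[d][i] = static_cast<float>(d + 1) + 0.5f * static_cast<float>(i);
            }
        }

        // all_gather on dim 0 of the [1, W*Z, Y, X] view concatenates the per-device copies
        // into [num_devices, W*Z, Y, X]; ttnn::sum over dim 0 then reduces them pointwise.
        std::vector<float> reduced(numel, 0.0f);
        for (size_t d = 0; d < num_devices; ++d) {
            for (size_t i = 0; i < numel; ++i) {
                reduced[i] += chips[d][i];
            }
        }

        // Every device ends up holding the elementwise sum across devices; the trailing
        // ttnn::reshape in the op only restores the original [W, Z, Y, X] shape.
        std::printf("reduced[0] = %.1f (expect 10.0)\n", reduced[0]);
        return 0;
    }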