diff --git a/tests/scripts/run_profiler_regressions.sh b/tests/scripts/run_profiler_regressions.sh
index f8fc6a69759..12a9b970c65 100755
--- a/tests/scripts/run_profiler_regressions.sh
+++ b/tests/scripts/run_profiler_regressions.sh
@@ -8,7 +8,7 @@ run_additional_T3000_test(){
     remove_default_log_locations
     mkdir -p $PROFILER_ARTIFACTS_DIR
-    ./tt_metal/tools/profiler/profile_this.py -c "pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py::test_all_gather_on_t3000_post_commit[mem_config0-input_dtype0-8-1-input_shape1-0-layout1]" > $PROFILER_ARTIFACTS_DIR/test_out.log
+    ./tt_metal/tools/profiler/profile_this.py -c "'pytest tests/ttnn/unit_tests/operations/test_all_gather.py::test_all_gather_on_t3000_post_commit[mem_config=MemoryConfig\(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::DRAM,shard_spec=std::nullopt\)-input_dtype=DataType.BFLOAT16-num_devices=8-num_links=1-input_shape=[8,\ 1,\ 33,\ 256]-dim=0-layout=Layout.ROW_MAJOR]'" > $PROFILER_ARTIFACTS_DIR/test_out.log
     cat $PROFILER_ARTIFACTS_DIR/test_out.log
@@ -26,7 +26,7 @@ run_additional_T3000_test(){
 run_async_mode_T3000_test(){
     #Some tests here do not skip grayskull
-    if [ "$ARCH_NAME" != "grayskull" ]; then
+    if [ "$ARCH_NAME" == "wormhole_b0" ]; then
        remove_default_log_locations
        mkdir -p $PROFILER_ARTIFACTS_DIR
diff --git a/tests/scripts/run_tests.sh b/tests/scripts/run_tests.sh
index ebd25264b9c..6f83c484e7b 100755
--- a/tests/scripts/run_tests.sh
+++ b/tests/scripts/run_tests.sh
@@ -81,7 +81,7 @@ run_frequent_api_pipeline_tests() {
        ./tests/scripts/run_python_api_unit_tests.sh
    else
        if [[ $tt_arch == "wormhole_b0" ]]; then
-            pytest -n auto tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py -k nightly
+            pytest -n auto tests/ttnn/unit_tests/operations/test_all_gather.py -k nightly
        else
            echo "API tests are not available for fast dispatch because they're already covered in post-commit"
        fi
diff --git a/tests/scripts/t3000/run_t3000_frequent_tests.sh b/tests/scripts/t3000/run_t3000_frequent_tests.sh
index 28885569990..1c6b59afcdf 100755
--- a/tests/scripts/t3000/run_t3000_frequent_tests.sh
+++ b/tests/scripts/t3000/run_t3000_frequent_tests.sh
@@ -70,7 +70,7 @@ run_t3000_tteager_tests() {
  echo "LOG_METAL: Running run_t3000_tteager_tests"
-  pytest -n auto tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py -k post_commit ; fail+=$?
+  pytest -n auto tests/ttnn/unit_tests/operations/test_all_gather.py -k post_commit ; fail+=$?
  pytest -n auto tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter_post_commit.py ; fail+=$?
  # distributed layernorm
diff --git a/tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp b/tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp
index 68f3cde0696..3fa4d26db90 100644
--- a/tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp
+++ b/tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp
@@ -3,7 +3,7 @@
 // SPDX-License-Identifier: Apache-2.0
 #include "gtest/gtest.h"
-#include "ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
 TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTileRow_ClockWise_In3x5_NumShards3) {
@@ -23,7 +23,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
     uint16_t shard_offset = 0;
     uint16_t old_curr_shard = curr_shard;
-    tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
+    ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
         curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
     ASSERT_EQ(curr_shard_tile_x, 0);
     ASSERT_EQ(curr_shard_tile_y, 1);
@@ -39,7 +39,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
     uint16_t curr_shard = 0;
     uint16_t old_curr_shard = curr_shard;
-    tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
+    ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
         curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
     ASSERT_EQ(curr_shard_tile_x, 0);
     ASSERT_EQ(curr_shard_tile_y, 1);
@@ -57,7 +57,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
     uint16_t old_curr_shard = curr_shard;
     ASSERT_EQ(curr_core_index, 0);
-    tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
+    ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
         curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
     ASSERT_EQ(curr_shard_tile_x, 0);
     ASSERT_EQ(curr_shard_tile_y, 2);
@@ -74,7 +74,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
     uint16_t old_curr_shard = curr_shard;
     ASSERT_EQ(curr_core_index, 0);
-    tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
+    ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
         curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
     ASSERT_EQ(curr_shard_tile_x, 0);
     ASSERT_EQ(curr_shard_tile_y, 2);
@@ -92,7 +92,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
     uint16_t old_curr_shard = curr_shard;
-    tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
+    ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
         curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
     ASSERT_EQ(curr_shard_tile_x, 0);
     ASSERT_EQ(curr_shard_tile_y, 0);
@@ -108,7 +108,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
     uint16_t curr_shard = 0;
     uint16_t old_curr_shard = curr_shard;
-    tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
+    ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
         curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
     ASSERT_EQ(curr_shard_tile_x, 0);
     ASSERT_EQ(curr_shard_tile_y, 0);
@@ -146,7 +146,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
     for (uint16_t i = 0; i < num_core_iterations; i++) {
         for (uint16_t tile_row = curr_shard_tile_y; tile_row < input_shard_num_tiles_y; tile_row++) {
-            tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
+            ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
                 curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
             uint16_t next_tile_row = tile_row + 1;
             if (next_tile_row == input_shard_num_tiles_y) {
@@ -184,7 +184,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
     for (uint16_t i = 0; i < num_core_iterations; i++) {
         for (uint16_t tile_row = curr_shard_tile_y; tile_row < input_shard_num_tiles_y; tile_row++) {
-            tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
+            ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
                 curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
             uint16_t next_tile_row = tile_row + 1;
             if (next_tile_row == input_shard_num_tiles_y) {
@@ -224,7 +224,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceSingleT
     for (uint16_t tile_row = curr_shard_tile_y; tile_row < input_shard_num_tiles_y; tile_row++) {
         for (uint16_t tile_col = curr_shard_tile_x; tile_col < input_shard_num_tiles_x; tile_col++) {
-            tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance (
+            ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance (
                 curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
             uint16_t next_tile_row = tile_row;
             uint16_t next_tile_col = tile_col + 1;
@@ -269,7 +269,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceSingleT
     for (uint16_t tile_row = curr_shard_tile_y; tile_row < input_shard_num_tiles_y; tile_row++) {
         for (uint16_t tile_col = curr_shard_tile_x; tile_col < input_shard_num_tiles_x; tile_col++) {
-            tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance (
+            ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance (
                 curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
             uint16_t next_tile_row = tile_row;
             uint16_t next_tile_col = tile_col + 1;
diff --git a/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp b/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp
index af05715ad61..35335223344 100644
--- a/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp
+++ b/tests/tt_eager/ops/ccl/test_all_gather_utils.cpp
@@ -3,22 +3,22 @@
 // SPDX-License-Identifier: Apache-2.0
 #include "gtest/gtest.h"
-#include "ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp"
-#include "ttnn/experimental/tt_dnn/op_library/all_gather/all_gather_op.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
+#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp"
 //////////////////////////////////////////////////////
-/// InputTensorShardAddrGenArgGenerator TESTS
+/// ttnn::InputTensorShardAddrGenArgGenerator TESTS
 //////////////////////////////////////////////////////
 // Col major orientation not supported yet
 TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetFirstOutputShardStartingLocation_RowMajorOrientation)
 {
-    // tt::tt_metal::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
+    // ttnn::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
     //     num_workers, input_tensor_shard_grid_size, ring_index, serving_worker_index);
     {
         uint32_t ring_size = 8;
         auto const [dest_worker_index, offset_chunk_in_worker] =
-            tt::tt_metal::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
+            ttnn::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
                 1, 1, 0, ring_size, 0);
         ASSERT_EQ(dest_worker_index, 0);
         ASSERT_EQ(offset_chunk_in_worker, 0);
@@ -26,7 +26,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetFirstOutputShardSta
     {
         uint32_t ring_size = 8;
         auto const [dest_worker_index, offset_chunk_in_worker] =
-            tt::tt_metal::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
+            ttnn::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
                 1, 1, 1, ring_size, 0);
         ASSERT_EQ(dest_worker_index, 0);
         ASSERT_EQ(offset_chunk_in_worker, 1);
@@ -34,7 +34,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetFirstOutputShardSta
     {
         uint32_t ring_size = 2;
         auto const [dest_worker_index, offset_chunk_in_worker] =
-            tt::tt_metal::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
+            ttnn::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
                 4, 16, 0, ring_size, 0);
         ASSERT_EQ(dest_worker_index, 0);
         ASSERT_EQ(offset_chunk_in_worker, 0);
@@ -42,7 +42,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetFirstOutputShardSta
     {
         uint32_t ring_size = 2;
         auto const [dest_worker_index, offset_chunk_in_worker] =
-            tt::tt_metal::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
+            ttnn::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
                 4, 16, 0, ring_size, 1);
         ASSERT_EQ(dest_worker_index, 2);
         ASSERT_EQ(offset_chunk_in_worker, 0);
@@ -50,7 +50,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetFirstOutputShardSta
     {
         uint32_t ring_size = 2;
         auto const [dest_worker_index, offset_chunk_in_worker] =
-            tt::tt_metal::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
+            ttnn::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
                 4, 16, 1, ring_size, 0);
         ASSERT_EQ(dest_worker_index, 8);
         ASSERT_EQ(offset_chunk_in_worker, 0);
@@ -58,7 +58,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetFirstOutputShardSta
     {
         uint32_t ring_size = 2;
         auto const [dest_worker_index, offset_chunk_in_worker] =
-            tt::tt_metal::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
+            ttnn::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
                 4, 16, 1, ring_size, 1);
         ASSERT_EQ(dest_worker_index, 10);
         ASSERT_EQ(offset_chunk_in_worker, 0);
@@ -67,7 +67,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetFirstOutputShardSta
         uint32_t ring_size = 8;
         uint32_t num_workers = 1;
         auto const [dest_worker_index, offset_chunk_in_worker] =
-            tt::tt_metal::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
+            ttnn::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
                 num_workers, 2, 0, ring_size, 0);
         ASSERT_EQ(dest_worker_index, 0);
         ASSERT_EQ(offset_chunk_in_worker, 0);
@@ -76,7 +76,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetFirstOutputShardSta
         uint32_t ring_size = 8;
         uint32_t num_workers = 2;
         auto const [dest_worker_index, offset_chunk_in_worker] =
-            tt::tt_metal::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
+            ttnn::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
                 num_workers, 2, 0, ring_size, 0);
         ASSERT_EQ(dest_worker_index, 0);
         ASSERT_EQ(offset_chunk_in_worker, 0);
@@ -85,7 +85,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetFirstOutputShardSta
         uint32_t ring_size = 8;
         uint32_t num_workers = 2;
         auto const [dest_worker_index, offset_chunk_in_worker] =
-            tt::tt_metal::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
+            ttnn::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
                 num_workers, 2, 0, ring_size, 1);
         ASSERT_EQ(dest_worker_index, 0);
         ASSERT_EQ(offset_chunk_in_worker, 1);
@@ -94,7 +94,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetFirstOutputShardSta
         uint32_t ring_size = 8;
         uint32_t num_workers = 2;
         auto const [dest_worker_index, offset_chunk_in_worker] =
-            tt::tt_metal::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
+            ttnn::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
                 num_workers, 2, 1, ring_size, 1);
         ASSERT_EQ(dest_worker_index, 0);
         ASSERT_EQ(offset_chunk_in_worker, 3);
@@ -103,7 +103,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetFirstOutputShardSta
         uint32_t ring_size = 8;
         uint32_t num_workers = 8;
         auto const [dest_worker_index, offset_chunk_in_worker] =
-            tt::tt_metal::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
+            ttnn::OutputTensorShardAddrGenArgGenerator::get_first_output_shard_starting_location(
                 num_workers, 32, 1, ring_size, 0);
         ASSERT_EQ(dest_worker_index, 4);
         ASSERT_EQ(offset_chunk_in_worker, 0);
@@ -113,7 +113,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetFirstOutputShardSta
 TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_ComputeWorkerDestCores_WidthSharding)
 {
     bool is_shard_orientation_row_major = true;
-    ccl::ShardType shard_type = ccl::ShardType::Width;
+    ttnn::ccl::ShardType shard_type = ttnn::ccl::ShardType::Width;
     {
         // shard grid size = 32, ring size = 8, num_workers = 8
         std::vector global_shard_dest_cores =
@@ -126,7 +126,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_ComputeWorkerDestCores
     uint32_t num_workers = 8;
     {
         std::cout << "HERE" << std::endl;
-        auto const& dest_cores = tt::tt_metal::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
+        auto const& dest_cores = ttnn::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
            shard_type, global_shard_dest_cores, global_shard_dest_cores.size(), global_shard_dest_cores.size() * ring_size, num_workers, 0, is_shard_orientation_row_major);
         ASSERT_EQ(dest_cores.size(), 8);
         ASSERT_EQ(dest_cores.at(0), CoreCoord(0,0));
         ASSERT_EQ(dest_cores.at(1), CoreCoord(4,0));
@@ -136,7 +136,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_ComputeWorkerDestCores
     }
     {
-        auto const& dest_cores = tt::tt_metal::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
+        auto const& dest_cores = ttnn::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
           shard_type, global_shard_dest_cores, global_shard_dest_cores.size(), global_shard_dest_cores.size() * ring_size, num_workers, 1, is_shard_orientation_row_major);
         ASSERT_EQ(dest_cores.size(), 8);
         ASSERT_EQ(dest_cores.at(0), CoreCoord(0,0));
         ASSERT_EQ(dest_cores.at(1), CoreCoord(4,0));
@@ -146,7 +146,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_ComputeWorkerDestCores
     }
     {
-        auto const& dest_cores = tt::tt_metal::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
+        auto const& dest_cores = ttnn::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
           shard_type, global_shard_dest_cores, global_shard_dest_cores.size(), global_shard_dest_cores.size() * ring_size, num_workers, 2, is_shard_orientation_row_major);
         ASSERT_EQ(dest_cores.size(), 8);
         ASSERT_EQ(dest_cores.at(0), CoreCoord(1,0));
         ASSERT_EQ(dest_cores.at(1), CoreCoord(5,0));
@@ -156,7 +156,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_ComputeWorkerDestCores
     }
     {
-        auto const& dest_cores = tt::tt_metal::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
+        auto const& dest_cores = ttnn::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
           shard_type, global_shard_dest_cores, global_shard_dest_cores.size(), global_shard_dest_cores.size() * ring_size, num_workers, 3, is_shard_orientation_row_major);
         ASSERT_EQ(dest_cores.size(), 8);
         ASSERT_EQ(dest_cores.at(0), CoreCoord(1,0));
         ASSERT_EQ(dest_cores.at(1), CoreCoord(5,0));
@@ -166,7 +166,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_ComputeWorkerDestCores
     }
     {
-        auto const& dest_cores = tt::tt_metal::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
+        auto const& dest_cores = ttnn::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
          shard_type, global_shard_dest_cores, global_shard_dest_cores.size(), global_shard_dest_cores.size() * ring_size, num_workers, 4, is_shard_orientation_row_major);
         ASSERT_EQ(dest_cores.size(), 8);
         ASSERT_EQ(dest_cores.at(0), CoreCoord(2,0));
         ASSERT_EQ(dest_cores.at(1), CoreCoord(6,0));
@@ -176,7 +176,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_ComputeWorkerDestCores
     }
     {
-        auto const& dest_cores = tt::tt_metal::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
+        auto const& dest_cores = ttnn::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
          shard_type, global_shard_dest_cores, global_shard_dest_cores.size(), global_shard_dest_cores.size() * ring_size, num_workers, 5, is_shard_orientation_row_major);
         ASSERT_EQ(dest_cores.size(), 8);
         ASSERT_EQ(dest_cores.at(0), CoreCoord(2,0));
         ASSERT_EQ(dest_cores.at(1), CoreCoord(6,0));
@@ -186,7 +186,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_ComputeWorkerDestCores
     }
     {
-        auto const& dest_cores = tt::tt_metal::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
+        auto const& dest_cores = ttnn::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
          shard_type, global_shard_dest_cores, global_shard_dest_cores.size(), global_shard_dest_cores.size() * ring_size, num_workers, 6, is_shard_orientation_row_major);
         ASSERT_EQ(dest_cores.size(), 8);
         ASSERT_EQ(dest_cores.at(0), CoreCoord(3,0));
         ASSERT_EQ(dest_cores.at(1), CoreCoord(7,0));
@@ -196,7 +196,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_ComputeWorkerDestCores
     }
     {
-        auto const& dest_cores = tt::tt_metal::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
+        auto const& dest_cores = ttnn::OutputTensorShardAddrGenArgGenerator::compute_worker_coord_worker_dest_cores (
          shard_type, global_shard_dest_cores, global_shard_dest_cores.size(), global_shard_dest_cores.size() * ring_size, num_workers, 7, is_shard_orientation_row_major);
         ASSERT_EQ(dest_cores.size(), 8);
         ASSERT_EQ(dest_cores.at(0), CoreCoord(3,0));
         ASSERT_EQ(dest_cores.at(1), CoreCoord(7,0));
@@ -218,21 +218,21 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetIntraCoreStrideInSh
         uint32_t input_shard_grid_size = 2;
         uint32_t num_workers = 2;
         uint32_t ring_size = 8;
-        auto stride = OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size,num_workers,ring_size);
+        auto stride = ttnn::OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size,num_workers,ring_size);
         ASSERT_EQ(stride, 2);
     }
     {
         uint32_t input_shard_grid_size = 4;
         uint32_t num_workers = 2;
         uint32_t ring_size = 8;
-        auto stride = OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size,num_workers,ring_size);
+        auto stride = ttnn::OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size,num_workers,ring_size);
         ASSERT_EQ(stride, 3);
     }
     {
         uint32_t input_shard_grid_size = 16;
         uint32_t num_workers = 4;
         uint32_t ring_size = 8;
-        auto stride = OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size,num_workers,ring_size);
+        auto stride = ttnn::OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size,num_workers,ring_size);
         // Since we should be striding past the end of the core for this case, we don't care
         // so either of these values would be valid
         // the first would be the hypothetical stride if ring_size was bigger
@@ -243,7 +243,7 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetIntraCoreStrideInSh
         uint32_t input_shard_grid_size = 56;
         uint32_t num_workers = 1;
         uint32_t ring_size = 8;
-        auto stride = OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size,num_workers,ring_size);
+        auto stride = ttnn::OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size,num_workers,ring_size);
         ASSERT_EQ(stride, 1);
     }
@@ -255,42 +255,42 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetContiguousChunkCoun
         uint32_t input_shard_grid_size = 1;
         uint32_t num_workers = 1;
         uint32_t ring_size = 8;
-        auto num_contiguous_shards = OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
+        auto num_contiguous_shards = ttnn::OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
         TT_ASSERT(num_contiguous_shards, 1);
     }
     {
         uint32_t input_shard_grid_size = 2;
         uint32_t num_workers = 2;
         uint32_t ring_size = 8;
-        auto num_contiguous_shards = OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
+        auto num_contiguous_shards = ttnn::OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
         TT_ASSERT(num_contiguous_shards, 1);
     }
     {
         uint32_t input_shard_grid_size = 4;
         uint32_t num_workers = 2;
         uint32_t ring_size = 8;
-        auto num_contiguous_shards = OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
+        auto num_contiguous_shards = ttnn::OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
         TT_ASSERT(num_contiguous_shards, 2);
     }
     {
         uint32_t input_shard_grid_size = 16;
         uint32_t num_workers = 4;
         uint32_t ring_size = 8;
-        auto num_contiguous_shards = OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
+        auto num_contiguous_shards = ttnn::OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
         TT_ASSERT(num_contiguous_shards, 4);
     }
     {
         uint32_t input_shard_grid_size = 56;
         uint32_t num_workers = 1;
         uint32_t ring_size = 8;
-        auto num_contiguous_shards = OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
+        auto num_contiguous_shards = ttnn::OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
         TT_ASSERT(num_contiguous_shards, 1);
     }
     {
         uint32_t input_shard_grid_size = 32;
         uint32_t num_workers = 8;
         uint32_t ring_size = 8;
-        auto num_contiguous_shards = OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
+        auto num_contiguous_shards = ttnn::OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
         TT_ASSERT(num_contiguous_shards, 4);
     }
 }
@@ -300,16 +300,16 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetContiguousChunksBef
         uint32_t input_shard_grid_size = 1;
         uint32_t num_workers = 1;
         uint32_t ring_size = 8;
-        auto num_contiguous_shards = OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
-        auto intra_core_stride_in_chunks = OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size, num_workers, ring_size);
+        auto num_contiguous_shards = ttnn::OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
+        auto intra_core_stride_in_chunks = ttnn::OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size, num_workers, ring_size);
         ASSERT_EQ(num_contiguous_shards, intra_core_stride_in_chunks);
     }
     {
         uint32_t input_shard_grid_size = 2;
         uint32_t num_workers = 1;
         uint32_t ring_size = 8;
-        auto num_contiguous_shards = OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
-        auto intra_core_stride_in_chunks = OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size, num_workers, ring_size);
+        auto num_contiguous_shards = ttnn::OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
+        auto intra_core_stride_in_chunks = ttnn::OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size, num_workers, ring_size);
         ASSERT_EQ(num_contiguous_shards, 1);
         ASSERT_EQ(intra_core_stride_in_chunks, 1);
     }
@@ -317,22 +317,22 @@ TEST(AllGatherUtils, OutputTensorShardAddrGenArgGenerator_GetContiguousChunksBef
         uint32_t input_shard_grid_size = 16;
         uint32_t num_workers = 4;
         uint32_t ring_size = 8;
-        auto num_contiguous_shards = OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
-        auto intra_core_stride_in_chunks = OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size, num_workers, ring_size);
+        auto num_contiguous_shards = ttnn::OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
+        auto intra_core_stride_in_chunks = ttnn::OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size, num_workers, ring_size);
         ASSERT_TRUE(num_contiguous_shards == 4 && intra_core_stride_in_chunks == 5);
     }
     {
         uint32_t input_shard_grid_size = 32;
         uint32_t num_workers = 8;
         uint32_t ring_size = 8;
-        auto num_contiguous_shards = OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
-        auto intra_core_stride_in_chunks = OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size, num_workers, ring_size);
+        auto num_contiguous_shards = ttnn::OutputTensorShardAddrGenArgGenerator::get_contiguous_chunks_before_stride(input_shard_grid_size, num_workers, ring_size);
+        auto intra_core_stride_in_chunks = ttnn::OutputTensorShardAddrGenArgGenerator::get_intra_core_stride_in_shards(input_shard_grid_size, num_workers, ring_size);
         ASSERT_TRUE(num_contiguous_shards == 4 && intra_core_stride_in_chunks == 5);
     }
 }
 //////////////////////////////////////////////////////
-/// InputTensorShardAddrGenArgGenerator TESTS
+/// ttnn::InputTensorShardAddrGenArgGenerator TESTS
 //////////////////////////////////////////////////////
 TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWidthSharding_2Workers_WorkerIdx0_2Cores)
@@ -340,7 +340,7 @@ TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWi
     CoreRangeSet const& all_shard_cores = CoreRangeSet({CoreRange{CoreCoord(0,0)}, CoreCoord(1,0)});
     uint32_t num_workers = 2;
     uint32_t worker_index = 0;
-    auto const& dest_cores = InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
+    auto const& dest_cores = ttnn::InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
        CoreRangeSet(all_shard_cores), worker_index, num_workers);
     ASSERT_EQ(dest_cores.size(), 1);
     ASSERT_EQ(dest_cores.at(0), CoreCoord(0,0));
@@ -350,7 +350,7 @@ TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWi
     CoreRangeSet const& all_shard_cores = CoreRangeSet({CoreRange{CoreCoord(0,0), CoreCoord(1,0)}});
     uint32_t num_workers = 2;
     uint32_t worker_index = 1;
-    auto const& dest_cores = InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
+    auto const& dest_cores = ttnn::InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
        all_shard_cores, worker_index, num_workers);
     ASSERT_EQ(dest_cores.size(), 1);
     ASSERT_EQ(dest_cores.at(0), CoreCoord(1,0));
@@ -361,7 +361,7 @@ TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWi
     CoreRangeSet const& all_shard_cores = CoreRangeSet({CoreRange{CoreCoord(0,0), CoreCoord(3,0)}});
     uint32_t num_workers = 2;
     uint32_t worker_index = 0;
-    auto const& dest_cores = InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
+    auto const& dest_cores = ttnn::InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
        all_shard_cores, worker_index, num_workers);
     ASSERT_EQ(dest_cores.size(), 2);
     ASSERT_EQ(dest_cores.at(0), CoreCoord(0,0));
@@ -372,7 +372,7 @@ TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWi
     CoreRangeSet const& all_shard_cores = CoreRangeSet({CoreRange{CoreCoord(0,0), CoreCoord(3,0)}});
     uint32_t num_workers = 2;
     uint32_t worker_index = 1;
-    auto const& dest_cores = InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
+    auto const& dest_cores = ttnn::InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
        all_shard_cores, worker_index, num_workers);
     ASSERT_EQ(dest_cores.size(), 2);
     ASSERT_EQ(dest_cores.at(0), CoreCoord(2,0));
@@ -383,7 +383,7 @@ TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWi
     CoreRangeSet const& all_shard_cores = CoreRangeSet({CoreRange{CoreCoord(0,0), CoreCoord(7,1)}});
     uint32_t num_workers = 4;
     uint32_t worker_index = 0;
-    auto const& dest_cores = InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
+    auto const& dest_cores = ttnn::InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
        all_shard_cores, worker_index, num_workers);
     ASSERT_EQ(dest_cores.size(), 4);
     ASSERT_EQ(dest_cores.at(0), CoreCoord(0,0));
@@ -396,7 +396,7 @@ TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWi
     CoreRangeSet const& all_shard_cores = CoreRangeSet({CoreRange{CoreCoord(0,0), CoreCoord(7,1)}});
     uint32_t num_workers = 4;
     uint32_t worker_index = 1;
-    auto const& dest_cores = InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
+    auto const& dest_cores = ttnn::InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
        all_shard_cores, worker_index, num_workers);
     ASSERT_EQ(dest_cores.size(), 4);
     ASSERT_EQ(dest_cores.at(0), CoreCoord(4,0));
@@ -409,7 +409,7 @@ TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWi
     CoreRangeSet const& all_shard_cores = CoreRangeSet({CoreRange{CoreCoord(0,0), CoreCoord(7,1)}});
     uint32_t num_workers = 4;
     uint32_t worker_index = 2;
-    auto const& dest_cores = InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
+    auto const& dest_cores = ttnn::InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
        all_shard_cores, worker_index, num_workers);
     ASSERT_EQ(dest_cores.size(), 4);
     ASSERT_EQ(dest_cores.at(0), CoreCoord(0,1));
@@ -422,7 +422,7 @@ TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWi
     CoreRangeSet const& all_shard_cores = CoreRangeSet({CoreRange{CoreCoord(0,0), CoreCoord(7,1)}});
     uint32_t num_workers = 4;
     uint32_t worker_index = 3;
-    auto const& dest_cores = InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
+    auto const& dest_cores = ttnn::InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
        all_shard_cores, worker_index, num_workers);
     ASSERT_EQ(dest_cores.size(), 4);
     ASSERT_EQ(dest_cores.at(0), CoreCoord(4,1));
@@ -435,7 +435,7 @@ TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWi
     CoreRangeSet const& all_shard_cores = CoreRangeSet({CoreRange{CoreCoord(0,0), CoreCoord(7,3)}});
     uint32_t num_workers = 8;
     uint32_t worker_index = 0;
-    auto const& dest_cores = InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
+    auto const& dest_cores = ttnn::InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
        all_shard_cores, worker_index, num_workers);
     ASSERT_EQ(dest_cores.size(), 4);
     ASSERT_EQ(dest_cores.at(0), CoreCoord(0,0));
@@ -448,7 +448,7 @@ TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWi
     CoreRangeSet const& all_shard_cores = CoreRangeSet({CoreRange{CoreCoord(0,0), CoreCoord(7,3)}});
     uint32_t num_workers = 8;
     uint32_t worker_index = 1;
-    auto const& dest_cores = InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
+    auto const& dest_cores = ttnn::InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
        all_shard_cores, worker_index, num_workers);
     ASSERT_EQ(dest_cores.size(), 4);
     ASSERT_EQ(dest_cores.at(0), CoreCoord(4,0));
@@ -461,7 +461,7 @@ TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWi
     CoreRangeSet const& all_shard_cores = CoreRangeSet({CoreRange{CoreCoord(0,0), CoreCoord(7,3)}});
     uint32_t num_workers = 8;
     uint32_t worker_index = 2;
-    auto const& dest_cores = InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
+    auto const& dest_cores = ttnn::InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
        all_shard_cores, worker_index, num_workers);
     ASSERT_EQ(dest_cores.size(), 4);
     ASSERT_EQ(dest_cores.at(0), CoreCoord(0,1));
@@ -475,7 +475,7 @@ TEST(AllGatherUtils, InputTensorShardAddrGenArgGenerator_CtorGenerateDestCoresWi
     uint32_t num_workers = 8;
     uint32_t worker_index = 3;
     std::cout << "sup" << std::endl;
-    auto const& dest_cores = InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
+    auto const& dest_cores = ttnn::InputTensorShardAddrGenArgGenerator::ctor_generate_dest_cores(
        all_shard_cores, worker_index, num_workers);
     std::cout << "hey" << std::endl;
     ASSERT_EQ(dest_cores.size(), 4);
@@ -499,39 +499,39 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor
     const uint16_t intra_core_stride_in_shards = 2; // skip 1
     const uint16_t contiguous_chunks_before_stride = 1;
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 6);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 4);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 2);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 0);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
     // Should have moved to the next core
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 6);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 1);
     // Should have moved to the next core
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 4);
     ASSERT_EQ(curr_worker_index, 0);
@@ -551,39 +551,39 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor
     const uint16_t intra_core_stride_in_shards = 2; // skip 1
     const uint16_t contiguous_chunks_before_stride = 1;
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 2);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 4);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 6);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 0);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
     // Should have moved to the next core
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 2);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
     // Should have moved to the next core
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 4);
     ASSERT_EQ(curr_worker_index, 1);
@@ -603,39 +603,39 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor
     const uint16_t intra_core_stride_in_shards = 2; // skip 1
     const uint16_t contiguous_chunks_before_stride = 1;
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 3);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 5);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 7);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 1);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
     // Should have moved to the next core
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 3);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
     // Should have moved to the next core
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 5);
     ASSERT_EQ(curr_worker_index, 1);
@@ -654,38 +654,38 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex1_NumWor
     const uint16_t intra_core_stride_in_shards = 2; // skip 1
     const uint16_t contiguous_chunks_before_stride = 1;
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 2);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 4);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 6);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 0);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 1);
     // Should have moved to the next core
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 2);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 1);
     // Should have moved to the next core
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 4);
     ASSERT_EQ(curr_worker_index, 0);
@@ -704,40 +704,40 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex1_NumWor
     const uint16_t intra_core_stride_in_shards = 2; // skip 1
     const uint16_t contiguous_chunks_before_stride = 1;
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 3);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 5);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 7);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
     // ASSERT_EQ(current_core_chunks_visited, 3);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 1);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 1);
     // Should have moved to the next core
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 3);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 1);
     // Should have moved to the next core
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 5);
     ASSERT_EQ(curr_worker_index, 0);
@@ -756,49 +756,49 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor
     const uint16_t intra_core_stride_in_shards = 5; // skip 4
     const uint16_t contiguous_chunks_before_stride = 4;
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 1);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 2);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 2);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 3);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 3);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 4);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 0);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 1);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 2);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 2);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 3);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 3);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 4);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 0);
     ASSERT_EQ(curr_worker_index, 2);
@@ -817,49 +817,49 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex1_NumWor
     const uint16_t intra_core_stride_in_shards = 5; // skip 1
     const uint16_t contiguous_chunks_before_stride = 4;
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 5);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 2);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 6);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 3);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 7);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 4);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 4);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 5);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 2);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 6);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 3);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 7);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 4);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 4);
     ASSERT_EQ(curr_worker_index, 2);
@@ -877,49 +877,49 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor
     const uint16_t intra_core_stride_in_shards = 5; // skip 4
     const uint16_t contiguous_chunks_before_stride = 4;
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 1);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 2);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 2);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 3);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 3);
     ASSERT_EQ(curr_worker_index, 0);
     ASSERT_EQ(contiguous_chunk_count, 4);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 0);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 1);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 1);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 2);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
        curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise);
     ASSERT_EQ(curr_core_chunk_index, 2);
     ASSERT_EQ(curr_worker_index, 1);
     ASSERT_EQ(contiguous_chunk_count, 3);
-    ccl::all_gather::addr_gen_advance_width_sharded(
+    ttnn::ccl::all_gather::addr_gen_advance_width_sharded(
curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 3); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 0); ASSERT_EQ(curr_worker_index, 2); @@ -928,7 +928,7 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor // check for wraparound for (uint32_t i = 0; i < num_dest_cores; i++) { for (uint32_t c = 0; c < contiguous_chunks_before_stride; c++) { - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); } } @@ -949,49 +949,49 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor const uint16_t intra_core_stride_in_shards = 5; // skip 4 const uint16_t contiguous_chunks_before_stride = 4; - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index, contiguous_chunk_count, total_chunks_per_core, num_dest_cores, intra_core_stride_in_shards, contiguous_chunks_before_stride, is_clockwise); ASSERT_EQ(curr_core_chunk_index, 5); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 2); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 6); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 3); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 7); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 4); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 1); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 5); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 2); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( 
curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 6); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 3); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 7); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 4); ASSERT_EQ(curr_worker_index, 2); @@ -1000,7 +1000,7 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor // check for wraparound for (uint32_t i = 0; i < num_dest_cores; i++) { for (uint32_t c = 0; c < contiguous_chunks_before_stride; c++) { - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); } } @@ -1020,49 +1020,49 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor const uint16_t intra_core_stride_in_shards = 5; // skip 4 const uint16_t contiguous_chunks_before_stride = 4; - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index, curr_worker_index, contiguous_chunk_count, total_chunks_per_core, num_dest_cores, intra_core_stride_in_shards, contiguous_chunks_before_stride, is_clockwise); ASSERT_EQ(curr_core_chunk_index, 1); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 2); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 2); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 3); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 3); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 0); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 1); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( 
curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 1); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 2); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 2); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 3); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 3); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 0); ASSERT_EQ(curr_worker_index, 2); @@ -1071,7 +1071,7 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor // check for wraparound for (uint32_t i = 0; i < num_dest_cores; i++) { for (uint32_t c = 0; c < contiguous_chunks_before_stride; c++) { - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); } } @@ -1092,49 +1092,49 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor const uint16_t intra_core_stride_in_shards = 5; // skip 4 const uint16_t contiguous_chunks_before_stride = 4; - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 5); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 2); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 6); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 3); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 7); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( 
curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 4); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 1); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 5); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 2); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 6); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 3); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 7); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 4); ASSERT_EQ(curr_worker_index, 2); @@ -1143,7 +1143,7 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor // check for wraparound for (uint32_t i = 0; i < num_dest_cores; i++) { for (uint32_t c = 0; c < contiguous_chunks_before_stride; c++) { - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); } } @@ -1164,49 +1164,49 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex1_NumWor const uint16_t intra_core_stride_in_shards = 5; // skip 4 const uint16_t contiguous_chunks_before_stride = 4; - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 1); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 2); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 2); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 3); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( 
curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 3); ASSERT_EQ(curr_worker_index, 1); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 0); ASSERT_EQ(curr_worker_index, 2); ASSERT_EQ(contiguous_chunk_count, 1); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 1); ASSERT_EQ(curr_worker_index, 2); ASSERT_EQ(contiguous_chunk_count, 2); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 2); ASSERT_EQ(curr_worker_index, 2); ASSERT_EQ(contiguous_chunk_count, 3); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 3); ASSERT_EQ(curr_worker_index, 2); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 0); ASSERT_EQ(curr_worker_index, 3); @@ -1215,7 +1215,7 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex1_NumWor // Check for wraparound for (uint32_t i = 0; i < num_dest_cores; i++) { for (uint32_t c = 0; c < contiguous_chunks_before_stride; c++) { - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); } } @@ -1235,49 +1235,49 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor const uint16_t intra_core_stride_in_shards = 5; // skip 4 const uint16_t contiguous_chunks_before_stride = 4; - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 1); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 2); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( 
curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 2); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 3); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 3); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 0); ASSERT_EQ(curr_worker_index, 7); ASSERT_EQ(contiguous_chunk_count, 1); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 1); ASSERT_EQ(curr_worker_index, 7); ASSERT_EQ(contiguous_chunk_count, 2); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 2); ASSERT_EQ(curr_worker_index, 7); ASSERT_EQ(contiguous_chunk_count, 3); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 3); ASSERT_EQ(curr_worker_index, 7); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 0); ASSERT_EQ(curr_worker_index, 6); @@ -1286,7 +1286,7 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor // check for wraparound for (uint32_t i = 0; i < num_dest_cores; i++) { for (uint32_t c = 0; c < contiguous_chunks_before_stride; c++) { - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); } } @@ -1306,49 +1306,49 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor const uint16_t intra_core_stride_in_shards = 5; // skip 4 const uint16_t contiguous_chunks_before_stride = 4; - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( 
curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 5); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 2); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 6); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 3); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 7); ASSERT_EQ(curr_worker_index, 0); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 4); ASSERT_EQ(curr_worker_index, 7); ASSERT_EQ(contiguous_chunk_count, 1); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 5); ASSERT_EQ(curr_worker_index, 7); ASSERT_EQ(contiguous_chunk_count, 2); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 6); ASSERT_EQ(curr_worker_index, 7); ASSERT_EQ(contiguous_chunk_count, 3); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 7); ASSERT_EQ(curr_worker_index, 7); ASSERT_EQ(contiguous_chunk_count, 4); - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); ASSERT_EQ(curr_core_chunk_index, 4); ASSERT_EQ(curr_worker_index, 6); @@ -1357,7 +1357,7 @@ TEST(AllGatherUtilsDevice, AddrGenAdvanceWidthSharded_RingSize8RingIndex0_NumWor // check for wraparound for (uint32_t i = 0; i < num_dest_cores; i++) { for (uint32_t c = 0; c < contiguous_chunks_before_stride; c++) { - ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( curr_core_chunk_index,curr_worker_index,contiguous_chunk_count,total_chunks_per_core,num_dest_cores,intra_core_stride_in_shards,contiguous_chunks_before_stride,is_clockwise); } } diff --git 
a/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp b/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp index 31a90fef183..ab7eb532ad1 100644 --- a/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp +++ b/tests/tt_eager/ops/ccl/test_ccl_helpers.cpp @@ -4,15 +4,15 @@ #include "device/tt_xy_pair.h" #include "gtest/gtest.h" -#include "ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/ccl_common.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" TEST(CclHelpers, CreateEriscDatamoverBuilder_Chan4_PageSize2048_RRBufferSharingMode) { std::size_t num_channels = 4; uint32_t page_size = 2048; - ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode = ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; - ccl::EriscDataMoverTerminationMode termination_mode = ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; + ttnn::ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode = ttnn::ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; + ttnn::ccl::EriscDataMoverTerminationMode termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED; auto edm_builder = create_erisc_datamover_builder(num_channels, page_size, buffer_sharing_mode, termination_mode); std::vector worker_semaphore_addresses = { @@ -22,18 +22,18 @@ TEST(CclHelpers, CreateEriscDatamoverBuilder_Chan4_PageSize2048_RRBufferSharingM 0x1030, }; std::vector message_counts = {256, 512, 24, 1}; - std::vector<std::vector<ccl::WorkerXY>> const& worker_coords = { - {ccl::WorkerXY{1, 1}, ccl::WorkerXY{2, 1}}, - {ccl::WorkerXY{3, 1}}, - {ccl::WorkerXY{4, 1}, ccl::WorkerXY{5, 1}, ccl::WorkerXY{6, 1}}, - {ccl::WorkerXY{1, 2}}, + std::vector<std::vector<ttnn::ccl::WorkerXY>> const& worker_coords = { + {ttnn::ccl::WorkerXY{1, 1}, ttnn::ccl::WorkerXY{2, 1}}, + {ttnn::ccl::WorkerXY{3, 1}}, + {ttnn::ccl::WorkerXY{4, 1}, ttnn::ccl::WorkerXY{5, 1}, ttnn::ccl::WorkerXY{6, 1}}, + {ttnn::ccl::WorkerXY{1, 2}}, }; std::vector<bool> is_sender_channel{true, false, true, false}; - std::vector<ccl::EriscDatamoverBuilder::ChannelBufferInterface> channel_buffer_interfaces; + std::vector<ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface> channel_buffer_interfaces; channel_buffer_interfaces.reserve(num_channels); for (std::size_t i = 0; i < num_channels; i++) { - ccl::EriscDatamoverBuilder::ChannelBufferInterface const& channel_buffer_interface = + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& channel_buffer_interface = (is_sender_channel[i]) ?
edm_builder.add_sender_channel(worker_semaphore_addresses[i], message_counts[i], worker_coords[i]) : edm_builder.add_receiver_channel(worker_semaphore_addresses[i], message_counts[i], worker_coords[i]); @@ -55,36 +55,36 @@ TEST(CclHelpers, CreateEriscDatamoverBuilder_Chan4_PageSize2048_RRBufferSharingM TEST(CclHelpers, EriscDatamoverConfig_GetEdmHandshakeAddress_GT_0) { for (std::size_t i = 0; i < 8; i++) { - ASSERT_TRUE(ccl::EriscDatamoverConfig::get_edm_handshake_address() > 0); + ASSERT_TRUE(ttnn::ccl::EriscDatamoverConfig::get_edm_handshake_address() > 0); } } TEST(CclHelpers, EriscDatamoverConfig_GetSemaphoresBaseAddress_GT_0) { for (std::size_t i = 0; i < 8; i++) { ASSERT_TRUE( - ccl::EriscDatamoverConfig::get_semaphores_base_address(i) >= - (ccl::EriscDatamoverConfig::get_edm_handshake_address() + - ccl::EriscDatamoverConfig::handshake_location_size + - ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size)); + ttnn::ccl::EriscDatamoverConfig::get_semaphores_base_address(i) >= + (ttnn::ccl::EriscDatamoverConfig::get_edm_handshake_address() + + ttnn::ccl::EriscDatamoverConfig::handshake_location_size + + ttnn::ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size)); } } TEST(CclHelpers, EriscDatamoverConfig_GetBuffersBaseAddress_GT_0) { for (std::size_t i = 0; i < 8; i++) { ASSERT_TRUE( - ccl::EriscDatamoverConfig::get_buffers_base_address(i) >= - (ccl::EriscDatamoverConfig::get_edm_handshake_address() + - ccl::EriscDatamoverConfig::handshake_location_size + - ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size)); + ttnn::ccl::EriscDatamoverConfig::get_buffers_base_address(i) >= + (ttnn::ccl::EriscDatamoverConfig::get_edm_handshake_address() + + ttnn::ccl::EriscDatamoverConfig::handshake_location_size + + ttnn::ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size)); } } TEST(CclHelpers, EriscDatamoverConfig_ComputeBufferSize_GT_0) { for (std::size_t i = 0; i < 8; i++) { ASSERT_TRUE( - ccl::EriscDatamoverConfig::get_buffers_base_address(i) >= - (ccl::EriscDatamoverConfig::get_edm_handshake_address() + - ccl::EriscDatamoverConfig::handshake_location_size + - ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size)); + ttnn::ccl::EriscDatamoverConfig::get_buffers_base_address(i) >= + (ttnn::ccl::EriscDatamoverConfig::get_edm_handshake_address() + + ttnn::ccl::EriscDatamoverConfig::handshake_location_size + + ttnn::ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size)); } } @@ -93,32 +93,32 @@ TEST(CclHelpers, EriscDatamoverConfig_ComputeBufferSize_GT_0) { ///////////////////////////////////////// // x_y x_y x_y TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_1) { - const auto expected = tt::tt_metal::ccl::coord_t(1, 0); - auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 0}, {1, 1}, {2, 2}, 1); + const auto expected = ttnn::ccl::coord_t(1, 0); + auto const& result = ttnn::ccl::advance_slice_row_major({0, 0}, {1, 1}, {2, 2}, 1); ASSERT_EQ(result.x, expected.x); ASSERT_EQ(result.y, expected.y); } TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_1) { - const auto expected = tt::tt_metal::ccl::coord_t(0, 1); - auto const& result = tt::tt_metal::ccl::advance_slice_row_major({1, 0}, {1, 1}, {2, 2}, 1); + const auto expected = ttnn::ccl::coord_t(0, 1); + auto const& result = ttnn::ccl::advance_slice_row_major({1, 0}, {1, 1}, {2, 2}, 1); ASSERT_EQ(result.x, 
expected.x); ASSERT_EQ(result.y, expected.y); } TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_1__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_1) { - const auto expected = tt::tt_metal::ccl::coord_t(1, 1); - auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 1}, {1, 1}, {2, 2}, 1); + const auto expected = ttnn::ccl::coord_t(1, 1); + auto const& result = ttnn::ccl::advance_slice_row_major({0, 1}, {1, 1}, {2, 2}, 1); ASSERT_EQ(result.x, expected.x); ASSERT_EQ(result.y, expected.y); } TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_2) { - const auto expected = tt::tt_metal::ccl::coord_t(0, 1); - auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 0}, {1, 1}, {2, 2}, 2); + const auto expected = ttnn::ccl::coord_t(0, 1); + auto const& result = ttnn::ccl::advance_slice_row_major({0, 0}, {1, 1}, {2, 2}, 2); ASSERT_EQ(result.x, expected.x); ASSERT_EQ(result.y, expected.y); } TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_2) { - const auto expected = tt::tt_metal::ccl::coord_t(1, 1); - auto const& result = tt::tt_metal::ccl::advance_slice_row_major({1, 0}, {1, 1}, {2, 2}, 2); + const auto expected = ttnn::ccl::coord_t(1, 1); + auto const& result = ttnn::ccl::advance_slice_row_major({1, 0}, {1, 1}, {2, 2}, 2); ASSERT_EQ(result.x, expected.x); ASSERT_EQ(result.y, expected.y); } @@ -126,102 +126,102 @@ TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_0__InnerShape_1_1__OuterShape // Test cases pulled from LLama 70B prefill configurations // chip 0 worker 0 link 0 reader unidirectional TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_0__InnerShape_24_1__OuterShape_32_4__NumWorkers_4) { - const auto worker_slice_offset = tt::tt_metal::ccl::coord_t(0, 0); - const auto worker_slice_shape = tt::tt_metal::ccl::coord_t(24, 1); - const auto tensor_slice_shape = tt::tt_metal::ccl::coord_t(32, 4); + const auto worker_slice_offset = ttnn::ccl::coord_t(0, 0); + const auto worker_slice_shape = ttnn::ccl::coord_t(24, 1); + const auto tensor_slice_shape = ttnn::ccl::coord_t(32, 4); const uint32_t num_workers = 4; - const auto expected = tt::tt_metal::ccl::coord_t(0, 2); - auto const& result_offset = tt::tt_metal::ccl::advance_slice_row_major(worker_slice_offset, worker_slice_shape, tensor_slice_shape, num_workers); + const auto expected = ttnn::ccl::coord_t(0, 2); + auto const& result_offset = ttnn::ccl::advance_slice_row_major(worker_slice_offset, worker_slice_shape, tensor_slice_shape, num_workers); ASSERT_EQ(result_offset.x, expected.x); ASSERT_EQ(result_offset.y, expected.y); - auto const& result_offset2 = tt::tt_metal::ccl::advance_slice_row_major(result_offset, worker_slice_shape, tensor_slice_shape, num_workers); + auto const& result_offset2 = ttnn::ccl::advance_slice_row_major(result_offset, worker_slice_shape, tensor_slice_shape, num_workers); ASSERT_TRUE(result_offset2.x >= tensor_slice_shape.x || result_offset2.y >= tensor_slice_shape.y); } TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_24_0__InnerShape_24_1__OuterShape_32_4__NumWorkers_4) { - const auto worker_slice_offset = tt::tt_metal::ccl::coord_t(24, 0); - const auto worker_slice_shape = tt::tt_metal::ccl::coord_t(24, 1); - const auto tensor_slice_shape = tt::tt_metal::ccl::coord_t(32, 4); + const auto worker_slice_offset = ttnn::ccl::coord_t(24, 0); + const auto worker_slice_shape = ttnn::ccl::coord_t(24, 1); + const auto tensor_slice_shape = ttnn::ccl::coord_t(32, 4); const uint32_t num_workers 
= 4; - const auto expected = tt::tt_metal::ccl::coord_t(24, 2); - auto const& result_offset = tt::tt_metal::ccl::advance_slice_row_major(worker_slice_offset, worker_slice_shape, tensor_slice_shape, num_workers); + const auto expected = ttnn::ccl::coord_t(24, 2); + auto const& result_offset = ttnn::ccl::advance_slice_row_major(worker_slice_offset, worker_slice_shape, tensor_slice_shape, num_workers); ASSERT_EQ(result_offset.x, expected.x); ASSERT_EQ(result_offset.y, expected.y); - auto const& result_offset2 = tt::tt_metal::ccl::advance_slice_row_major(result_offset, worker_slice_shape, tensor_slice_shape, num_workers); + auto const& result_offset2 = ttnn::ccl::advance_slice_row_major(result_offset, worker_slice_shape, tensor_slice_shape, num_workers); ASSERT_TRUE(result_offset2.x >= tensor_slice_shape.x || result_offset2.y >= tensor_slice_shape.y); } TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_1__InnerShape_24_1__OuterShape_32_4__NumWorkers_4) { - const auto worker_slice_offset = tt::tt_metal::ccl::coord_t(0, 1); - const auto worker_slice_shape = tt::tt_metal::ccl::coord_t(24, 1); - const auto tensor_slice_shape = tt::tt_metal::ccl::coord_t(32, 4); + const auto worker_slice_offset = ttnn::ccl::coord_t(0, 1); + const auto worker_slice_shape = ttnn::ccl::coord_t(24, 1); + const auto tensor_slice_shape = ttnn::ccl::coord_t(32, 4); const uint32_t num_workers = 4; - const auto expected = tt::tt_metal::ccl::coord_t(0, 3); - auto const& result_offset = tt::tt_metal::ccl::advance_slice_row_major(worker_slice_offset, worker_slice_shape, tensor_slice_shape, num_workers); + const auto expected = ttnn::ccl::coord_t(0, 3); + auto const& result_offset = ttnn::ccl::advance_slice_row_major(worker_slice_offset, worker_slice_shape, tensor_slice_shape, num_workers); ASSERT_EQ(result_offset.x, expected.x); ASSERT_EQ(result_offset.y, expected.y); - auto const& result_offset2 = tt::tt_metal::ccl::advance_slice_row_major(result_offset, worker_slice_shape, tensor_slice_shape, num_workers); + auto const& result_offset2 = ttnn::ccl::advance_slice_row_major(result_offset, worker_slice_shape, tensor_slice_shape, num_workers); ASSERT_TRUE(result_offset2.x >= tensor_slice_shape.x || result_offset2.y >= tensor_slice_shape.y); } TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_24_1__InnerShape_24_1__OuterShape_32_4__NumWorkers_4) { - const auto worker_slice_offset = tt::tt_metal::ccl::coord_t(24, 1); - const auto worker_slice_shape = tt::tt_metal::ccl::coord_t(24, 1); - const auto tensor_slice_shape = tt::tt_metal::ccl::coord_t(32, 4); + const auto worker_slice_offset = ttnn::ccl::coord_t(24, 1); + const auto worker_slice_shape = ttnn::ccl::coord_t(24, 1); + const auto tensor_slice_shape = ttnn::ccl::coord_t(32, 4); const uint32_t num_workers = 4; - const auto expected = tt::tt_metal::ccl::coord_t(24, 3); - auto const& result_offset = tt::tt_metal::ccl::advance_slice_row_major(worker_slice_offset, worker_slice_shape, tensor_slice_shape, num_workers); + const auto expected = ttnn::ccl::coord_t(24, 3); + auto const& result_offset = ttnn::ccl::advance_slice_row_major(worker_slice_offset, worker_slice_shape, tensor_slice_shape, num_workers); ASSERT_EQ(result_offset.x, expected.x); ASSERT_EQ(result_offset.y, expected.y); - auto const& result_offset2 = tt::tt_metal::ccl::advance_slice_row_major(result_offset, worker_slice_shape, tensor_slice_shape, num_workers); + auto const& result_offset2 = ttnn::ccl::advance_slice_row_major(result_offset, worker_slice_shape, tensor_slice_shape, num_workers); 
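// Advancing once more from this worker's final offset is expected to step past the end of the
// tensor slice (the dedicated out-of-bounds tests below make the same point); the ASSERT_TRUE that
// follows accepts any offset at or beyond tensor_slice_shape in x or y as having walked off the slice.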
ASSERT_TRUE(result_offset2.x >= tensor_slice_shape.x || result_offset2.y >= tensor_slice_shape.y); } // Test that we successfully go out of bounds on the last iteration TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_1__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_2) { - auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 1}, {1, 1}, {2, 2}, 2); + auto const& result = ttnn::ccl::advance_slice_row_major({0, 1}, {1, 1}, {2, 2}, 2); ASSERT_TRUE(result.x >= 2 || result.y >= 2); } TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_1__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_2) { - auto const& result = tt::tt_metal::ccl::advance_slice_row_major({1, 1}, {1, 1}, {2, 2}, 2); + auto const& result = ttnn::ccl::advance_slice_row_major({1, 1}, {1, 1}, {2, 2}, 2); ASSERT_TRUE(result.x >= 2 || result.y >= 2); } TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_0_0__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_3) { - const auto expected = tt::tt_metal::ccl::coord_t(1, 1); - auto const& result = tt::tt_metal::ccl::advance_slice_row_major({0, 0}, {1, 1}, {2, 2}, 3); + const auto expected = ttnn::ccl::coord_t(1, 1); + auto const& result = ttnn::ccl::advance_slice_row_major({0, 0}, {1, 1}, {2, 2}, 3); ASSERT_EQ(result.x, expected.x); ASSERT_EQ(result.y, expected.y); } TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_1_1__InnerShape_1_1__OuterShape_2_2__NumActiveSlices_3) { - const auto expected = tt::tt_metal::ccl::coord_t(1, 1); - const auto outer_shape = tt::tt_metal::ccl::coord_t(2, 2); - const auto inner_offset = tt::tt_metal::ccl::coord_t(1, 1); - const auto inner_shape = tt::tt_metal::ccl::coord_t(1, 1); + const auto expected = ttnn::ccl::coord_t(1, 1); + const auto outer_shape = ttnn::ccl::coord_t(2, 2); + const auto inner_offset = ttnn::ccl::coord_t(1, 1); + const auto inner_shape = ttnn::ccl::coord_t(1, 1); const uint32_t num_parallel_workers = 3; auto const& result = - tt::tt_metal::ccl::advance_slice_row_major(inner_offset, inner_shape, outer_shape, num_parallel_workers); + ttnn::ccl::advance_slice_row_major(inner_offset, inner_shape, outer_shape, num_parallel_workers); ASSERT_TRUE(result.x >= outer_shape.x || result.y >= outer_shape.y); } TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_24_0__InnerShape_24_0__OuterShape_32_4__NumActiveSlices_4) { - const auto expected = tt::tt_metal::ccl::coord_t(24, 2); - const auto outer_shape = tt::tt_metal::ccl::coord_t(32, 4); - const auto inner_offset = tt::tt_metal::ccl::coord_t(24, 0); - const auto inner_shape = tt::tt_metal::ccl::coord_t(24, 1); + const auto expected = ttnn::ccl::coord_t(24, 2); + const auto outer_shape = ttnn::ccl::coord_t(32, 4); + const auto inner_offset = ttnn::ccl::coord_t(24, 0); + const auto inner_shape = ttnn::ccl::coord_t(24, 1); const uint32_t num_parallel_workers = 4; auto const& result = - tt::tt_metal::ccl::advance_slice_row_major(inner_offset, inner_shape, outer_shape, num_parallel_workers); + ttnn::ccl::advance_slice_row_major(inner_offset, inner_shape, outer_shape, num_parallel_workers); ASSERT_EQ(result.x, expected.x); ASSERT_EQ(result.y, expected.y); } @@ -232,7 +232,7 @@ TEST(CclHelper_AdvanceSliceRowMajor, InnerOffset_24_0__InnerShape_24_0__OuterSha TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_AllWorkersSameRow) { auto worker_slice_shapes = std::vector(4, {2, 2}); tt_xy_pair tensor_slice_shape = {8, 4}; - auto const& worker_slice_offsets = ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( + auto const& worker_slice_offsets = 
ttnn::ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( worker_slice_shapes, tensor_slice_shape); ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2, 0)); @@ -242,7 +242,7 @@ TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_AllWorkersSame TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_1WorkerWrapToNextRowAligned) { auto worker_slice_shapes = std::vector(4, {2, 2}); tt_xy_pair tensor_slice_shape = {6, 4}; - auto const& worker_slice_offsets = ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( + auto const& worker_slice_offsets = ttnn::ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( worker_slice_shapes, tensor_slice_shape); ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2, 0)); @@ -253,7 +253,7 @@ TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_1WorkerWrapToN { auto worker_slice_shapes = std::vector(4, {2, 2}); tt_xy_pair tensor_slice_shape = {5, 4}; - auto const& worker_slice_offsets = ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( + auto const& worker_slice_offsets = ttnn::ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( worker_slice_shapes, tensor_slice_shape); ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2, 0)); @@ -264,7 +264,7 @@ TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_1WorkerWrapToN TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_MultipleWorkersWrapToNextRowAligned) { auto worker_slice_shapes = std::vector(8, {2, 2}); tt_xy_pair tensor_slice_shape = {10, 4}; - auto const& worker_slice_offsets = ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( + auto const& worker_slice_offsets = ttnn::ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( worker_slice_shapes, tensor_slice_shape); ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2, 0)); @@ -279,7 +279,7 @@ TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_MultipleWorker TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_MultipleWorkersWrapToNextRowMisaligned) { auto worker_slice_shapes = std::vector(8, {2, 2}); tt_xy_pair tensor_slice_shape = {9, 4}; - auto const& worker_slice_offsets = ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( + auto const& worker_slice_offsets = ttnn::ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( worker_slice_shapes, tensor_slice_shape); ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(2, 0)); @@ -294,7 +294,7 @@ TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_MultipleWorker TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_NMinus1WorkersWrapToNextRowAligned) { auto worker_slice_shapes = std::vector(3, {4, 4}); tt_xy_pair tensor_slice_shape = {4, 12}; - auto const& worker_slice_offsets = ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( + auto const& worker_slice_offsets = ttnn::ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( worker_slice_shapes, tensor_slice_shape); ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(0, 4)); @@ -304,7 +304,7 @@ TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_NMinus1Workers 
TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_NMinus1WorkersWrapToNextRowMisaligned) { auto worker_slice_shapes = std::vector(3, {4, 3}); tt_xy_pair tensor_slice_shape = {3, 12}; - auto const& worker_slice_offsets = ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( + auto const& worker_slice_offsets = ttnn::ccl::RingReduceScatterTensorSlicer::compute_worker_slice_offsets( worker_slice_shapes, tensor_slice_shape); ASSERT_EQ(worker_slice_offsets.at(0), tt_xy_pair(0, 0)); ASSERT_EQ(worker_slice_offsets.at(1), tt_xy_pair(0, 3)); @@ -314,7 +314,7 @@ TEST(Ccl_RingReduceScatterTensorSlicer, ComputeWorkerSliceOffsets_NMinus1Workers TEST( Ccl_InterleavedTensorWorkerSlice_ComputeNumWorkerSliceIterations, InnerOffset_0_0__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) { - auto worker_slice = ccl::InterleavedTensorWorkerSlice( + auto worker_slice = ttnn::ccl::InterleavedTensorWorkerSlice( tt_xy_pair(99999, 99999), // tensor shape shouldn't affect the result tt_xy_pair(32, 4), tt_xy_pair(24, 1), @@ -328,7 +328,7 @@ TEST( TEST( Ccl_InterleavedTensorWorkerSlice_ComputeNumWorkerSliceIterations, InnerOffset_24_0__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) { - auto worker_slice = ccl::InterleavedTensorWorkerSlice( + auto worker_slice = ttnn::ccl::InterleavedTensorWorkerSlice( tt_xy_pair(99999, 99999), // tensor shape shouldn't affect the result tt_xy_pair(32, 4), tt_xy_pair(24, 1), @@ -342,7 +342,7 @@ TEST( TEST( Ccl_InterleavedTensorWorkerSlice_ComputeNumWorkerSliceIterations, InnerOffset_0_1__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) { - auto worker_slice = ccl::InterleavedTensorWorkerSlice( + auto worker_slice = ttnn::ccl::InterleavedTensorWorkerSlice( tt_xy_pair(99999, 99999), // tensor shape shouldn't affect the result tt_xy_pair(32, 4), tt_xy_pair(24, 1), @@ -356,7 +356,7 @@ TEST( TEST( Ccl_InterleavedTensorWorkerSlice_ComputeNumWorkerSliceIterations, InnerOffset_24_1__InnerShape_24_1__OuterShape_32_4__NumActiveSlices_4) { - auto worker_slice = ccl::InterleavedTensorWorkerSlice( + auto worker_slice = ttnn::ccl::InterleavedTensorWorkerSlice( tt_xy_pair(99999, 99999), // tensor shape shouldn't affect the result tt_xy_pair(32, 4), tt_xy_pair(24, 1), diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py b/tests/ttnn/unit_tests/operations/test_all_gather.py similarity index 99% rename from tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py rename to tests/ttnn/unit_tests/operations/test_all_gather.py index 1d3126a1486..bf6f8da997b 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py +++ b/tests/ttnn/unit_tests/operations/test_all_gather.py @@ -410,14 +410,15 @@ def run_line_all_gather( for i, t in enumerate(input_tensors): tt_input_tensors.append(ttl.tensor.Tensor(t, input_dtype).to(layout).to(devices[i], mem_config)) + input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) for i in range(num_iters): - tt_out_tensors = ttl.tensor.line_all_gather(tt_input_tensors, dim, num_links, output_mem_config=mem_config) + tt_out_tensor = ttnn.line_all_gather(input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config) for d in devices: ttl.device.Synchronize(d) logger.info(f"Done iteration {i}") - for i, t in enumerate(tt_out_tensors): + for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() if input_dtype == ttl.tensor.DataType.BFLOAT16: eq, output = comp_equal(tt_output_tensor, 
input_tensor) diff --git a/tests/ttnn/unit_tests/operations/test_distributed_layernorm.py b/tests/ttnn/unit_tests/operations/test_distributed_layernorm.py index b84514ff8fa..4d3cb45519f 100644 --- a/tests/ttnn/unit_tests/operations/test_distributed_layernorm.py +++ b/tests/ttnn/unit_tests/operations/test_distributed_layernorm.py @@ -40,12 +40,12 @@ def tt_distributed_layernorm(inp, gamma, beta, epsilon, is_rmsnorm, compute_kern ) ) + tt_stats = ttnn.aggregate_as_tensor(tt_stats) # AllGather stats - tt_stats = ttnn.experimental.tensor.all_gather( - tt_stats, dim=3, num_links=1, output_mem_config=ttnn.DRAM_MEMORY_CONFIG - ) + tt_stats = ttnn.all_gather(tt_stats, dim=3, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG) # Run layernorm part 2 + tt_stats = ttnn.get_device_tensors(tt_stats) tt_out = [] for d in range(n_devices): if is_rmsnorm: diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 9396614b94e..23ba8f58f90 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -30,6 +30,10 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/example_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/single_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/examples/example/device/multi_core_program_factory.cpp diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index 5aacef48291..cf6d26d6d57 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -7,7 +7,8 @@ #include #include -#include "pybind11/operations/ccl.hpp" +#include "ttnn/operations/ccl/all_gather/all_gather_pybind.hpp" +#include "ttnn/operations/ccl/line_all_gather/line_all_gather_pybind.hpp" #include "pybind11/operations/conv2d.hpp" #include "pybind11/operations/copy.hpp" #include "pybind11/operations/core.hpp" @@ -62,6 +63,10 @@ void py_module(py::module& module) { auto m_unary_backward = module.def_submodule("unary_backward", "unary_backward operations"); unary_backward::py_module(m_unary_backward); + auto m_ccl = module.def_submodule("ccl", "collective communication operations"); + ccl::py_bind_all_gather(m_ccl); + ccl::py_bind_line_all_gather(m_ccl); + auto m_complex_unary = module.def_submodule("complex_unary", "complex_unary operations"); complex_unary::py_module(m_complex_unary); @@ -98,9 +103,6 @@ void py_module(py::module& module) { auto m_reduction = module.def_submodule("reduction", "reduction operations"); reduction::py_module(m_reduction); - auto m_ccl = module.def_submodule("ccl", "collective communication operations"); - ccl::py_module(m_ccl); - auto m_kv_cache = module.def_submodule("kv_cache", "KV cache operations"); kv_cache::py_module(m_kv_cache); diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt 
b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt index 68c9eeaec1d..c3e9e6904f5 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt +++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt @@ -4,11 +4,8 @@ set(TT_DNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/auto_format.cpp ${CMAKE_CURRENT_SOURCE_DIR}/data_transfer/data_transfer_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout_conversion/layout_conversion_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/all_gather/all_gather_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/all_gather/multi_core/all_gather_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter/reduce_scatter_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ccl/ccl_common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sharded/sharded_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sharded/multi_core/sharded_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sharded_partial/sharded_op_partial.cpp diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp index 532a9f7c56e..2088728d2f9 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp +++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp @@ -8,9 +8,9 @@ #include "impl/buffers/buffer.hpp" #include "impl/kernels/data_types.hpp" #include "tensor/tensor_impl.hpp" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/ccl_common.hpp" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/impl/buffers/circular_buffer_types.hpp" @@ -57,7 +57,7 @@ struct WorkerTransferInfo { }; static std::size_t decide_number_of_edm_channels( - ccl::CCLOpConfig const& ccl_op_config, std::size_t max_num_workers, bool enable_bidirectional) { + ttnn::ccl::CCLOpConfig const& ccl_op_config, std::size_t max_num_workers, bool enable_bidirectional) { return ccl_op_config.is_input_sharded() ? std::min( ccl_op_config.get_shard_grid_size(), std::min(max_num_workers, enable_bidirectional ? 
8 : 4)) @@ -66,9 +66,9 @@ static std::size_t decide_number_of_edm_channels( struct ReduceScatterWorkerArgBuilder { ReduceScatterWorkerArgBuilder( - ccl::CCLOpConfig const& op_config, - ccl::RingTopology const& topology_config, - ccl::InterleavedTensorWorkerSlice const& worker_input_slice, + ttnn::ccl::CCLOpConfig const& op_config, + ttnn::ccl::RingTopology const& topology_config, + ttnn::ccl::InterleavedTensorWorkerSlice const& worker_input_slice, WorkerTransferInfo const& worker_transfer_info, uint32_t worker_idx, uint32_t link, @@ -134,7 +134,7 @@ struct ReduceScatterWorkerArgBuilder { } std::vector generate_receiver_kernel_rt_args( - ccl::WorkerXY edm_core, + ttnn::ccl::WorkerXY edm_core, uint32_t edm_core_semaphore_address, uint32_t edm_core_buffer_address, uint32_t link, @@ -229,7 +229,7 @@ struct ReduceScatterWorkerArgBuilder { } std::vector generate_sender_kernel_rt_args( - ccl::WorkerXY edm_core, + ttnn::ccl::WorkerXY edm_core, uint32_t edm_core_semaphore_address, uint32_t edm_core_buffer_address, uint32_t link, @@ -295,9 +295,9 @@ struct ReduceScatterWorkerArgBuilder { return args; } - ccl::RingTopology const topology_config; - ccl::CCLOpConfig const op_config; - ccl::InterleavedTensorWorkerSlice const worker_input_slice; + ttnn::ccl::RingTopology const topology_config; + ttnn::ccl::CCLOpConfig const op_config; + ttnn::ccl::InterleavedTensorWorkerSlice const worker_input_slice; WorkerTransferInfo const worker_transfer_info; uint32_t cb_num_pages_per_packet; uint32_t worker_sender_semaphore_address; @@ -321,13 +321,13 @@ struct EdmInterfaceAddresses { // For now - the mapping between workers and EDM channels is 1:1 static void add_worker_config_to_edm_builders( Device* device, - RingReduceScatterTensorSlicer& tensor_slicer, // TODO: Update to Generic ReduceScatterSlicer when it is implemented - ccl::CCLOpConfig const& op_config, + ttnn::ccl::RingReduceScatterTensorSlicer& tensor_slicer, // TODO: Update to Generic ReduceScatterSlicer when it is implemented + ttnn::ccl::CCLOpConfig const& op_config, std::vector const& worker_cores, uint32_t num_channels_per_edm, - std::vector& clockwise_edm_builders, - std::vector& counter_clockwise_edm_builders, + std::vector& clockwise_edm_builders, + std::vector& counter_clockwise_edm_builders, uint32_t worker_sender_semaphore_address, uint32_t worker_receiver_semaphore_address, @@ -340,13 +340,13 @@ static void add_worker_config_to_edm_builders( uint32_t global_worker_idx = c + num_channels_per_edm * link; uint32_t num_workers_per_eth_buffer = 1; - std::vector sender_worker_coords; - std::vector receiver_worker_coords; + std::vector sender_worker_coords; + std::vector receiver_worker_coords; for (uint32_t w = c * num_workers_per_eth_buffer; w < (c + 1) * num_workers_per_eth_buffer; ++w) { - sender_worker_coords.push_back(ccl::WorkerXY( + sender_worker_coords.push_back(ttnn::ccl::WorkerXY( device->worker_core_from_logical_core(worker_cores.at(w)).x, device->worker_core_from_logical_core(worker_cores.at(w)).y)); - receiver_worker_coords.push_back(ccl::WorkerXY( + receiver_worker_coords.push_back(ttnn::ccl::WorkerXY( device->worker_core_from_logical_core(worker_cores.at(w)).x, device->worker_core_from_logical_core(worker_cores.at(w)).y)); } @@ -359,7 +359,7 @@ static void add_worker_config_to_edm_builders( auto& sender_edm_builder = is_buffer_in_clockwise_direction_fn(c) ? 
clockwise_edm_builders.at(link) : counter_clockwise_edm_builders.at(link); log_trace(tt::LogOp, "Adding sender EDM channel"); - ccl::EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info = + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info = sender_edm_builder.add_sender_channel( worker_sender_semaphore_address, 1, // cw_edm_channel_num_messages_to_send_per_transfer.at(c) * (ring_size - 1), @@ -377,7 +377,7 @@ static void add_worker_config_to_edm_builders( ? counter_clockwise_edm_builders.at(link) : clockwise_edm_builders.at(link); log_trace(tt::LogOp, "Adding receiver EDM channel"); - ccl::EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = + ttnn::ccl::EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = receiver_edm_builder.add_receiver_channel( worker_receiver_semaphore_address, // Since we are in worker signal EDM termination mode, we don't need to set the actual number of @@ -396,11 +396,11 @@ static void add_worker_config_to_edm_builders( static std::tuple build_reduce_scatter_worker( tt_metal::Program& program, Device const* device, - ccl::RingTopology const& topology_config, - ccl::CCLOpConfig const& op_config, + ttnn::ccl::RingTopology const& topology_config, + ttnn::ccl::CCLOpConfig const& op_config, ReduceScatterWorkerArgBuilder const& worker_arg_builder, - std::vector& cw_edm_builders, - std::vector& ccw_edm_builders, + std::vector& cw_edm_builders, + std::vector& ccw_edm_builders, EdmInterfaceAddresses const& edm_interface_addresses, CoreCoord const& worker_core, uint32_t num_edm_channels, @@ -427,7 +427,7 @@ static std::tuple build_reduce_scatter_worker( { CoreCoord const& receiver_edm = is_in_clockwise_direction ? topology_config.eth_receiver_cores.at(link) : topology_config.eth_sender_cores.at(link); - ccl::WorkerXY receiver_edm_noc_coord = ccl::WorkerXY( + ttnn::ccl::WorkerXY receiver_edm_noc_coord = ttnn::ccl::WorkerXY( device->ethernet_core_from_logical_core(receiver_edm).x, device->ethernet_core_from_logical_core(receiver_edm).y); const uint32_t edm_core_semaphore_address = @@ -484,7 +484,7 @@ static std::tuple build_reduce_scatter_worker( { CoreCoord sender_edm = is_in_clockwise_direction ?
topology_config.eth_sender_cores.at(link) : topology_config.eth_receiver_cores.at(link); - ccl::WorkerXY const sender_edm_noc_coord = ccl::WorkerXY( + ttnn::ccl::WorkerXY const sender_edm_noc_coord = ttnn::ccl::WorkerXY( device->ethernet_core_from_logical_core(sender_edm).x, device->ethernet_core_from_logical_core(sender_edm).y); TT_ASSERT(sender_edm_noc_coord.y == 0 || sender_edm_noc_coord.y == 6); @@ -519,21 +519,21 @@ static std::tuple build_reduce_scatter_worker( } static CoreRangeSet select_worker_cores( - ccl::CCLOpConfig const& op_config, std::size_t num_links, std::size_t num_edm_channels) { + ttnn::ccl::CCLOpConfig const& op_config, std::size_t num_links, std::size_t num_edm_channels) { switch (op_config.get_topology()) { - case tt::tt_metal::ccl::Topology::Linear: + case ttnn::ccl::Topology::Linear: return CoreRangeSet({CoreRange(CoreCoord(0, 0), CoreCoord(num_edm_channels - 1, num_links - 1))}); - case tt::tt_metal::ccl::Topology::Ring: + case ttnn::ccl::Topology::Ring: return CoreRangeSet({CoreRange(CoreCoord(0, 0), CoreCoord(num_edm_channels - 1, num_links - 1))}); default: TT_ASSERT(false, "Unsupported topology"); return CoreRangeSet({}); }; } static WorkerTransferInfo compute_num_edm_messages_per_channel( - ccl::CCLOpConfig const& op_config, - RingReduceScatterTensorSlicer& tensor_slicer, // TODO: Update to Generic ReduceScatterSlicer when it is implemented - std::vector const& cw_per_link_edm_builders, - std::vector const& ccw_per_link_edm_builders, + ttnn::ccl::CCLOpConfig const& op_config, + ttnn::ccl::RingReduceScatterTensorSlicer& tensor_slicer, // TODO: Update to Generic ReduceScatterSlicer when it is implemented + std::vector const& cw_per_link_edm_builders, + std::vector const& ccw_per_link_edm_builders, std::size_t const num_edm_channels, std::size_t const num_links, std::size_t const ring_size) { @@ -601,7 +601,7 @@ static uint32_t compute_maximum_worker_slice_in_bytes( } static bool is_cb_buffering_sufficient_to_avoid_deadlock( - ccl::InterleavedTensorWorkerSlice const& worker_slice, + ttnn::ccl::InterleavedTensorWorkerSlice const& worker_slice, uint32_t cb_src0_size_pages, uint32_t cb_dst0_size_pages, uint32_t cb_short_circuit_size_pages, @@ -627,7 +627,7 @@ static bool is_cb_buffering_sufficient_to_avoid_deadlock( static std::tuple create_worker_circular_buffers( Tensor const& input_tensor, - ccl::CCLOpConfig const& op_config, + ttnn::ccl::CCLOpConfig const& op_config, CoreRangeSet const& worker_core_range, uint32_t worker_pages_per_transfer, tt_metal::Program& program) { @@ -679,7 +679,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( const uint32_t ring_index, const std::optional receiver_device_id, const std::optional sender_device_id, - ccl::Topology topology) { + ttnn::ccl::Topology topology) { log_trace(tt::LogOp, "reduce_scatter_with_workers entry"); TT_ASSERT( input_tensors.at(0).get_legacy_shape()[scatter_split_dim] == @@ -691,12 +691,12 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( /////////////// Constants/Configuration /// Constants/Configuration - ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode = ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; - auto const& op_config = ccl::CCLOpConfig(input_tensors, output_tensors, topology); - std::unique_ptr input_tensor_config = - CclOpTensorConfig::build_all_gather_tensor_config(input_tensors.at(0)); - std::unique_ptr output_tensor_config = - CclOpTensorConfig::build_all_gather_tensor_config(output_tensors.at(0)); + ttnn::ccl::EriscDataMoverBufferSharingMode
buffer_sharing_mode = ttnn::ccl::EriscDataMoverBufferSharingMode::ROUND_ROBIN; + auto const& op_config = ttnn::ccl::CCLOpConfig(input_tensors, output_tensors, topology); + std::unique_ptr input_tensor_config = + ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensors.at(0)); + std::unique_ptr output_tensor_config = + ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensors.at(0)); uint32_t per_step_dim_size = input_tensors.at(0).get_legacy_shape()[scatter_split_dim] / ring_size; uint32_t input_tensor_num_units_per_scatter_dim = per_step_dim_size / constants::TILE_WIDTH; // TODO: find the divisibility based on layout @@ -705,7 +705,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( bool enable_bidirectional = false; auto num_edm_channels = decide_number_of_edm_channels(op_config, max_num_workers, enable_bidirectional); log_trace(tt::LogOp, "num_edm_channels: {}", num_edm_channels); - auto edm_termination_mode = ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; + auto edm_termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED; auto const& edm_builder = create_erisc_datamover_builder( num_edm_channels, op_config.get_page_size(), buffer_sharing_mode, edm_termination_mode); TT_ASSERT(num_edm_channels > 0); @@ -716,8 +716,8 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( std::map worker_defines; std::vector worker_receiver_kernels; std::vector worker_sender_kernels; - std::vector cw_per_link_edm_builders(num_links, edm_builder); - std::vector ccw_per_link_edm_builders(num_links, edm_builder); + std::vector cw_per_link_edm_builders(num_links, edm_builder); + std::vector ccw_per_link_edm_builders(num_links, edm_builder); bool rm = local_chip_tensor.get_layout() == Layout::ROW_MAJOR; if (rm) { @@ -731,7 +731,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( const auto& device = local_chip_tensor.device(); auto const& topology_config = - ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); + ttnn::ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); auto dim_slice_factors = Shape(std::vector(local_chip_tensor.get_legacy_shape().rank(), 1)); dim_slice_factors[-1] = ring_size; @@ -756,7 +756,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( cb_num_pages, cw_per_link_edm_builders.at(0).get_eth_buffer_size_bytes(), op_config.get_page_size()); - auto tensor_slicer = ccl::RingReduceScatterTensorSlicer( + auto tensor_slicer = ttnn::ccl::RingReduceScatterTensorSlicer( local_chip_tensor, local_chip_output_tensor, scatter_split_dim, @@ -853,7 +853,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( } // Generate the EDM kernels - ccl::generate_edm_kernels_for_ring_or_linear_topology( + ttnn::ccl::generate_edm_kernels_for_ring_or_linear_topology( program, device, topology_config, diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp index bf86efb8cc3..850bbfadca4 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp +++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_reader.cpp @@ -8,12 +8,12 @@ #include "dataflow_api.h"
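// For reference, the hunks in this patch apply one mechanical migration, summarized here as a
// minimal before/after sketch inside a comment so it does not alter the patch itself. All
// identifiers are taken from the surrounding hunks; the x/y arguments merely stand in for the
// device->ethernet_core_from_logical_core(...) calls shown above.
//
//   before:  #include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp"
//            using tt::tt_metal::ccl::WorkerXY;
//            ccl::WorkerXY receiver_edm_noc_coord = ccl::WorkerXY(x, y);
//
//   after:   #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
//            using ttnn::ccl::WorkerXY;
//            ttnn::ccl::WorkerXY receiver_edm_noc_coord = ttnn::ccl::WorkerXY(x, y);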
#include "debug/assert.h" #include "tensix_types.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" -using tt::tt_metal::ccl::coord_t; -using tt::tt_metal::ccl::WorkerXY; +using ttnn::ccl::coord_t; +using ttnn::ccl::WorkerXY; struct reduce_scatter_reader_common_args_t { reduce_scatter_reader_common_args_t(uint32_t& arg_idx) : @@ -34,10 +34,10 @@ struct reduce_scatter_reader_common_args_t { edm_core_buffer_address(get_arg_val(arg_idx++)), num_concurrent_workers(get_arg_val(arg_idx++)), - input_tensor_shape(tt::tt_metal::ccl::coord_from_args(arg_idx)), - tensor_slice_shape(tt::tt_metal::ccl::coord_from_args(arg_idx)), - worker_slice_shape(tt::tt_metal::ccl::coord_from_args(arg_idx)), - worker_slice_offset(tt::tt_metal::ccl::coord_from_args(arg_idx)), + input_tensor_shape(ttnn::ccl::coord_from_args(arg_idx)), + tensor_slice_shape(ttnn::ccl::coord_from_args(arg_idx)), + worker_slice_shape(ttnn::ccl::coord_from_args(arg_idx)), + worker_slice_offset(ttnn::ccl::coord_from_args(arg_idx)), total_eltwise_kernel_num_pages(get_arg_val(arg_idx++)) { ASSERT(full_chunk_num_pages > 0); @@ -322,7 +322,7 @@ void kernel_main() { if (!last_worker_message_to_edm) { noc_semaphore_inc( eth_receiver_l1_semaphore_noc_addr, - tt::tt_metal::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE); + ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE); } if (n_pages < args.half_cb_n_pages) { uint32_t num_filler_pages = args.half_cb_n_pages - n_pages; @@ -351,6 +351,6 @@ void kernel_main() { noc_semaphore_inc( eth_receiver_l1_semaphore_noc_addr, - tt::tt_metal::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); + ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); DEBUG_STATUS("DONE"); } diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp index 4e93fea738e..ac8647cb584 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp +++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/kernels/worker_interleaved_ring_reduce_scatter_sender.cpp @@ -5,10 +5,10 @@ #include #include "dataflow_api.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" -using tt::tt_metal::ccl::coord_t; +using ttnn::ccl::coord_t; void kernel_main() { constexpr bool is_sharded = get_compile_time_arg_val(0) == 1; @@ -27,9 +27,9 @@ void kernel_main() { uint32_t const half_cb_n_pages = get_arg_val(arg_idx++); uint32_t const num_concurrent_workers = 
get_arg_val(arg_idx++); - coord_t const& output_tensor_shape = tt::tt_metal::ccl::coord_from_args(arg_idx); - coord_t const& worker_slice_shape = tt::tt_metal::ccl::coord_from_args(arg_idx); - coord_t worker_slice_base_offset = tt::tt_metal::ccl::coord_from_args(arg_idx); + coord_t const& output_tensor_shape = ttnn::ccl::coord_from_args(arg_idx); + coord_t const& worker_slice_shape = ttnn::ccl::coord_from_args(arg_idx); + coord_t worker_slice_base_offset = ttnn::ccl::coord_from_args(arg_idx); uint32_t total_eltwise_kernel_num_pages = get_arg_val(arg_idx++); @@ -80,7 +80,7 @@ void kernel_main() { send_chunk(cb_in, n_pages, page_size, eth_l1_sender_base_noc_addr); noc_semaphore_inc( eth_l1_sender_semaphore_addr, - tt::tt_metal::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE); + ttnn::ccl::EriscDataMoverWorkerSignal::NEXT_MESSAGE_AVAILABLE); if (i != 0) { total_lifetime_cb_pages_popped_from_math += n_pages; } @@ -144,5 +144,5 @@ void kernel_main() { noc_semaphore_wait(writer_send_semaphore_addr_ptr, 1); noc_semaphore_set(writer_send_semaphore_addr_ptr, 0); noc_semaphore_inc( - eth_l1_sender_semaphore_addr, tt::tt_metal::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); + eth_l1_sender_semaphore_addr, ttnn::ccl::EriscDataMoverWorkerSignal::TERMINATE_IMMEDIATELY); } diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp index 4111720ee48..8af69597b2b 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp +++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.cpp @@ -5,7 +5,7 @@ #include "ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" #include "tt_metal/host_api.hpp" #include "ttnn/operations/eltwise/binary/binary.hpp" @@ -65,14 +65,14 @@ std::vector reduce_scatter_impl( const uint32_t scatter_dim, const uint32_t num_links, const MemoryConfig& output_mem_config, - const ccl::Topology topology) { + const ttnn::ccl::Topology topology) { TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch"); std::vector output_tensors; output_tensors.reserve(input_tensors.size()); std::vector ops; ops.reserve(input_tensors.size()); - bool is_ring = topology == ccl::Topology::Ring; + bool is_ring = topology == ttnn::ccl::Topology::Ring; for (uint32_t i = 0; i < input_tensors.size(); ++i) { bool is_last_chip_in_clockwise_direction = is_ring ? false : i == (input_tensors.size() - 1); bool is_last_chip_in_counter_clockwise_direction = is_ring ?
false : i == 0; @@ -116,7 +116,7 @@ std::vector reduce_scatter( const MemoryConfig& output_mem_config) { ttnn::operations::binary::BinaryOpType binary_op_type = convert_reduce_type_to_eltwise_type(math_op); return reduce_scatter_impl( - input_tensors, binary_op_type, scatter_dim, num_links, output_mem_config, ccl::Topology::Ring); + input_tensors, binary_op_type, scatter_dim, num_links, output_mem_config, ttnn::ccl::Topology::Ring); } }; // namespace tt_metal diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp index 7bfc7c19001..357902e319b 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp +++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp @@ -5,8 +5,8 @@ #pragma once #include "ttnn/experimental/tt_dnn/op_library/run_operation.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/ccl_common.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" #include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" #include "ttnn/operations/eltwise/binary/binary.hpp" @@ -23,7 +23,7 @@ struct ReduceScatter { const std::optional receiver_device_id; const std::optional sender_device_id; const MemoryConfig output_mem_config; - const ccl::Topology topology; + const ttnn::ccl::Topology topology; void validate(const std::vector &input_tensors) const; std::vector compute_output_shapes(const std::vector &input_tensors) const; @@ -51,7 +51,7 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers( const uint32_t ring_index, const std::optional receiver_device_id, const std::optional sender_device_id, - ccl::Topology topology); + ttnn::ccl::Topology topology); } }; // namespace ccl diff --git a/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp b/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp index 0e66f17e49c..ce168738c82 100644 --- a/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp +++ b/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp @@ -20,7 +20,6 @@ #include "ttnn/experimental/tt_dnn/op_library/non_zero_indices/non_zero_indices_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/sharded/sharded_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/sharded_partial/sharded_op_partial.hpp" -#include "ttnn/experimental/tt_dnn/op_library/all_gather/all_gather_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/ccl/reduce_scatter/reduce_scatter_op.hpp" @@ -462,15 +461,6 @@ namespace tt::tt_metal::detail{ ); // ---------- Multi-Device ops ---------- - // All Gather - m_tensor.def("all_gather", &all_gather, - py::arg("input_tensors"), py::arg("dim"), py::arg("num_links") = 1, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - R"doc(Performs all gather on a list of tensors that form one tensor that is distributed across devices.
The output is a list of a tensor which has been duplciated across the input devices.)doc" - ); - m_tensor.def("line_all_gather", &line_all_gather, - py::arg("input_tensors"), py::arg("dim"), py::arg("num_links") = 1, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - R"doc(Performs all gather on a list of tensors that form one tensor that is distributed across devices. The output is a list of a tensor which has been duplciated across the input devices.)doc" - ); // Reduce Scatter m_tensor.def("reduce_scatter", &reduce_scatter, diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/README.md b/ttnn/cpp/ttnn/operations/ccl/README.md similarity index 100% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/README.md rename to ttnn/cpp/ttnn/operations/ccl/README.md diff --git a/ttnn/cpp/ttnn/operations/ccl.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp similarity index 79% rename from ttnn/cpp/ttnn/operations/ccl.hpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp index cf3152e9563..94c84f41a38 100644 --- a/ttnn/cpp/ttnn/operations/ccl.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_op.hpp @@ -4,7 +4,7 @@ #pragma once -#include "ttnn/experimental/tt_dnn/op_library/all_gather/all_gather_op.hpp" +#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" #include "ttnn/cpp/ttnn/multi_device.hpp" namespace ttnn { @@ -18,7 +18,7 @@ struct ExecuteAllGather { const uint32_t dim, const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt) { - return tt::operations::ccl::all_gather(input_tensor, dim, num_links, memory_config); + return ttnn::operations::ccl::all_gather(input_tensor, dim, num_links, memory_config); } }; diff --git a/ttnn/cpp/pybind11/operations/ccl.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp similarity index 89% rename from ttnn/cpp/pybind11/operations/ccl.hpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp index cb4c680d2cc..c3aa4015c49 100644 --- a/ttnn/cpp/pybind11/operations/ccl.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.hpp @@ -8,7 +8,7 @@ #include #include "ttnn/cpp/pybind11/decorators.hpp" -#include "ttnn/operations/ccl.hpp" +#include "ttnn/operations/ccl/all_gather/all_gather_op.hpp" #include "ttnn/types.hpp" namespace py = pybind11; @@ -20,7 +20,7 @@ namespace ccl { namespace detail { template -void bind_ccl_operation(py::module& module, const ccl_operation_t& operation, const char* doc) { +void bind_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { bind_registered_operation( module, operation, @@ -43,8 +43,8 @@ void bind_ccl_operation(py::module& module, const ccl_operation_t& operation, co } // namespace detail -void py_module(py::module& module) { - detail::bind_ccl_operation( +void py_bind_all_gather(py::module& module) { + detail::bind_all_gather( module, ttnn::all_gather, R"doc(all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp similarity index 65% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/all_gather_op.cpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index f6b986840d7..725a79c26b7 100644 --- 
a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttnn/experimental/tt_dnn/op_library/all_gather/all_gather_op.hpp" +#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/math.hpp" #include "tt_metal/host_api.hpp" @@ -11,9 +11,7 @@ #include "eth_l1_address_map.h" -namespace tt { - -namespace tt_metal { +namespace ttnn { AllGatherMode choose_all_gather_mode(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim) { bool is_sharded = input_tensor.is_sharded(); @@ -63,10 +61,10 @@ void AllGather::validate(const std::vector &input_tensors) const { } } -std::vector AllGather::compute_output_shapes(const std::vector &input_tensors) const { +std::vector AllGather::compute_output_shapes(const std::vector &input_tensors) const { auto shape = input_tensors[0].get_legacy_shape(); shape[this->dim] *= this->ring_size; - return std::vector(input_tensors.size(), shape); + return std::vector(input_tensors.size(), shape); } std::vector AllGather::create_output_tensors(const std::vector &input_tensors) const { @@ -85,7 +83,7 @@ std::vector AllGather::create_output_tensors(const std::vector & } operation::ProgramWithCallbacks AllGather::create_program(const std::vector & input_tensors, std::vector &output_tensors) const { - AllGatherMode all_gather_mode = tt::tt_metal::choose_all_gather_mode(input_tensors.at(0), output_tensors.at(0), dim); + AllGatherMode all_gather_mode = choose_all_gather_mode(input_tensors.at(0), output_tensors.at(0), dim); switch (all_gather_mode) { case AllGatherMode::RING_INTERLEAVED: case AllGatherMode::SINGLE_TILE_HIGH_WIDTH_SHARDED: @@ -99,49 +97,6 @@ operation::ProgramWithCallbacks AllGather::create_program(const std::vector all_gather_impl(const std::vector& input_tensors, const uint32_t dim, const uint32_t num_links, const MemoryConfig& output_mem_config, const all_gather_op::Topology topology) { - - TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch"); - - std::vector output_tensors = std::vector(input_tensors.size()); - - bool is_ring = topology == all_gather_op::Topology::Ring; - uint32_t num_inputs = static_cast(input_tensors.size()); - for (uint32_t i = 0; i < input_tensors.size(); ++i) { - output_tensors[i] = Tensor(operation::get_workers_for_op_output({input_tensors[i]})); - // Extract these tensors in the main thread, since they're used to get the sender and receiver device ids - // Dont get the device in the main thread, since it can cause stalls in async mode. - const Tensor& tensor_on_receiver = input_tensors[(i + 1) % num_inputs]; - const Tensor& tensor_on_sender = input_tensors[i == 0 ? num_inputs - 1 : i - 1]; - // Package output in vector, to populate it with launch_op - std::vector output_for_curr_device = {output_tensors[i]}; - operation::launch_op( - [is_ring, dim, num_links, i, num_inputs, output_mem_config, topology] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { - bool is_last_chip_in_clockwise_direction = is_ring ? false : i == (num_inputs - 1); - bool is_last_chip_in_counter_clockwise_direction = is_ring ? false : i == 0; - - std::optional receiver_device_id = is_last_chip_in_clockwise_direction ? 
- std::nullopt : - std::optional(input_tensors.at(1).device()->id()); - std::optional sender_device_id = is_last_chip_in_counter_clockwise_direction ? - std::nullopt : - std::optional(input_tensors.at(2).device()->id()); - return operation::run(AllGather{dim, num_links, num_inputs, i, receiver_device_id, sender_device_id, output_mem_config,topology}, {input_tensors.at(0)}); - }, - {input_tensors[i], tensor_on_receiver, tensor_on_sender}, output_for_curr_device); - } - return output_tensors; -} - -std::vector all_gather(const std::vector& input_tensors, const uint32_t dim, const uint32_t num_links, const MemoryConfig& output_mem_config) { - return all_gather_impl(input_tensors, dim, num_links, output_mem_config, all_gather_op::Topology::Ring); -} -std::vector line_all_gather(const std::vector& input_tensors, const uint32_t dim, const uint32_t num_links, const MemoryConfig& output_mem_config) { - return all_gather_impl(input_tensors, dim, num_links, output_mem_config, all_gather_op::Topology::Linear); -} - -} // namespace tt_metal - namespace operations { namespace ccl { @@ -175,7 +130,7 @@ Tensor all_gather( } return operation::run( - AllGather{ + ttnn::AllGather{ dim, num_links, num_devices, device_index, receiver_device_id, sender_device_id, memory_config.value_or(input_tensor.memory_config())}, {input_tensor}); }, @@ -184,7 +139,8 @@ Tensor all_gather( return output_tensors.at(0); } + } // namespace ccl } // namespace operations -} // namespace tt +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp similarity index 93% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/all_gather_op.hpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp index a665ab4c30a..c3469068f73 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp @@ -8,11 +8,11 @@ #include "common/core_coord.h" #include "impl/buffers/buffer.hpp" #include "tensor/tensor.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/ccl_common.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" #include "ttnn/experimental/tt_dnn/op_library/run_operation.hpp" @@ -20,9 +20,7 @@ #include #include -namespace tt { - -namespace tt_metal { +namespace ttnn { enum AllGatherMode { RING_INTERLEAVED, @@ -45,7 +43,7 @@ class AllGatherConfig { semaphore_size(32), ring_size(ring_size), - erisc_handshake_address(round_up(eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, 16)), + erisc_handshake_address(tt::round_up(eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, 16)), topology(topology), enable_bidirectional(topology == all_gather_op::Topology::Ring), @@ -92,7 +90,7 @@ class AllGatherConfig { this->eth_buffers_l1_base_byte_address = this->eth_sems_l1_base_byte_address + this->semaphore_offset; uint32_t const page_size = input_tensor.buffer()->page_size(); - this->eth_buffer_size = round_down((total_l1_buffer_space - this->semaphore_offset) / (this->num_eth_buffers * 
num_duplicate_directions), page_size); + this->eth_buffer_size = tt::round_down((total_l1_buffer_space - this->semaphore_offset) / (this->num_eth_buffers * num_duplicate_directions), page_size); TT_FATAL(eth_buffer_size == 0 or (this->num_eth_buffers * num_duplicate_directions) <= eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS); TT_FATAL(this->eth_buffer_size * (this->num_eth_buffers * num_duplicate_directions) + this->semaphore_offset <= total_l1_buffer_space); @@ -184,24 +182,24 @@ class AllGatherConfig { struct RingInterleavedAllGatherVariantConfig : public AllGatherConfig { - std::string const& send_reader_kernel_path = "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp"; - std::string const& sender_writer_kernel_path = "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp"; - std::string const& receiver_reader_kernel_path = "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_receive_reader.cpp"; - std::string const& receiver_writer_kernel_path = "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp"; + std::string const& send_reader_kernel_path = "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp"; + std::string const& sender_writer_kernel_path = "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp"; + std::string const& receiver_reader_kernel_path = "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_reader.cpp"; + std::string const& receiver_writer_kernel_path = "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp"; }; struct SingleTileHighWidthShardedAllGatherVariantConfig : public AllGatherConfig { - std::string const& send_reader_kernel_path = "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_send_reader.cpp"; - std::string const& sender_writer_kernel_path = "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_send_writer.cpp"; - std::string const& receiver_reader_kernel_path = "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_receive_reader.cpp"; - std::string const& receiver_writer_kernel_path = "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_receive_writer.cpp"; + std::string const& send_reader_kernel_path = "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_send_reader.cpp"; + std::string const& sender_writer_kernel_path = "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_send_writer.cpp"; + std::string const& receiver_reader_kernel_path = "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_receive_reader.cpp"; + std::string const& receiver_writer_kernel_path = "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_receive_writer.cpp"; }; struct FullWorkerGridShardedAllGatherVariantConfig : public AllGatherConfig { - std::string const& reader_kernel = 
"ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_all_shard_workers_ring_gather_reader.cpp"; - std::string const& writer_kernel = "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_all_shard_workers_ring_gather_writer.cpp"; + std::string const& reader_kernel = "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_all_shard_workers_ring_gather_reader.cpp"; + std::string const& writer_kernel = "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_all_shard_workers_ring_gather_writer.cpp"; }; struct AllGather { @@ -215,7 +213,7 @@ struct AllGather { const all_gather_op::Topology topology; void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; }; @@ -242,24 +240,6 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers( const std::optional sender_device_id, all_gather_op::Topology topology); - -std::vector all_gather_impl( - const std::vector& input_tensors, - const uint32_t dim, - const uint32_t num_links, - const MemoryConfig& output_mem_config, - const all_gather_op::Topology topology); -std::vector all_gather( - const std::vector &input_tensors, - const uint32_t dim, - const uint32_t num_links = 1, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); -std::vector line_all_gather( - const std::vector &input_tensors, - const uint32_t dim, - const uint32_t num_links = 1, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - struct ShardedAllGatherConfig { ShardedAllGatherConfig(Tensor const& input_tensor, Tensor const& output_tensor, uint32_t dim) { @@ -290,7 +270,7 @@ struct ShardedAllGatherConfig { break; }; - Shape const& output_shape = output_tensor.get_legacy_shape(); + tt::tt_metal::Shape const& output_shape = output_tensor.get_legacy_shape(); bool multiple_dims_are_multi_tile = std::count_if(output_shape.begin(), output_shape.end(), [](uint32_t s) { return s > 1; }) > 1; this->requires_post_all_gather_reshard = !multiple_dims_are_multi_tile; } @@ -754,8 +734,6 @@ struct FullWorkerGridShardAddrGenArgGenerator { bool initialized; }; -} // namespace tt_metal - namespace operations { namespace ccl { @@ -768,4 +746,4 @@ Tensor all_gather( } // namespace ccl } // namespace operations -} // namespace tt +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_receive_reader.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_reader.cpp similarity index 96% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_receive_reader.cpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_reader.cpp index bf7dcf378c1..a3eb5f11cd4 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_receive_reader.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_reader.cpp @@ -4,7 +4,7 @@ #include 
#include "dataflow_api.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" void kernel_main() { constexpr uint32_t num_transfers = get_compile_time_arg_val(0); diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp similarity index 98% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp index c148feb6a39..754c939bc8a 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp @@ -4,7 +4,7 @@ #include #include "dataflow_api.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" void kernel_main() { const uint32_t dst_addr = get_arg_val(0); diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp similarity index 98% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp index bc95272bce9..6d386287b66 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp @@ -4,7 +4,7 @@ #include #include "dataflow_api.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" void kernel_main() { const uint32_t src_addr = get_arg_val(0); diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp similarity index 97% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp index 6af895ebbb4..682f9b684f1 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp @@ -4,7 +4,7 @@ #include #include "dataflow_api.h" -#include 
"ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" void kernel_main() { diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp similarity index 94% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp index da2b0b7d2fb..043853e7876 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp @@ -4,13 +4,13 @@ #include "dataflow_api.h" #include "debug/assert.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" -using tt::tt_metal::ccl::ShardType; -using tt::tt_metal::ccl::UNINITIALIZED_VALUE_U16; -using tt::tt_metal::ccl::UNINITIALIZED_VALUE_U32; -using tt::tt_metal::ccl::WorkerXY; +using ttnn::ccl::ShardType; +using ttnn::ccl::UNINITIALIZED_VALUE_U16; +using ttnn::ccl::UNINITIALIZED_VALUE_U32; +using ttnn::ccl::WorkerXY; // Only workers on local worker core, hence no uint64_t noc addresses template @@ -18,7 +18,7 @@ struct FullWorkerGridShardAddrGen { FullWorkerGridShardAddrGen() = default; FORCE_INLINE static void build_with_placement_new( FullWorkerGridShardAddrGen* placement_new_address, const uint32_t arg_index) { - tt::tt_metal::ccl::FullWorkerGridShardAddrGenArgs input_args; + ttnn::ccl::FullWorkerGridShardAddrGenArgs input_args; uint32_t curr_arg_index = arg_index; input_args.tile_size_in_bytes = get_arg_val(curr_arg_index++); @@ -54,7 +54,7 @@ struct FullWorkerGridShardAddrGen { } FullWorkerGridShardAddrGen( - uint8_t num_args_consumed, tt::tt_metal::ccl::FullWorkerGridShardAddrGenArgs const& input_args) : + uint8_t num_args_consumed, ttnn::ccl::FullWorkerGridShardAddrGenArgs const& input_args) : dest_cores(input_args.dest_cores), tile_size_in_bytes(input_args.tile_size_in_bytes), shards_start_address(input_args.shards_start_address), @@ -101,7 +101,7 @@ struct FullWorkerGridShardAddrGen { } FORCE_INLINE void advance() { - tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance( + ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance( this->curr_shard_tile_x, this->curr_shard_tile_y, this->curr_tile_index, @@ -115,7 +115,7 @@ struct FullWorkerGridShardAddrGen { } FORCE_INLINE void advance_to_next_tile_row() { - tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row( + ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row( this->curr_shard_tile_x, this->curr_shard_tile_y, this->curr_tile_index, @@ -164,7 +164,7 @@ struct ShardAddrGen final { ShardAddrGen() = default; FORCE_INLINE static void build_with_placement_new(ShardAddrGen* placement_new_address, const uint32_t arg_index) { - tt::tt_metal::ccl::ShardAddrGenArgs input_args; + 
ttnn::ccl::ShardAddrGenArgs input_args; uint32_t curr_arg_index = arg_index; input_args.is_clockwise = bool(get_arg_val(curr_arg_index++) == 1); @@ -196,7 +196,7 @@ struct ShardAddrGen final { // This addr gen will dump all tiles from an input shard contiguously, and dump the // next input shard contiguously after it. This approach depends on a follow up // - ShardAddrGen(uint8_t num_args_consumed, tt::tt_metal::ccl::ShardAddrGenArgs const& input_args) : + ShardAddrGen(uint8_t num_args_consumed, ttnn::ccl::ShardAddrGenArgs const& input_args) : dest_cores(input_args.dest_cores), shards_start_address(input_args.shards_start_address), shard_size_in_bytes(input_args.shard_size_in_bytes), @@ -224,7 +224,7 @@ struct ShardAddrGen final { // correc order per worker FORCE_INLINE void advance() { if constexpr (TYPE == ShardType::Width or TYPE == ShardType::Height) { - tt::tt_metal::ccl::all_gather::addr_gen_advance_width_sharded( + ttnn::ccl::all_gather::addr_gen_advance_width_sharded( this->curr_core_chunk_index, this->curr_worker_index, this->contiguous_chunk_count, @@ -532,11 +532,11 @@ FORCE_INLINE void read_chunk_from_output_tensor( template FORCE_INLINE void read_chunk_from_output_tensor_v2( uint32_t& curr_page_idx, - tt::tt_metal::ccl::coord_t& offset_into_worker_slice, - const tt::tt_metal::ccl::coord_t& worker_slice_shape, + ttnn::ccl::coord_t& offset_into_worker_slice, + const ttnn::ccl::coord_t& worker_slice_shape, // In tiles for tile layout - const tt::tt_metal::ccl::coord_t& tensor_shape, + const ttnn::ccl::coord_t& tensor_shape, const uint32_t cb_id, const AddrGen& s, const uint32_t num_pages, @@ -581,11 +581,11 @@ FORCE_INLINE void read_chunk_from_output_tensor_v2( template FORCE_INLINE void write_chunk_v2( uint32_t& curr_page_idx, - tt::tt_metal::ccl::coord_t& offset_into_worker_slice, - const tt::tt_metal::ccl::coord_t& worker_slice_shape, + ttnn::ccl::coord_t& offset_into_worker_slice, + const ttnn::ccl::coord_t& worker_slice_shape, // In tiles for tile layout - const tt::tt_metal::ccl::coord_t& tensor_shape, + const ttnn::ccl::coord_t& tensor_shape, uint32_t cb_id, const AddrGen& d, const uint32_t num_pages, diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_receive_reader.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_receive_reader.cpp similarity index 96% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_receive_reader.cpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_receive_reader.cpp index 12d03dce709..340c1fe56ba 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_receive_reader.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_receive_reader.cpp @@ -6,7 +6,7 @@ #include "dataflow_api.h" #include "debug/assert.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" void kernel_main() { // TODO: Update the interleaver receive reader kernel invocation to just be able to use this diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_receive_writer.cpp 
b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_receive_writer.cpp similarity index 95% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_receive_writer.cpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_receive_writer.cpp index 3e1cedd0310..61a2b29495a 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_receive_writer.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_receive_writer.cpp @@ -5,7 +5,7 @@ #include #include "dataflow_api.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" void kernel_main() { constexpr ShardType shard_type = static_cast(get_compile_time_arg_val(0)); diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_send_reader.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_send_reader.cpp similarity index 95% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_send_reader.cpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_send_reader.cpp index 2d1ed297ad1..e6b8cde34ab 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_send_reader.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_send_reader.cpp @@ -5,7 +5,7 @@ #include #include "dataflow_api.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" void kernel_main() { constexpr ShardType shard_type = static_cast(get_compile_time_arg_val(0)); diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_send_writer.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_send_writer.cpp similarity index 96% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_send_writer.cpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_send_writer.cpp index 414c39be1b3..5de984ea7f5 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_send_writer.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_send_writer.cpp @@ -5,7 +5,7 @@ #include #include "dataflow_api.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" void kernel_main() { constexpr ShardType shard_type = static_cast(get_compile_time_arg_val(0)); diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp similarity index 94% rename from 
ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp rename to ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp index caafa0f28dc..2280d7b352e 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp @@ -8,10 +8,10 @@ #include "eth_l1_address_map.h" #include "impl/buffers/buffer.hpp" #include "tensor/tensor_impl.hpp" -#include "ttnn/experimental/tt_dnn/op_library/all_gather/all_gather_op.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/ccl_common.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" #include "ttnn/experimental/tt_dnn/op_library/math.hpp" #include "ttnn/experimental/tt_dnn/op_library/work_split.hpp" #include "tt_metal/common/constants.hpp" @@ -22,9 +22,7 @@ using namespace tt::constants; -namespace tt { - -namespace tt_metal { +namespace ttnn { using namespace ccl; @@ -66,13 +64,13 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& TT_FATAL(!(receiver_device_id == std::nullopt && sender_device_id == std::nullopt), "At least one of receiver_device_id or sender_device_id must be specified"); bool is_linear = topology == all_gather_op::Topology::Linear; - std::unique_ptr input_tensor_config = CclOpTensorConfig::build_all_gather_tensor_config(input_tensor); - std::unique_ptr output_tensor_config = CclOpTensorConfig::build_all_gather_tensor_config(output_tensor); + std::unique_ptr input_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensor); + std::unique_ptr output_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensor); - tt_metal::Program program{}; + tt::tt_metal::Program program{}; const auto& device = input_tensor.device(); auto const& all_gather_config = AllGatherConfig(input_tensor, output_tensor, dim, ring_size, num_links, topology); - auto const& topology_config = ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); + auto const& topology_config = ttnn::ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index); auto const& sharding_info = ShardedAllGatherConfig(input_tensor, output_tensor, dim); bool enable_print = false; // ring_index == 0 @@ -96,8 +94,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& log_trace(tt::LogOp, "input_buffer->shard_spec().tensor2d_shape[1]: {}", input_buffer->shard_spec().tensor2d_shape[1]); } const uint32_t max_buffer_per_chunk = is_sharded ? - round_down(all_gather_config.get_eth_buffer_size(), shard_size_in_bytes): - round_down(all_gather_config.get_eth_buffer_size(), input_tensor_config->get_page_size()); + tt::round_down(all_gather_config.get_eth_buffer_size(), shard_size_in_bytes): + tt::round_down(all_gather_config.get_eth_buffer_size(), input_tensor_config->get_page_size()); const uint32_t max_pages_per_chunk = is_sharded ? 
max_buffer_per_chunk / shard_size_in_bytes : max_buffer_per_chunk / input_page_size; @@ -107,7 +105,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& log_trace(tt::LogOp, "max_pages_per_chunk: {}", max_pages_per_chunk); bool rm = input_tensor.get_layout() == Layout::ROW_MAJOR; bool width = input_tensor.get_legacy_shape().rank() - 1 == dim; - DataFormat df = tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + tt::DataFormat df = datatype_to_dataformat_converter(input_tensor.get_dtype()); uint32_t global_num_workers = all_gather_config.get_num_eth_buffers_per_edm() * num_links; @@ -170,9 +168,9 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& } clockwise_edm_builders.emplace_back( - all_gather_config.get_eth_buffer_size(), all_gather_config.get_erisc_handshake_address(), edm_sem_addrs_per_link.at(link), edm_buffer_addrs_per_link.at(link), ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED); + all_gather_config.get_eth_buffer_size(), all_gather_config.get_erisc_handshake_address(), edm_sem_addrs_per_link.at(link), edm_buffer_addrs_per_link.at(link), ttnn::ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED); counter_clockwise_edm_builders.emplace_back( - all_gather_config.get_eth_buffer_size(), all_gather_config.get_erisc_handshake_address(), edm_sem_addrs_per_link.at(link), edm_buffer_addrs_per_link.at(link), ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED); + all_gather_config.get_eth_buffer_size(), all_gather_config.get_erisc_handshake_address(), edm_sem_addrs_per_link.at(link), edm_buffer_addrs_per_link.at(link), ttnn::ccl::EriscDataMoverBufferSharingMode::NOT_SHARED, ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED); } for (uint32_t direction = 0; direction < num_full_send_directions; direction++) { @@ -233,7 +231,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& pages_per_link.at(i)++; } - auto tensor_slicer = ccl::InterleavedRingAllGatherTensorSlicer ( + auto tensor_slicer = ttnn::ccl::InterleavedRingAllGatherTensorSlicer ( input_tensor, output_tensor, dim, @@ -259,20 +257,20 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& log_trace(tt::LogOp, "input_page_size: {}", input_page_size); uint32_t cb_num_pages = 2 * max_pages_per_chunk; log_trace(tt::LogOp, "cb_num_pages: {}", cb_num_pages); - uint32_t src0_cb_index = CB::c_in0; - tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig(cb_num_pages * cb_page_size, {{src0_cb_index, df}}) + uint32_t src0_cb_index = tt::CB::c_in0; + CircularBufferConfig cb_src0_config = CircularBufferConfig(cb_num_pages * cb_page_size, {{src0_cb_index, df}}) .set_page_size(src0_cb_index, cb_page_size); CBHandle cb_src0_sender_workers = CreateCircularBuffer(program, sender_workers, cb_src0_config); CBHandle cb_src0_receiver_workers = CreateCircularBuffer(program, receiver_workers, cb_src0_config); // This semaphore is used by the receiver core to tell workers that data is available to read - auto receiver_worker_semaphore_addr = tt_metal::CreateSemaphore(program, receiver_workers, 0); + auto receiver_worker_semaphore_addr = CreateSemaphore(program, receiver_workers, 0); // This semaphore is used by the receiver core to tell the worker sender writer that sender buffer is available to write 
to - auto sender_worker_writer_semaphore_addr = tt_metal::CreateSemaphore(program, sender_workers, 0); + auto sender_worker_writer_semaphore_addr = CreateSemaphore(program, sender_workers, 0); // This semaphore is used by the worker receiver writer to tell the worker sender reader that data has been committed to memory // This is currently a running counter of how many chunks were committed since the sender worker never decrements this buffer // Potentially avoid overflow by having it actually decrement (using noc atomic inc with value of -1) - auto sender_worker_reader_semaphore_addr = tt_metal::CreateSemaphore(program, sender_workers, 0); + auto sender_worker_reader_semaphore_addr = CreateSemaphore(program, sender_workers, 0); // Rename this the _channel std::vector pages_per_buffer; @@ -344,7 +342,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { auto input_tensor_shard_arg_generator = InputTensorShardAddrGenArgGenerator( device, - dynamic_cast(input_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -398,7 +396,6 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& receiver_semaphores_base_address.push_back(all_gather_config.get_eth_sems_l1_base_byte_address() + b * all_gather_config.get_semaphore_size()); link_buffer_receiver_addresses.push_back(all_gather_config.get_eth_buffers_l1_base_byte_address() + b * all_gather_config.get_eth_buffer_size()); } - std::vector sender_eth_sem_addrs; sender_eth_sem_addrs.reserve(all_gather_config.get_num_eth_buffers_per_edm()); std::vector sender_eth_buffer_addrs; sender_eth_buffer_addrs.reserve(all_gather_config.get_num_eth_buffers_per_edm()); std::vector receiver_eth_sem_addrs; receiver_eth_sem_addrs.reserve(all_gather_config.get_num_eth_buffers_per_edm()); @@ -410,11 +407,11 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& std::vector receiver_worker_coords; for (uint32_t w = b * num_workers_per_eth_buffer; w < (b + 1) * num_workers_per_eth_buffer; ++w) { sender_worker_coords.push_back( - ccl::WorkerXY( + ttnn::ccl::WorkerXY( device->worker_core_from_logical_core(sender_worker_cores.at(w)).x, device->worker_core_from_logical_core(sender_worker_cores.at(w)).y)); receiver_worker_coords.push_back( - ccl::WorkerXY( + ttnn::ccl::WorkerXY( device->worker_core_from_logical_core(receiver_worker_cores.at(w)).x, device->worker_core_from_logical_core(receiver_worker_cores.at(w)).y)); } @@ -556,7 +553,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& auto input_tensor_shard_arg_generator = InputTensorShardAddrGenArgGenerator( device, - dynamic_cast(input_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -579,8 +576,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& OutputTensorShardAddrGenArgGenerator( all_gather_config, device, - dynamic_cast(input_tensor_config.get()), - dynamic_cast(output_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), + dynamic_cast(output_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -619,17 +616,17 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& std::vector const& worker_send_reader_rt_args = build_worker_send_reader_rt_args(); std::string const& send_reader_kernel_path = is_sharded ? 
- "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_send_reader.cpp" : - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp"; - KernelHandle worker_reader_sender_kernel_id = tt_metal::CreateKernel( + "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_send_reader.cpp" : + "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp"; + KernelHandle worker_reader_sender_kernel_id = tt::tt_metal::CreateKernel( program, send_reader_kernel_path, sender_worker_cores.at(b), - tt_metal::ReaderDataMovementConfig(worker_send_reader_ct_args, worker_defines)); + tt::tt_metal::ReaderDataMovementConfig(worker_send_reader_ct_args, worker_defines)); worker_reader_sender_kernels.push_back(worker_reader_sender_kernel_id); - tt_metal::SetRuntimeArgs( + tt::tt_metal::SetRuntimeArgs( program, worker_reader_sender_kernel_id, sender_worker_cores.at(b), @@ -723,7 +720,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& global_worker_index); auto input_tensor_shard_arg_generator = InputTensorShardAddrGenArgGenerator( device, - dynamic_cast(input_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -738,8 +735,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& OutputTensorShardAddrGenArgGenerator( all_gather_config, device, - dynamic_cast(input_tensor_config.get()), - dynamic_cast(output_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), + dynamic_cast(output_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -794,17 +791,17 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& std::vector const& worker_sender_writer_rt_args = build_worker_sender_writer_rt_args(); std::string const& sender_writer_kernel_path = is_sharded ? - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_send_writer.cpp" : - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp"; - KernelHandle worker_sender_writer_kernel_id = tt_metal::CreateKernel( + "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_send_writer.cpp" : + "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp"; + KernelHandle worker_sender_writer_kernel_id = tt::tt_metal::CreateKernel( program, sender_writer_kernel_path, sender_worker_cores.at(b), - tt_metal::WriterDataMovementConfig(worker_sender_writer_ct_args, worker_defines)); + tt::tt_metal::WriterDataMovementConfig(worker_sender_writer_ct_args, worker_defines)); worker_writer_sender_kernels.push_back(worker_sender_writer_kernel_id); - tt_metal::SetRuntimeArgs( + tt::tt_metal::SetRuntimeArgs( program, worker_sender_writer_kernel_id, sender_worker_cores.at(b), @@ -922,7 +919,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& CoreCoord const& worker_eth_receiver_core = is_clockwise_direction ? 
eth_receiver_cores.at(i) : eth_sender_cores.at(i); auto input_tensor_shard_arg_generator = InputTensorShardAddrGenArgGenerator( device, - dynamic_cast(input_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -971,17 +968,17 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& std::vector worker_receiver_reader_rt_args = build_worker_receiver_reader_rt_args(); std::string const& receiver_reader_kernel_path = is_sharded ? - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_receive_reader.cpp" : - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_receive_reader.cpp"; - KernelHandle worker_receiver_reader_kernel_id = tt_metal::CreateKernel( + "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_receive_reader.cpp" : + "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_reader.cpp"; + KernelHandle worker_receiver_reader_kernel_id = tt::tt_metal::CreateKernel( program, receiver_reader_kernel_path, receiver_worker_cores.at(b), - tt_metal::ReaderDataMovementConfig(worker_receiver_reader_ct_args, worker_defines)); + tt::tt_metal::ReaderDataMovementConfig(worker_receiver_reader_ct_args, worker_defines)); worker_reader_receiver_kernels.push_back(worker_receiver_reader_kernel_id); - tt_metal::SetRuntimeArgs( + tt::tt_metal::SetRuntimeArgs( program, worker_receiver_reader_kernel_id, receiver_worker_cores.at(b), @@ -1077,8 +1074,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& OutputTensorShardAddrGenArgGenerator output_tensor_shard_arg_generator( all_gather_config, device, - dynamic_cast(input_tensor_config.get()), - dynamic_cast(output_tensor_config.get()), + dynamic_cast(input_tensor_config.get()), + dynamic_cast(output_tensor_config.get()), ring_index, ring_size, global_num_workers, @@ -1123,17 +1120,17 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& std::vector worker_receive_writer_rt_args = build_worker_receive_writer_rt_args(); std::string const& receiver_writer_kernel_path = is_sharded ? 
- "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_sharded_ring_gather_receive_writer.cpp" : - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/all_gather/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp"; - KernelHandle worker_receive_writer_kernel_id = tt_metal::CreateKernel( + "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_sharded_ring_gather_receive_writer.cpp" : + "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp"; + KernelHandle worker_receive_writer_kernel_id = tt::tt_metal::CreateKernel( program, receiver_writer_kernel_path, receiver_worker_cores.at(b), - tt_metal::WriterDataMovementConfig(worker_receive_writer_ct_args, worker_defines)); + tt::tt_metal::WriterDataMovementConfig(worker_receive_writer_ct_args, worker_defines)); worker_writer_receiver_kernels.push_back(worker_receive_writer_kernel_id); - tt_metal::SetRuntimeArgs( + tt::tt_metal::SetRuntimeArgs( program, worker_receive_writer_kernel_id, receiver_worker_cores.at(b), @@ -1159,7 +1156,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& } } // num_full_send_directions - ccl::generate_edm_kernels_for_ring_or_linear_topology( + ttnn::ccl::generate_edm_kernels_for_ring_or_linear_topology( program, device, topology_config, @@ -1218,6 +1215,4 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback}; } -} // namespace tt_metal - -} // namespace tt +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/ccl_common.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp similarity index 97% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/ccl_common.cpp rename to ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp index 4a7563a72ab..16ea367a731 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/ccl_common.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp @@ -8,8 +8,7 @@ #include "ccl_host_datastructures.hpp" -namespace tt { -namespace tt_metal { +namespace ttnn { namespace ccl { std::unique_ptr CclOpTensorConfig::build_all_gather_tensor_config(Tensor const& tensor) { @@ -21,7 +20,7 @@ std::unique_ptr CclOpTensorConfig::build_all_gather_tensor_co } void generate_edm_kernels_for_ring_or_linear_topology( - tt_metal::Program& program, + tt::tt_metal::Program& program, Device const* device, RingTopology const& topology_config, std::vector const& clockwise_edm_builders, @@ -74,7 +73,7 @@ void generate_edm_kernels_for_ring_or_linear_topology( } KernelHandle generate_edm_kernel( - tt_metal::Program& program, + tt::tt_metal::Program& program, Device const* device, ccl::EriscDatamoverBuilder const& edm_builder, CoreCoord const& eth_core, @@ -90,13 +89,13 @@ KernelHandle generate_edm_kernel( log_trace(tt::LogOp, "\t{}", s); } - auto eth_sender_kernel = tt_metal::CreateKernel( + auto eth_sender_kernel =tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp", + "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp", eth_core, - tt_metal::EthernetConfig{.noc = noc_id, .compile_args = eth_sender_ct_args}); + tt::tt_metal::EthernetConfig{.noc = noc_id, .compile_args = eth_sender_ct_args}); - tt_metal::SetRuntimeArgs(program, eth_sender_kernel, eth_core, edm_clockwise_kernel_rt_args); + tt::tt_metal::SetRuntimeArgs(program, 
eth_sender_kernel, eth_core, edm_clockwise_kernel_rt_args); std::stringstream ss; ss << "EDM ARGS:\n"; @@ -215,10 +214,10 @@ RingReduceScatterTensorSlicer::RingReduceScatterTensorSlicer( input_tensor.get_legacy_shape()[2]}; } else { this->flattened_tensor_shape = tt_xy_pair{ - input_tensor.get_legacy_shape()[3] / constants::TILE_WIDTH, + input_tensor.get_legacy_shape()[3] /tt::constants::TILE_WIDTH, (input_tensor.get_legacy_shape()[0] * input_tensor.get_legacy_shape()[1] * input_tensor.get_legacy_shape()[2]) / - constants::TILE_HEIGHT}; + tt::constants::TILE_HEIGHT}; } this->worker_slice_offsets = compute_worker_slice_offsets(this->worker_slice_shapes, this->tensor_slice_shape); TT_ASSERT(this->worker_slice_offsets.size() == this->worker_slice_shapes.size()); @@ -302,7 +301,7 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape } std::vector RingReduceScatterTensorSlicer::create_worker_slice_shapes_for_tile_layout( - Shape const& tensor_shape, + tt::tt_metal::Shape const& tensor_shape, tt_xy_pair const& tensor_slice_shape_in_tiles, uint32_t num_workers, uint32_t max_slice_size_in_pages, @@ -450,7 +449,7 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape largest_worker_slice_shape.y); log_trace(tt::LogOp, "max_slice_size_in_tiles={}", max_slice_size_in_tiles); auto get_padded_worker_slice_size_in_tiles = [](tt_xy_pair const& worker_slice_shape, uint32_t half_cb_n_pages) { - return round_up(worker_slice_shape.x * worker_slice_shape.y, half_cb_n_pages); + return tt::round_up(worker_slice_shape.x * worker_slice_shape.y, half_cb_n_pages); }; while (get_padded_worker_slice_size_in_tiles(largest_worker_slice_shape, half_cb_n_pages) > max_slice_size_in_tiles) { @@ -514,5 +513,4 @@ std::vector RingReduceScatterTensorSlicer::create_worker_slice_shape } // namespace ccl -} // namespace tt_metal -} // namespace tt +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/ccl_common.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp similarity index 93% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/ccl_common.hpp rename to ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp index 8b6d249857b..9a71b4b3034 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/ccl_common.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp @@ -8,13 +8,13 @@ #include #include "common/constants.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/impl/program/program.hpp" +#include "ttnn/experimental/tensor/types.hpp" -namespace tt { -namespace tt_metal { +namespace ttnn { namespace ccl { // Eventual home: ccl_topology_descriptors @@ -89,7 +89,7 @@ class CclOpTensorConfig { CclOpTensorConfig(Tensor const& tensor) : buffer_start_address(tensor.buffer()->address()), - df(tt_metal::datatype_to_dataformat_converter(tensor.get_dtype())) {} + df(tt::tt_metal::datatype_to_dataformat_converter(tensor.get_dtype())) {} virtual uint32_t get_page_size() const = 0; virtual uint32_t get_unit_size() const = 0; @@ -100,14 +100,14 @@ class CclOpTensorConfig { protected: uint32_t buffer_start_address; - DataFormat df; + tt::DataFormat df; }; class CclOpInterleavedTensorConfig final : public virtual CclOpTensorConfig { public: 
CclOpInterleavedTensorConfig(Tensor const& input_tensor) : CclOpTensorConfig(input_tensor) { if (input_tensor.get_layout() == Layout::TILE) { - this->page_size = tt_metal::detail::TileSize(this->df); + this->page_size =tt::tt_metal::detail::TileSize(this->df); } else { this->page_size = input_tensor.buffer()->page_size(); } @@ -124,13 +124,13 @@ class CclOpShardedTensorConfig final : public virtual CclOpTensorConfig { CclOpShardedTensorConfig(Tensor const& tensor) : CclOpTensorConfig(tensor), shard_spec(tensor.shard_spec().value()) { if (tensor.get_layout() == Layout::TILE) { - this->page_size = tt_metal::detail::TileSize(this->df); + this->page_size =tt::tt_metal::detail::TileSize(this->df); TT_ASSERT( this->shard_spec.shape.at(0) * this->shard_spec.shape.at(1) % - (constants::TILE_HEIGHT * constants::TILE_WIDTH) == + (tt::constants::TILE_HEIGHT *tt::constants::TILE_WIDTH) == 0); this->unit_size = (this->shard_spec.shape.at(0) * this->shard_spec.shape.at(1) / - (constants::TILE_HEIGHT * constants::TILE_WIDTH)) * + (tt::constants::TILE_HEIGHT *tt::constants::TILE_WIDTH)) * this->page_size; } else { this->page_size = tensor.get_legacy_shape()[-1] * tensor.element_size(); @@ -154,9 +154,9 @@ class CclOpShardedTensorConfig final : public virtual CclOpTensorConfig { struct CclTensorSlicer { CclTensorSlicer( - Shape tensor_shape, - Shape dim_slice_factors, - // Shape page_shape, + tt::tt_metal::Shape tensor_shape, + tt::tt_metal::Shape dim_slice_factors, + // tt::tt_metal::Shape page_shape, std::size_t num_pages, std::size_t elem_size, std::size_t page_size_in_bytes) : @@ -183,12 +183,12 @@ struct CclTensorSlicer { return n; } - Shape const tensor_shape; - Shape const dim_slice_factors_per_rank; - // Shape const page_shape; + tt::tt_metal::Shape const tensor_shape; + tt::tt_metal::Shape const dim_slice_factors_per_rank; + // tt::tt_metal::Shape const page_shape; std::size_t const num_pages; - // Shape rank_slice_shape; + // tt::tt_metal::Shape rank_slice_shape; std::size_t const page_size_in_bytes; std::size_t const elem_size; @@ -289,7 +289,7 @@ struct InterleavedTensorWorkerSlice { uint32_t num_iterations = 0; while (slice_offset.y < tensor_slice_shape.y && slice_offset.x < tensor_slice_shape.x) { slice_offset = - tt::tt_metal::ccl::advance_slice_row_major(slice_offset, slice_shape, outer_slice_shape, num_workers); + ccl::advance_slice_row_major(slice_offset, slice_shape, outer_slice_shape, num_workers); num_iterations++; } @@ -342,7 +342,7 @@ class RingReduceScatterTensorSlicer : public LegacyCclTensorSlicer { tt_xy_pair const& tensor_slice_shape_in_elems, uint32_t num_workers, uint32_t max_slice_size_in_elements); std::vector create_worker_slice_shapes_for_tile_layout( - Shape const& tensor_shape, + tt::tt_metal::Shape const& tensor_shape, tt_xy_pair const& tensor_slice_shape_in_tiles, uint32_t num_workers, uint32_t max_slice_size_in_pages, @@ -450,14 +450,14 @@ class InterleavedRingAllGatherTensorSlicer : public LegacyCclTensorSlicer { }; KernelHandle generate_edm_kernel( - tt_metal::Program& program, + tt::tt_metal::Program& program, Device const* device, ccl::EriscDatamoverBuilder const& edm_builder, CoreCoord const& eth_core, NOC noc_id); void generate_edm_kernels_for_ring_or_linear_topology( - tt_metal::Program& program, + tt::tt_metal::Program& program, Device const* device, RingTopology const& topology_config, std::vector const& clockwise_edm_builders, @@ -472,5 +472,4 @@ ccl::EriscDatamoverBuilder create_erisc_datamover_builder( EriscDataMoverTerminationMode termination_mode); } // 
namespace ccl -} // namespace tt_metal -} // namespace tt +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp similarity index 98% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp rename to ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp index b8b5cdecbdc..55066f63eea 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp @@ -6,11 +6,10 @@ #include "eth_l1_address_map.h" #include "tensor/tensor_impl.hpp" -#include "ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include -namespace tt { -namespace tt_metal { +namespace ttnn { namespace ccl { enum Topology { Ring = 0, Linear = 1, Meash = 2 }; @@ -35,7 +34,7 @@ struct EriscDatamoverConfig { return usable_l1_base_address + (handshake_location_size * 3) + edm_receiver_first_level_ack_source_word_size; } static uint32_t get_buffers_base_address(std::size_t num_edm_channels) { - uint32_t base_address = round_up( + uint32_t base_address =tt::round_up( get_semaphores_base_address(num_edm_channels) + num_edm_channels * semaphore_size, eth_word_size_bytes); TT_ASSERT(base_address % eth_word_size_bytes == 0); return base_address; @@ -43,7 +42,7 @@ struct EriscDatamoverConfig { static uint32_t compute_buffer_size(std::size_t num_edm_channels, uint32_t page_size = eth_word_size_bytes) { page_size = std::max(page_size, eth_word_size_bytes); TT_ASSERT(num_edm_channels > 0); - uint32_t buffer_size = round_down( + uint32_t buffer_size =tt::round_down( (total_l1_buffer_space - get_buffers_base_address(num_edm_channels)) / (num_edm_channels), page_size); log_trace(tt::LogOp, "total_l1_buffer_space: {}", total_l1_buffer_space); log_trace( @@ -333,5 +332,4 @@ class EriscDatamoverBuilder { }; }; // namespace ccl -}; // namespace tt_metal -}; // namespace tt +}; // namespace ttnn diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp similarity index 89% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp rename to ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp index 49391feb915..d1037849dd8 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/kernel_common/worker_edm_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp @@ -7,14 +7,13 @@ #include "dataflow_api.h" #include "debug/assert.h" #include "debug/dprint.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" -using tt::tt_metal::ccl::ShardType; -using tt::tt_metal::ccl::WorkerXY; -// using tt::tt_metal::ccl::coord_t; +using ttnn::ccl::ShardType; +using ttnn::ccl::WorkerXY; +// using ttnn::ccl::coord_t; -namespace tt { -namespace tt_metal { +namespace ttnn { namespace ccl { static FORCE_INLINE coord_t coord_from_args(uint32_t& arg_idx) { uint32_t x = get_arg_val(arg_idx++); @@ -23,8 +22,7 @@ static FORCE_INLINE coord_t coord_from_args(uint32_t& arg_idx) { } } // namespace ccl -} // namespace tt_metal -} // namespace tt +} // namespace ttnn FORCE_INLINE void 
push_filler_pages_to_cb(const uint32_t& cb_id, uint32_t num_pages) { ASSERT(num_pages < cb_interface[cb_id].fifo_num_pages); diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/edm/README.md b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/README.md similarity index 100% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/edm/README.md rename to ttnn/cpp/ttnn/operations/ccl/kernels/edm/README.md diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/edm/erisc_async_datamover.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp similarity index 98% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/edm/erisc_async_datamover.hpp rename to ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp index 843cd445653..8248934c780 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/edm/erisc_async_datamover.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp @@ -10,12 +10,12 @@ #include "debug/assert.h" #include "eth_l1_address_map.h" #include "ethernet/dataflow_api.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "tt_metal/hw/inc/wormhole/noc/noc.h" -using tt::tt_metal::ccl::EriscDataMoverBufferSharingMode; -using tt::tt_metal::ccl::EriscDataMoverTerminationMode; -using tt::tt_metal::ccl::EriscDataMoverWorkerSignal; +using ttnn::ccl::EriscDataMoverBufferSharingMode; +using ttnn::ccl::EriscDataMoverTerminationMode; +using ttnn::ccl::EriscDataMoverWorkerSignal; namespace erisc { namespace datamover { @@ -34,7 +34,7 @@ struct edm_worker_index { uint16_t worker_index = 0; }; -using tt::tt_metal::ccl::WorkerXY; +using ttnn::ccl::WorkerXY; /* * The `ChannelBuffer` is a building block of the Erisc Data Mover (EDM). 
For every concurrent transaction @@ -115,13 +115,13 @@ class ChannelBuffer final { is_sender_side(is_sender_side) { clear_local_semaphore(); - if (TERMINATION_MODE != tt::tt_metal::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) { + if (TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED || total_num_messages_to_move != 0) { if (is_sender_side) { // Tell the sender side workers that we're ready to accept data on this channel increment_worker_semaphores(); } } else { - ASSERT(TERMINATION_MODE != tt::tt_metal::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED); + ASSERT(TERMINATION_MODE != ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED); goto_state(STATE::DONE); } }; diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp similarity index 96% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp rename to ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp index e7b7faaadc4..23d8c41e252 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/edm/erisc_datamover.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_datamover.cpp @@ -8,8 +8,9 @@ #include "dataflow_api.h" #include "debug/dprint.h" #include "eth_l1_address_map.h" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" -#include "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/edm/erisc_async_datamover.hpp" + +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernels/edm/erisc_async_datamover.hpp" // Args Schema: // 1) handshake addr @@ -44,7 +45,7 @@ FORCE_INLINE void eth_setup_handshake2(std::uint32_t handshake_register_address, } } -using tt::tt_metal::ccl::WorkerXY; +using ttnn::ccl::WorkerXY; template struct sender_receiver_index_t { @@ -117,11 +118,11 @@ void kernel_main() { constexpr uint32_t num_senders = get_compile_time_arg_val(2); constexpr uint32_t num_receivers = get_compile_time_arg_val(3); - constexpr tt::tt_metal::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = - static_cast(get_compile_time_arg_val(4)); + constexpr ttnn::ccl::EriscDataMoverBufferSharingMode edm_buffer_sharing_mode = + static_cast(get_compile_time_arg_val(4)); - constexpr tt::tt_metal::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = - static_cast(get_compile_time_arg_val(5)); + constexpr ttnn::ccl::EriscDataMoverTerminationMode terminate_on_worker_signal = + static_cast(get_compile_time_arg_val(5)); constexpr auto EDM_CONFIG = erisc::datamover::EriscDatamoverConfig(); using EDM_CONFIG_T = decltype(EDM_CONFIG); diff --git a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp new file mode 100644 index 00000000000..4c32817bdfb --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp @@ -0,0 +1,132 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp" +#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" +#include "tt_dnn/op_library/math.hpp" + +#include "tt_metal/host_api.hpp" + +#include "tensor/tensor_utils.hpp" + +#include "eth_l1_address_map.h" + +namespace ttnn { + +void LineAllGather::validate(const std::vector &input_tensors) const { + TT_FATAL(input_tensors.size() == 1); + const auto& input_tensor = input_tensors[0]; + const auto& layout = input_tensors[0].get_layout(); + const auto& dtype = input_tensors[0].get_dtype(); + const auto& page_size = input_tensors[0].buffer()->page_size(); + TT_FATAL(page_size % input_tensors[0].buffer()->alignment() == 0, "All Gather currently requires aligned pages"); + + // TODO: This can be removed by passing two page sizes, actual and aligned to be used for address offsets + // Buffer sizes also need to take this aligned page size into consideration + // TODO: Validate ring + TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to all_gather need to be on device!"); + TT_FATAL(input_tensor.buffer() != nullptr , "Operands to all_gather need to be allocated in buffers on device!"); + TT_FATAL(this->num_links > 0); + TT_FATAL(this->num_links <= input_tensor.device()->compute_with_storage_grid_size().y, "Worker cores used by links are parallelizaed over rows"); + TT_FATAL(this->receiver_device_id.has_value() || this->sender_device_id.has_value()); + if (this->receiver_device_id == this->sender_device_id) { + TT_FATAL(input_tensor.device()->get_ethernet_sockets(this->receiver_device_id.value()).size() >= 2 * this->num_links, "2 Device all gather requires at least 2 eth connections per link"); + } else { + TT_FATAL(this->topology == all_gather_op::Topology::Linear || (this->receiver_device_id.has_value() && input_tensor.device()->get_ethernet_sockets(this->receiver_device_id.value()).size() >= this->num_links), "All gather requires at least 1 eth connection per link between sender device {} and receiver device {}", this->sender_device_id, this->receiver_device_id); + TT_FATAL(this->topology == all_gather_op::Topology::Linear || (this->sender_device_id.has_value() &&input_tensor.device()->get_ethernet_sockets(this->sender_device_id.value()).size() >= this->num_links), "All gather requires at least 1 eth connection per link between sender device {} and receiver device {}", this->sender_device_id, this->receiver_device_id); + } + + TT_FATAL(input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED || + input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED || + input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); + + // Sharding Config checks + bool input_sharded = input_tensor.is_sharded(); + if (input_sharded) { + // TODO(snijjar) + } +} + +std::vector LineAllGather::compute_output_shapes(const std::vector &input_tensors) const { + auto shape = input_tensors[0].get_legacy_shape(); + shape[this->dim] *= this->ring_size; + return std::vector(input_tensors.size(), shape); +} + +std::vector LineAllGather::create_output_tensors(const std::vector &input_tensors) const { + const auto& input_tensor = input_tensors[0]; + if(this->output_mem_config.is_sharded()) { + return {create_device_tensor( + this->compute_output_shapes(input_tensors).at(0), + input_tensor.get_dtype(), + input_tensor.get_layout(), + input_tensor.device(), + this->output_mem_config + )}; + } else { + return 
operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config); + } +} + +operation::ProgramWithCallbacks LineAllGather::create_program(const std::vector & input_tensors, std::vector &output_tensors) const { + AllGatherMode line_all_gather_mode = choose_all_gather_mode(input_tensors.at(0), output_tensors.at(0), dim); + switch (line_all_gather_mode) { + case AllGatherMode::RING_INTERLEAVED: + case AllGatherMode::SINGLE_TILE_HIGH_WIDTH_SHARDED: + return all_gather_multi_core_with_workers(input_tensors[0], output_tensors[0], this->dim, this->num_links, this->ring_size, this->ring_index, this->receiver_device_id, this->sender_device_id, this->topology); + break; + case AllGatherMode::FULL_WORKER_GRID_SHARDED: + TT_THROW("Unsupported AllGatherMode"); + break; + default: + TT_THROW("Unsupported AllGatherMode"); + }; +} + +namespace operations { +namespace ccl { + +Tensor line_all_gather( + const Tensor& input_tensor, const uint32_t dim, const uint32_t num_links, const std::optional& memory_config) { + + TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch"); + + auto devices = input_tensor.get_workers(); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; + operation::launch_op( + [dim, num_links, memory_config, devices]( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) mutable -> std::vector { + + const auto& input_tensor = input_tensors.at(0); + uint32_t num_devices = devices.size(); + + uint32_t device_index = 0; // Initialize device index + uint32_t receiver_device_id = 0; // Initialize receiver device ID + uint32_t sender_device_id = 0; // Initialize sender device ID + + for (uint32_t i = 0; i < num_devices; ++i) { + if (devices[i] == input_tensor.device()) { + device_index = i; + receiver_device_id = devices[(i + 1) % num_devices]->id(); // Next device in the ring + sender_device_id = devices[(i + num_devices - 1) % num_devices]->id(); // Previous device in the ring + break; + } + } + + return operation::run( + ttnn::LineAllGather{ + dim, num_links, num_devices, device_index, receiver_device_id, sender_device_id, memory_config.value_or(input_tensor.memory_config()), ttnn::all_gather_op::Topology::Linear}, + {input_tensor}); + }, + {input_tensor}, + output_tensors); + return output_tensors.at(0); +} + +} // namespace ccl +} // namespace operations + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp new file mode 100644 index 00000000000..c6171bc57f0 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include "common/core_coord.h" +#include "impl/buffers/buffer.hpp" +#include "tensor/tensor.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/host_api.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" + +#include "tt_dnn/op_library/run_operation.hpp" + +#include +#include +#include + +namespace ttnn { + +namespace all_gather_op { +using ccl::Topology; +}; // namespace all_gather_op + +using ccl::EriscDatamoverBuilder; + + +struct LineAllGather { + const uint32_t dim; + const uint32_t num_links; + const uint32_t ring_size; + const uint32_t ring_index; + const std::optional receiver_device_id; + const std::optional sender_device_id; + const MemoryConfig output_mem_config; + const all_gather_op::Topology topology; + + void validate(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector create_output_tensors(const std::vector &input_tensors) const; + operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; +}; + + +namespace operations { +namespace ccl { + +Tensor line_all_gather( + const Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links = 1, + const std::optional& memory_config = std::nullopt); + +} // namespace ccl +} // namespace operations + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/line_all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/line_all_gather_op.hpp new file mode 100644 index 00000000000..102d6a6b6c7 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/line_all_gather_op.hpp @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp" +#include "ttnn/cpp/ttnn/multi_device.hpp" + +namespace ttnn { +namespace operations { +namespace ccl { + +struct ExecuteLineAllGather { + + static ttnn::Tensor execute_on_main_thread( + const ttnn::Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links = 1, + const std::optional& memory_config = std::nullopt) { + return ttnn::operations::ccl::line_all_gather(input_tensor, dim, num_links, memory_config); + } +}; + +} // namespace ccl +} // namespace operations + +constexpr auto line_all_gather = ttnn::register_operation("ttnn::line_all_gather"); + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/line_all_gather/line_all_gather_pybind.hpp b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/line_all_gather_pybind.hpp new file mode 100644 index 00000000000..9f0e0f954e8 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/line_all_gather/line_all_gather_pybind.hpp @@ -0,0 +1,73 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "ttnn/operations/ccl/line_all_gather/line_all_gather_op.hpp" +#include "ttnn/types.hpp" + +namespace py = pybind11; + +namespace ttnn { +namespace operations { +namespace ccl { + +namespace detail { + +template +void bind_line_all_gather(py::module& module, const ccl_operation_t& operation, const char* doc) { + bind_registered_operation( + module, + operation, + doc, + ttnn::pybind_overload_t{ + [](const ccl_operation_t& self, + const ttnn::Tensor& input_tensor, + const uint32_t dim, + const uint32_t num_links, + const std::optional& memory_config) -> ttnn::Tensor { + return self(input_tensor, dim, num_links, memory_config); + }, + py::arg("input_tensor"), + py::arg("dim"), + py::kw_only(), + py::arg("num_links") = 1, + py::arg("memory_config") = std::nullopt}); +} + +} // namespace detail + + +void py_bind_line_all_gather(py::module& module) { + + detail::bind_line_all_gather( + module, + ttnn::line_all_gather, + R"doc(line_all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor + + Performs an all-gather operation on multi-device :attr:`input_tensor` across all devices. + + Args: + * :attr:`input_tensor` (ttnn.Tensor): multi-device tensor + * :attr:`dim` (int) + + Keyword Args: + * :attr:`num_links` (int): Number of links to use for the all-gather operation. + * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. + + Example: + + >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) + >>> output = ttnn.line_all_gather(tensor, dim=0) + + )doc"); +} + +} // namespace ccl +} // namespace operations +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp similarity index 99% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp rename to ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp index 87799e8f4a9..c6a37929fac 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp @@ -9,8 +9,7 @@ #include #include -namespace tt { -namespace tt_metal { +namespace ttnn { namespace ccl { enum EriscDataMoverBufferSharingMode : uint32_t { @@ -304,5 +303,4 @@ inline void full_worker_grid_addr_gen_width_sharded_advance( }; // namespace all_gather } // namespace ccl -} // namespace tt_metal -} // namespace tt +} // namespace ttnn diff --git a/ttnn/ttnn/operations/ccl.py b/ttnn/ttnn/operations/ccl.py deleted file mode 100644 index 6fffcb71ce1..00000000000 --- a/ttnn/ttnn/operations/ccl.py +++ /dev/null @@ -1,11 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import sys - -import ttnn - -__all__ = [] - -ttnn.register_python_operation(name="ttnn.all_gather")(ttnn._ttnn.operations.ccl.all_gather)
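
Note on the new line_all_gather op: in the launch lambda of line_all_gather_op.cpp above, each device finds its ring index by matching input_tensor.device() against the worker list from input_tensor.get_workers(), then takes the next device as receiver_device_id and the previous device as sender_device_id, both modulo the ring size. The snippet below is a minimal, standalone C++ sketch of just that neighbor-selection arithmetic; the four device IDs are hypothetical placeholders for whatever ordering get_workers() returns on real hardware, and nothing here calls the ttnn APIs themselves.

// Standalone sketch of the ring-neighbor indexing used by line_all_gather.
// Device IDs are hypothetical; the real op takes them from input_tensor.get_workers().
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<uint32_t> device_ids = {0, 1, 2, 3};  // assumed ring order
    const uint32_t num_devices = static_cast<uint32_t>(device_ids.size());

    for (uint32_t device_index = 0; device_index < num_devices; ++device_index) {
        // receiver_device_id is the next device in the ring; sender_device_id is the previous one.
        const uint32_t receiver_device_id = device_ids[(device_index + 1) % num_devices];
        const uint32_t sender_device_id = device_ids[(device_index + num_devices - 1) % num_devices];
        std::printf("ring index %u: sender=%u receiver=%u\n",
                    device_index, sender_device_id, receiver_device_id);
    }
    return 0;
}

The same wrap-around indexing is used even though LineAllGather is constructed with Topology::Linear, matching the launch code in line_all_gather_op.cpp above.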