Skip to content

Commit

Permalink
#9486: Merge CCL line_all_gather to TTNN (#9909)
Browse files: browse the repository at this point in the history
* #9486: Merge line_all_gather to TTNN

* #9486: Move CCL kernel files to TTNN

* #9486: Move pytests to TTNN

* #9486: Move CCL common to TTNN

* #9486: Re-enable test cases

* #9486: Re-organize namespace

* #9486: Use mesh_mapper for multi-device test

* #9486: Replace pcie_device_mesh with t3k_device_mesh

* #9486: Move kernel files into kernels directory

* #9486: Skip test for GS

* #0: (MINOR) fix namespace

* #9486: Modify namespace and profiler tests

* #0: Rebased
  • Loading branch information
Aswinmcw authored Jul 20, 2024
1 parent d75e0eb commit 5df2e14
Show file tree
Hide file tree
Showing 45 changed files with 828 additions and 628 deletions.
4 changes: 2 additions & 2 deletions tests/scripts/run_profiler_regressions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ run_additional_T3000_test(){
remove_default_log_locations
mkdir -p $PROFILER_ARTIFACTS_DIR

./tt_metal/tools/profiler/profile_this.py -c "pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py::test_all_gather_on_t3000_post_commit[mem_config0-input_dtype0-8-1-input_shape1-0-layout1]" > $PROFILER_ARTIFACTS_DIR/test_out.log
./tt_metal/tools/profiler/profile_this.py -c "'pytest tests/ttnn/unit_tests/operations/test_all_gather.py::test_all_gather_on_t3000_post_commit[mem_config=MemoryConfig\(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::DRAM,shard_spec=std::nullopt\)-input_dtype=DataType.BFLOAT16-num_devices=8-num_links=1-input_shape=[8,\ 1,\ 33,\ 256]-dim=0-layout=Layout.ROW_MAJOR]'" > $PROFILER_ARTIFACTS_DIR/test_out.log

cat $PROFILER_ARTIFACTS_DIR/test_out.log

Expand All @@ -26,7 +26,7 @@ run_additional_T3000_test(){

run_async_mode_T3000_test(){
#Some tests here do not skip grayskull
if [ "$ARCH_NAME" != "grayskull" ]; then
if [ "$ARCH_NAME" == "wormhole_b0" ]; then
remove_default_log_locations
mkdir -p $PROFILER_ARTIFACTS_DIR

Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ run_frequent_api_pipeline_tests() {
./tests/scripts/run_python_api_unit_tests.sh
else
if [[ $tt_arch == "wormhole_b0" ]]; then
pytest -n auto tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py -k nightly
pytest -n auto tests/ttnn/unit_tests/operations/test_all_gather.py -k nightly
else
echo "API tests are not available for fast dispatch because they're already covered in post-commit"
fi
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/t3000/run_t3000_frequent_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ run_t3000_tteager_tests() {

echo "LOG_METAL: Running run_t3000_tteager_tests"

pytest -n auto tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py -k post_commit ; fail+=$?
pytest -n auto tests/ttnn/unit_tests/operations/test_all_gather.py -k post_commit ; fail+=$?
pytest -n auto tests/tt_eager/python_api_testing/unit_testing/misc/test_reduce_scatter_post_commit.py ; fail+=$?

# distributed layernorm
Expand Down
22 changes: 11 additions & 11 deletions tests/tt_eager/ops/ccl/test_all_gather_sharded_indexing_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "gtest/gtest.h"
#include "ttnn/experimental/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"


TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTileRow_ClockWise_In3x5_NumShards3) {
Expand All @@ -23,7 +23,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
uint16_t shard_offset = 0;
uint16_t old_curr_shard = curr_shard;

tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
ASSERT_EQ(curr_shard_tile_x, 0);
ASSERT_EQ(curr_shard_tile_y, 1);
Expand All @@ -39,7 +39,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
uint16_t curr_shard = 0;
uint16_t old_curr_shard = curr_shard;

tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
ASSERT_EQ(curr_shard_tile_x, 0);
ASSERT_EQ(curr_shard_tile_y, 1);
Expand All @@ -57,7 +57,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
uint16_t old_curr_shard = curr_shard;
ASSERT_EQ(curr_core_index, 0);

tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
ASSERT_EQ(curr_shard_tile_x, 0);
ASSERT_EQ(curr_shard_tile_y, 2);
Expand All @@ -74,7 +74,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
uint16_t old_curr_shard = curr_shard;
ASSERT_EQ(curr_core_index, 0);

tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
ASSERT_EQ(curr_shard_tile_x, 0);
ASSERT_EQ(curr_shard_tile_y, 2);
Expand All @@ -92,7 +92,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
uint16_t old_curr_shard = curr_shard;


tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
ASSERT_EQ(curr_shard_tile_x, 0);
ASSERT_EQ(curr_shard_tile_y, 0);
Expand All @@ -108,7 +108,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
uint16_t curr_shard = 0;
uint16_t old_curr_shard = curr_shard;

tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
ASSERT_EQ(curr_shard_tile_x, 0);
ASSERT_EQ(curr_shard_tile_y, 0);
Expand Down Expand Up @@ -146,7 +146,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
for (uint16_t i = 0; i < num_core_iterations; i++) {

for (uint16_t tile_row = curr_shard_tile_y; tile_row < input_shard_num_tiles_y; tile_row++) {
tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
uint16_t next_tile_row = tile_row + 1;
if (next_tile_row == input_shard_num_tiles_y) {
Expand Down Expand Up @@ -184,7 +184,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceFullTil
for (uint16_t i = 0; i < num_core_iterations; i++) {

for (uint16_t tile_row = curr_shard_tile_y; tile_row < input_shard_num_tiles_y; tile_row++) {
tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance_full_tile_row(
curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
uint16_t next_tile_row = tile_row + 1;
if (next_tile_row == input_shard_num_tiles_y) {
Expand Down Expand Up @@ -224,7 +224,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceSingleT

for (uint16_t tile_row = curr_shard_tile_y; tile_row < input_shard_num_tiles_y; tile_row++) {
for (uint16_t tile_col = curr_shard_tile_x; tile_col < input_shard_num_tiles_x; tile_col++) {
tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance (
ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance (
curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
uint16_t next_tile_row = tile_row;
uint16_t next_tile_col = tile_col + 1;
Expand Down Expand Up @@ -269,7 +269,7 @@ TEST(AllGatherSharded_WidthShardedIndexing_FullWorkerGridVariant, AdvanceSingleT

for (uint16_t tile_row = curr_shard_tile_y; tile_row < input_shard_num_tiles_y; tile_row++) {
for (uint16_t tile_col = curr_shard_tile_x; tile_col < input_shard_num_tiles_x; tile_col++) {
tt::tt_metal::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance (
ttnn::ccl::all_gather::full_worker_grid_addr_gen_width_sharded_advance (
curr_shard_tile_x, curr_shard_tile_y, curr_tile_index, curr_core_index, total_num_cores, input_shard_num_tiles_x, input_shard_num_tiles_y, num_shards_x, curr_shard, is_clockwise);
uint16_t next_tile_row = tile_row;
uint16_t next_tile_col = tile_col + 1;
Expand Down
Loading

0 comments on commit 5df2e14

Please sign in to comment.