From d6c5a99687ecdb7d5d321653b71196ef624eb072 Mon Sep 17 00:00:00 2001 From: Nilaykumar Patel Date: Thu, 5 Dec 2024 20:21:50 +0530 Subject: [PATCH] Conv op non tile multiple shard width (#15742) --- tests/scripts/run_tt_eager.py | 1 + tests/tt_eager/CMakeLists.txt | 1 + tests/tt_eager/ops/test_tensor_utils.cpp | 483 ++++++++++++++++++ .../unit_tests/operations/test_new_conv2d.py | 86 ++++ .../ttnn/operations/conv/conv2d/conv2d.cpp | 21 +- .../operations/conv/conv2d/conv2d_pybind.cpp | 12 +- .../operations/conv/conv2d/conv2d_utils.cpp | 139 +++-- .../operations/conv/conv2d/conv2d_utils.hpp | 6 +- .../conv/conv2d/device/conv2d_op.cpp | 47 +- .../conv/conv2d/device/conv2d_op.hpp | 2 - .../conv2d_op_sharded_program_factory.cpp | 139 +++-- ...onv2d_op_width_sharded_program_factory.cpp | 12 +- .../conv_bmm_tilize_col_major_out_blocks.cpp | 8 +- ..._mcast_padded_with_halo_3x3_weights_v2.cpp | 8 +- ...er_conv_weights_tiled_col_to_rm_blocks.cpp | 46 +- ...er_conv_weights_tiled_col_to_rm_blocks.cpp | 44 -- ...er_conv_weights_tiled_col_to_rm_blocks.cpp | 30 ++ ...er_conv_weights_tiled_col_to_rm_blocks.cpp | 70 ++- .../conv/conv2d/prepare_conv2d_weights.cpp | 45 +- .../conv/conv2d/prepare_conv2d_weights.hpp | 3 +- .../device/kernels/dataflow/halo_gather.cpp | 49 +- .../untilize_with_halo_v2_program_factory.cpp | 7 +- ttnn/cpp/ttnn/tensor/tensor_utils.cpp | 277 ++++++++-- ttnn/cpp/ttnn/tensor/tensor_utils.hpp | 10 + 24 files changed, 1219 insertions(+), 327 deletions(-) create mode 100644 tests/tt_eager/ops/test_tensor_utils.cpp diff --git a/tests/scripts/run_tt_eager.py b/tests/scripts/run_tt_eager.py index e4811b6a2ee..12369a2984c 100644 --- a/tests/scripts/run_tt_eager.py +++ b/tests/scripts/run_tt_eager.py @@ -36,6 +36,7 @@ TestEntry("tt_eager/tests/ops/test_bcast_op", "ops/test_bcast_op"), TestEntry("tt_eager/tests/ops/test_transpose_op", "ops/test_transpose_op"), TestEntry("tt_eager/tests/ops/test_sliding_window_ops", "ops/test_sliding_window_ops"), + TestEntry("tt_eager/tests/ops/test_tensor_utils", "ops/test_tensor_utils"), TestEntry("tt_eager/tests/ops/test_bmm_op", "ops/test_bmm_op"), void_for_bh(void_for_whb0(TestEntry("tt_eager/tests/ops/test_eltwise_unary_op", "ops/test_eltwise_unary_op"))), void_for_whb0( diff --git a/tests/tt_eager/CMakeLists.txt b/tests/tt_eager/CMakeLists.txt index 823fda96cfc..018c871b98f 100644 --- a/tests/tt_eager/CMakeLists.txt +++ b/tests/tt_eager/CMakeLists.txt @@ -24,6 +24,7 @@ set(TT_EAGER_TESTS_OPS ops/test_sfpu.cpp ops/test_sliding_window_ops.cpp ops/test_fold_op.cpp + ops/test_tensor_utils.cpp ) set(TT_EAGER_TESTS_TENSORS diff --git a/tests/tt_eager/ops/test_tensor_utils.cpp b/tests/tt_eager/ops/test_tensor_utils.cpp new file mode 100644 index 00000000000..1121b455b2a --- /dev/null +++ b/tests/tt_eager/ops/test_tensor_utils.cpp @@ -0,0 +1,483 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "common/assert.hpp" +#include "common/bfloat16.hpp" +#include "ttnn/cpp/ttnn/tensor/host_buffer/functions.hpp" +#include "ttnn/cpp/ttnn/tensor/tensor_utils.hpp" +#include "ttnn/cpp/ttnn/tensor/types.hpp" +#include "ttnn/tensor/host_buffer/functions.hpp" +#include "ttnn/tensor/host_buffer/types.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/operations/numpy/functions.hpp" +#include "ttnn/tensor/types.hpp" +#include "ttnn/tensor/tensor_utils.hpp" +using std::vector; +using tt::tt_metal::Tensor; +using namespace tt::tt_metal; + +static vector> ref_weight_in = { + { + 16140, 16151, 16183, 16216, 16154, 16219, 16139, 16216, 16088, 16159, 16165, 16068, 16096, 16024, 16228, 15720, + 16246, 16011, 16068, 16116, 16202, 16207, 16135, 16117, 16145, 16073, 16236, 16214, 15761, 16044, 15794, 16165, + 15525, 16060, 16213, 16245, 16199, 15887, 16222, 16222, 16250, 16114, 16204, 16205, 16108, 16133, 16199, 16173, + 15858, 16184, 16163, 16148, 15890, 16137, 16241, 16194, 16133, 15832, 16084, 16114, 16007, 15934, 16198, 16188, + 16105, 15965, 16145, 15882, 15513, 16037, 16158, 15897, 16156, 15971, 16157, 16069, 16241, 16231, 16174, 16102, + 16056, 16156, 16095, 16231, 16178, 15819, 15734, 16248, 16170, 16167, 16171, 15919, 15959, 16055, 15876, 16192, + 16033, 16155, 16058, 16038, 16145, 15645, 16096, 16162, 16253, 16245, 15824, 16167, 15957, 16162, 15909, 16254, + 16167, 16148, 16001, 16084, 16110, 16115, 15994, 16159, 15906, 16045, 15842, 16172, 16168, 16034, 15885, 16199, + 15945, 16243, 16060, 16169, 16210, 15454, 15814, 16159, 16214, 16172, 15812, 16248, 16249, 16224, 16111, 16130, + 16250, 15716, 16154, 16102, 16189, 15523, 15648, 16098, 16016, 16250, 15862, 16056, 16023, 16118, 15859, 16176, + 16034, 16225, 16084, 16235, 15747, 15966, 16177, 16144, 16145, 16221, 16007, 16130, 16133, 16234, 15808, 16235, + 16147, 15786, 16237, 16014, 16035, 15385, 16170, 16215, 15878, 16165, 16183, 16215, 16020, 16007, 15931, 16075, + 16150, 16141, 15524, 15912, 16212, 16061, 15257, 15893, 16173, 16145, 16010, 16180, 16188, 16019, 16246, 16093, + 15998, 16193, 16147, 16074, 16151, 16229, 16146, 16163, 15972, 16228, 16243, 16174, 16100, 16101, 16216, 16250, + 16179, 15853, 16024, 16196, 16208, 16082, 16075, 16172, 16225, 15999, 16148, 16032, 16225, 16247, 16177, 16150, + 16185, 16168, 16128, 16136, 16244, 15980, 16164, 16074, 16089, 16158, 16155, 16115, 15517, 16112, 16026, 16183, + 16169, 16019, 16020, 16068, 16158, 16191, 16091, 16224, 15882, 15826, 16024, 15805, 16145, 16053, 16151, 16141, + 16147, 15625, 16167, 16248, 16166, 16036, 16092, 15970, 16229, 15888, 16060, 15815, 16095, 16251, 16228, 16005, + 16206, 16137, 16180, 16101, 15821, 15819, 16235, 16052, 16182, 16112, 16255, 16215, 15897, 16231, 16222, 15641, + 15910, 16130, 16157, 15914, 15869, 16199, 16217, 16221, 16206, 16082, 16145, 15887, 16080, 15624, 15757, 16251, + 16178, 16063, 16104, 16087, 16184, 15695, 16221, 16059, 16249, 15496, 16219, 15980, 15423, 16195, 16056, 16241, + 16186, 16191, 15919, 16045, 16133, 16122, 15710, 16045, 15948, 15927, 15511, 15919, 16203, 16109, 15973, 16223, + 16048, 16241, 16237, 16155, 16180, 16152, 15618, 16200, 15912, 16128, 16159, 15694, 16147, 16178, 15987, 16254, + 16239, 16008, 16157, 16173, 16137, 16221, 16151, 16192, 16186, 16246, 16031, 16141, 16075, 15961, 15958, 15971, + 15934, 15967, 16241, 16145, 16189, 16103, 16123, 16248, 15976, 16174, 16002, 15790, 15725, 15719, 16094, 16121, + 16031, 16225, 16178, 16249, 
16065, 16158, 15927, 16138, 15562, 16218, 15753, 16190, 16173, 16117, 16104, 16173, + 16137, 16155, 16229, 16182, 16253, 16112, 15966, 16105, 16169, 16232, 16006, 15884, 15529, 15978, 16194, 16225, + 16035, 16231, 16068, 16165, 16150, 16038, 16212, 16133, 16161, 14440, 16223, 16031, 16012, 16089, 16204, 16226, + 15934, 16174, 16243, 16105, 16175, 16119, 15964, 16201, 16242, 15978, 16187, 16225, 16002, 16032, 15962, 16245, + 16132, 16113, 15570, 16182, 15956, 15901, 16089, 16186, 16063, 16165, 16109, 15964, 16014, 15934, 16150, 16206, + 16221, 16191, 15856, 16172, 16132, 16013, 15879, 15923, 16183, 16180, 16074, 16109, 16144, 16215, 15931, 15953, + 15892, 15912, 16121, 15871, 16054, 16184, 16240, 15609, 16195, 16191, 16191, 15805, 16231, 15966, 15786, 16191, + 16141, 16187, 16149, 15674, 16246, 15958, 16021, 16018, 15990, 16173, 15821, 15745, 15494, 16142, 16237, 15383, + 16171, 16213, 16200, 16251, 16016, 16180, 16150, 15929, 15746, 16131, 16120, 16148, 16250, 16201, 16224, 16155, + 16045, 15967, 16246, 16105, 15981, 16224, 16243, 16124, 16240, 16183, 16204, 16120, 16161, 16181, 16223, 16127, + 16022, 16216, 16217, 15943, 16158, 16197, 15448, 16249, 16049, 16220, 15895, 16199, 16251, 16252, 16116, 16192, + }, + + { + 16140, 16151, 16183, 16216, 16154, 16219, 16139, 16216, 16088, 16159, 16165, 16068, 16096, 16024, 16228, 15720, + 16246, 16011, 16068, 16116, 16202, 16207, 16135, 16117, 16145, 16073, 16236, 16214, 15761, 16044, 15794, 16165, + 15525, 16060, 16213, 16245, 16199, 15887, 16222, 16222, 16250, 16114, 16204, 16205, 16108, 16133, 16199, 16173, + 15858, 16184, 16163, 16148, 15890, 16137, 16241, 16194, 16133, 15832, 16084, 16114, 16007, 15934, 16198, 16188, + 16105, 15965, 16145, 15882, 15513, 16037, 16158, 15897, 16156, 15971, 16157, 16069, 16241, 16231, 16174, 16102, + 16056, 16156, 16095, 16231, 16178, 15819, 15734, 16248, 16170, 16167, 16171, 15919, 15959, 16055, 15876, 16192, + 16033, 16155, 16058, 16038, 16145, 15645, 16096, 16162, 16253, 16245, 15824, 16167, 15957, 16162, 15909, 16254, + 16167, 16148, 16001, 16084, 16110, 16115, 15994, 16159, 15906, 16045, 15842, 16172, 16168, 16034, 15885, 16199, + 15945, 16243, 16060, 16169, 16210, 15454, 15814, 16159, 16214, 16172, 15812, 16248, 16249, 16224, 16111, 16130, + 16250, 15716, 16154, 16102, 16189, 15523, 15648, 16098, 16016, 16250, 15862, 16056, 16023, 16118, 15859, 16176, + 16034, 16225, 16084, 16235, 15747, 15966, 16177, 16144, 16145, 16221, 16007, 16130, 16133, 16234, 15808, 16235, + 16147, 15786, 16237, 16014, 16035, 15385, 16170, 16215, 15878, 16165, 16183, 16215, 16020, 16007, 15931, 16075, + 16150, 16141, 15524, 15912, 16212, 16061, 15257, 15893, 16173, 16145, 16010, 16180, 16188, 16019, 16246, 16093, + 15998, 16193, 16147, 16074, 16151, 16229, 16146, 16163, 15972, 16228, 16243, 16174, 16100, 16101, 16216, 16250, + 16179, 15853, 16024, 16196, 16208, 16082, 16075, 16172, 16225, 15999, 16148, 16032, 16225, 16247, 16177, 16150, + 16185, 16168, 16128, 16136, 16244, 15980, 16164, 16074, 16089, 16158, 16155, 16115, 15517, 16112, 16026, 16183, + 16169, 16019, 16020, 16068, 16158, 16191, 16091, 16224, 15882, 15826, 16024, 15805, 16145, 16053, 16151, 16141, + 16147, 15625, 16167, 16248, 16166, 16036, 16092, 15970, 16229, 15888, 16060, 15815, 16095, 16251, 16228, 16005, + 16206, 16137, 16180, 16101, 15821, 15819, 16235, 16052, 16182, 16112, 16255, 16215, 15897, 16231, 16222, 15641, + 15910, 16130, 16157, 15914, 15869, 16199, 16217, 16221, 16206, 16082, 16145, 15887, 16080, 15624, 15757, 16251, + 16178, 16063, 16104, 16087, 16184, 
15695, 16221, 16059, 16249, 15496, 16219, 15980, 15423, 16195, 16056, 16241, + 16186, 16191, 15919, 16045, 16133, 16122, 15710, 16045, 15948, 15927, 15511, 15919, 16203, 16109, 15973, 16223, + 16048, 16241, 16237, 16155, 16180, 16152, 15618, 16200, 15912, 16128, 16159, 15694, 16147, 16178, 15987, 16254, + 16239, 16008, 16157, 16173, 16137, 16221, 16151, 16192, 16186, 16246, 16031, 16141, 16075, 15961, 15958, 15971, + 15934, 15967, 16241, 16145, 16189, 16103, 16123, 16248, 15976, 16174, 16002, 15790, 15725, 15719, 16094, 16121, + 16031, 16225, 16178, 16249, 16065, 16158, 15927, 16138, 15562, 16218, 15753, 16190, 16173, 16117, 16104, 16173, + 16137, 16155, 16229, 16182, 16253, 16112, 15966, 16105, 16169, 16232, 16006, 15884, 15529, 15978, 16194, 16225, + 16035, 16231, 16068, 16165, 16150, 16038, 16212, 16133, 16161, 14440, 16223, 16031, 16012, 16089, 16204, 16226, + 15934, 16174, 16243, 16105, 16175, 16119, 15964, 16201, 16242, 15978, 16187, 16225, 16002, 16032, 15962, 16245, + 16132, 16113, 15570, 16182, 15956, 15901, 16089, 16186, 16063, 16165, 16109, 15964, 16014, 15934, 16150, 16206, + 16221, 16191, 15856, 16172, 16132, 16013, 15879, 15923, 16183, 16180, 16074, 16109, 16144, 16215, 15931, 15953, + 15892, 15912, 16121, 15871, 16054, 16184, 16240, 15609, 16195, 16191, 16191, 15805, 16231, 15966, 15786, 16191, + 16141, 16187, 16149, 15674, 16246, 15958, 16021, 16018, 15990, 16173, 15821, 15745, 15494, 16142, 16237, 15383, + 16171, 16213, 16200, 16251, 16016, 16180, 16150, 15929, 15746, 16131, 16120, 16148, 16250, 16201, 16224, 16155, + 16045, 15967, 16246, 16105, 15981, 16224, 16243, 16124, 16240, 16183, 16204, 16120, 16161, 16181, 16223, 16127, + 16022, 16216, 16217, 15943, 16158, 16197, 15448, 16249, 16049, 16220, 15895, 16199, 16251, 16252, 16116, 16192, + 16126, 15236, 16163, 16009, 16060, 16082, 15884, 16091, 16210, 16024, 15938, 16077, 16130, 15863, 15973, 16251, + 15816, 16079, 16220, 16145, 16249, 16047, 16245, 16201, 16232, 16082, 16198, 16055, 16042, 16076, 15782, 16026, + 16080, 16198, 15981, 16237, 15879, 16038, 15706, 16243, 16185, 15460, 15419, 16136, 16197, 16027, 15894, 16226, + 15778, 16000, 15799, 16173, 16172, 16207, 15995, 16093, 16087, 16192, 16142, 16212, 16220, 16066, 16186, 15813, + 16010, 16003, 15878, 16151, 15714, 16115, 16026, 16121, 16006, 16106, 16105, 16134, 16174, 16098, 16178, 16218, + 16017, 16093, 16066, 16211, 15929, 16130, 16201, 15792, 15720, 16168, 16178, 15955, 16199, 16216, 16199, 16174, + 16004, 15926, 16063, 15759, 16150, 15390, 16011, 16228, 16061, 15880, 15945, 16199, 16107, 16236, 15670, 16183, + 16204, 16123, 15773, 16112, 16132, 16225, 16029, 16122, 16147, 16084, 16245, 15922, 16165, 16115, 15632, 16200, + 16092, 16142, 16130, 15907, 16137, 15891, 16174, 16166, 16014, 16138, 15875, 16038, 16073, 15894, 16244, 15907, + 15935, 15876, 16231, 16148, 16139, 15804, 16105, 16233, 16225, 15785, 16106, 16204, 16185, 16224, 16076, 15807, + 16231, 16090, 16176, 16114, 16179, 16148, 16039, 16183, 16193, 15581, 16162, 16187, 15989, 16196, 15908, 15392, + 16203, 16029, 16245, 15982, 16106, 16128, 16151, 16244, 16219, 16142, 16106, 15815, 16243, 16159, 16147, 16220, + 16210, 15905, 16232, 16254, 16208, 15790, 15907, 15809, 16160, 16162, 16075, 16243, 15744, 16239, 16089, 16101, + 16004, 16186, 16217, 16190, 15624, 16029, 16245, 15861, 16053, 16099, 16054, 16072, 15493, 16136, 15933, 16216, + 16077, 16137, 16237, 16174, 15820, 16155, 16241, 15817, 16222, 15804, 16104, 15717, 16039, 15793, 15982, 15986, + 16157, 16214, 15623, 16133, 15487, 16131, 16091, 16166, 
15755, 16139, 16000, 15620, 15970, 16148, 16001, 16197, + 15878, 16064, 15429, 16123, 15852, 16251, 16158, 15994, 16249, 16063, 16253, 15675, 16081, 16030, 15910, 16212, + 16163, 16206, 16123, 16163, 16253, 16060, 15749, 16032, 16200, 16205, 16019, 15760, 15991, 16174, 16169, 16066, + 15995, 16162, 16170, 16237, 16132, 16218, 16089, 16126, 16142, 16091, 16018, 16210, 16180, 16188, 16084, 16100, + 16056, 16248, 16212, 16057, 16236, 16075, 15676, 16189, 15982, 16101, 16050, 16239, 16208, 16003, 16252, 16067, + 16248, 16178, 16231, 16229, + }, + + { + 16140, 16151, 16183, 16216, 16154, 16219, 16139, 16216, 16088, 16159, 16165, 16068, 16096, 16024, 16228, 15720, + 16246, 16011, 16068, 16116, 16202, 16207, 16135, 16117, 16145, 16073, 16236, 16214, 15761, 16044, 15794, 16165, + 15525, 16060, 16213, 16245, 16199, 15887, 16222, 16222, 16250, 16114, 16204, 16205, 16108, 16133, 16199, 16173, + 15858, 16184, 16163, 16148, 15890, 16137, 16241, 16194, 16133, 15832, 16084, 16114, 16007, 15934, 16198, 16188, + 16105, 15965, 16145, 15882, 15513, 16037, 16158, 15897, 16156, 15971, 16157, 16069, 16241, 16231, 16174, 16102, + 16056, 16156, 16095, 16231, 16178, 15819, 15734, 16248, 16170, 16167, 16171, 15919, 15959, 16055, 15876, 16192, + 16033, 16155, 16058, 16038, 16145, 15645, 16096, 16162, 16253, 16245, 15824, 16167, 15957, 16162, 15909, 16254, + 16167, 16148, 16001, 16084, 16110, 16115, 15994, 16159, 15906, 16045, 15842, 16172, 16168, 16034, 15885, 16199, + 15945, 16243, 16060, 16169, 16210, 15454, 15814, 16159, 16214, 16172, 15812, 16248, 16249, 16224, 16111, 16130, + 16250, 15716, 16154, 16102, 16189, 15523, 15648, 16098, 16016, 16250, 15862, 16056, 16023, 16118, 15859, 16176, + 16034, 16225, 16084, 16235, 15747, 15966, 16177, 16144, 16145, 16221, 16007, 16130, 16133, 16234, 15808, 16235, + 16147, 15786, 16237, 16014, 16035, 15385, 16170, 16215, 15878, 16165, 16183, 16215, 16020, 16007, 15931, 16075, + 16150, 16141, 15524, 15912, 16212, 16061, 15257, 15893, 16173, 16145, 16010, 16180, 16188, 16019, 16246, 16093, + 15998, 16193, 16147, 16074, 16151, 16229, 16146, 16163, 15972, 16228, 16243, 16174, 16100, 16101, 16216, 16250, + 16179, 15853, 16024, 16196, 16208, 16082, 16075, 16172, 16225, 15999, 16148, 16032, 16225, 16247, 16177, 16150, + 16185, 16168, 16128, 16136, 16244, 15980, 16164, 16074, 16089, 16158, 16155, 16115, 15517, 16112, 16026, 16183, + 16169, 16019, 16020, 16068, 16158, 16191, 16091, 16224, 15882, 15826, 16024, 15805, 16145, 16053, 16151, 16141, + 16147, 15625, 16167, 16248, 16166, 16036, 16092, 15970, 16229, 15888, 16060, 15815, 16095, 16251, 16228, 16005, + 16206, 16137, 16180, 16101, 15821, 15819, 16235, 16052, 16182, 16112, 16255, 16215, 15897, 16231, 16222, 15641, + 15910, 16130, 16157, 15914, 15869, 16199, 16217, 16221, 16206, 16082, 16145, 15887, 16080, 15624, 15757, 16251, + 16178, 16063, 16104, 16087, 16184, 15695, 16221, 16059, 16249, 15496, 16219, 15980, 15423, 16195, 16056, 16241, + 16186, 16191, 15919, 16045, 16133, 16122, 15710, 16045, 15948, 15927, 15511, 15919, 16203, 16109, 15973, 16223, + 16048, 16241, 16237, 16155, 16180, 16152, 15618, 16200, 15912, 16128, 16159, 15694, 16147, 16178, 15987, 16254, + 16239, 16008, 16157, 16173, 16137, 16221, 16151, 16192, 16186, 16246, 16031, 16141, 16075, 15961, 15958, 15971, + 15934, 15967, 16241, 16145, 16189, 16103, 16123, 16248, 15976, 16174, 16002, 15790, 15725, 15719, 16094, 16121, + 16031, 16225, 16178, 16249, 16065, 16158, 15927, 16138, 15562, 16218, 15753, 16190, 16173, 16117, 16104, 16173, + 16137, 16155, 16229, 16182, 16253, 
16112, 15966, 16105, 16169, 16232, 16006, 15884, 15529, 15978, 16194, 16225, + 16035, 16231, 16068, 16165, 16150, 16038, 16212, 16133, 16161, 14440, 16223, 16031, 16012, 16089, 16204, 16226, + 15934, 16174, 16243, 16105, 16175, 16119, 15964, 16201, 16242, 15978, 16187, 16225, 16002, 16032, 15962, 16245, + 16132, 16113, 15570, 16182, 15956, 15901, 16089, 16186, 16063, 16165, 16109, 15964, 16014, 15934, 16150, 16206, + 16221, 16191, 15856, 16172, 16132, 16013, 15879, 15923, 16183, 16180, 16074, 16109, 16144, 16215, 15931, 15953, + 15892, 15912, 16121, 15871, 16054, 16184, 16240, 15609, 16195, 16191, 16191, 15805, 16231, 15966, 15786, 16191, + 16141, 16187, 16149, 15674, 16246, 15958, 16021, 16018, 15990, 16173, 15821, 15745, 15494, 16142, 16237, 15383, + 16171, 16213, 16200, 16251, 16016, 16180, 16150, 15929, 15746, 16131, 16120, 16148, 16250, 16201, 16224, 16155, + 16045, 15967, 16246, 16105, 15981, 16224, 16243, 16124, 16240, 16183, 16204, 16120, 16161, 16181, 16223, 16127, + 16022, 16216, 16217, 15943, 16158, 16197, 15448, 16249, 16049, 16220, 15895, 16199, 16251, 16252, 16116, 16192, + 16126, 15236, 16163, 16009, 16060, 16082, 15884, 16091, 16210, 16024, 15938, 16077, 16130, 15863, 15973, 16251, + 15816, 16079, 16220, 16145, 16249, 16047, 16245, 16201, 16232, 16082, 16198, 16055, 16042, 16076, 15782, 16026, + 16080, 16198, 15981, 16237, 15879, 16038, 15706, 16243, 16185, 15460, 15419, 16136, 16197, 16027, 15894, 16226, + 15778, 16000, 15799, 16173, 16172, 16207, 15995, 16093, 16087, 16192, 16142, 16212, 16220, 16066, 16186, 15813, + 16010, 16003, 15878, 16151, 15714, 16115, 16026, 16121, 16006, 16106, 16105, 16134, 16174, 16098, 16178, 16218, + 16017, 16093, 16066, 16211, 15929, 16130, 16201, 15792, 15720, 16168, 16178, 15955, 16199, 16216, 16199, 16174, + 16004, 15926, 16063, 15759, 16150, 15390, 16011, 16228, 16061, 15880, 15945, 16199, 16107, 16236, 15670, 16183, + 16204, 16123, 15773, 16112, 16132, 16225, 16029, 16122, 16147, 16084, 16245, 15922, 16165, 16115, 15632, 16200, + 16092, 16142, 16130, 15907, 16137, 15891, 16174, 16166, 16014, 16138, 15875, 16038, 16073, 15894, 16244, 15907, + 15935, 15876, 16231, 16148, 16139, 15804, 16105, 16233, 16225, 15785, 16106, 16204, 16185, 16224, 16076, 15807, + 16231, 16090, 16176, 16114, 16179, 16148, 16039, 16183, 16193, 15581, 16162, 16187, 15989, 16196, 15908, 15392, + 16203, 16029, 16245, 15982, 16106, 16128, 16151, 16244, 16219, 16142, 16106, 15815, 16243, 16159, 16147, 16220, + 16210, 15905, 16232, 16254, 16208, 15790, 15907, 15809, 16160, 16162, 16075, 16243, 15744, 16239, 16089, 16101, + 16004, 16186, 16217, 16190, 15624, 16029, 16245, 15861, 16053, 16099, 16054, 16072, 15493, 16136, 15933, 16216, + 16077, 16137, 16237, 16174, 15820, 16155, 16241, 15817, 16222, 15804, 16104, 15717, 16039, 15793, 15982, 15986, + 16157, 16214, 15623, 16133, 15487, 16131, 16091, 16166, 15755, 16139, 16000, 15620, 15970, 16148, 16001, 16197, + 15878, 16064, 15429, 16123, 15852, 16251, 16158, 15994, 16249, 16063, 16253, 15675, 16081, 16030, 15910, 16212, + 16163, 16206, 16123, 16163, 16253, 16060, 15749, 16032, 16200, 16205, 16019, 15760, 15991, 16174, 16169, 16066, + }, + { + 16140, 16151, 16183, 16216, 16154, 16219, 16139, 16216, 16088, 16159, 16165, 16068, 16096, 16024, 16228, 15720, + 16246, 16011, 16068, 16116, 16202, 16207, 16135, 16117, 16145, 16073, 16236, 16214, 15761, 16044, 15794, 16165, + 15525, 16060, 16213, 16245, 16199, 15887, 16222, 16222, 16250, 16114, 16204, 16205, 16108, 16133, 16199, 16173, + 15858, 16184, 16163, 16148, 15890, 16137, 
16241, 16194, 16133, 15832, 16084, 16114, 16007, 15934, 16198, 16188, + 16105, 15965, 16145, 15882, 15513, 16037, 16158, 15897, 16156, 15971, 16157, 16069, 16241, 16231, 16174, 16102, + 16056, 16156, 16095, 16231, 16178, 15819, 15734, 16248, 16170, 16167, 16171, 15919, 15959, 16055, 15876, 16192, + 16033, 16155, 16058, 16038, 16145, 15645, 16096, 16162, 16253, 16245, 15824, 16167, 15957, 16162, 15909, 16254, + 16167, 16148, 16001, 16084, 16110, 16115, 15994, 16159, 15906, 16045, 15842, 16172, 16168, 16034, 15885, 16199, + 15945, 16243, 16060, 16169, 16210, 15454, 15814, 16159, 16214, 16172, 15812, 16248, 16249, 16224, 16111, 16130, + 16250, 15716, 16154, 16102, 16189, 15523, 15648, 16098, 16016, 16250, 15862, 16056, 16023, 16118, 15859, 16176, + 16034, 16225, 16084, 16235, 15747, 15966, 16177, 16144, 16145, 16221, 16007, 16130, 16133, 16234, 15808, 16235, + 16147, 15786, 16237, 16014, 16035, 15385, 16170, 16215, 15878, 16165, 16183, 16215, 16020, 16007, 15931, 16075, + 16150, 16141, 15524, 15912, 16212, 16061, 15257, 15893, 16173, 16145, 16010, 16180, 16188, 16019, 16246, 16093, + 15998, 16193, 16147, 16074, 16151, 16229, 16146, 16163, 15972, 16228, 16243, 16174, 16100, 16101, 16216, 16250, + 16179, 15853, 16024, 16196, 16208, 16082, 16075, 16172, 16225, 15999, 16148, 16032, 16225, 16247, 16177, 16150, + 16185, 16168, 16128, 16136, 16244, 15980, 16164, 16074, 16089, 16158, 16155, 16115, 15517, 16112, 16026, 16183, + 16169, 16019, 16020, 16068, 16158, 16191, 16091, 16224, 15882, 15826, 16024, 15805, 16145, 16053, 16151, 16141, + 16147, 15625, 16167, 16248, 16166, 16036, 16092, 15970, 16229, 15888, 16060, 15815, 16095, 16251, 16228, 16005, + 16206, 16137, 16180, 16101, 15821, 15819, 16235, 16052, 16182, 16112, 16255, 16215, 15897, 16231, 16222, 15641, + 15910, 16130, 16157, 15914, 15869, 16199, 16217, 16221, 16206, 16082, 16145, 15887, 16080, 15624, 15757, 16251, + 16178, 16063, 16104, 16087, 16184, 15695, 16221, 16059, 16249, 15496, 16219, 15980, 15423, 16195, 16056, 16241, + 16186, 16191, 15919, 16045, 16133, 16122, 15710, 16045, 15948, 15927, 15511, 15919, 16203, 16109, 15973, 16223, + 16048, 16241, 16237, 16155, 16180, 16152, 15618, 16200, 15912, 16128, 16159, 15694, 16147, 16178, 15987, 16254, + 16239, 16008, 16157, 16173, 16137, 16221, 16151, 16192, 16186, 16246, 16031, 16141, 16075, 15961, 15958, 15971, + 15934, 15967, 16241, 16145, 16189, 16103, 16123, 16248, 15976, 16174, 16002, 15790, 15725, 15719, 16094, 16121, + 16031, 16225, 16178, 16249, 16065, 16158, 15927, 16138, 15562, 16218, 15753, 16190, 16173, 16117, 16104, 16173, + 16137, 16155, 16229, 16182, 16253, 16112, 15966, 16105, 16169, 16232, 16006, 15884, 15529, 15978, 16194, 16225, + 16035, 16231, 16068, 16165, 16150, 16038, 16212, 16133, 16161, 14440, 16223, 16031, 16012, 16089, 16204, 16226, + 15934, 16174, 16243, 16105, 16175, 16119, 15964, 16201, 16242, 15978, 16187, 16225, 16002, 16032, 15962, 16245, + 16132, 16113, 15570, 16182, 15956, 15901, 16089, 16186, 16063, 16165, 16109, 15964, 16014, 15934, 16150, 16206, + 16221, 16191, 15856, 16172, 16132, 16013, 15879, 15923, 16183, 16180, 16074, 16109, 16144, 16215, 15931, 15953, + 15892, 15912, 16121, 15871, 16054, 16184, 16240, 15609, 16195, 16191, 16191, 15805, 16231, 15966, 15786, 16191, + 16141, 16187, 16149, 15674, 16246, 15958, 16021, 16018, 15990, 16173, 15821, 15745, 15494, 16142, 16237, 15383, + 16171, 16213, 16200, 16251, 16016, 16180, 16150, 15929, 15746, 16131, 16120, 16148, 16250, 16201, 16224, 16155, + 16045, 15967, 16246, 16105, 15981, 16224, 16243, 16124, 16240, 
16183, 16204, 16120, 16161, 16181, 16223, 16127, + 16022, 16216, 16217, 15943, 16158, 16197, 15448, 16249, 16049, 16220, 15895, 16199, 16251, 16252, 16116, 16192, + 16126, 15236, 16163, 16009, 16060, 16082, 15884, 16091, 16210, 16024, 15938, 16077, 16130, 15863, 15973, 16251, + 15816, 16079, 16220, 16145, 16249, 16047, 16245, 16201, 16232, 16082, 16198, 16055, 16042, 16076, 15782, 16026, + 16080, 16198, 15981, 16237, 15879, 16038, 15706, 16243, 16185, 15460, 15419, 16136, 16197, 16027, 15894, 16226, + 15778, 16000, 15799, 16173, 16172, 16207, 15995, 16093, 16087, 16192, 16142, 16212, 16220, 16066, 16186, 15813, + 16010, 16003, 15878, 16151, 15714, 16115, 16026, 16121, 16006, 16106, 16105, 16134, 16174, 16098, 16178, 16218, + 16017, 16093, 16066, 16211, 15929, 16130, 16201, 15792, 15720, 16168, 16178, 15955, 16199, 16216, 16199, 16174, + 16004, 15926, 16063, 15759, 16150, 15390, 16011, 16228, 16061, 15880, 15945, 16199, 16107, 16236, 15670, 16183, + 16204, 16123, 15773, 16112, 16132, 16225, 16029, 16122, 16147, 16084, 16245, 15922, 16165, 16115, 15632, 16200, + 16092, 16142, 16130, 15907, 16137, 15891, 16174, 16166, 16014, 16138, 15875, 16038, 16073, 15894, 16244, 15907, + 15935, 15876, 16231, 16148, 16139, 15804, 16105, 16233, 16225, 15785, 16106, 16204, 16185, 16224, 16076, 15807, + 16231, 16090, 16176, 16114, 16179, 16148, 16039, 16183, 16193, 15581, 16162, 16187, 15989, 16196, 15908, 15392, + 16203, 16029, 16245, 15982, 16106, 16128, 16151, 16244, 16219, 16142, 16106, 15815, 16243, 16159, 16147, 16220, + 16210, 15905, 16232, 16254, 16208, 15790, 15907, 15809, 16160, 16162, 16075, 16243, 15744, 16239, 16089, 16101, + 16004, 16186, 16217, 16190, 15624, 16029, 16245, 15861, 16053, 16099, 16054, 16072, 15493, 16136, 15933, 16216, + 16077, 16137, 16237, 16174, 15820, 16155, 16241, 15817, 16222, 15804, 16104, 15717, 16039, 15793, 15982, 15986, + 16157, 16214, 15623, 16133, 15487, 16131, 16091, 16166, 15755, 16139, 16000, 15620, 15970, 16148, 16001, 16197, + 15878, 16064, 15429, 16123, 15852, 16251, 16158, 15994, 16249, 16063, 16253, 15675, 16081, 16030, 15910, 16212, + 16163, 16206, 16123, 16163, 16253, 16060, 15749, 16032, 16200, 16205, 16019, 15760, 15991, 16174, 16169, 16066, + 15995, 16162, 16170, 16237, 16132, 16218, 16089, 16126, 16142, 16091, 16018, 16210, 16180, 16188, 16084, 16100, + 16056, 16248, 16212, 16057, 16236, 16075, 15676, 16189, 15982, 16101, 16050, 16239, 16208, 16003, 16252, 16067, + 16248, 16178, 16231, 16229, 16023, 15863, 16253, 15991, 15999, 15977, 15832, 16122, 16243, 16228, 15983, 16055, + 16176, 16069, 15727, 16234, 16187, 15849, 16225, 16161, 16011, 15880, 16066, 16063, 16063, 16038, 16191, 16174, + 15987, 16203, 15919, 16129, 16102, 16023, 16027, 16226, 16214, 16052, 15987, 16189, 16128, 16142, 16241, 15950, + 16162, 16140, 16222, 16133, 16240, 16050, 16192, 15561, 16179, 15896, 16247, 15879, 16254, 16181, 16103, 16181, + 15761, 16156, 16021, 16172, 15900, 16101, 16085, 16178, 15878, 16065, 16154, 15820, 16067, 16245, 16229, 15764, + 16247, 15518, 16140, 16250, 16012, 15896, 16151, 16004, 16229, 15964, 16080, 16148, 16141, 16249, 16011, 16011, + 16105, 16248, 16077, 15568, 15998, 16227, 16129, 16181, 16030, 16014, 16062, 16229, 16134, 15577, 16192, 16160, + 16042, 16040, 16236, 16247, 16220, 15916, 15687, 16230, 16001, 16040, 16100, 16227, 15830, 16131, 16050, 16130, + 16189, 16070, 16174, 16135, 16159, 16241, 16181, 16228, 15953, 16173, 16046, 16163, 16173, 16140, 16225, 16011, + 16139, 15895, 16016, 16219, 15607, 16162, 16181, 16025, 15361, 16107, 16062, 15560, 
16135, 16142, 16236, 16056, + 15799, 16128, 16079, 15901, 15559, 16089, 16047, 16231, 16159, 15371, 16014, 16248, 15958, 16176, 15852, 15819, + 16147, 16020, 16177, 16138, 16172, 16185, 16242, 16071, + } + +}; +static vector> ref_weight_out = { + {16140, 16151, 16183, 16216, 16154, 16219, 16139, 16216, 16088, 16156, 15971, 16157, 16069, 16241, 16231, 16174, + 16102, 16056, 16250, 15716, 16154, 16102, 16189, 15523, 15648, 16098, 16016, 15972, 16228, 16243, 16174, 16100, + 16101, 16216, 16250, 16179, 16206, 16137, 16180, 16101, 15821, 15819, 16235, 16052, 16182, 15912, 16128, 16159, + 15694, 16147, 16178, 15987, 16254, 16239, 16035, 16231, 16068, 16165, 16150, 16038, 16212, 16133, 16161, 16195, + 16191, 16191, 15805, 16231, 15966, 15786, 16191, 16141, 16159, 16165, 16068, 16096, 16024, 16228, 15720, 16246, + 16011, 16156, 16095, 16231, 16178, 15819, 15734, 16248, 16170, 16167, 16250, 15862, 16056, 16023, 16118, 15859, + 16176, 16034, 16225, 15853, 16024, 16196, 16208, 16082, 16075, 16172, 16225, 15999, 16112, 16255, 16215, 15897, + 16231, 16222, 15641, 15910, 16130, 16008, 16157, 16173, 16137, 16221, 16151, 16192, 16186, 16246, 14440, 16223, + 16031, 16012, 16089, 16204, 16226, 15934, 16174, 16187, 16149, 15674, 16246, 15958, 16021, 16018, 15990, 16173, + 16068, 16116, 16202, 16207, 16135, 16117, 16145, 16073, 16236, 16171, 15919, 15959, 16055, 15876, 16192, 16033, + 16155, 16058, 16084, 16235, 15747, 15966, 16177, 16144, 16145, 16221, 16007, 16148, 16032, 16225, 16247, 16177, + 16150, 16185, 16168, 16128, 16157, 15914, 15869, 16199, 16217, 16221, 16206, 16082, 16145, 16031, 16141, 16075, + 15961, 15958, 15971, 15934, 15967, 16241, 16243, 16105, 16175, 16119, 15964, 16201, 16242, 15978, 16187, 15821, + 15745, 15494, 16142, 16237, 15383, 16171, 16213, 16200, 16214, 15761, 16044, 15794, 16165, 15525, 16060, 16213, + 16245, 16038, 16145, 15645, 16096, 16162, 16253, 16245, 15824, 16167, 16130, 16133, 16234, 15808, 16235, 16147, + 15786, 16237, 16014, 16136, 16244, 15980, 16164, 16074, 16089, 16158, 16155, 16115, 15887, 16080, 15624, 15757, + 16251, 16178, 16063, 16104, 16087, 16145, 16189, 16103, 16123, 16248, 15976, 16174, 16002, 15790, 16225, 16002, + 16032, 15962, 16245, 16132, 16113, 15570, 16182, 16251, 16016, 16180, 16150, 15929, 15746, 16131, 16120, 16148, + 16199, 15887, 16222, 16222, 16250, 16114, 16204, 16205, 16108, 15957, 16162, 15909, 16254, 16167, 16148, 16001, + 16084, 16110, 16035, 15385, 16170, 16215, 15878, 16165, 16183, 16215, 16020, 15517, 16112, 16026, 16183, 16169, + 16019, 16020, 16068, 16158, 16184, 15695, 16221, 16059, 16249, 15496, 16219, 15980, 15423, 15725, 15719, 16094, + 16121, 16031, 16225, 16178, 16249, 16065, 15956, 15901, 16089, 16186, 16063, 16165, 16109, 15964, 16014, 16250, + 16201, 16224, 16155, 16045, 15967, 16246, 16105, 15981, 16133, 16199, 16173, 15858, 16184, 16163, 16148, 15890, + 16137, 16115, 15994, 16159, 15906, 16045, 15842, 16172, 16168, 16034, 16007, 15931, 16075, 16150, 16141, 15524, + 15912, 16212, 16061, 16191, 16091, 16224, 15882, 15826, 16024, 15805, 16145, 16053, 16195, 16056, 16241, 16186, + 16191, 15919, 16045, 16133, 16122, 16158, 15927, 16138, 15562, 16218, 15753, 16190, 16173, 16117, 15934, 16150, + 16206, 16221, 16191, 15856, 16172, 16132, 16013, 16224, 16243, 16124, 16240, 16183, 16204, 16120, 16161, 16181, + 16241, 16194, 16133, 15832, 16084, 16114, 16007, 15934, 16198, 15885, 16199, 15945, 16243, 16060, 16169, 16210, + 15454, 15814, 15257, 15893, 16173, 16145, 16010, 16180, 16188, 16019, 16246, 16151, 16141, 16147, 15625, 16167, + 
16248, 16166, 16036, 16092, 15710, 16045, 15948, 15927, 15511, 15919, 16203, 16109, 15973, 16104, 16173, 16137, + 16155, 16229, 16182, 16253, 16112, 15966, 15879, 15923, 16183, 16180, 16074, 16109, 16144, 16215, 15931, 16223, + 16127, 16022, 16216, 16217, 15943, 16158, 16197, 15448, 16188, 16105, 15965, 16145, 15882, 15513, 16037, 16158, + 15897, 16159, 16214, 16172, 15812, 16248, 16249, 16224, 16111, 16130, 16093, 15998, 16193, 16147, 16074, 16151, + 16229, 16146, 16163, 15970, 16229, 15888, 16060, 15815, 16095, 16251, 16228, 16005, 16223, 16048, 16241, 16237, + 16155, 16180, 16152, 15618, 16200, 16105, 16169, 16232, 16006, 15884, 15529, 15978, 16194, 16225, 15953, 15892, + 15912, 16121, 15871, 16054, 16184, 16240, 15609, 16249, 16049, 16220, 15895, 16199, 16251, 16252, 16116, 16192}, + { + 16140, 16171, 16035, 16159, 16038, 16007, 16068, 15957, 15257, 16151, 15919, 15385, 16165, 16145, 15931, 16116, + 16162, 15893, 16183, 15959, 16170, 16068, 15645, 16075, 16202, 15909, 16173, 16216, 16055, 16215, 16096, 16096, + 16150, 16207, 16254, 16145, 16154, 15876, 15878, 16024, 16162, 16141, 16135, 16167, 16010, 16219, 16192, 16165, + 16228, 16253, 15524, 16117, 16148, 16180, 16139, 16033, 16183, 15720, 16245, 15912, 16145, 16001, 16188, 16216, + 16155, 16215, 16246, 15824, 16212, 16073, 16084, 16019, 16088, 16058, 16020, 16011, 16167, 16061, 16236, 16110, + 16246, 16151, 15912, 16243, 15970, 16008, 16225, 16206, 16031, 15956, 16141, 16128, 16105, 16229, 16157, 16002, + 16137, 16141, 15901, 16147, 16159, 16175, 15888, 16173, 16032, 16180, 16075, 16089, 15625, 15694, 16119, 16060, + 16137, 15962, 16101, 15961, 16186, 16167, 16147, 15964, 15815, 16221, 16245, 15821, 15958, 16063, 16248, 16178, + 16201, 16095, 16151, 16132, 15819, 15971, 16165, 16166, 15987, 16242, 16251, 16192, 16113, 16235, 15934, 16109, + 16036, 16254, 15978, 16228, 16186, 15570, 16052, 15967, 15964, 16092, 16239, 16187, 16005, 16246, 16182, 16182, + 16241, 16014, 16250, 15995, 15935, 16224, 15813, 15785, 16223, 16006, 16176, 16201, 16093, 15876, 16243, 16010, + 16106, 16127, 16106, 16114, 16224, 16087, 16231, 16124, 16003, 16204, 16022, 16105, 16179, 16155, 16192, 16148, + 16240, 15878, 16185, 16216, 16134, 16148, 16045, 16142, 16139, 16183, 16151, 16224, 16217, 16174, 16039, 15967, + 16212, 15804, 16204, 15714, 16076, 15943, 16098, 16183, 16246, 16220, 16105, 16120, 16115, 15807, 16158, 16178, + 16193, 16105, 16066, 16233, 16161, 16026, 16231, 16197, 16218, 15581, 15981, 16186, 16225, 16181, 16121, 16090, + 15448, 16017, 16162, 16214, 16115, 16093, 16199, 15885, 15972, 16133, 16159, 15853, 15761, 15994, 15998, 15887, + 16199, 16228, 16199, 16214, 16024, 16044, 16159, 16193, 16222, 15945, 16243, 16173, 16172, 16196, 15794, 15906, + 16147, 16222, 16243, 16174, 15858, 15812, 16208, 16165, 16045, 16074, 16250, 16060, 16100, 16184, 16248, 16082, + 15525, 15842, 16151, 16114, 16169, 16101, 16163, 16249, 16075, 16060, 16172, 16229, 16204, 16210, 16216, 16148, + 16224, 16172, 16213, 16168, 16146, 16205, 15454, 16250, 15890, 16111, 16225, 16245, 16034, 16163, 16108, 15814, + 16179, 16137, 16130, 15999, 16112, 16145, 15934, 16157, 15725, 15879, 15887, 16158, 15953, 16255, 16189, 16150, + 15914, 15719, 15923, 16080, 15927, 15892, 16215, 16103, 16206, 15869, 16094, 16183, 15624, 16138, 15912, 15897, + 16123, 16221, 16199, 16121, 16180, 15757, 15562, 16121, 16231, 16248, 16191, 16217, 16031, 16074, 16251, 16218, + 15871, 16222, 15976, 15856, 16221, 16225, 16109, 16178, 15753, 16054, 15641, 16174, 16172, 16206, 16178, 16144, + 16063, 16190, 
16184, 15910, 16002, 16132, 16082, 16249, 16215, 16104, 16173, 16240, 16130, 15790, 16013, 16145, + 16065, 15931, 16087, 16117, 15609, 16249, 16093, 16187, 16126, 16178, 16106, 16024, 15759, 16159, 16049, 16066, + 15989, 15236, 15955, 16128, 15938, 16150, 16147, 16220, 16211, 16196, 16163, 16199, 16151, 16077, 15390, 16220, + 15895, 15929, 15908, 16009, 16216, 16244, 16130, 16011, 16210, 16199, 16130, 15392, 16060, 16199, 16219, 15863, + 16228, 15905, 16251, 16201, 16203, 16082, 16174, 16142, 15973, 16061, 16232, 16252, 15792, 16029, 15884, 16004, + 16106, 16251, 15880, 16254, 16116, 15720, 16245, 16091, 15926, 15815, 15816, 15945, 16208, 16192, 16168, 15982, + 16210, 16063, 16243, 16079, 16199, 15790, 16241, 16250, 16148, 16188, 16250, 16136, 16156, 16084, 15517, 16194, + 15716, 16032, 16105, 15862, 16244, 15971, 16235, 16112, 16133, 16154, 16225, 15965, 16056, 15980, 16157, 15747, + 16026, 15832, 16102, 16247, 16145, 16023, 16164, 16069, 15966, 16183, 16084, 16189, 16177, 15882, 16118, 16074, + 16241, 16177, 16169, 16114, 15523, 16150, 15513, 15859, 16089, 16231, 16144, 16019, 16007, 15648, 16185, 16037, + 16176, 16158, 16174, 16145, 16020, 15934, 16098, 16168, 16158, 16034, 16155, 16102, 16221, 16068, 16198, 16016, + 16128, 15897, 16225, 16115, 16056, 16007, 16158, 16184, 16104, 16195, 16195, 16105, 16187, 15710, 16035, 15821, + 15695, 16173, 16191, 16056, 16169, 16149, 16045, 16231, 15745, 16221, 16137, 16191, 16241, 16232, 15674, 15948, + 16068, 15494, 16059, 16155, 15805, 16186, 16006, 16246, 15927, 16165, 16142, 16249, 16229, 16231, 16191, 15884, + 15958, 15511, 16150, 16237, 15496, 16182, 15966, 15919, 15529, 16021, 15919, 16038, 15383, 16219, 16253, 15786, + 16045, 15978, 16018, 16203, 16212, 16171, 15980, 16112, 16191, 16133, 16194, 15990, 16109, 16133, 16213, 15423, + 15966, 16141, 16122, 16225, 16173, 15973, 16161, 16200, 16220, 16107, 15907, 16055, 16225, 16101, 15879, 15632, + 16053, 16145, 16236, 15809, 16042, 16029, 16004, 16038, 16200, 16099, 16249, 15670, 16160, 16076, 16122, 16186, + 15706, 16092, 16054, 16047, 16183, 16162, 15782, 16147, 16217, 16243, 16142, 16072, 16245, 16204, 16075, 16026, + 16084, 16190, 16185, 16130, 15493, 16201, 16123, 16243, 16080, 16245, 15624, 15460, 15907, 16136, 16232, 15773, + 15744, 16198, 15922, 16029, 15419, 16137, 15933, 16082, 16112, 16239, 15981, 16165, 16245, 16136, 15891, 16216, + 16198, 16132, 16089, 16237, 16115, 15861, 16197, 16174, 16077, + }, + { + 16140, 16156, 16151, 15971, 16183, 16157, 16216, 16069, 16154, 16241, 16219, 16231, 16139, 16174, 16216, 16102, + 16088, 16056, 16250, 15972, 15716, 16228, 16154, 16243, 16102, 16174, 16189, 16100, 15523, 16101, 15648, 16216, + 16098, 16250, 16016, 16179, 16206, 15912, 16137, 16128, 16180, 16159, 16101, 15694, 15821, 16147, 15819, 16178, + 16235, 15987, 16052, 16254, 16182, 16239, 16035, 16195, 16231, 16191, 16068, 16191, 16165, 15805, 16150, 16231, + 16038, 15966, 16212, 15786, 16133, 16191, 16161, 16141, 16126, 16006, 15236, 16106, 16163, 16105, 16009, 16134, + 16060, 16174, 16082, 16098, 15884, 16178, 16091, 16218, 16210, 16017, 16159, 16156, 16165, 16095, 16068, 16231, + 16096, 16178, 16024, 15819, 16228, 15734, 15720, 16248, 16246, 16170, 16011, 16167, 16250, 15853, 15862, 16024, + 16056, 16196, 16023, 16208, 16118, 16082, 15859, 16075, 16176, 16172, 16034, 16225, 16225, 15999, 16112, 16008, + 16255, 16157, 16215, 16173, 15897, 16137, 16231, 16221, 16222, 16151, 15641, 16192, 15910, 16186, 16130, 16246, + 14440, 16187, 16223, 16149, 16031, 15674, 16012, 16246, 16089, 15958, 
16204, 16021, 16226, 16018, 15934, 15990, + 16174, 16173, 16024, 16093, 15938, 16066, 16077, 16211, 16130, 15929, 15863, 16130, 15973, 16201, 16251, 15792, + 15816, 15720, 16079, 16168, 16068, 16171, 16116, 15919, 16202, 15959, 16207, 16055, 16135, 15876, 16117, 16192, + 16145, 16033, 16073, 16155, 16236, 16058, 16084, 16148, 16235, 16032, 15747, 16225, 15966, 16247, 16177, 16177, + 16144, 16150, 16145, 16185, 16221, 16168, 16007, 16128, 16157, 16031, 15914, 16141, 15869, 16075, 16199, 15961, + 16217, 15958, 16221, 15971, 16206, 15934, 16082, 15967, 16145, 16241, 16243, 15821, 16105, 15745, 16175, 15494, + 16119, 16142, 15964, 16237, 16201, 15383, 16242, 16171, 15978, 16213, 16187, 16200, 16220, 16178, 16145, 15955, + 16249, 16199, 16047, 16216, 16245, 16199, 16201, 16174, 16232, 16004, 16082, 15926, 16198, 16063, 16214, 16038, + 15761, 16145, 16044, 15645, 15794, 16096, 16165, 16162, 15525, 16253, 16060, 16245, 16213, 15824, 16245, 16167, + 16130, 16136, 16133, 16244, 16234, 15980, 15808, 16164, 16235, 16074, 16147, 16089, 15786, 16158, 16237, 16155, + 16014, 16115, 15887, 16145, 16080, 16189, 15624, 16103, 15757, 16123, 16251, 16248, 16178, 15976, 16063, 16174, + 16104, 16002, 16087, 15790, 16225, 16251, 16002, 16016, 16032, 16180, 15962, 16150, 16245, 15929, 16132, 15746, + 16113, 16131, 15570, 16120, 16182, 16148, 16055, 15759, 16042, 16150, 16076, 15390, 15782, 16011, 16026, 16228, + 16080, 16061, 16198, 15880, 15981, 15945, 16237, 16199, 16199, 15957, 15887, 16162, 16222, 15909, 16222, 16254, + 16250, 16167, 16114, 16148, 16204, 16001, 16205, 16084, 16108, 16110, 16035, 15517, 15385, 16112, 16170, 16026, + 16215, 16183, 15878, 16169, 16165, 16019, 16183, 16020, 16215, 16068, 16020, 16158, 16184, 15725, 15695, 15719, + 16221, 16094, 16059, 16121, 16249, 16031, 15496, 16225, 16219, 16178, 15980, 16249, 15423, 16065, 15956, 16250, + 15901, 16201, 16089, 16224, 16186, 16155, 16063, 16045, 16165, 15967, 16109, 16246, 15964, 16105, 16014, 15981, + 15879, 16107, 16038, 16236, 15706, 15670, 16243, 16183, 16185, 16204, 15460, 16123, 15419, 15773, 16136, 16112, + 16197, 16132, + }, + { + 16140, 16159, 16159, 16250, 16068, 16250, 16151, 16214, 16165, 15716, 16116, 15862, 16183, 16172, 16068, 16154, + 16202, 16056, 16216, 15812, 16096, 16102, 16207, 16023, 16154, 16248, 16024, 16189, 16135, 16118, 16219, 16249, + 16228, 15523, 16117, 15859, 16139, 16224, 15720, 15648, 16145, 16176, 16216, 16111, 16246, 16098, 16073, 16034, + 16088, 16130, 16011, 16016, 16236, 16225, 16151, 16158, 15970, 16104, 16206, 16105, 16141, 15927, 16229, 16173, + 16137, 16169, 16147, 16138, 15888, 16137, 16180, 16232, 15625, 15562, 16060, 16155, 16101, 16006, 16167, 16218, + 15815, 16229, 15821, 15884, 16248, 15753, 16095, 16182, 15819, 15529, 16166, 16190, 16251, 16253, 16235, 15978, + 16036, 16173, 16228, 16112, 16052, 16194, 16092, 16117, 16005, 15966, 16182, 16225, 16250, 15759, 16224, 16107, + 16223, 16225, 16201, 16150, 16243, 16236, 16127, 16029, 16224, 15390, 16124, 15670, 16022, 16122, 16155, 16011, + 16240, 16183, 16216, 16147, 16045, 16228, 16183, 16204, 16217, 16084, 15967, 16061, 16204, 16123, 15943, 16245, + 16246, 15880, 16120, 15773, 16158, 15922, 16105, 15945, 16161, 16112, 16197, 16165, 15981, 16199, 16181, 16132, + 15448, 16115, 16104, 16140, 16133, 16247, 15970, 16172, 15717, 16222, 15487, 15879, 16148, 15900, 16039, 16133, + 16131, 16254, 16001, 16101, 15793, 16240, 16091, 16181, 16197, 16085, 15982, 16050, 16166, 16103, 15878, 16178, + 15986, 16192, 15755, 16181, 16064, 15878, 16157, 15561, 16139, 
15761, 15429, 16065, 16214, 16179, 16000, 16156, + 16123, 16154, 15623, 15896, 15620, 16021, 15852, 15820, 16214, 16084, 16199, 16130, 16133, 16035, 15761, 16235, + 15887, 16133, 16199, 15385, 16044, 15747, 16222, 16234, 16173, 16170, 15794, 15966, 16222, 15808, 15858, 16215, + 16165, 16177, 16250, 16235, 16184, 15878, 15525, 16144, 16114, 16147, 16163, 16165, 16060, 16145, 16204, 15786, + 16148, 16183, 16213, 16221, 16205, 16237, 15890, 16215, 16245, 16007, 16108, 16014, 16137, 16020, 16112, 16035, + 16157, 14440, 15887, 16243, 16255, 16231, 15914, 16223, 16080, 16105, 16215, 16068, 15869, 16031, 15624, 16175, + 15897, 16165, 16199, 16012, 15757, 16119, 16231, 16150, 16217, 16089, 16251, 15964, 16222, 16038, 16221, 16204, + 16178, 16201, 15641, 16212, 16206, 16226, 16063, 16242, 15910, 16133, 16082, 15934, 16104, 15978, 16130, 16161, + 16145, 16174, 16087, 16187, 16249, 15632, 16126, 16166, 16024, 15935, 16049, 16200, 15236, 16014, 15938, 15876, + 16220, 16092, 16163, 16138, 16077, 16231, 15895, 16142, 16009, 15875, 16130, 16148, 16199, 16130, 16060, 16038, + 15863, 16139, 16251, 15907, 16082, 16073, 15973, 15804, 16252, 16137, 15884, 15894, 16251, 16105, 16116, 15891, + 16091, 16244, 15816, 16233, 16192, 16174, 16210, 15907, 16079, 16225, 16251, 16067, 15910, 15896, 16032, 16011, + 16158, 16245, 16212, 16151, 16200, 16011, 15994, 16229, 16163, 16004, 16205, 16105, 16249, 15764, 16206, 16229, + 16019, 16248, 16063, 16247, 16123, 15964, 15760, 16077, 16253, 15518, 16163, 16080, 15991, 15568, 15675, 16140, + 16253, 16148, 16174, 15998, 16081, 16250, 16060, 16141, 16169, 16227, 16030, 16012, 15749, 16249, 16066, 16129, + 16241, 16007, 16188, 15257, 16156, 16093, 16194, 15931, 16105, 15893, 15971, 15998, 16133, 16075, 15965, 16173, + 16157, 16193, 15832, 16150, 16145, 16145, 16069, 16147, 16084, 16141, 15882, 16010, 16241, 16074, 16114, 15524, + 15513, 16180, 16231, 16151, 16007, 15912, 16037, 16188, 16174, 16229, 15934, 16212, 16158, 16019, 16102, 16146, + 16198, 16061, 15897, 16246, 16056, 16163, 16184, 16225, 16195, 15956, 15710, 15934, 15695, 16002, 16056, 15901, + 16045, 16150, 16221, 16032, 16241, 16089, 15948, 16206, 16059, 15962, 16186, 16186, 15927, 16221, 16249, 16245, + 16191, 16063, 15511, 16191, 15496, 16132, 15919, 16165, 15919, 15856, 16219, 16113, 16045, 16109, 16203, 16172, + 15980, 15570, 16133, 15964, 16109, 16132, 15423, 16182, 16122, 16014, 15973, 16013, 16220, 15785, 16055, 16176, + 15879, 16187, 16145, 16106, 16042, 16114, 16038, 15989, 16249, 16204, 16076, 16179, 15706, 16196, 16047, 16185, + 15782, 16148, 16243, 15908, 16245, 16224, 16026, 16039, 16185, 15392, 16201, 16076, 16080, 16183, 15460, 16203, + 16232, 15807, 16198, 16193, 15419, 16029, 16082, 16231, 15981, 15581, 16136, 16245, 16198, 16090, 16237, 16162, + 16197, 15982, 15995, 16181, 16091, 16042, 16212, 16040, 16162, 16030, 16018, 16040, 16057, 16100, 16170, 16014, + 16210, 16236, 16236, 16227, 16237, 16062, 16180, 16247, 16075, 15830, 16132, 16229, 16188, 16220, 15676, 16131, + 16218, 16134, 16084, 15916, 16189, 16050, 16089, 15577, 16100, 15687, 15982, 16130, 16126, 16192, 16056, 16230, + 16101, 16189, 16142, 16160, 16248, 16001, 16050, 16070, 16156, 15972, 16171, 15853, 16038, 16148, 16095, 16228, + 15919, 16024, 16145, 16032, 16231, 16243, 15959, 16196, 15645, 16225, 16178, 16174, 16055, 16208, 16096, 16247, + 15819, 16100, 15876, 16082, 16162, 16177, 15734, 16101, 16192, 16075, 16253, 16150, 16248, 16216, 16033, 16172, + 16245, 16185, 16170, 16250, 16155, 16225, 15824, 16168, 16167, 16179, 16058, 15999, 
16167, 16128, 16223, 15879, + 15912, 15953, 16008, 16195, 16048, 15923, 16128, 15892, 16157, 16191, 16241, 16183, 16159, 15912, 16173, 16191, + 16237, 16180, 15694, 16121, 16137, 15805, 16155, 16074, 16147, 15871, 16221, 16231, 16180, 16109, 16178, 16054, + 16151, 15966, 16152, 16144, 15987, 16184, 16192, 15786, 15618, 16215, 16254, 16240, 16186, 16191, 16200, 15931, + 16239, 15609, 16246, 16141, 16027, 16106, 15995, 16159, 15813, 15907, 15894, 16128, 16093, 16147, 16010, 15809, + 16226, 16151, 16087, 16220, 16003, 16160, 15778, 16244, 16192, 16210, 15878, 16162, 16000, 16219, 16142, 15905, + 16151, 16075, 15799, 16142, 16212, 16232, 15714, 16243, 16173, 16106, 16220, 16254, 16115, 15744, 16172, 15815, + 16066, 16208, 16026, 16239, 16207, 16243, 16186, 15790, 16121, 16089, 16239, 16174, 16023, 16163, 16228, 15607, + 16208, 16135, 15863, 16173, 15983, 16162, 16003, 16159, 16253, 16140, 16055, 16181, 16252, 16241, 15991, 16225, + 16176, 16025, 16067, 16181, 15999, 16011, 16069, 15361, 16248, 16228, 15977, 16139, 15727, 16107, 16178, 15953, + 15832, 15895, 16234, 16062, 16231, 16173, 16122, 16016, 16187, 15560, 16229, 16046, 16243, 16219, 15849, 16135, + }}; + +static vector weight_tensor_shape = {{8, 8, 3, 3}, {10, 10, 3, 3}, {12, 8, 3, 3}, {8, 15, 3, 3}}; +static vector bias_tensor_shape = {{1, 1, 1, 32}, {1, 1, 1, 60}, {12, 1, 1, 320}, {8, 1, 1, 48}}; +static vector shards = {8, 3, 5, 4}; + +template +static uint32_t compare_out_with_ref(const owned_buffer::Buffer& out_buf, T& ref) { + uint32_t diff = 0, j = 0; + for (uint32_t i = 0; i < out_buf.size(); i++) { + if (out_buf[i] == 0) { + continue; + } + if (out_buf[i] != ref[j]) { + log_info( + tt::LogTest, + "Error at i = {}, Golden = {}, Calculated = {}", + i, + out_buf[i].to_float(), + ref[j].to_float()); + diff++; + } + j++; + } + return diff; +} + +static void test_convert_conv_weight_tensor_to_tiled_layout_block_sharded() { + tt::log_info(tt::LogTest, "Running {}", __func__); + for (auto i = 0; i < weight_tensor_shape.size(); i++) { + auto input_tensor = ttnn::numpy::zeros(weight_tensor_shape[i]); + auto input_buffer = owned_buffer::get_as(input_tensor); + for (auto j = 0; j < input_buffer.size(); j++) { + input_buffer[j] = ref_weight_in[i][j]; + } + auto output_tensor = + convert_conv_weight_tensor_to_tiled_layout_block_sharded(input_tensor, shards[i], DataType::BFLOAT16); + auto out_buffer = owned_buffer::get_as(output_tensor); + + TT_FATAL(compare_out_with_ref(out_buffer, ref_weight_out[i]) == 0, "Error"); + } +} + +static void test_convert_conv_bias_tensor_to_tiled_layout_block_sharded() { + tt::log_info(tt::LogTest, "Running {}", __func__); + for (auto i = 0; i < bias_tensor_shape.size(); i++) { + auto input_tensor = + ttnn::numpy::random::random(bias_tensor_shape[i], DataType::BFLOAT16).to(Layout::ROW_MAJOR).cpu(); + auto input_buffer = owned_buffer::get_as(input_tensor); + auto output_tensor = + convert_conv_bias_tensor_to_tiled_layout_block_sharded(input_tensor, shards[i], DataType::BFLOAT16); + auto out_buffer = owned_buffer::get_as(output_tensor); + /* Expected output should be same as input buffer except some padding*/ + TT_FATAL(compare_out_with_ref(out_buffer, input_buffer) == 0, "Error"); + } +} + +int main() { + tt::log_info(tt::LogTest, "Tests for Tensor utils starts"); + test_convert_conv_weight_tensor_to_tiled_layout_block_sharded(); + test_convert_conv_bias_tensor_to_tiled_layout_block_sharded(); + tt::log_info(tt::LogTest, "Tests for Tensor utils ends"); + return 0; +} diff --git 
a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index 6322819f593..3e5f5f857f9 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -2620,6 +2620,92 @@ def test_non_tile_multiple_height_conv_wh( ) +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + ( + (1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 64, 128, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 64, 192, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 64, 256, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 64, 320, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 64, 384, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 64, 448, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 64, 512, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 64, 576, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 64, 640, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 128, 64, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 128, 128, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 128, 192, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 128, 256, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 128, 320, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 128, 384, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 128, 448, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 128, 512, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 128, 576, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 128, 640, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 320, 320, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + ), +) +@pytest.mark.parametrize( + "weights_dtype", + [ttnn.bfloat16, ttnn.bfloat8_b], +) +@pytest.mark.parametrize( + "activations_dtype", + [ttnn.bfloat16], +) +@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) +@pytest.mark.parametrize("enable_auto_formatting", [False]) +def test_non_tile_multiple_width_conv_wh( + device, + use_program_cache, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + config_override, + enable_auto_formatting, +): + run_conv( + device, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + config_override, + use_shallow_conv_variant=(input_channels == 16), + transpose_mcast=use_1d_systolic_array, + enable_auto_formatting=enable_auto_formatting, + padded_input_channels=16 if input_channels == 16 else None, + output_layout=ttnn.ROW_MAJOR_LAYOUT, + ) + + @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) def test_shallow_conv_with_tiled_input(device): diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index 7539700b27a..5f9ba6f0ea9 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -80,8 +80,17 @@ Result conv2d( ttnn::is_tensor_on_device_or_multidevice(input_tensor) ? 
std::make_optional(input_tensor.memory_config()) : std::nullopt); } + ShardOrientation shard_orientation = + conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; + auto num_cores_c = shard_orientation == ShardOrientation::COL_MAJOR ? device->compute_with_storage_grid_size().y : device->compute_with_storage_grid_size().x; + auto elem_size = conv_config.weights_dtype == DataType::BFLOAT8_B ? 1 : 2; + bool is_non_tile_mul_width = + (conv_config.shard_layout == TensorMemoryLayout::BLOCK_SHARDED) && conv_config.act_block_h_override == 0 && + (conv_config.weights_dtype == DataType::BFLOAT8_B || conv_config.weights_dtype == DataType::BFLOAT16) && + conv_config.output_layout == Layout::ROW_MAJOR && ((elem_size * in_channels) % (16 * num_cores_c)) == 0; + auto [input_tensor_post_tm, parallel_config, output_parallel_config, tensor_manipulated, use_non_tile_height] = shard_or_reshard_tensor_if_required( - device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels, mm_conv); + device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels, mm_conv, is_non_tile_mul_width); if (tensor_manipulated) { if (conv_config.deallocate_activation) { ttnn::Tensor input_tensor_ = input_tensor; // TODO: allow in place modification of inputs to the op @@ -96,6 +105,9 @@ Result conv2d( uint32_t out_channels_padded = tt::round_up( out_channels, get_num_cores_channels_from_parallel_config(output_parallel_config) * tt::constants::TILE_WIDTH); + if(is_non_tile_mul_width) { + out_channels_padded = tt::round_up(out_channels, 32); + } MemoryConfig conv_out_memory_config = create_sharded_memory_config_from_parallel_config( ttnn::Shape(std::array{1, 1, nhw_out, out_channels_padded}), output_parallel_config, @@ -110,6 +122,9 @@ Result conv2d( uint32_t in_channels_padded = tt::round_up( in_channels, get_num_cores_channels_from_parallel_config(parallel_config) * conv_config.input_channels_alignment); + if(is_non_tile_mul_width){ + in_channels_padded = tt::round_up(in_channels, conv_config.input_channels_alignment); + } uint32_t nhw_out_padded_ntile = get_num_cores_nhw_from_parallel_config(output_parallel_config) * conv_out_memory_config.shard_spec.value().shape[0] / tt::constants::TILE_HEIGHT; @@ -141,7 +156,9 @@ Result conv2d( device, groups, opt_conv_op_block_config.act_block_h_ntiles, - input_width); + input_width, + true, + is_non_tile_mul_width); } // if 1x1 conv w/ stride 1, convert input tensor to tile layout if required if (mm_conv) { diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp index 9cda5a46af8..6ac28cf56ca 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp @@ -321,16 +321,16 @@ void py_bind_conv2d(py::module& module) { py::arg("grid_size"), py::arg("num_cores_nhw") = 1, py::arg("num_cores_c") = 1, - py::arg("per_core_out_matrix_height_ntiles").noconvert() = 1, - py::arg("per_core_out_matrix_width_ntiles").noconvert() = 1) + py::arg("per_core_out_matrix_height").noconvert(), + py::arg("per_core_out_matrix_width").noconvert()) .def_property_readonly("grid_size", [](OptimizedConvParallelizationConfig const& c) { return c.grid_size; }) .def_property_readonly( "num_cores_nhw", [](OptimizedConvParallelizationConfig const& c) { return c.num_cores_nhw; }) .def_property_readonly( - "per_core_out_matrix_height_ntiles", - 
[](OptimizedConvParallelizationConfig const& c) { return c.per_core_out_matrix_height_ntiles; }) - .def_property_readonly("per_core_out_matrix_width_ntiles", [](OptimizedConvParallelizationConfig const& c) { - return c.per_core_out_matrix_width_ntiles; + "per_core_out_matrix_height", + [](OptimizedConvParallelizationConfig const& c) { return c.per_core_out_matrix_height; }) + .def_property_readonly("per_core_out_matrix_width", [](OptimizedConvParallelizationConfig const& c) { + return c.per_core_out_matrix_width; }); py::class_(module, "OptimizedConvBlockConfig") diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp index 56783493763..bf215230584 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp @@ -38,6 +38,12 @@ uint32_t find_closest_largest_divisor(uint32_t num1, uint32_t num2, uint32_t sta return divisor; } +uint32_t find_closest_common_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor) { + uint32_t divisor = start_divisor; + while (num1 % divisor != 0 or num2 % divisor != 0) divisor = divisor - 1; + return divisor; +} + uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor) { uint32_t divisor = start_divisor; uint32_t padded_num = round_up(num, divisor); @@ -85,6 +91,41 @@ Tensor convert_conv_weight_tensor_to_grouped_layout(const Tensor& conv_weight_te return tt::tt_metal::convert_conv_weight_tensor_to_grouped_layout(std::move(conv_weight_tensor), num_groups, output_dtype); } +ParallelConfig determine_parallel_config_non_tile_mul_width( + const TensorMemoryLayout shard_layout, + uint32_t batch_size, + uint32_t input_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t output_channels, + const CoreCoord& compute_grid_size, + ShardOrientation block_shard_orientation) { + + uint32_t effective_tile_height = 1; + uint32_t effective_tile_width = 1; + CoreRangeSet grid; + uint32_t out_nhw_ntiles = tt::round_up(batch_size * output_height * output_width, tt::constants::TILE_HEIGHT); + uint32_t start_divisor = + block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.x : compute_grid_size.y; + uint32_t num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, start_divisor); + + uint32_t start_divisor_c = + block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.y : compute_grid_size.x; + uint32_t num_cores_c = find_closest_common_largest_divisor(output_channels, input_channels, start_divisor_c); + uint32_t cores_x = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_nhw : num_cores_c; + uint32_t cores_y = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_c : num_cores_nhw; + CoreRange core_range = CoreRange(CoreCoord({0, 0}), CoreCoord({cores_x - 1, cores_y - 1})); + grid = CoreRangeSet({core_range}); + auto shard_orientation = shard_layout == TensorMemoryLayout::BLOCK_SHARDED ? 
block_shard_orientation : ShardOrientation::ROW_MAJOR; + ParallelConfig pconfig = { + .grid = grid, + .shard_scheme = shard_layout, + .shard_orientation = block_shard_orientation}; + + return pconfig; + +} + ParallelConfig determine_parallel_config( const TensorMemoryLayout shard_layout, uint32_t batch_size, @@ -242,14 +283,11 @@ OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_o TT_ASSERT(conv_output_mem_config.shard_spec.has_value()); const auto& shard_spec = conv_output_mem_config.shard_spec.value(); const auto& shard_shape = shard_spec.shape; - TT_ASSERT(shard_shape[1] % 32 == 0); uint32_t per_core_out_matrix_height_ntiles = div_up(shard_shape[0], 32); return { .grid_size = shard_spec.grid.bounding_box().grid_size(), .num_cores_nhw = num_cores_nhw, .num_cores_c = num_cores_c, - .per_core_out_matrix_height_ntiles = per_core_out_matrix_height_ntiles, - .per_core_out_matrix_width_ntiles = shard_shape[1] / 32, .per_core_out_matrix_height = shard_shape[0], .per_core_out_matrix_width = shard_shape[1], }; @@ -304,7 +342,7 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( "Config Error: act_block_h_override must be a multiple of 32 (tile height)."); } - uint32_t act_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntiles; + uint32_t act_block_h_ntiles = div_up(conv_op_parallel_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT); if (act_block_h_override > 0) { if (parallel_config.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { @@ -324,22 +362,22 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( } } + auto grid_size = parallel_config.grid.bounding_box().grid_size(); + uint32_t act_c_num_blocks = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED ? 1 + : parallel_config.shard_orientation == ShardOrientation::COL_MAJOR ? grid_size.y + : grid_size.x; + TT_ASSERT(padded_in_channels % act_c_num_blocks == 0); uint32_t act_block_w = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED - ? round_up(padded_in_channels * window_w, 32) - : padded_in_channels; + ? round_up(padded_in_channels * window_w, 32) + : round_up((padded_in_channels / act_c_num_blocks) * window_h * window_w, tt::constants::TILE_WIDTH); if(parallel_config.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { TT_ASSERT(padded_in_channels % (32 * parallel_config.grid.num_cores() * act_block_w_div) == 0); act_block_w = (padded_in_channels * window_h * window_w)/(parallel_config.grid.num_cores() * act_block_w_div); } TT_ASSERT(act_block_w % 32 == 0); uint32_t act_block_w_ntiles = act_block_w / 32; - auto grid_size = parallel_config.grid.bounding_box().grid_size(); - uint32_t act_c_num_blocks = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED ? 1 - : parallel_config.shard_orientation == ShardOrientation::COL_MAJOR ? 
grid_size.y - : grid_size.x; - uint32_t out_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntiles; - uint32_t weight_block_w_ntiles = conv_op_parallel_config.per_core_out_matrix_width_ntiles; - + uint32_t out_block_h_ntiles = div_up(conv_op_parallel_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT); + uint32_t weight_block_w_ntiles = div_up(conv_op_parallel_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH); auto [out_subblock_h_ntiles, out_subblock_w_ntiles] = determine_largest_subblock_size(act_block_h_ntiles, weight_block_w_ntiles, fp32_accum, split_reader_enabled); return { @@ -439,7 +477,8 @@ std::tuple get_conv_padded_input_sh uint32_t width, uint32_t in_channels, uint32_t out_channels, - bool is_mm_conv) { + bool is_mm_conv, + bool is_non_tile_mul_width) { ttnn::Tensor input_tensor = input_tensor_; // tensor to return bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); bool needs_shard_or_reshard = false; @@ -510,17 +549,30 @@ std::tuple get_conv_padded_input_sh if (conv_config.reshard_if_not_optimal || needs_shard_or_reshard) { auto block_shard_orientation = conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; - ParallelConfig optimal_parallel_config = determine_parallel_config( - shard_layout, - batch_size, - in_channels, - height, - width, - out_channels, - device->compute_with_storage_grid_size(), - block_shard_orientation, - !is_mm_conv, - !use_non_tile_height); + ParallelConfig optimal_parallel_config; + if (is_non_tile_mul_width) { + optimal_parallel_config = determine_parallel_config_non_tile_mul_width( + shard_layout, + batch_size, + in_channels, + height, + width, + out_channels, + device->compute_with_storage_grid_size(), + block_shard_orientation); + } else { + optimal_parallel_config = determine_parallel_config( + shard_layout, + batch_size, + in_channels, + height, + width, + out_channels, + device->compute_with_storage_grid_size(), + block_shard_orientation, + !is_mm_conv, + !use_non_tile_height); + } if (conv_config.override_sharding_config) { TT_FATAL(conv_config.core_grid.has_value(), "Error"); @@ -555,6 +607,10 @@ std::tuple get_conv_padded_input_sh TT_ASSERT(input_tensor_height_snapped_to_tile >= tensor_height); uint32_t input_tensor_width_snapped_to_channels_alignment = tt::round_up(input_tensor.get_shape()[3], input_num_cores_c * conv_config.input_channels_alignment); + if(is_non_tile_mul_width) { + input_tensor_width_snapped_to_channels_alignment = + tt::round_up(input_tensor.get_shape()[3], conv_config.input_channels_alignment); + } auto input_padded_shape = ttnn::Shape(std::array{ 1, @@ -584,7 +640,8 @@ std::tuple shard_or_re uint32_t width, uint32_t in_channels, uint32_t out_channels, - bool is_mm_conv) { + bool is_mm_conv, + bool is_non_tile_mul_width) { ttnn::Tensor input_tensor = input_tensor_; // tensor to return bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); auto compute_grid_size = device->compute_with_storage_grid_size(); @@ -599,7 +656,8 @@ std::tuple shard_or_re width, in_channels, out_channels, - is_mm_conv); + is_mm_conv, + is_non_tile_mul_width); ParallelConfig parallel_config = { .grid = input_tensor_sharded_memory_config.shard_spec.value().grid, .shard_scheme = input_tensor_sharded_memory_config.memory_layout, @@ -696,8 +754,8 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co .in0_block_w = conv_blocking_config.act_block_w_ntiles, .out_subblock_h = 
conv_blocking_config.out_subblock_h_ntiles, .out_subblock_w = conv_blocking_config.out_subblock_w_ntiles, - .per_core_M = conv_parallelization_config.per_core_out_matrix_height_ntiles, - .per_core_N = conv_parallelization_config.per_core_out_matrix_width_ntiles, + .per_core_M = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT), + .per_core_N = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH), .fuse_batch = true, .mcast_in0 = false}; if (activation != "") { @@ -705,16 +763,15 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co } return matmul_config; } else { - TT_ASSERT(conv_blocking_config.act_block_w_ntiles % grid_size_along_c == 0); ttnn::operations::matmul::MatmulMultiCoreReuseMultiCastProgramConfig matmul_config = { .compute_with_storage_grid_size = conv_parallelization_config.grid_size, - .in0_block_w = conv_blocking_config.act_block_w_ntiles / grid_size_along_c, + .in0_block_w = conv_blocking_config.act_block_w_ntiles, .out_subblock_h = conv_blocking_config.out_subblock_h_ntiles, .out_subblock_w = conv_blocking_config.out_subblock_w_ntiles, - .out_block_h = conv_parallelization_config.per_core_out_matrix_height_ntiles, - .out_block_w = conv_parallelization_config.per_core_out_matrix_width_ntiles, - .per_core_M = conv_parallelization_config.per_core_out_matrix_height_ntiles, - .per_core_N = conv_parallelization_config.per_core_out_matrix_width_ntiles, + .out_block_h = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT), + .out_block_w = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH), + .per_core_M = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT), + .per_core_N = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH), .transpose_mcast = transpose_mcast}; if (activation != "") { matmul_config.fused_activation = ttnn::operations::unary::utils::string_to_unary_with_param(activation); @@ -800,7 +857,8 @@ template std::tuple get_conv_padded uint32_t width, uint32_t in_channels, uint32_t out_channels, - bool is_mm_conv); + bool is_mm_conv, + bool is_non_tile_mul_width); template std::tuple get_conv_padded_input_shape_and_mem_config( MeshDevice * device, @@ -811,7 +869,8 @@ template std::tuple get_conv_padded uint32_t width, uint32_t in_channels, uint32_t out_channels, - bool is_mm_conv); + bool is_mm_conv, + bool is_non_tile_mul_width); template std::tuple shard_or_reshard_tensor_if_required( Device* device, @@ -822,7 +881,8 @@ template std::tuple sh uint32_t width, uint32_t in_channels, uint32_t out_channels, - bool is_mm_conv); + bool is_mm_conv, + bool is_non_tile_mul_width); template std::tuple shard_or_reshard_tensor_if_required( MeshDevice * device, @@ -833,7 +893,8 @@ template std::tuple sh uint32_t width, uint32_t in_channels, uint32_t out_channel, - bool is_mm_conv); + bool is_mm_conv, + bool is_non_tile_mul_width); } // namespace operations } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp index b837b3ca81e..9b9645f821f 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp @@ -171,7 +171,8 @@ std::tuple get_conv_padded_input_sh uint32_t width, uint32_t in_channels, uint32_t out_channels, - bool is_mm_conv); + bool is_mm_conv, + bool 
is_non_tile_mul_width=false); OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config( const MemoryConfig& conv_output_mem_config, uint32_t num_cores_nhw, uint32_t num_cores_c); @@ -200,7 +201,8 @@ std::tuple #include #include "conv2d_op.hpp" +#include "common/math.hpp" #include "common/math.hpp" #include "tt_metal/host_api.hpp" @@ -110,7 +111,8 @@ void OptimizedConvNew::validate(const std::vector& input_tensors, const TT_FATAL((this->dtype == DataType::BFLOAT16) || (this->dtype == DataType::FLOAT32), "Error"); } if (this->memory_config.is_sharded()) { - uint32_t out_block_h_ntiles = parallelization_config.per_core_out_matrix_height_ntiles; + uint32_t out_block_h_ntiles = optimized_conv_op_utils::div_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); + uint32_t per_core_out_matrix_width_ntiles = optimized_conv_op_utils::div_up(parallelization_config.per_core_out_matrix_width, TILE_WIDTH); auto [act_matrix_shape, act_matrix_shape_unpadded] = optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape( input_tensor_a.get_legacy_shape(), @@ -119,7 +121,7 @@ void OptimizedConvNew::validate(const std::vector& input_tensors, const out_block_h_ntiles); uint32_t out_width_ntiles = this->compute_output_shapes(input_tensors).at(0)[-1] / TILE_WIDTH; if(this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - TT_FATAL(this->parallelization_config.per_core_out_matrix_width_ntiles == out_width_ntiles, "Error"); + TT_FATAL(per_core_out_matrix_width_ntiles == out_width_ntiles, "Error"); TT_FATAL(this->block_config.out_subblock_w_ntiles == out_width_ntiles || this->block_config.out_subblock_h_ntiles == 1, "Error"); } else if (this->memory_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { // For block sharded, out_width per core is shard width, and this is split along row @@ -129,8 +131,8 @@ void OptimizedConvNew::validate(const std::vector& input_tensors, const } else { out_width_ntiles = tt::div_up(out_width_ntiles, this->parallelization_config.grid_size.x); } - TT_FATAL(this->block_config.out_subblock_w_ntiles == out_width_ntiles || this->block_config.out_subblock_h_ntiles == 1, "Error"); } + TT_FATAL(this->block_config.out_subblock_w_ntiles == per_core_out_matrix_width_ntiles || this->block_config.out_subblock_h_ntiles == 1, "Error"); } } @@ -154,9 +156,7 @@ std::vector OptimizedConvNew::compute_output_shapes(c // Tiled output shape is padded shape. Padded to tile shape. auto shape_w = batch_size * conv_output_h * conv_output_w; auto shape_c = output_channels; - auto padded_shape_w = this->use_non_tile_height ? - parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height - : parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height_ntiles * TILE_HEIGHT; + auto padded_shape_w = this->use_non_tile_height ? 
parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height : parallelization_config.num_cores_nhw * tt::round_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); auto padded_shape_c = tt::round_up(this->output_channels, TILE_WIDTH); auto output_padding = Padding( {{0, 0}, {0, 0}, {0, (padded_shape_w - shape_w)}, {0, (padded_shape_c - shape_c)}}, Padding::PadValue::Zero); @@ -179,8 +179,10 @@ std::vector OptimizedConvNew::create_output_tensors(const std::vectorparallelization_config.per_core_out_matrix_height_ntiles; - shard_shape = {this->parallelization_config.per_core_out_matrix_height_ntiles * TILE_HEIGHT, output_shape[-1]}; + num_cores = total_height_tiles / tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); + CoreRangeSet shard_grid = tt::tt_metal::num_cores_to_corerangeset(num_cores, this->parallelization_config.grid_size, true); + + shard_shape = {optimized_conv_op_utils::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT) * TILE_HEIGHT, output_shape[-1]}; } CoreRangeSet shard_grid = tt::tt_metal::num_cores_to_corerangeset(num_cores, this->parallelization_config.grid_size, true); auto shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR}; @@ -189,7 +191,7 @@ std::vector OptimizedConvNew::create_output_tensors(const std::vectordtype, output_layout, input_tensor.device(), mem_config)}; } else if(this->memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED) { uint32_t total_height_tiles = tt::tt_metal::compute_volume(output_shape) / output_shape[-1] / TILE_HEIGHT; - std::array shard_shape = {this->parallelization_config.per_core_out_matrix_height_ntiles * TILE_HEIGHT, this->parallelization_config.per_core_out_matrix_width_ntiles * TILE_WIDTH}; + std::array shard_shape = {tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT) * TILE_HEIGHT, tt::div_up(this->parallelization_config.per_core_out_matrix_width, TILE_WIDTH) * TILE_WIDTH}; auto shard_grid = this->memory_config.shard_spec.value().grid; auto shard_spec = ShardSpec{shard_grid, shard_shape, this->memory_config.shard_spec.value().orientation}; auto mem_config = this->memory_config; @@ -197,35 +199,10 @@ std::vector OptimizedConvNew::create_output_tensors(const std::vectordtype, output_layout, input_tensor.device(), mem_config)}; } else if (this->memory_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { - auto [act_matrix_shape, act_matrix_shape_unpadded] = - optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape( - this->input_tensor_shape, - sliding_window_config, - this->parallelization_config.num_cores_nhw, - this->parallelization_config.per_core_out_matrix_height_ntiles); - uint32_t act_matrix_height = (uint32_t) act_matrix_shape[1]; - uint32_t act_matrix_height_ntiles = act_matrix_height / TILE_HEIGHT; - uint32_t total_active_num_cores_per_weight_slice = act_matrix_height_ntiles / this->parallelization_config.per_core_out_matrix_height_ntiles; - uint32_t weight_matrix_width = weight_tensor.get_legacy_shape()[-1]; - uint32_t weight_matrix_width_ntiles = weight_matrix_width / TILE_WIDTH; - uint32_t num_weight_slices_width = weight_matrix_width_ntiles / this->parallelization_config.per_core_out_matrix_width_ntiles ; - uint32_t total_active_num_cores = total_active_num_cores_per_weight_slice * num_weight_slices_width; - log_debug(tt::LogOp, "Total active num cores: {}", total_active_num_cores); - log_debug(tt::LogOp, "Parallelization config grid size: 
{}", this->parallelization_config.grid_size.str()); - uint32_t num_cores_x = this->parallelization_config.grid_size.x; - uint32_t num_cores_y = this->parallelization_config.grid_size.y; - CoreRangeSet shard_grid = - CoreRangeSet(CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1})); - log_debug(tt::LogOp, "Calculated shard_grid: {}", shard_grid.str()); - std::array shard_shape = {this->parallelization_config.per_core_out_matrix_height_ntiles * TILE_HEIGHT, this->parallelization_config.per_core_out_matrix_width_ntiles * TILE_WIDTH}; - auto shard_spec = ShardSpec{shard_grid, shard_shape, this->memory_config.shard_spec.value().orientation}; - auto mem_config = this->memory_config; - mem_config.shard_spec = shard_spec; - return {create_device_tensor(output_shape, this->dtype, output_layout, input_tensor.device(), mem_config)}; + return {create_device_tensor(output_shape, this->dtype, output_layout, input_tensor.device(), this->memory_config)}; } else { TT_THROW("Unsupported shard scheme"); } - } return operation::generic_create_output_tensors(*this, input_tensors, this->dtype, output_layout, this->memory_config); } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp index 88151d5a83e..830ca917e33 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp @@ -23,8 +23,6 @@ struct OptimizedConvParallelizationConfig { CoreCoord grid_size; // (x,y) uint32_t num_cores_nhw = 1; uint32_t num_cores_c = 1; - uint32_t per_core_out_matrix_height_ntiles = 1; - uint32_t per_core_out_matrix_width_ntiles = 1; uint32_t per_core_out_matrix_height = 1; uint32_t per_core_out_matrix_width = 1; // std::size_t in0_block_w; diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp index de10fb342a7..0b452a583df 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp @@ -33,6 +33,7 @@ const uint32_t tilize_mode_tilized_act_cb = CBIndex::c_25; const uint32_t untilize_mode_reblock_cb = CBIndex::c_26; const uint32_t out0_cb = CBIndex::c_16; const uint32_t temp_sum_cb = CBIndex::c_27; +const uint32_t untilized_padded_out_cb = CBIndex::c_28; } // namespace CMAKE_UNIQUE_NAMESPACE } // namespace @@ -184,6 +185,7 @@ std::tuple create_CBs_for_sharded_input_v2( CBHandle cb_output = 0; if (untilize_out) { + auto output_shard_shape = output.shard_spec().value().shape; CircularBufferConfig cb_matmul_partials_config = CircularBufferConfig(num_output_tiles * interm0_single_tile_size, {{matmul_partials_cb, interm0_df}}) .set_page_size(matmul_partials_cb, interm0_single_tile_size); @@ -195,31 +197,50 @@ std::tuple create_CBs_for_sharded_input_v2( num_output_tiles, interm0_single_tile_size); - // Supposed to be a small CB only responsible for reorganizing - // the output blocks to fill the whole "per core output block width" - CircularBufferConfig cb_reblock_config = - CircularBufferConfig(num_reblock_cb_tiles * out_tile_size, {{untilize_mode_reblock_cb, out_df}}) - .set_page_size(untilize_mode_reblock_cb, out_tile_size); - auto cb_reblock = tt_metal::CreateCircularBuffer(program, core, cb_reblock_config); - log_debug( - LogOp, - "Reblock CB: {}, npages: {}, pagesize: {}", - untilize_mode_reblock_cb, - num_reblock_cb_tiles, - 
out_tile_size); - - auto shard_shape = output.shard_spec().value().shape; - uint32_t aligned_output_stick_nbytes = - use_non_tile_height ? shard_shape[1] * output.element_size() : out_tile_size; - uint32_t aligned_output_num_pages = use_non_tile_height ? shard_shape[0] : num_writer_output_tiles; - CircularBufferConfig cb_output_config = - CircularBufferConfig(aligned_output_num_pages * aligned_output_stick_nbytes, {{out0_cb, out_df}}) - .set_page_size(out0_cb, aligned_output_stick_nbytes); - - if (output.is_sharded()) { + bool need_unpad_after_untilize = + output_shard_shape[1] * output_shard_shape[0] < num_writer_output_tiles * TILE_HW; + // If only width is non-tile multiple + if (need_unpad_after_untilize && !use_non_tile_height && weight_width_sliced) { + uint32_t num_bytes_for_df = datum_size(out_df); + CircularBufferConfig compute_cb_output_config = + CircularBufferConfig(num_writer_output_tiles * out_tile_size, {{untilized_padded_out_cb, out_df}}) + .set_page_size(untilized_padded_out_cb, out_tile_size); + auto compute_cb_output = tt_metal::CreateCircularBuffer(program, core, compute_cb_output_config); + log_debug( + LogOp, + "untilized padded out CB(shard width non-tile multiple): {}, npages: {}, pagesize: {}", + untilized_padded_out_cb, + num_writer_output_tiles, + out_tile_size * num_bytes_for_df); + CircularBufferConfig cb_output_config = + CircularBufferConfig( + num_bytes_for_df * output_shard_shape[0] * output_shard_shape[1], {{out0_cb, out_df}}) + .set_page_size(out0_cb, output_shard_shape[1] * num_bytes_for_df); cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer()); + cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config); + log_debug( + LogOp, + "output CB(shard width non-tile multiple): {}, npages: {}, pagesize: {}", + out0_cb, + output_shard_shape[0], + output_shard_shape[1] * num_bytes_for_df); + } else { + auto shard_shape = output.shard_spec().value().shape; + uint32_t aligned_output_stick_nbytes = + use_non_tile_height ? shard_shape[1] * output.element_size() : out_tile_size; + uint32_t aligned_output_num_pages = use_non_tile_height ? shard_shape[0] : num_writer_output_tiles; + CircularBufferConfig cb_output_config = + CircularBufferConfig(aligned_output_num_pages * aligned_output_stick_nbytes, {{out0_cb, out_df}}) + .set_page_size(out0_cb, aligned_output_stick_nbytes); + cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer()); + cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config); + log_debug( + LogOp, + "output CB: {}, npages: {}, pagesize: {}", + out0_cb, + aligned_output_num_pages, + aligned_output_stick_nbytes); } - cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config); } else { // Share buffer if same data format if (interm0_df == out_df) { @@ -406,8 +427,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( TT_FATAL(output_channels <= b.get_legacy_shape()[3], "Invalid weight shape. 
Incorrect weight tensor."); uint32_t act_block_h_ntiles = block_config.act_block_h_ntiles; uint32_t act_block_w_ntiles = block_config.act_block_w_ntiles; - uint32_t weight_block_w_ntiles = parallelization_config.per_core_out_matrix_width_ntiles; - uint32_t out_block_h_ntiles = parallelization_config.per_core_out_matrix_height_ntiles; + uint32_t weight_block_w_ntiles = div_up(parallelization_config.per_core_out_matrix_width, TILE_WIDTH); + uint32_t out_block_h_ntiles = div_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); uint32_t out_subblock_h_ntiles = block_config.out_subblock_h_ntiles; uint32_t out_subblock_w_ntiles = block_config.out_subblock_w_ntiles; @@ -495,6 +516,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( // TODO: Only updated variables which is affected, but there may be more that needs to account for this // TODO: Loop naming in reader, writer, and compute kernels could also be cleaned up // TODO: Can conv_act_c_blocks be same as num_blocks_act_w? + auto a_shard_spec = a.shard_spec().value(); auto shard_shape = a.shard_spec().value().shape; // parallelization config @@ -503,14 +525,16 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( uint32_t num_cores_y = p_config.grid_size.y; uint32_t total_num_cores = num_cores_x * num_cores_y; - uint32_t per_core_out_matrix_height_ntiles = p_config.per_core_out_matrix_height_ntiles; - uint32_t per_core_out_matrix_width_ntiles = p_config.per_core_out_matrix_width_ntiles; + uint32_t per_core_out_matrix_width_ntiles = div_up(parallelization_config.per_core_out_matrix_width, TILE_WIDTH); + uint32_t per_core_out_matrix_height_ntiles = div_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); + bool block_sharded = a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED; // weight_width_sliced determines is 1d-sysarr-conv or 2d-sysarr-conv bool weight_width_sliced = per_core_out_matrix_width_ntiles < weight_matrix_width_ntiles; uint32_t conv_act_c_blocks = weight_matrix_width_ntiles / per_core_out_matrix_width_ntiles; uint32_t input_channels_padded = 0; if (weight_width_sliced) { + conv_act_c_blocks = a_shard_spec.orientation == ShardOrientation::ROW_MAJOR ? num_cores_x : num_cores_y; if (transpose_mcast) { TT_FATAL(conv_act_c_blocks == num_cores_y, "Expected conv_act_c_blocks to be equal to height of grid"); input_channels_padded = shard_shape[1] * num_cores_y; @@ -574,6 +598,10 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( assert(act_matrix_shape[0] == 1); uint32_t act_matrix_height = (uint32_t)act_matrix_shape[1]; uint32_t act_matrix_width = (uint32_t)act_matrix_shape[2]; + if (block_sharded) { + act_matrix_width = + round_up((input_channels_padded / conv_act_c_blocks) * filter_w * filter_h, TILE_WIDTH) * conv_act_c_blocks; + } uint32_t act_matrix_height_unpadded = (uint32_t)act_matrix_shape_unpadded[1]; uint32_t act_matrix_width_unpadded = (uint32_t)act_matrix_shape_unpadded[2]; @@ -629,7 +657,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( uint32_t num_blocks_act_h = act_matrix_height_ntiles / act_block_h_ntiles; uint32_t num_blocks_out_h = act_matrix_height_ntiles / out_block_h_ntiles; - uint32_t num_blocks_act_w = act_matrix_width_ntiles / act_block_w_ntiles; + uint32_t num_blocks_act_w = a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED ? 
1 : filter_h; uint32_t num_blocks_weight_w = weight_matrix_width_ntiles / weight_block_w_ntiles; // act block info @@ -657,11 +685,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( log_debug(LogOp, "act_block_num_tiles_split: {}", act_block_num_tiles_split); log_debug(LogOp, "act_block_num_tiles_split_last: {}", act_block_num_tiles_split_last); - TT_FATAL( - (act_block_w_datums == round_up(conv_act_size_c * filter_w, TILE_WIDTH)) || - ((act_block_w_datums <= conv_act_size_c) && (conv_act_size_c % act_block_w_datums == 0)), - "Error"); - // weight block info uint32_t weight_block_w_datums = weight_matrix_width / num_blocks_weight_w; assert(weight_block_w_ntiles % out_subblock_w_ntiles == 0); @@ -682,9 +705,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( : (output_channels_padded_to_tile_width % weight_block_w_datums); assert(last_block_width_datums % TILE_WIDTH == 0); - // sanity check - assert(num_blocks_output_w == num_blocks_weight_w); - uint32_t out_block_h_datums = out_block_h_ntiles * TILE_HEIGHT; tt_metal::Buffer* src0_dram_buffer = a.buffer(); @@ -906,7 +926,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( } if (has_bias) { - TT_FATAL(bias_ntiles % num_weight_slices_width == 0, "Error"); TT_FATAL(bias_ntiles == weight_matrix_width_ntiles, "Error"); } uint32_t bias_ntiles_per_core = bias_ntiles / num_weight_slices_width; @@ -998,6 +1017,11 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( uint32_t num_weight_cb_tiles = weight_block_h_ntiles * weight_block_w_ntiles / conv_act_c_blocks; bool fully_buffer_weights = false; uint32_t num_act_cb_tiles = act_block_h_ntiles * act_block_w_ntiles / conv_act_c_blocks; + + if (block_sharded) { + num_act_cb_tiles = act_block_h_ntiles * act_block_w_ntiles; + num_weight_cb_tiles = weight_block_h_ntiles * weight_block_w_ntiles; + } uint32_t num_act_cb_second_reader_tiles = 0; // TODO: This flag should be set in kernel logic but need this for create_CB if (weight_width_sliced) { @@ -1005,8 +1029,10 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( // window before pushing in reader/writer // TODO: Generalize this to not make this assumption read_window_in_inner_loop = true; - num_weight_cb_tiles *= filter_h * filter_w; - num_act_cb_tiles *= filter_h * filter_w; + if (!block_sharded) { + num_weight_cb_tiles *= filter_h * filter_w; + num_act_cb_tiles *= filter_h * filter_w; + } } else if (num_blocks_act_h_per_core > 1) { fully_buffer_weights = true; } @@ -1048,7 +1074,11 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( uint32_t conv_act_c_read_bytes = conv_act_size_c * a.element_size() / conv_act_c_blocks; uint32_t act_block_w_extra_align_bytes = - (round_up(conv_act_size_c * filter_w, TILE_WIDTH) - (conv_act_size_c * filter_w)) * a.element_size(); + block_sharded ? (round_up(a_shard_spec.shape[1] * filter_h * filter_w, TILE_WIDTH) - + (a_shard_spec.shape[1] * filter_h * filter_w)) * + a.element_size() + : (round_up(a_shard_spec.shape[1] * filter_w, TILE_WIDTH) - (a_shard_spec.shape[1] * filter_w)) * + a.element_size(); uint32_t in0_block_w = act_block_w_ntiles / conv_act_c_blocks; uint32_t in0_block_num_tiles = act_block_num_tiles / conv_act_c_blocks; @@ -1226,7 +1256,12 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( } uint32_t reader_arg_act_block_h_datums = (split_reader ? 
act_block_h_datums_split : act_block_h_datums); TT_FATAL(reader_arg_act_block_h_datums % 2 == 0, "2 Indices are packed in one uint32_t word."); - + if (block_sharded) { + in0_block_num_tiles = act_block_num_tiles; + in1_block_num_tiles = weight_block_num_tiles; + in0_block_w = act_block_w_ntiles; + in0_num_blocks_w = 1 * conv_act_c_blocks; + } reader_compile_time_args = { (uint32_t)(src0_dram_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0), (uint32_t)stride_h, @@ -1299,7 +1334,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( num_blocks_act_w, // = number of blocks of weight in height dim in1_block_num_tiles, conv_act_c_blocks, - weight_block_h_ntiles / conv_act_c_blocks, + weight_block_h_ntiles, weight_block_w_ntiles, weight_matrix_width_ntiles, // weight_stride_h weight_matrix_width_ntiles * weight_block_h_ntiles, // weight_next_block_stride_h, @@ -1345,12 +1380,25 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( writer_compile_time_args.insert( writer_compile_time_args.end(), split_reader_args.begin(), split_reader_args.end()); } + bool need_unpad_after_untilize = + parallelization_config.per_core_out_matrix_width < per_core_out_matrix_width_ntiles * TILE_WIDTH; + if (need_unpad_after_untilize) { + TT_FATAL(block_sharded, "Need to handle this case for non-sliced weights"); + TT_FATAL(untilize_out, "Cannot support non-tile multiple shard width with tilized output"); + writer_compile_time_args.push_back(per_core_out_matrix_width_ntiles); + writer_compile_time_args.push_back(per_core_out_matrix_width_ntiles * TILE_WIDTH * 2); + writer_compile_time_args.push_back(parallelization_config.per_core_out_matrix_width * 2); + writer_compile_time_args.push_back(untilized_padded_out_cb); + writer_defines["UNPAD_UNTILIZE_OUT"] = 1; + writer_mcast_sender_defines["UNPAD_UNTILIZE_OUT"] = 1; + } + uint32_t compute_output_cb = need_unpad_after_untilize ? 
untilized_padded_out_cb : out0_cb; std::vector compute_kernel_args = { in0_block_w, act_num_subblocks, - in0_block_num_tiles, - in0_subblock_num_tiles, + act_block_num_tiles, + act_subblock_num_tiles, act_subblock_h_ntiles, weight_num_subblocks, @@ -1369,6 +1417,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( untilize_out, bias_ntiles_per_core, + compute_output_cb, aligned_output_num_pages, use_non_tile_height}; @@ -1573,6 +1622,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( writer_rt_args.push_back(num_cores_x - 1); // weights_mcast_num_cores writer_rt_args.push_back(weights_mcast_sender_semaphore_id); writer_rt_args.push_back(weights_mcast_receiver_semaphore_id); + writer_rt_args.push_back(output.buffer()->aligned_page_size()); SetRuntimeArgs(program, writer_mcast_sender_id, core, writer_rt_args); } else { @@ -1581,6 +1631,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( writer_rt_args.push_back(right_core_physical.y); // weights_mcast_sender_noc_y writer_rt_args.push_back(weights_mcast_sender_semaphore_id); writer_rt_args.push_back(weights_mcast_receiver_semaphore_id); + writer_rt_args.push_back(output.buffer()->aligned_page_size()); SetRuntimeArgs(program, writer_mcast_receiver_id, core, writer_rt_args); } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp index 53f8e5bceab..41f00c99ff0 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp @@ -62,8 +62,9 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl( TT_FATAL(output_channels <= b.get_legacy_shape()[3], "Invalid weight shape. 
Incorrect weight tensor."); uint32_t act_block_h_ntiles = block_config.act_block_h_ntiles; uint32_t act_block_w_ntiles = block_config.act_block_w_ntiles; - uint32_t weight_block_w_ntiles = parallelization_config.per_core_out_matrix_width_ntiles; - uint32_t out_block_h_ntiles = parallelization_config.per_core_out_matrix_height_ntiles; + uint32_t weight_block_w_ntiles = + div_up(parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH); + uint32_t out_block_h_ntiles = div_up(parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT); uint32_t out_subblock_h_ntiles = block_config.out_subblock_h_ntiles; uint32_t out_subblock_w_ntiles = block_config.out_subblock_w_ntiles; @@ -168,9 +169,9 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl( uint32_t num_cores_y = p_config.grid_size.y; TT_FATAL(num_cores_x < 13, "Error"); TT_FATAL(num_cores_y < 10, "Error"); - uint32_t per_core_out_matrix_height_ntiles = p_config.per_core_out_matrix_height_ntiles; - uint32_t per_core_out_matrix_width_ntiles = p_config.per_core_out_matrix_width_ntiles; - + uint32_t per_core_out_matrix_height_ntiles = + div_up(p_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT); + uint32_t per_core_out_matrix_width_ntiles = div_up(p_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH); // weight_width_sliced determines is 1d-sysarr-conv or 2d-sysarr-conv bool weight_width_sliced = per_core_out_matrix_width_ntiles < weight_matrix_width_ntiles; // uint32_t conv_act_c_blocks = weight_matrix_width_ntiles / per_core_out_matrix_width_ntiles; @@ -666,6 +667,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl( bias_ntiles_per_core, + out0_cb, num_output_tiles, use_non_tile_height, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp index 047aaef8d68..f2ad3573b3f 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp @@ -100,11 +100,12 @@ void MAIN { constexpr uint32_t out_subblock_num_tiles = get_compile_time_arg_val(13); // out_subblock_h * out_subblock_w; constexpr bool tilize_in0 = get_compile_time_arg_val(14); constexpr bool untilize_out = get_compile_time_arg_val(15); - uint32_t output_rows_h = get_compile_time_arg_val(17); - constexpr bool is_non_tile_height = get_compile_time_arg_val(18); + constexpr uint32_t out_cb_id = get_compile_time_arg_val(17); + uint32_t output_rows_h = get_compile_time_arg_val(18); + constexpr bool is_non_tile_height = get_compile_time_arg_val(19); #ifdef WIDTH_SHARDED - constexpr uint32_t in0_nblocks_w_tilize = get_compile_time_arg_val(19); + constexpr uint32_t in0_nblocks_w_tilize = get_compile_time_arg_val(20); #endif constexpr uint32_t out_block_num_tiles = in0_num_subblocks * in1_num_subblocks * out_subblock_num_tiles; @@ -119,7 +120,6 @@ void MAIN { constexpr uint32_t matmul_partials_cb = tt::CBIndex::c_24; constexpr uint32_t tilized_in0_cb_id = tt::CBIndex::c_25; // constexpr uint32_t untilize_mode_reblock_cb = tt::CBIndex::c_26; - constexpr uint32_t out_cb_id = tt::CBIndex::c_16; constexpr uint32_t untilize_mode_out_cb_id = untilize_out ? 
matmul_partials_cb : out_cb_id; diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp index 8bbe216e90d..23ee3f28a2d 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp @@ -81,6 +81,7 @@ void kernel_main() { constexpr uint32_t window_inner = get_compile_time_arg_val(9); constexpr uint32_t act_block_h_datums = get_compile_time_arg_val(10); constexpr uint32_t padded_conv_act_size_w = get_compile_time_arg_val(13); + constexpr uint32_t act_block_w_extra_align_bytes = get_compile_time_arg_val(14); constexpr uint32_t act_num_blocks_h = get_compile_time_arg_val(16); constexpr uint32_t act_block_num_tiles = get_compile_time_arg_val(17); constexpr uint32_t act_w_num_outer = get_compile_time_arg_val(18); @@ -131,7 +132,6 @@ void kernel_main() { act_mcast_sender_semaphore_valid_addr_ptr[0] = 1; // Load const 1 to be used as semaphore valid value sent from sender to receivers uint32_t act_mcast_sender_semaphore_valid_addr = reinterpret_cast(&l1_array[0]); - // Set up remote VALID value volatile tt_l1_ptr uint32_t* act_mcast_receiver_semaphore_addr_ptr = reinterpret_cast(act_mcast_receiver_semaphore_addr); @@ -178,6 +178,9 @@ void kernel_main() { conv_act_c_read_bytes, coalesced_read_bytes, stride_h_bytes); + if constexpr (act_block_w_extra_align_bytes) { + l1_write_addr_act += act_block_w_extra_align_bytes; + } read_channels( l1_write_addr_act, act_l1_read_addr, @@ -185,6 +188,9 @@ void kernel_main() { conv_act_c_read_bytes, coalesced_read_bytes, stride_h_bytes); + if constexpr (act_block_w_extra_align_bytes) { + l1_write_addr_act += act_block_w_extra_align_bytes; + } #else read_dilated_channels( l1_write_addr_act, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp index b58d64fe2b5..652e37e890e 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp @@ -172,51 +172,7 @@ void kernel_main() { cb_reserve_back(cb_id_weight, total_weight_num_tiles); cb_push_back(cb_id_weight, total_weight_num_tiles); } -#ifndef SHARDED_OUT - uint32_t out_sbh_start_tile_id = out_block_h_start_tile_id; - uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; // - for (uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t out_sbw_start_tile_id = out_sbh_start_tile_id; - uint32_t out_sbw_start_tile_id_w = out_block_w_start_tile_id_w; - for (uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_sb_row_start_tile_id = out_sbw_start_tile_id; - // wait for one subblock worth tiles - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for (uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tile_id = out_sb_row_start_tile_id; - uint32_t out_tile_id_h = out_sbh_start_tile_id_h + h; - if (out_tile_id_h >= 
out_height_num_tiles) { // block shape height padding - break; - } - for (uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t out_tile_id_w = out_sbw_start_tile_id_w + w; - if (out_tile_id_w >= out_width_num_tiles) { // block shape width padding - l1_read_addr += tile_nbytes; - } else { - // DPRINT << "out_tile_id - " << out_tile_id << ENDL(); - uint64_t out_tile_noc_addr = get_noc_addr(out_tile_id, s); - // DPRINT << "out_tile_id=" << out_tile_id << ENDL(); - noc_async_write(l1_read_addr, out_tile_noc_addr, tile_nbytes); - l1_read_addr += tile_nbytes; - out_tile_id += out_next_tile_stride_w; - } - } // out_subblock_w (ntiles) - out_sb_row_start_tile_id += out_next_tile_stride_h; - } // out_subblock_h (ntiles) - noc_async_write_barrier(); - // DPRINT << "Done writing subblock." << ENDL(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_sbw_start_tile_id += out_next_subblock_stride_w; - out_sbw_start_tile_id_w += out_subblock_w; - } // out_num_subblocks_w - out_sbh_start_tile_id += out_next_subblock_stride_h; - out_sbh_start_tile_id_h += out_subblock_h; - } // out_num_subblocks_h - out_block_h_start_tile_id += out_next_block_stride_h; - out_block_h_start_tile_id_h += out_block_height_num_tiles; -#endif - } // out_num_blocks_h + } out_block_w_start_tile_id += out_next_block_stride_w; out_block_w_start_tile_id_w += weight_block_width_ntiles; diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp index ef9bed9e808..c6dd8d3d08a 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp @@ -315,50 +315,6 @@ void kernel_main() { cb_push_back(cb_id_weight, total_weight_num_tiles); } -#ifndef SHARDED_OUT - uint32_t out_sbh_start_tile_id = out_block_h_start_tile_id; - uint32_t out_sbh_start_tile_id_h = out_block_h_start_tile_id_h; // - for (uint32_t sbh = 0; sbh < out_num_subblocks_h; sbh++) { - uint32_t out_sbw_start_tile_id = out_sbh_start_tile_id; - uint32_t out_sbw_start_tile_id_w = out_block_w_start_tile_id_w; - for (uint32_t sbw = 0; sbw < out_num_subblocks_w; sbw++) { - uint32_t out_sb_row_start_tile_id = out_sbw_start_tile_id; - // wait for one subblock worth tiles - cb_wait_front(cb_id_out0, out_subblock_tile_count); - uint32_t l1_read_addr = get_read_ptr(cb_id_out0); - for (uint32_t h = 0; h < out_subblock_h; h++) { - uint32_t out_tile_id = out_sb_row_start_tile_id; - uint32_t out_tile_id_h = out_sbh_start_tile_id_h + h; - if (out_tile_id_h >= out_height_num_tiles) { // block shape height padding - break; - } - for (uint32_t w = 0; w < out_subblock_w; w++) { - uint32_t out_tile_id_w = out_sbw_start_tile_id_w + w; - if (out_tile_id_w >= out_width_num_tiles) { // block shape width padding - l1_read_addr += tile_nbytes; - } else { - // DPRINT << "out_tile_id - " << out_tile_id << ENDL(); - uint64_t out_tile_noc_addr = get_noc_addr(out_tile_id, s); - // DPRINT << "out_tile_id=" << out_tile_id << ENDL(); - noc_async_write(l1_read_addr, out_tile_noc_addr, tile_nbytes); - l1_read_addr += tile_nbytes; - out_tile_id += out_next_tile_stride_w; - } - } // out_subblock_w (ntiles) - out_sb_row_start_tile_id += out_next_tile_stride_h; - } // out_subblock_h (ntiles) - 
noc_async_write_barrier(); - // DPRINT << "Done writing subblock." << ENDL(); - cb_pop_front(cb_id_out0, out_subblock_tile_count); - out_sbw_start_tile_id += out_next_subblock_stride_w; - out_sbw_start_tile_id_w += out_subblock_w; - } // out_num_subblocks_w - out_sbh_start_tile_id += out_next_subblock_stride_h; - out_sbh_start_tile_id_h += out_subblock_h; - } // out_num_subblocks_h - out_block_h_start_tile_id += out_next_block_stride_h; - out_block_h_start_tile_id_h += out_block_height_num_tiles; -#endif } // out_num_blocks_h out_block_w_start_tile_id += out_next_block_stride_w; out_block_w_start_tile_id_w += weight_block_width_ntiles; diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp index 6690a8d6517..064e2ed3eab 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp @@ -58,6 +58,12 @@ void kernel_main() { constexpr uint32_t out_addr = get_compile_time_arg_val(29); +#ifdef UNPAD_UNTILIZE_OUT + constexpr uint32_t out_block_width_ntiles = get_compile_time_arg_val(33); + constexpr uint32_t out_block_width_padded_bytes = get_compile_time_arg_val(34); + constexpr uint32_t out_block_width_bytes = get_compile_time_arg_val(35); + constexpr uint32_t untilized_padded_out_cb = get_compile_time_arg_val(36); +#endif uint32_t i = 0; i += 19; uint32_t out_start_tile_id = get_arg_val(i); @@ -82,6 +88,8 @@ void kernel_main() { i += 1; uint32_t weights_mcast_receiver_semaphore_addr = get_semaphore(get_arg_val(i)); i += 1; + uint32_t out_aligned_page_size = get_arg_val(i); + i += 1; volatile tt_l1_ptr uint32_t* weights_mcast_receiver_semaphore_addr_ptr = reinterpret_cast(weights_mcast_receiver_semaphore_addr); @@ -196,8 +204,30 @@ void kernel_main() { } // out_num_blocks_w #ifdef SHARDED_OUT +#ifdef UNPAD_UNTILIZE_OUT + uint32_t dst_cb_addr = get_write_ptr(cb_id_out0); + + uint32_t src_cb_addr = get_read_ptr(untilized_padded_out_cb); + for (uint32_t nbw = 0; nbw < out_num_blocks_w; nbw++) { + for (uint32_t nbh = 0; nbh < out_num_blocks_h; nbh++) { + for (uint32_t bh = 0; bh < out_block_height_num_tiles; bh++) { + cb_wait_front(untilized_padded_out_cb, out_block_width_ntiles); + uint32_t src_cb_addr = get_read_ptr(untilized_padded_out_cb); + for (uint32_t r = 0; r < 32; r++) { + noc_async_read(get_noc_addr(src_cb_addr), dst_cb_addr, out_block_width_bytes); + noc_async_read_barrier(); + src_cb_addr += out_block_width_padded_bytes; + + dst_cb_addr += out_aligned_page_size; + } + cb_pop_front(untilized_padded_out_cb, out_block_width_ntiles); + } + } + } +#else cb_wait_front( cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); #endif +#endif } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp index d29e0d1116c..8aed491bb61 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp +++ 
b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp @@ -59,6 +59,12 @@ void kernel_main() { constexpr uint32_t out_addr = get_compile_time_arg_val(29); +#ifdef UNPAD_UNTILIZE_OUT + constexpr uint32_t out_block_width_ntiles = get_compile_time_arg_val(33); + constexpr uint32_t out_block_width_padded_bytes = get_compile_time_arg_val(34); + constexpr uint32_t out_block_width_bytes = get_compile_time_arg_val(35); + constexpr uint32_t untilized_padded_out_cb = get_compile_time_arg_val(36); +#endif uint32_t i = 0; i += 1; const uint32_t weight_addr_dram_base = get_arg_val(i); @@ -100,6 +106,8 @@ void kernel_main() { i += 1; uint32_t weights_mcast_receiver_semaphore_addr = get_semaphore(get_arg_val(i)); i += 1; + uint32_t out_aligned_page_size = get_arg_val(i); + i += 1; #ifndef SKIP_MCAST // Set ur local VALID value, to be mcasted to destinations flag address after the data has been mcasted @@ -158,39 +166,33 @@ void kernel_main() { // read weight blocks inner dim // read weight slice - 1 block of weights in width dim and full weight matrix height // read slice only once for all activation blocks - uint32_t weight_h_offset = 0; + uint32_t weight_current_block_start_tile_id = weight_start_tile_id; for (uint32_t weight_tile_h_outer_i = 0; weight_tile_h_outer_i < weight_block_height_num_outer; weight_tile_h_outer_i++) { - uint32_t weight_current_block_start_tile_id = weight_start_tile_id; cb_reserve_back(cb_id_weight, weight_block_num_tiles); uint32_t weight_write_l1_addr = get_write_ptr(cb_id_weight); // mcast args uint32_t weights_start_address = weight_write_l1_addr; uint32_t weights_block_size_bytes = 0; - - for (uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h; block_weight_h++) { // TODO: 9 - uint32_t weight_row_start_tile_id = weight_current_block_start_tile_id + weight_h_offset; - + // loop over weight block tiles along h + // num_blocks_weight_h * weight_block_height_ntiles + // weight_stride_h + for (uint32_t block_weight_h = 0; block_weight_h < num_blocks_weight_h * weight_block_height_ntiles; + block_weight_h++) { // mcast args // uint32_t weights_start_address = weight_write_l1_addr; // uint32_t weights_block_size_bytes = 0; - // loop over weight block tiles along h - for (uint32_t weight_tile_h_i = 0; weight_tile_h_i < weight_block_height_ntiles; - ++weight_tile_h_i) { // TODO: 2 - uint32_t weight_tile_id = weight_row_start_tile_id; - // loop over weight block tiles along w - for (uint32_t weight_tile_w_i = 0; weight_tile_w_i < weight_block_width_ntiles; - ++weight_tile_w_i) { - s_weight.noc_async_read_tile(weight_tile_id, weight_write_l1_addr); - weight_write_l1_addr += weight_tile_nbytes; - weights_block_size_bytes += weight_tile_nbytes; - weight_tile_id += 1; - } // for weight_block_w - weight_row_start_tile_id += weight_stride_h; - } // for weight_block_h - weight_current_block_start_tile_id += weight_next_block_stride_h; + uint32_t weight_tile_id = weight_current_block_start_tile_id; + // loop over weight block tiles along w + for (uint32_t weight_tile_w_i = 0; weight_tile_w_i < weight_block_width_ntiles; ++weight_tile_w_i) { + s_weight.noc_async_read_tile(weight_tile_id, weight_write_l1_addr); + weight_write_l1_addr += weight_tile_nbytes; + weights_block_size_bytes += weight_tile_nbytes; + weight_tile_id += 1; + } // for weight_block_w + weight_current_block_start_tile_id += weight_stride_h; } noc_async_read_barrier(); @@ -226,7 +228,6 @@ void kernel_main() { // be sent in order they are 
issued noc_async_writes_flushed(); #endif - // We should also multicast the flag to destinations // num_dests must not include source, since we are NOT really doing a local copy! noc_semaphore_set_multicast( @@ -234,9 +235,7 @@ void kernel_main() { weights_mcast_receiver_semaphore_noc_addr, weights_mcast_num_cores); #endif - cb_push_back(cb_id_weight, weight_block_num_tiles); - weight_h_offset += weight_inner_block_stride_h; } // for weight_block_height_num_outer #ifdef FUSE_BIAS @@ -286,7 +285,6 @@ void kernel_main() { // be sent in order they are issued noc_async_writes_flushed(); #endif - // We should also multicast the flag to destinations // num_dests must not include source, since we are NOT really doing a local copy! noc_semaphore_set_multicast( @@ -349,8 +347,30 @@ void kernel_main() { weight_start_tile_id += weight_next_block_stride_w; } // out_num_blocks_w #ifdef SHARDED_OUT +#ifdef UNPAD_UNTILIZE_OUT + uint32_t dst_cb_addr = get_write_ptr(cb_id_out0); + + uint32_t src_cb_addr = get_read_ptr(untilized_padded_out_cb); + for (uint32_t nbw = 0; nbw < out_num_blocks_w; nbw++) { + for (uint32_t nbh = 0; nbh < out_num_blocks_h; nbh++) { + for (uint32_t bh = 0; bh < out_block_height_num_tiles; bh++) { + cb_wait_front(untilized_padded_out_cb, out_block_width_ntiles); + uint32_t src_cb_addr = get_read_ptr(untilized_padded_out_cb); + for (uint32_t r = 0; r < 32; r++) { + noc_async_read(get_noc_addr(src_cb_addr), dst_cb_addr, out_block_width_bytes); + noc_async_read_barrier(); + src_cb_addr += out_block_width_padded_bytes; + + dst_cb_addr += out_aligned_page_size; + } + cb_pop_front(untilized_padded_out_cb, out_block_width_ntiles); + } + } + } +#else cb_wait_front( cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); #endif +#endif } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp index 100869a6cb5..668372c49a4 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp @@ -157,7 +157,8 @@ std::pair> prepare_conv_weights_biases uint32_t groups, uint32_t act_block_h_ntiles, uint32_t input_width, - const bool parameters_on_device) { + const bool parameters_on_device, + bool is_non_tile_mul_width) { validate_weight_and_bias_tensors(weight_tensor, bias_tensor); ttnn::Tensor weight_tensor_; // tensor to return @@ -201,6 +202,11 @@ std::pair> prepare_conv_weights_biases tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array( {out_channels_padded, in_channels_padded, window_h, window_w})); + if(is_non_tile_mul_width) { + weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array( + {round_up(out_channels, 32), round_up(in_channels, input_channels_alignment), window_h, window_w})); + out_channels_padded = tt::round_up(out_channels, 32); + } if (weights_bias_dtype == DataType::BFLOAT8_B) { TT_ASSERT(weight_tensor_.get_dtype() == DataType::FLOAT32); if (bias_tensor.has_value()) { @@ -219,6 +225,9 @@ std::pair> prepare_conv_weights_biases if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) { weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_special_padding_tiled_layout( weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype); + } else if(parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) { + weight_tensor_ = 
tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout_block_sharded( + weight_tensor_, num_cores_channels, weights_bias_dtype); } else { weight_tensor_ = tt::tt_metal::convert_conv_weight_tensor_to_tiled_layout( weight_tensor_, weight_block_h_ntiles, weight_block_w_ntiles, weights_bias_dtype); @@ -241,19 +250,23 @@ std::pair> prepare_conv_weights_biases weight_tensor_ = ttnn::operations::core::to_device(weight_tensor_, device, std::nullopt); if (bias_tensor.has_value()) { - bias_tensor_ = bias_tensor.value(); - auto bias_shape = bias_tensor_.get_shape(); - TT_ASSERT(bias_shape[3] == out_channels && bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1); - tt::tt_metal::LegacyShape bias_channels_padded_shape = tt::tt_metal::LegacyShape( - std::array({1, 1, 32, tt::round_up(out_channels, weight_block_w_ntiles * 32)})); - bias_tensor_ = ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), 0); - bias_tensor_ = ttnn::to_layout( - bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, (T*)nullptr); - if (bias_tensor_.get_dtype() != weights_bias_dtype) { - bias_tensor_ = ttnn::to_dtype(bias_tensor_, weights_bias_dtype); + if (!is_non_tile_mul_width) { + bias_tensor_ = bias_tensor.value(); + auto bias_shape = bias_tensor_.get_shape(); + TT_ASSERT(bias_shape[3] == out_channels && bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1); + tt::tt_metal::LegacyShape bias_channels_padded_shape = tt::tt_metal::LegacyShape( + std::array({1, 1, 32, round_up(out_channels, weight_block_w_ntiles * 32)})); + bias_tensor_ = ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D{0, 0, 0, 0}, 0); + bias_tensor_ = ttnn::to_layout( + bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, (T*)nullptr); + if (bias_tensor_.get_dtype() != weights_bias_dtype) { + bias_tensor_ = ttnn::to_dtype(bias_tensor_, weights_bias_dtype); + } + } else { + bias_tensor_ = convert_conv_bias_tensor_to_tiled_layout_block_sharded( + bias_tensor.value(), num_cores_channels, weights_bias_dtype); } - if(parameters_on_device) - bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt); + bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt); } return {weight_tensor_, bias_tensor.has_value() ? 
bias_tensor_ : std::optional()}; @@ -475,7 +488,8 @@ template std::pair> prepare_conv_weigh uint32_t groups, uint32_t act_block_h_ntiles, uint32_t input_width, - const bool parameters_on_device); + const bool parameters_on_device, + bool is_non_tile_mul_width); template std::pair> prepare_conv_weights_biases_and_move_to_device( const ttnn::Tensor& weight_tensor, @@ -489,7 +503,8 @@ template std::pair> prepare_conv_weigh uint32_t groups, uint32_t act_block_h_ntiles, uint32_t input_width, - const bool parameters_on_device); + const bool parameters_on_device, + bool is_non_tile_mul_width); template ttnn::Tensor prepare_conv_bias( const ttnn::Tensor& bias_tensor, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp index 61ac759d745..18e654ad37c 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp @@ -76,7 +76,8 @@ std::pair> prepare_conv_weights_biases uint32_t groups, uint32_t act_block_h_ntiles, uint32_t input_width, - const bool parameters_on_device=true); + const bool parameters_on_device=true, + bool is_non_tile_mul_width=false); } // namespace conv2d } // namespace operations::conv diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather.cpp index 1435c3f97d7..63e4ed9828f 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/kernels/dataflow/halo_gather.cpp @@ -45,7 +45,13 @@ inline bool fill_with_val(uint32_t begin_addr, uint32_t n, uint16_t val) { return true; } -template +template < + uint32_t stick_nbytes, + uint32_t input_aligned_page_size, + bool is_block_sharded, + bool is_width_sharded, + bool is_read, + bool is_col_major> void copy_sticks_async( tt_l1_ptr uint16_t const* config_data, const uint16_t my_noc_x, @@ -68,15 +74,31 @@ void copy_sticks_async( uint16_t nsticks = config_data[i + j + 2]; uint32_t size = nsticks * stick_nbytes; uint32_t dst_offset = dst_local_idx * stick_nbytes; - uint32_t src_offset = src_local_idx * stick_nbytes; + uint32_t src_offset = src_local_idx * input_aligned_page_size; if constexpr (is_read) { uint32_t dst_addr = out_base_l1_addr + dst_offset; uint64_t src_addr = base_addr + src_offset; - noc_async_read(src_addr, dst_addr, size); + if constexpr (stick_nbytes == input_aligned_page_size) { + noc_async_read(src_addr, dst_addr, size); + } else { + for (uint16_t k = 0; k < nsticks; k++) { + noc_async_read(src_addr, dst_addr, stick_nbytes); + dst_addr += stick_nbytes; + src_addr += input_aligned_page_size; + } + } } else { uint64_t dst_addr = base_addr + dst_offset; uint32_t src_addr = in_base_l1_addr + src_offset; - noc_async_write(src_addr, dst_addr, size); + if constexpr (stick_nbytes == input_aligned_page_size) { + noc_async_write(src_addr, dst_addr, size); + } else { + for (uint16_t k = 0; k < nsticks; k++) { + noc_async_write(src_addr, dst_addr, stick_nbytes); + dst_addr += stick_nbytes; + src_addr += input_aligned_page_size; + } + } } } @@ -99,6 +121,7 @@ void kernel_main() { constexpr uint32_t remote_read = get_compile_time_arg_val(11); constexpr bool is_col_major = get_compile_time_arg_val(12) == 1; constexpr uint32_t is_width_sharded = get_compile_time_arg_val(13); + constexpr 
uint32_t input_aligned_page_size = get_compile_time_arg_val(14); constexpr uint32_t elem_nbytes = sizeof(uint16_t); constexpr uint16_t pad_core_id = 0xFFFF; @@ -144,15 +167,25 @@ void kernel_main() { if constexpr (remote_config_cb_id) { uint32_t config_data_l1_addr = get_read_ptr(remote_config_cb_id); tt_l1_ptr uint16_t const* config_data = reinterpret_cast(config_data_l1_addr); - copy_sticks_async( - config_data, my_noc_x, my_noc_y, in_base_l1_addr, out_base_l1_addr); + copy_sticks_async< + stick_nbytes, + input_aligned_page_size, + is_block_sharded, + is_width_sharded, + remote_read, + is_col_major>(config_data, my_noc_x, my_noc_y, in_base_l1_addr, out_base_l1_addr); } if constexpr (local_config_cb_id) { uint32_t config_data_l1_addr = get_read_ptr(local_config_cb_id); tt_l1_ptr uint16_t const* config_data = reinterpret_cast(config_data_l1_addr); - copy_sticks_async( - config_data, my_noc_x, my_noc_y, in_base_l1_addr, out_base_l1_addr); + copy_sticks_async< + stick_nbytes, + input_aligned_page_size, + is_block_sharded, + is_width_sharded, + false, + is_col_major>(config_data, my_noc_x, my_noc_y, in_base_l1_addr, out_base_l1_addr); } noc_async_read_barrier(); diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.cpp index 16fbcfc239e..246c2e22cc9 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_halo_v2/device/untilize_with_halo_v2_program_factory.cpp @@ -168,6 +168,10 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2( bool const is_block_sharded = input_tensor.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED; bool const is_width_sharded = input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED; + auto aligned_input_nstick_nbytes = out_stick_nbytes; + if (out_stick_nbytes % input_tensor.buffer()->alignment() != 0) { + aligned_input_nstick_nbytes = tt::round_up(out_stick_nbytes, input_tensor.buffer()->alignment()); + } // reader kernel std::vector reader_ct_args = { 0, // padding_config_cb_id @@ -183,7 +187,8 @@ operation::ProgramWithCallbacks untilize_with_halo_multi_core_v2( is_block_sharded, remote_read, (uint32_t)(transpose_mcast ? 1 : 0), - is_width_sharded}; + is_width_sharded, + aligned_input_nstick_nbytes}; reader_ct_args[0] = 0; reader_ct_args[1] = local_config_cb_id; diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp index b364413f57b..ef1e685998d 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp @@ -12,6 +12,27 @@ namespace tt { namespace tt_metal { +template +Tensor convert_tensor(const Tensor& input_tensor, compute_& compute) { + auto convert_tensor = [&compute](const auto& input_tensor) { + return std::visit( + [&compute](auto&& storage) -> Tensor { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + return compute(owned_buffer::get_as(storage.buffer)); + } else if constexpr (std::is_same_v) { + return compute(borrowed_buffer::get_as(storage.buffer)); + } else { + TT_THROW("Unsupported storage type"); + } + }, + input_tensor.get_storage()); + }; + + return ttnn::distributed::is_multi_device_tensor(input_tensor) ? 
transform(input_tensor, convert_tensor) + : convert_tensor(input_tensor); +} + template Tensor to_weight_special_padding_tile_layout( const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { @@ -68,23 +89,7 @@ Tensor to_weight_special_padding_tile_layout( Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); return rm_tensor.to(Layout::TILE); }; - auto convert_tensor = [&compute](const auto& conv_weight_tensor) { - return std::visit( - [&compute](auto&& storage) -> Tensor { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return compute(owned_buffer::get_as(storage.buffer)); - } else if constexpr (std::is_same_v) { - return compute(borrowed_buffer::get_as(storage.buffer)); - } else { - TT_THROW("Unsupported storage type"); - } - }, - conv_weight_tensor.get_storage()); - }; - - return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) - : convert_tensor(conv_weight_tensor); + return convert_tensor(conv_weight_tensor, compute); } template @@ -146,22 +151,7 @@ Tensor to_weight_tile_layout( return rm_tensor.to(Layout::TILE); }; - auto convert_tensor = [&compute](const auto& conv_weight_tensor) { - return std::visit( - [&compute](auto&& storage) -> Tensor { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return compute(owned_buffer::get_as(storage.buffer)); - } else if constexpr (std::is_same_v) { - return compute(borrowed_buffer::get_as(storage.buffer)); - } else { - TT_THROW("Unsupported storage type"); - } - }, - conv_weight_tensor.get_storage()); - }; - return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) - : convert_tensor(conv_weight_tensor); + return convert_tensor(conv_weight_tensor, compute); } // Converts convolution weights to tilized 2d matrix layout. 
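For readers of the next hunk: to_weight_tile_layout_block_sharded re-orders an OIHW weight tensor so that every output-channel shard and every input-channel shard starts on a tile boundary, zero-filling the gap whenever a shard width is not a multiple of 32. The standalone sketch below is an illustration only, not part of this patch; all names and sizes (out_c, in_c, num_shards, TILE_W, TILE_H and so on) are hypothetical toy choices. It replays the same matrix_idx / idx loop nest on the host so the mapping can be checked in isolation.

// Minimal host-side sketch (illustration only, not part of this patch). It replays the
// matrix_idx / idx arithmetic of to_weight_tile_layout_block_sharded on a toy OIHW tensor
// so the shard-padding behaviour is easy to inspect. All sizes below are hypothetical.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Toy weights: out_channels=4, in_channels=4, 1x1 kernel, 2 channel shards, so each
    // shard is only 2 channels wide -- deliberately not a multiple of the 32-wide tile.
    const uint32_t out_c = 4, in_c = 4, kh = 1, kw = 1, num_shards = 2;
    const uint32_t TILE_W = 32, TILE_H = 32;  // stand-ins for constants::TILE_WIDTH/HEIGHT

    const uint32_t out_shard_w = out_c / num_shards;                                     // 2
    const uint32_t out_shard_w_padded = ((out_shard_w + TILE_W - 1) / TILE_W) * TILE_W;  // 32
    const uint32_t in_shard_w = in_c / num_shards;                                       // 2
    const uint32_t block_h = in_shard_w * kh * kw;                                       // 2
    const uint32_t block_h_padded = ((block_h + TILE_H - 1) / TILE_H) * TILE_H;          // 32

    const uint32_t cols = out_shard_w_padded * num_shards;  // padded 2D matrix width
    const uint32_t rows = block_h_padded * num_shards;      // padded 2D matrix height

    std::vector<float> in(out_c * in_c * kh * kw);
    for (uint32_t i = 0; i < in.size(); i++) in[i] = float(i + 1);
    std::vector<float> out(rows * cols, 0.0f);  // zeros remain wherever padding is inserted

    // Same loop nest and index arithmetic as the patch, with w_shape = {out_c, in_c, kh, kw}.
    for (uint32_t ic = 0; ic < num_shards; ic++)
        for (uint32_t r = 0; r < kh; r++)
            for (uint32_t s = 0; s < kw; s++)
                for (uint32_t c_s = 0; c_s < in_shard_w; c_s++)
                    for (uint32_t oc = 0; oc < num_shards; oc++)
                        for (uint32_t k_s = 0; k_s < out_shard_w; k_s++) {
                            uint32_t matrix_idx = (oc * out_shard_w_padded + k_s) + c_s * cols +
                                                  s * in_shard_w * cols + r * kw * in_shard_w * cols +
                                                  ic * block_h_padded * cols;
                            uint32_t idx = (oc * out_shard_w + k_s) * in_c * kh * kw +
                                           (ic * in_shard_w + c_s) * kh * kw + r * kw + s;
                            out[matrix_idx] = in[idx];
                        }

    // Each output-channel shard now starts at a tile-aligned column (0, 32, ...) and each
    // input-channel shard at a tile-aligned row; print the first weight of shard 0 and shard 1.
    std::cout << out[0] << " " << out[out_shard_w_padded] << std::endl;  // expect: 1 9
    return 0;
}

With out_channels = 4 split across 2 shards, each 2-wide shard still occupies a full 32-wide padded column block; this is the padding the conv program factories rely on when the per-core shard width is not a tile multiple.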
@@ -197,6 +187,211 @@ Tensor convert_conv_weight_tensor_to_tiled_layout( } } +template +Tensor to_weight_tile_layout_block_sharded( + const Tensor& conv_weight_tensor, uint32_t num_channel_shards, DataType output_dtype) { + auto w_shape = conv_weight_tensor.get_legacy_shape(); + auto compute = [&w_shape, &num_channel_shards, &output_dtype](const auto& input_buffer) { + auto weight_matrix_cols = w_shape[0]; + TT_ASSERT(weight_matrix_cols % num_channel_shards == 0); + auto conv_output_shard_width = weight_matrix_cols / num_channel_shards; + auto conv_output_shard_width_padded = + (uint32_t)std::ceil((double)conv_output_shard_width / (double)constants::TILE_WIDTH) * + constants::TILE_WIDTH; + if (conv_output_shard_width < conv_output_shard_width_padded) { + // width padding for conv output shard padding + weight_matrix_cols = conv_output_shard_width_padded * num_channel_shards; + } + + auto weight_matrix_rows = w_shape[1] * w_shape[2] * w_shape[3]; + TT_ASSERT(w_shape[1] % num_channel_shards == 0); + auto conv_input_shard_width = w_shape[1] / num_channel_shards; + auto weight_block_height = conv_input_shard_width * w_shape[2] * w_shape[3]; + auto weight_block_height_padded = + (uint32_t)std::ceil((double)weight_block_height / (double)constants::TILE_HEIGHT) * constants::TILE_HEIGHT; + if (weight_block_height < weight_block_height_padded) { + // height padding for non tile multiple block height + weight_matrix_rows = weight_block_height_padded * num_channel_shards; + } + ttnn::SimpleShape output_shape{1, 1, weight_matrix_rows, weight_matrix_cols}; + auto output_buffer = owned_buffer::create(output_shape.volume()); + for (auto ic = 0; ic < num_channel_shards; ic++) { + for (auto r = 0; r < w_shape[2]; r++) { + for (auto s = 0; s < w_shape[3]; s++) { + for (auto c_s = 0; c_s < conv_input_shard_width; c_s++) { + for (auto oc = 0; oc < num_channel_shards; oc++) { + for (auto k_s = 0; k_s < conv_output_shard_width; k_s++) { + auto matrix_idx = (oc * conv_output_shard_width_padded + k_s) + + c_s * weight_matrix_cols + + s * conv_input_shard_width * weight_matrix_cols + + r * w_shape[3] * conv_input_shard_width * weight_matrix_cols + + ic * weight_block_height_padded * weight_matrix_cols; + auto idx = (oc * conv_output_shard_width + k_s) * w_shape[1] * w_shape[2] * w_shape[3] + + (ic * conv_input_shard_width + c_s) * w_shape[2] * w_shape[3] + + r * w_shape[3] + s; + output_buffer[matrix_idx] = input_buffer[idx]; + } + } + } + } + } + } + if constexpr (std::is_same::value) { + if (output_dtype == DataType::BFLOAT8_B) { + auto tensor = Tensor( + std::move(OwnedStorage{std::move(output_buffer)}), + output_shape, + DataType::FLOAT32, + Layout::ROW_MAJOR) + .to(Layout::TILE); + auto output_float_data = owned_buffer::get_as(tensor).get(); + auto output_packed_data = + pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); + return Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE); + } + if (output_dtype == DataType::BFLOAT4_B) { + auto tensor = Tensor( + std::move(OwnedStorage{std::move(output_buffer)}), + output_shape, + DataType::FLOAT32, + Layout::ROW_MAJOR) + .to(Layout::TILE); + auto output_float_data = owned_buffer::get_as(tensor).get(); + auto output_packed_data = + pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = 
owned_buffer::create(std::move(output_packed_data)); + return Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE); + } + } else { + TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B)); + } + auto rm_tensor = + Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); + }; + return convert_tensor(conv_weight_tensor, compute); +} + +// Converts convolution weights to tilized 2d matrix layout for block sharded conv. +// Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_tiled_layout_block_sharded( + const Tensor& conv_weight_tensor, uint32_t num_channel_shards, std::optional output_dtype) { + TT_ASSERT( + conv_weight_tensor.get_layout() == Layout::ROW_MAJOR && + "Convolution weights should be in row major layout for conversion to tilized layout."); + const static std:: + map> + to_w_tile_layout_map = { + {DataType::BFLOAT16, &to_weight_tile_layout_block_sharded}, + {DataType::FLOAT32, &to_weight_tile_layout_block_sharded}, + {DataType::UINT32, &to_weight_tile_layout_block_sharded}, + }; + if (output_dtype.has_value()) { + if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { + TT_ASSERT(conv_weight_tensor.get_dtype() == DataType::FLOAT32); + } else { + TT_ASSERT(conv_weight_tensor.get_dtype() == conv_weight_tensor.get_dtype()); + } + } + return to_w_tile_layout_map.at(conv_weight_tensor.get_dtype())( + conv_weight_tensor, num_channel_shards, output_dtype.value_or(conv_weight_tensor.get_dtype())); +} + +template +Tensor to_bias_tile_layout_block_sharded( + const Tensor& conv_bias_tensor, uint32_t num_channel_shards, DataType output_dtype) { + auto b_shape = conv_bias_tensor.get_legacy_shape(); + TT_ASSERT(b_shape[0] == 1 && b_shape[1] == 1 && b_shape[2] == 1); + auto compute = [&b_shape, &num_channel_shards, &output_dtype](const auto& input_buffer) { + auto bias_matrix_cols = b_shape[3]; + /*TT_ASSERT(bias_matrix_cols % num_channel_shards == 0);*/ + auto conv_output_shard_width = bias_matrix_cols / num_channel_shards; + auto conv_output_shard_width_padded = + (uint32_t)std::ceil((double)conv_output_shard_width / (double)constants::TILE_WIDTH) * + constants::TILE_WIDTH; + if (conv_output_shard_width < conv_output_shard_width_padded) { + // width padding for conv output shard padding + bias_matrix_cols = conv_output_shard_width_padded * num_channel_shards; + } + + auto bias_matrix_rows = 32; + ttnn::SimpleShape output_shape{1, 1, bias_matrix_rows, bias_matrix_cols}; + auto output_buffer = owned_buffer::create(output_shape.volume()); + for (auto oc = 0; oc < num_channel_shards; oc++) { + for (auto k_s = 0; k_s < conv_output_shard_width; k_s++) { + auto matrix_idx = oc * conv_output_shard_width_padded + k_s; + auto idx = oc * conv_output_shard_width + k_s; + output_buffer[matrix_idx] = input_buffer[idx]; + } + } + if constexpr (std::is_same::value) { + if (output_dtype == DataType::BFLOAT8_B) { + auto tensor = Tensor( + std::move(OwnedStorage{std::move(output_buffer)}), + output_shape, + DataType::FLOAT32, + Layout::ROW_MAJOR) + .to(Layout::TILE); + auto output_float_data = owned_buffer::get_as(tensor).get(); + auto output_packed_data = + pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); + return Tensor( + 
std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE); + } + if (output_dtype == DataType::BFLOAT4_B) { + auto tensor = Tensor( + std::move(OwnedStorage{std::move(output_buffer)}), + output_shape, + DataType::FLOAT32, + Layout::ROW_MAJOR) + .to(Layout::TILE); + auto output_float_data = owned_buffer::get_as(tensor).get(); + auto output_packed_data = + pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); + return Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE); + } + } else { + TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B)); + } + auto rm_tensor = + Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); + }; + + return convert_tensor(conv_bias_tensor, compute); +} + +// Converts convolution bias to tilized 2d matrix layout for block sharded conv. +// Returns a new tensor with layout=Tile +Tensor convert_conv_bias_tensor_to_tiled_layout_block_sharded( + const Tensor& conv_bias_tensor, uint32_t num_channel_shards, std::optional output_dtype) { + TT_ASSERT( + conv_bias_tensor.get_layout() == Layout::ROW_MAJOR && + "Convolution weights should be in row major layout for conversion to tilized layout."); + const static std:: + map> + to_b_tile_layout_map = { + {DataType::BFLOAT16, &to_bias_tile_layout_block_sharded}, + {DataType::FLOAT32, &to_bias_tile_layout_block_sharded}, + {DataType::UINT32, &to_bias_tile_layout_block_sharded}, + }; + if (output_dtype.has_value()) { + if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { + TT_ASSERT(conv_bias_tensor.get_dtype() == DataType::FLOAT32); + } else { + TT_ASSERT(conv_bias_tensor.get_dtype() == conv_bias_tensor.get_dtype()); + } + } + return to_b_tile_layout_map.at(conv_bias_tensor.get_dtype())( + conv_bias_tensor, num_channel_shards, output_dtype.value_or(conv_bias_tensor.get_dtype())); +} + // Converts convolution weights to tilized 2d matrix layout. // Returns a new tensor with layout=Tile Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( @@ -275,21 +470,7 @@ static Tensor conv_group_weight_zero_pad_helper( std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR); }; - auto f = [&](const auto& tensor) { - return std::visit( - [&](auto&& storage) -> Tensor { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return pad_weight(owned_buffer::get_as(storage.buffer)); - } else if constexpr (std::is_same_v) { - return pad_weight(borrowed_buffer::get_as(storage.buffer)); - } else { - TT_THROW("Unsupported storage type"); - } - }, - tensor.get_storage()); - }; - return ttnn::distributed::is_multi_device_tensor(weight) ? transform(weight, f) : f(weight); + return convert_tensor(weight, pad_weight); } /* diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp index 36c495bffa6..96ce34431b9 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp @@ -20,6 +20,16 @@ Tensor convert_conv_weight_tensor_to_tiled_layout( uint32_t in1_block_w, std::optional output_dtype = std::nullopt); +// Converts convolution weights to tilized 2d matrix layout for block sharded conv. 
Adds zero padding between weight
+// blocks based on output shard width padding. Returns a new tensor with layout=Tile
+Tensor convert_conv_weight_tensor_to_tiled_layout_block_sharded(
+    const Tensor& conv_weight_tensor, uint32_t num_channel_shards, std::optional<DataType> output_dtype = std::nullopt);
+
+// Converts convolution bias to tilized layout for block sharded conv. Adds zero padding between bias blocks based on
+// output shard width padding. Returns a new tensor with layout=Tile
+Tensor convert_conv_bias_tensor_to_tiled_layout_block_sharded(
+    const Tensor& conv_bias_tensor, uint32_t num_channel_shards, std::optional<DataType> output_dtype = std::nullopt);
+
 // Converts convolution weights to tilized 2d matrix layout with special block height padding
 // Returns a new tensor with layout=Tile
 Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout(