diff --git a/models/demos/resnet/tt/metalResnetBlock50.py b/models/demos/resnet/tt/metalResnetBlock50.py index d775467e96c..c92543fc359 100644 --- a/models/demos/resnet/tt/metalResnetBlock50.py +++ b/models/demos/resnet/tt/metalResnetBlock50.py @@ -27,7 +27,7 @@ TTPyCompositeConv, SlidingWindowOpParamsWithParallelConfig, ) -from ttnn.operations.conv.tt_py_max_pool import TTPyMaxPool +from ttnn.operations.pool import TTPyMaxPool from models.utility_functions import ( _nearest_32, diff --git a/tests/tt_eager/ops/test_average_pool.cpp b/tests/tt_eager/ops/test_average_pool.cpp index c0772c6032e..108ce824906 100644 --- a/tests/tt_eager/ops/test_average_pool.cpp +++ b/tests/tt_eager/ops/test_average_pool.cpp @@ -2,8 +2,9 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttnn/experimental/tt_dnn/op_library/pool/average_pool.hpp" +#include "ttnn/cpp/ttnn/operations/pool/avgpool/avg_pool.hpp" #include "ttnn/experimental/tt_dnn/op_library/auto_format.hpp" +#include "tt_dnn/op_library/auto_format.hpp" #include "tt_numpy/functions.hpp" #include "tensor/tensor.hpp" @@ -24,7 +25,7 @@ Tensor run_avg_pool_2d_resnet(Shape& tensor_shape, Device* device) { if (!AutoFormat::check_input_tensor_format(input_tensor, padded_input_shape)) { padded_input_tensor = AutoFormat::format_input_tensor(input_tensor, device, padded_input_shape, 0, Layout::TILE); // pad with 0s } - auto device_output = average_pool_2d(padded_input_tensor); + auto device_output = avg_pool2d(padded_input_tensor); return device_output.cpu(); }; diff --git a/tests/tt_eager/python_api_testing/trace_testing/misc/test_average_pool.py b/tests/tt_eager/python_api_testing/trace_testing/misc/test_average_pool.py index abaea606780..77996476dc8 100644 --- a/tests/tt_eager/python_api_testing/trace_testing/misc/test_average_pool.py +++ b/tests/tt_eager/python_api_testing/trace_testing/misc/test_average_pool.py @@ -11,6 +11,7 @@ from tt_lib.utils import _nearest_32 from models.utility_functions import comp_pcc +import ttnn TILE_HEIGHT = TILE_WIDTH = 32 @@ -63,7 +64,7 @@ def test_run_average_pool(act_shape, dtype, device, use_program_cache, enable_as ttact_res = ttact.to(device) def run_ops(ttact_res): - return ttl.tensor.average_pool_2d(ttact_res) + return ttnn.avg_pool2d(ttact_res) # Compile run_ops(ttact_res) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_average_pool.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_average_pool.py index 6b0133a1637..beb3ee4ab00 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_average_pool.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_average_pool.py @@ -12,6 +12,8 @@ from tt_lib.utils import _nearest_32 from models.utility_functions import comp_pcc +import ttnn + TILE_HEIGHT = TILE_WIDTH = 32 @@ -43,7 +45,7 @@ def test_run_average_pool(act_shape, dtype, device): ttact = ttact.pad_to_tile(0.0) ttact = ttact.to(device) - out = ttl.tensor.average_pool_2d(ttact) + out = ttnn.avg_pool2d(ttact) out = out.cpu().to(ttl.tensor.Layout.ROW_MAJOR) out_shape = [batch_size, 1, 1, channels] diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_max_pool.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_max_pool.py index 40090d8aa6e..661d0a4588a 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_max_pool.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_max_pool.py @@ -17,6 +17,7 @@ from functools import reduce import operator +import ttnn def volume(shape): @@ -170,8 +171,8 @@ def 
test_run_max_pool( f"Skipping over Resnet specific config where parallelization does not fit on core grid {compute_grid_size}" ) - if (compute_grid_size.x * compute_grid_size.y) == ncores_on_n300: - pytest.skip(f"Skipping on N300 (8x7 core grid) due to bug https://github.com/tenstorrent/tt-metal/issues/5458") + # if (compute_grid_size.x * compute_grid_size.y) == ncores_on_n300: + # pytest.skip(f"Skipping on N300 (8x7 core grid) due to bug https://github.com/tenstorrent/tt-metal/issues/5458") torch.set_printoptions(precision=3, sci_mode=False, linewidth=500, threshold=10000, edgeitems=32) @@ -236,7 +237,7 @@ def test_run_max_pool( else: ttact = ttact.to(device, in_mem_config) - out_padded = ttl.tensor.max_pool2d( + out_padded = ttnn.max_pool2d( ttact, in_n, in_h, @@ -249,9 +250,9 @@ def test_run_max_pool( pad_w, dilation_h, dilation_w, - out_mem_config, - nblocks, - use_multicore, + memory_config=out_mem_config, + nblocks=nblocks, + use_multicore=use_multicore, ) if out_mem_config.is_sharded(): out_padded = ttl.tensor.sharded_to_interleaved(out_padded, interleaved_mem_config) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_untilize_with_halo_and_max_pool.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_untilize_with_halo_and_max_pool.py index 419597c2b1b..e7c09997703 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_untilize_with_halo_and_max_pool.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_untilize_with_halo_and_max_pool.py @@ -14,6 +14,7 @@ from tt_lib.utils import _nearest_32 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc from models.utility_functions import is_wormhole_b0 +import ttnn def volume(shape): @@ -186,7 +187,7 @@ def test_run_max_pool( # ttl.device.DumpDeviceMemoryState(device) ttact_sharded.deallocate() - out_padded = ttl.tensor.max_pool2d( + out_padded = ttnn.max_pool2d( out_untilize, in_n, in_h, @@ -199,9 +200,9 @@ def test_run_max_pool( pad_w, dilation_h, dilation_w, - out_mem_config, - nblocks, - True, + memory_config=out_mem_config, + nblocks=nblocks, + use_multicore=True, ) out_padded = ttl.tensor.sharded_to_interleaved(out_padded, interleaved_mem_config) out_padded = out_padded.cpu().to(ttl.tensor.Layout.ROW_MAJOR) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_untilize_with_halo_and_max_pool_v2.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_untilize_with_halo_and_max_pool_v2.py index b4198064967..fd2376f13d3 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_untilize_with_halo_and_max_pool_v2.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_untilize_with_halo_and_max_pool_v2.py @@ -9,10 +9,13 @@ import torch -from ttnn.operations.conv.tt_py_max_pool import ( + +from ttnn.operations.pool import ( TTPyMaxPool, SlidingWindowOpParamsWithParallelConfig, ) +from ttnn.operations.pool import max_pool2d_legacy as ttnn_max_pool2d_legacy + import tt_lib as ttl from tt_lib.utils import _nearest_32 @@ -170,7 +173,12 @@ def test_run_max_pool( assert kernel_w == kernel_h and stride_w == stride_h and pad_w == pad_h and dilation_w == dilation_h max_pool_reader_patterns_cache = {} - max_pool = TTPyMaxPool(sliding_window_op_params, device, max_pool_reader_patterns_cache, pad_val=pad_val) + max_pool = TTPyMaxPool( + sliding_window_op_params, + device, + max_pool_reader_patterns_cache, + pad_val=pad_val, + ) ttact_sharded = max_pool.copy_input_to_device(ttact) out_padded = max_pool(ttact_sharded) diff 
--git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index ab048c351f9..ae8cacab89a 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -58,6 +58,10 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/device/multi_core/groupnorm_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/transformer/device/transformer_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/avgpool/avg_pool.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool_multi_core.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool_single_core.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool_program_factory.cpp ) ### Setup TTNN as a shared library with optional Python bindings diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index 7ddba605eb6..dc40707d640 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -13,9 +13,10 @@ #include "pybind11/operations/core.hpp" #include "pybind11/operations/creation.hpp" #include "pybind11/operations/kv_cache.hpp" -#include "pybind11/operations/maxpool2d.hpp" -#include "pybind11/operations/pool.hpp" #include "pybind11/operations/ternary.hpp" + +#include "ttnn/operations/pool/avgpool/avg_pool_pybind.hpp" +#include "ttnn/operations/pool/maxpool/maxpool_pybind.hpp" #include "ttnn/operations/eltwise/binary/binary_pybind.hpp" #include "ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp" #include "ttnn/operations/conv2d/conv2d_pybind.hpp" @@ -34,7 +35,6 @@ #include "ttnn/operations/eltwise/complex_binary_backward/complex_binary_backward_pybind.hpp" #include "ttnn/operations/experimental/experimental_pybind.hpp" - namespace py = pybind11; namespace ttnn { @@ -91,8 +91,9 @@ void py_module(py::module& module) { auto m_conv2d = module.def_submodule("conv2d", "conv2d operation"); conv2d::py_module(m_conv2d); - auto m_maxpool2d = module.def_submodule("maxpool2d", "maxpool 2d operation"); - maxpool2d::py_module(m_maxpool2d); + auto m_pool = module.def_submodule("pool", "pooling operations"); + maxpool::py_module(m_pool); + avgpool::py_module(m_pool); auto m_normalization = module.def_submodule("normalization", "normalization operations"); normalization::py_module(m_normalization); @@ -106,9 +107,6 @@ void py_module(py::module& module) { auto m_kv_cache = module.def_submodule("kv_cache", "KV cache operations"); kv_cache::py_module(m_kv_cache); - auto m_pool = module.def_submodule("pool", "pool operations"); - pool::py_module(m_pool); - auto m_copy = module.def_submodule("copy", "copy operations"); copy::py_module(m_copy); diff --git a/ttnn/cpp/pybind11/operations/maxpool2d.hpp b/ttnn/cpp/pybind11/operations/maxpool2d.hpp deleted file mode 100644 index a215600d258..00000000000 --- a/ttnn/cpp/pybind11/operations/maxpool2d.hpp +++ /dev/null @@ -1,47 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-//
-// SPDX-License-Identifier: Apache-2.0
-
-#pragma once
-
-#include
-#include
-
-#include "ttnn/operations/maxpool2d.hpp"
-#include "ttnn/types.hpp"
-
-namespace py = pybind11;
-
-namespace ttnn::operations::maxpool2d {
-
-using array2_t = std::array<uint32_t, 2>;
-
-void py_module(py::module& module) {
-    module.def(
-        "maxpool2d",
-        [](const ttnn::Tensor& input_tensor,
-           uint32_t batch_size,
-           uint32_t input_height,
-           uint32_t input_width,
-           uint32_t channels,
-           array2_t kernel_size,
-           array2_t stride,
-           array2_t padding,
-           array2_t dilation,
-           Device& device) -> Tensor {
-            return maxpool2d(input_tensor, batch_size, input_height, input_width, channels, kernel_size, stride, padding, dilation, device);
-        },
-        py::kw_only(),
-        py::arg("input_tensor"),
-        py::arg("batch_size"),
-        py::arg("input_height"),
-        py::arg("input_width"),
-        py::arg("channels"),
-        py::arg("kernel_size"),
-        py::arg("stride"),
-        py::arg("padding"),
-        py::arg("dilation"),
-        py::arg("device"));
-}
-
-} // namespace ttnn::operations::maxpool
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt
index 412dda8643f..54bcd4d6c98 100644
--- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt
+++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/CMakeLists.txt
@@ -25,10 +25,6 @@ set(TT_DNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/non_zero_indices/non_zero_indices_op.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/fill_rm/fill_rm_op.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/fully_connected/fully_connected_op.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/pool/average_pool.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/pool/max_pool.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/pool/max_pool_single_core.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/pool/max_pool_multi_core.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/transpose/transpose_op.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/transpose/wh_multi_core/transpose_wh_op_multi_core.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/transpose/hc_multi_core/transpose_hc_op_multi_core.cpp
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/average_pool.hpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/average_pool.hpp
deleted file mode 100644
index e6f2b63da50..00000000000
--- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/average_pool.hpp
+++ /dev/null
@@ -1,22 +0,0 @@
-// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
-// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "tt_metal/host_api.hpp" -#include "tensor/tensor.hpp" - -#include "ttnn/experimental/tt_dnn/op_library/operation.hpp" - -namespace tt { -namespace tt_metal { - -enum class PoolType { - AVG -}; - -Tensor average_pool_2d(const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, const std::optional& output_dtype = std::nullopt); - -} // namespace tt_metal -} // namespace tt diff --git a/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor.cpp b/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor.cpp index 69f93ba78ae..46de5a50e1c 100644 --- a/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor.cpp +++ b/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor.cpp @@ -14,8 +14,6 @@ #include "ttnn/experimental/tt_dnn/op_library/fully_connected/fully_connected_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/layernorm_distributed/layernorm_pre_allgather_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/layernorm_distributed/layernorm_post_allgather_op.hpp" -#include "ttnn/experimental/tt_dnn/op_library/pool/average_pool.hpp" -#include "ttnn/experimental/tt_dnn/op_library/pool/max_pool.hpp" #include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/fast_reduce_nc/fast_reduce_nc_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/rotary_embedding/rotary_embedding_op.hpp" @@ -524,103 +522,6 @@ void TensorModule(py::module& m_tensor) { "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" )doc"); - // Pools - m_tensor.def( - "average_pool_2d", - &average_pool_2d, - py::arg().noconvert(), - py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - py::arg("output_dtype").noconvert() = std::nullopt, - R"doc( - Average Pool 2D - It operates on tensors whose that have channels as the last dimension - - +----------+----------------------------+------------+-------------------------------+----------+ - | Argument | Description | Data type | Valid range | Required | - +==========+============================+============+===============================+==========+ - | act | Input activations tensor | Tensor | | Yes | - +----------+----------------------------+------------+-------------------------------+----------+ - )doc"); - - m_tensor.def( - "max_pool2d", - &max_pool2d, - py::arg("input").noconvert(), - py::arg("in_n").noconvert(), - py::arg("in_h").noconvert(), - py::arg("in_w").noconvert(), - py::arg("kernel_h").noconvert(), - py::arg("kernel_w").noconvert(), - py::arg("stride_h") = 1, - py::arg("stride_w") = 1, - py::arg("pad_h") = 0, - py::arg("pad_w") = 0, - py::arg("dilation_h") = 1, - py::arg("dilation_w") = 1, - py::arg("output_mem_config") = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - py::arg("nblocks") = 1, - py::arg("use_multicore") = true, - R"doc( - Max Pool 2D - +-------------------+-------------------------------+---------------+-------------+----------+ - | Argument | Description | Data type | Valid range | Required | - +===================+===============================+===============+=============+==========+ - | input | Input activations tensor | Tensor | | Yes | - | in_n | Input nbatch | Tensor | | Yes | - | in_h | Input height | Tensor | | Yes | - | in_w | Input width | Tensor | | Yes | - | kernel_h | kernel window height | uint32_t | | Yes | - | kernel_w | kernel 
window width | uint32_t | | Yes | - | stride_h | stride in height dim | uint32_t | | No | - | stride_w | stride in width dim | uint32_t | | No | - | pad_h | padding in height dim | uint32_t | | No | - | pad_w | padding in width dim | uint32_t | | No | - | dilation_h | kernel dilation in height dim | uint32_t | | No | - | dilation_w | kernel dilation in width dim | uint32_t | | No | - | output_mem_config | output tensor memory config | MemoryConfig | | No | - +-------------------+-------------------------------+---------------+-------------+----------+ - )doc"); - - m_tensor.def( - "max_pool2d_v2", - &max_pool2d_v2, - py::arg("input").noconvert(), - py::arg("reader_indices").noconvert(), - py::arg("in_n").noconvert(), - py::arg("in_h").noconvert(), - py::arg("in_w").noconvert(), - py::arg("kernel_h").noconvert(), - py::arg("kernel_w").noconvert(), - py::arg("stride_h") = 1, - py::arg("stride_w") = 1, - py::arg("pad_h") = 0, - py::arg("pad_w") = 0, - py::arg("dilation_h") = 1, - py::arg("dilation_w") = 1, - py::arg("output_mem_config") = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - py::arg("nblocks") = 1, - py::arg("use_multicore") = true, - R"doc( - Max Pool 2D - +-------------------+-------------------------------+---------------+-------------+----------+ - | Argument | Description | Data type | Valid range | Required | - +===================+===============================+===============+=============+==========+ - | input | Input activations tensor | Tensor | | Yes | - | in_n | Input nbatch | Tensor | | Yes | - | in_h | Input height | Tensor | | Yes | - | in_w | Input width | Tensor | | Yes | - | kernel_h | kernel window height | uint32_t | | Yes | - | kernel_w | kernel window width | uint32_t | | Yes | - | stride_h | stride in height dim | uint32_t | | No | - | stride_w | stride in width dim | uint32_t | | No | - | pad_h | padding in height dim | uint32_t | | No | - | pad_w | padding in width dim | uint32_t | | No | - | dilation_h | kernel dilation in height dim | uint32_t | | No | - | dilation_w | kernel dilation in width dim | uint32_t | | No | - | output_mem_config | output tensor memory config | MemoryConfig | | No | - +-------------------+-------------------------------+---------------+-------------+----------+ - )doc"); - // TMs m_tensor.def( "split_last_dim_two_chunks_tiled", diff --git a/ttnn/cpp/ttnn/operations/maxpool2d.hpp b/ttnn/cpp/ttnn/operations/maxpool2d.hpp index ee15eacbdc4..db7472942e2 100644 --- a/ttnn/cpp/ttnn/operations/maxpool2d.hpp +++ b/ttnn/cpp/ttnn/operations/maxpool2d.hpp @@ -9,7 +9,7 @@ #include "ttnn/operations/core.hpp" #include "tt_metal/common/math.hpp" #include "ttnn/operations/conv2d/conv2d.hpp" -#include "ttnn/experimental/tt_dnn/op_library/pool/max_pool.hpp" +#include "ttnn/cpp/ttnn/operations/pool/max_pool.hpp" #include "ttnn/experimental/tt_dnn/op_library/sliding_window_op_infra/halo_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/sliding_window_op_infra/sliding_window.hpp" diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/average_pool.cpp b/ttnn/cpp/ttnn/operations/pool/avgpool/avg_pool.cpp similarity index 73% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/average_pool.cpp rename to ttnn/cpp/ttnn/operations/pool/avgpool/avg_pool.cpp index f2853708507..b341f0965fe 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/average_pool.cpp +++ b/ttnn/cpp/ttnn/operations/pool/avgpool/avg_pool.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include 
"ttnn/experimental/tt_dnn/op_library/pool/average_pool.hpp" +#include "ttnn/operations/pool/avgpool/avg_pool.hpp" #include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" @@ -10,20 +10,20 @@ namespace tt { namespace tt_metal { template -Tensor pool_2d(const Tensor& input, const MemoryConfig& output_mem_config, const std::optional& output_dtype) { +Tensor pool_2d(const Tensor& input, const MemoryConfig& memory_config, const std::optional& output_dtype) { TT_ASSERT(input.storage_type() == StorageType::DEVICE, "Input tensor needs to be on device"); auto input_shape = input.get_legacy_shape(); switch (pool) { case PoolType::AVG: { auto height_without_padding = input.get_legacy_shape().without_padding()[-2]; - return reduce(input, ReduceOpMath::SUM, ReduceOpDim::H, 1 / float(height_without_padding), output_mem_config, output_dtype); + return reduce(input, ReduceOpMath::SUM, ReduceOpDim::H, 1 / float(height_without_padding), memory_config, output_dtype); } default: TT_ASSERT(false && "Undefined pool type"); } } -Tensor average_pool_2d(const Tensor& input, const MemoryConfig& output_mem_config, const std::optional& output_dtype) { +Tensor avg_pool2d(const Tensor& input, const MemoryConfig& memory_config, const std::optional& output_dtype) { TT_ASSERT(input.storage_type() == StorageType::DEVICE, "Input tensor needs to be on device"); auto output = input; @@ -34,7 +34,7 @@ Tensor average_pool_2d(const Tensor& input, const MemoryConfig& output_mem_confi auto output_shape = Shape({in_shape[0], 1, in_shape[1] * in_shape[2], in_shape[3]}, output_padding); output = output.reshape(output_shape); - output = pool_2d(output, output_mem_config, output_dtype); + output = pool_2d(output, memory_config, output_dtype); return output; } diff --git a/ttnn/cpp/ttnn/operations/pool.hpp b/ttnn/cpp/ttnn/operations/pool/avgpool/avg_pool.hpp similarity index 59% rename from ttnn/cpp/ttnn/operations/pool.hpp rename to ttnn/cpp/ttnn/operations/pool/avgpool/avg_pool.hpp index 4631c645ed5..c7de056b039 100644 --- a/ttnn/cpp/ttnn/operations/pool.hpp +++ b/ttnn/cpp/ttnn/operations/pool/avgpool/avg_pool.hpp @@ -4,7 +4,25 @@ #pragma once -#include "ttnn/experimental/tt_dnn/op_library/pool/average_pool.hpp" +#include "tt_metal/host_api.hpp" +#include "tensor/tensor.hpp" + +#include "tt_dnn/op_library/operation.hpp" + +namespace tt { +namespace tt_metal { + +enum class PoolType { + AVG +}; + +Tensor avg_pool2d(const Tensor& input, const MemoryConfig& memory_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, const std::optional& output_dtype = std::nullopt); + +} // namespace tt_metal +} // namespace tt + + +#include "ttnn/cpp/ttnn/operations/pool/avgpool/avg_pool.hpp" #include "ttnn/decorators.hpp" #include "ttnn/operations/core.hpp" @@ -18,7 +36,7 @@ struct GlobalAveragePool2D { const std::optional& memory_config_arg = std::nullopt, const std::optional& output_dtype = std::nullopt) { auto memory_config = memory_config_arg.value_or(input.memory_config()); - auto result = tt::tt_metal::average_pool_2d(input, memory_config, output_dtype); + auto result = tt::tt_metal::avg_pool2d(input, memory_config, output_dtype); return result; } }; diff --git a/ttnn/cpp/pybind11/operations/pool.hpp b/ttnn/cpp/ttnn/operations/pool/avgpool/avg_pool_pybind.hpp similarity index 65% rename from ttnn/cpp/pybind11/operations/pool.hpp rename to ttnn/cpp/ttnn/operations/pool/avgpool/avg_pool_pybind.hpp index b4056ca4163..a44d75cef0d 100644 --- a/ttnn/cpp/pybind11/operations/pool.hpp +++ b/ttnn/cpp/ttnn/operations/pool/avgpool/avg_pool_pybind.hpp 
@@ -8,14 +8,13 @@
 #include
 #include "ttnn/cpp/pybind11/decorators.hpp"
-#include "ttnn/operations/pool.hpp"
+#include "ttnn/operations/pool/avgpool/avg_pool.hpp"
 #include "ttnn/types.hpp"

 namespace py = pybind11;
-
 namespace ttnn {
 namespace operations {
-namespace pool {
+namespace avgpool {

 namespace detail {
@@ -59,8 +58,27 @@ void bind_global_avg_pool2d(py::module& module) {

 }  // namespace detail

-void py_module(py::module& module) { detail::bind_global_avg_pool2d(module); }
+void py_module(py::module& module) {
+    detail::bind_global_avg_pool2d(module);
+    module.def(
+        "avg_pool2d",
+        &avg_pool2d,
+        py::arg().noconvert(),
+        py::kw_only(),
+        py::arg("memory_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+        py::arg("dtype").noconvert() = std::nullopt,
+        R"doc(
+        Average Pool 2D
+        It operates on tensors that have channels as the last dimension.
+
+        +----------+----------------------------+------------+-------------------------------+----------+
+        | Argument | Description                | Data type  | Valid range                   | Required |
+        +==========+============================+============+===============================+==========+
+        | act      | Input activations tensor  | Tensor     |                               | Yes      |
+        +----------+----------------------------+------------+-------------------------------+----------+
+        )doc");
+}

-}  // namespace pool
+}  // namespace avgpool
 }  // namespace operations
 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/compute/max_pool.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/compute/max_pool.cpp
rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool.cpp
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/compute/max_pool_multi_core.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/compute/max_pool_multi_core.cpp
rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core.cpp
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core.cpp
rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core.cpp
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core_sharded.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core_sharded.cpp
rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded.cpp
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo.cpp
rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo.cpp
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp
rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_single_core.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_single_core.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_single_core.cpp
rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_single_core.cpp
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_multi_core.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/writer_max_pool_2d_multi_core.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_multi_core.cpp
rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/writer_max_pool_2d_multi_core.cpp
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_multi_core_v2.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/writer_max_pool_2d_multi_core_v2.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_multi_core_v2.cpp
rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/writer_max_pool_2d_multi_core_v2.cpp
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_single_core.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/writer_max_pool_2d_single_core.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_single_core.cpp
rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/writer_max_pool_2d_single_core.cpp
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/max_pool_multi_core.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool_multi_core.cpp
similarity index 95%
rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/max_pool_multi_core.cpp
rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool_multi_core.cpp
index 7fd6ca2b3d6..702ea6a4a1d 100644
--- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/max_pool_multi_core.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool_multi_core.cpp
@@ -8,12 +8,12 @@
 #include "detail/util.hpp"
 #include "tensor/host_buffer/functions.hpp"
 #include "tensor/tensor_utils.hpp"
-#include "ttnn/experimental/tt_dnn/op_library/pool/max_pool.hpp"
-#include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils
-#include "ttnn/experimental/tt_dnn/op_library/sharding_utilities.hpp"
-#include "ttnn/experimental/tt_dnn/op_library/sliding_window_op_infra/sliding_window.hpp"
-#include "ttnn/experimental/tt_dnn/op_library/sliding_window_op_infra/utils.hpp"
-#include "ttnn/experimental/tt_dnn/op_library/work_split.hpp" +#include "ttnn/operations/pool/maxpool/max_pool.hpp" +#include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils +#include "tt_dnn/op_library/sharding_utilities.hpp" +#include "tt_dnn/op_library/sliding_window_op_infra/sliding_window.hpp" +#include "tt_dnn/op_library/sliding_window_op_infra/utils.hpp" +#include "tt_dnn/op_library/work_split.hpp" #include "tt_metal/host_api.hpp" namespace tt { @@ -91,7 +91,7 @@ uint32_t get_num_cores(const Device* device, uint32_t out_nhw, uint32_t nbatch) break; default: // TT_ASSERT(false, "General case is not yet handled! Only RN50 shapes supported in multicore."); - uint32_t out_nhw_per_core = (uint32_t)ceil((float)out_nhw / avail_ncores); + uint32_t out_nhw_per_core = (uint32_t)std::ceil((float)out_nhw / avail_ncores); ncores = out_nhw / out_nhw_per_core; while (avail_ncores > 0) { if (out_nhw % avail_ncores == 0 && (out_nhw / avail_ncores) % TILE_HEIGHT == 0) { @@ -104,7 +104,7 @@ uint32_t get_num_cores(const Device* device, uint32_t out_nhw, uint32_t nbatch) break; } } else if (device->arch() == ARCH::WORMHOLE_B0) { - uint32_t out_nhw_per_core = (uint32_t)ceil((float)out_nhw / avail_ncores); + uint32_t out_nhw_per_core = (uint32_t)std::ceil((float)out_nhw / avail_ncores); ncores = out_nhw / out_nhw_per_core; while (avail_ncores > 0) { if (out_nhw % avail_ncores == 0 && (out_nhw / avail_ncores) % TILE_HEIGHT == 0) { @@ -215,14 +215,14 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic( uint32_t kernel_size_hw = kernel_size_w * kernel_size_h; // number of valid rows, to read uint32_t kernel_size_hw_padded = ceil_multiple_of(kernel_size_hw, constants::TILE_HEIGHT); - uint32_t in_ntiles_hw = (uint32_t)ceil((float)kernel_size_hw_padded / constants::TILE_HEIGHT); - uint32_t in_ntiles_c = (uint32_t)ceil((float)input_shape[3] / constants::TILE_WIDTH); - uint32_t out_ntiles_hw = (uint32_t)ceil((float)output_shape[2] / constants::TILE_HEIGHT); - uint32_t out_ntiles_c = (uint32_t)ceil((float)output_shape[3] / constants::TILE_WIDTH); + uint32_t in_ntiles_hw = (uint32_t)std::ceil((float)kernel_size_hw_padded / constants::TILE_HEIGHT); + uint32_t in_ntiles_c = (uint32_t)std::ceil((float)input_shape[3] / constants::TILE_WIDTH); + uint32_t out_ntiles_hw = (uint32_t)std::ceil((float)output_shape[2] / constants::TILE_HEIGHT); + uint32_t out_ntiles_c = (uint32_t)std::ceil((float)output_shape[3] / constants::TILE_WIDTH); uint32_t out_nelems = nblocks; // TODO [AS]: Remove hard coding after identifying optimal param val // Also ensure the calculated ncores is good - uint32_t out_w_loop_count = ceil((float)out_w / out_nelems); + uint32_t out_w_loop_count = std::ceil((float)out_w / out_nelems); uint32_t in_hw = in_h * in_w; uint32_t in_nhw = in_hw * nbatch; @@ -492,10 +492,10 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic( if (input.memory_config().is_sharded()) { // sharded, without halo reader_kernel_fname = - std::string("ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core_sharded.cpp"); + std::string("ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded.cpp"); } else { reader_kernel_fname = - std::string("ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core.cpp"); + std::string("ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core.cpp"); } auto reader_kernel = CreateKernel(program, 
reader_kernel_fname, all_cores, reader_config); @@ -509,7 +509,7 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic( std::vector writer_ct_args = reader_ct_args; auto writer_config = WriterDataMovementConfig(writer_ct_args, writer_defines); std::string writer_kernel_fname( - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_multi_core.cpp"); + "ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/writer_max_pool_2d_multi_core.cpp"); auto writer_kernel = CreateKernel(program, writer_kernel_fname, all_cores, writer_config); /** @@ -523,8 +523,8 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic( kernel_size_hw, out_h, out_w, - (uint32_t)ceil((float)output_shape[2] / constants::TILE_HEIGHT), - (uint32_t)ceil((float)output_shape[3] / constants::TILE_WIDTH), + (uint32_t)std::ceil((float)output_shape[2] / constants::TILE_HEIGHT), + (uint32_t)std::ceil((float)output_shape[3] / constants::TILE_WIDTH), out_nelems, out_w_loop_count, nbatch, @@ -542,7 +542,7 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_generic( .math_approx_mode = false, .compile_args = compute_ct_args, .defines = reduce_op_utils::get_defines(reduce_op, reduce_dim)}; - std::string compute_kernel_fname("ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/compute/max_pool_multi_core.cpp"); + std::string compute_kernel_fname("ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core.cpp"); auto compute_kernel = CreateKernel(program, compute_kernel_fname, core_range, compute_config); if (out_nhw_per_core_cliff > 0) { @@ -724,9 +724,9 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl uint32_t kernel_size_hw = kernel_size_w * kernel_size_h; // number of valid rows, to read uint32_t kernel_size_hw_padded = ceil_multiple_of(kernel_size_hw, constants::TILE_HEIGHT); - uint32_t in_ntiles_hw = (uint32_t)ceil((float)kernel_size_hw_padded / constants::TILE_HEIGHT); - uint32_t in_ntiles_c = (uint32_t)ceil((float)input_shape[3] / constants::TILE_WIDTH); - uint32_t out_ntiles_c = (uint32_t)ceil((float)output_shape[3] / constants::TILE_WIDTH); + uint32_t in_ntiles_hw = (uint32_t)std::ceil((float)kernel_size_hw_padded / constants::TILE_HEIGHT); + uint32_t in_ntiles_c = (uint32_t)std::ceil((float)input_shape[3] / constants::TILE_WIDTH); + uint32_t out_ntiles_c = (uint32_t)std::ceil((float)output_shape[3] / constants::TILE_WIDTH); TT_ASSERT(nblocks == 1, "Multiple blocks not yet supported"); @@ -735,7 +735,7 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl TT_FATAL(input_shape[3] == 16); tile_w = constants::FACE_WIDTH; } - uint32_t out_w_loop_count = ceil((float)out_w / nblocks); + uint32_t out_w_loop_count = std::ceil((float)out_w / nblocks); // distributing out_hw across the grid auto grid_size = device->compute_with_storage_grid_size(); @@ -942,7 +942,7 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl bf16_one_u32}; std::string reader_kernel_fname( - "ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp"); + "ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp"); auto reader0_config = DataMovementConfig{ .processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default, .compile_args = reader0_ct_args}; @@ -973,7 +973,7 @@ operation::ProgramWithCallbacks 
max_pool_2d_multi_core_sharded_with_halo_v2_impl .compile_args = writer_ct_args, .defines = writer_defines}; std::string - writer_kernel_fname("ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_multi_core_v2.cpp"); auto + writer_kernel_fname("ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/writer_max_pool_2d_multi_core_v2.cpp"); auto writer_kernel = CreateKernel(program, writer_kernel_fname, all_cores, writer_config); */ @@ -1007,7 +1007,7 @@ operation::ProgramWithCallbacks max_pool_2d_multi_core_sharded_with_halo_v2_impl .math_approx_mode = false, .compile_args = compute_ct_args, .defines = reduce_op_utils::get_defines(reduce_op, reduce_dim)}; - std::string compute_kernel_fname("ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/compute/max_pool_multi_core.cpp"); + std::string compute_kernel_fname("ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core.cpp"); auto compute_kernel = CreateKernel(program, compute_kernel_fname, core_range, compute_config); /* diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/max_pool.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool_program_factory.cpp similarity index 99% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/max_pool.cpp rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool_program_factory.cpp index 6d2d5814a57..2fe6a5cde86 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/max_pool.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool_program_factory.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttnn/experimental/tt_dnn/op_library/pool/max_pool.hpp" +#include "ttnn/operations/pool/maxpool/max_pool.hpp" #include #include @@ -201,7 +201,7 @@ Tensor max_pool2d(const Tensor &input, {input}).at(0); } -Tensor max_pool2d_v2(const Tensor &input, +Tensor max_pool2d_legacy(const Tensor &input, const Tensor &reader_indices, uint32_t in_n, uint32_t in_h, uint32_t in_w, uint32_t kernel_size_h, uint32_t kernel_size_w, @@ -353,7 +353,7 @@ operation::ProgramWithCallbacks MaxPoolNew::create_program(const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector &output_tensors) const { +operation::OpPerformanceModel MaxPoolNew::create_op_performance_model(const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors, const std::vector &output_tensors) const { const auto& input = input_tensors.at(0); const auto& input_shape = input.get_shape(); uint32_t batch_size = sliding_window_config_.batch_size_; diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/max_pool_single_core.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool_single_core.cpp similarity index 92% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/max_pool_single_core.cpp rename to ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool_single_core.cpp index c5de5ed12d6..d8fce28c364 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/max_pool_single_core.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool_single_core.cpp @@ -8,9 +8,9 @@ #include "detail/util.hpp" #include "tensor/host_buffer/functions.hpp" #include "tensor/tensor_utils.hpp" -#include "ttnn/experimental/tt_dnn/op_library/pool/max_pool.hpp" -#include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils -#include "ttnn/experimental/tt_dnn/op_library/work_split.hpp" +#include 
"ttnn/operations/pool/maxpool/max_pool.hpp" +#include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils +#include "tt_dnn/op_library/work_split.hpp" #include "tt_metal/host_api.hpp" namespace tt { @@ -55,13 +55,13 @@ operation::ProgramWithCallbacks max_pool_2d_single_core(const Tensor &input, Ten uint32_t kernel_size_hw = kernel_size_w * kernel_size_h; // number of valid rows, to read uint32_t kernel_size_hw_padded = ceil_multiple_of(kernel_size_hw, constants::TILE_HEIGHT); - uint32_t in_ntiles_hw = (uint32_t) ceil((float) kernel_size_hw_padded / constants::TILE_HEIGHT); - uint32_t in_ntiles_c = (uint32_t) ceil((float) input_shape[3] / constants::TILE_WIDTH); - uint32_t out_ntiles_hw = (uint32_t) ceil((float) output_shape[2] / constants::TILE_HEIGHT); - uint32_t out_ntiles_c = (uint32_t) ceil((float) output_shape[3] / constants::TILE_WIDTH); + uint32_t in_ntiles_hw = (uint32_t) std::ceil((float) kernel_size_hw_padded / constants::TILE_HEIGHT); + uint32_t in_ntiles_c = (uint32_t) std::ceil((float) input_shape[3] / constants::TILE_WIDTH); + uint32_t out_ntiles_hw = (uint32_t) std::ceil((float) output_shape[2] / constants::TILE_HEIGHT); + uint32_t out_ntiles_c = (uint32_t) std::ceil((float) output_shape[3] / constants::TILE_WIDTH); uint32_t out_nelems = nblocks; // TODO [AS]: Remove hard coding after identifying optimal param val - uint32_t out_w_loop_count = ceil((float) out_w / out_nelems); + uint32_t out_w_loop_count = std::ceil((float) out_w / out_nelems); uint32_t in_hw = in_h * in_w; uint32_t out_hw = out_h * out_w; @@ -151,7 +151,7 @@ operation::ProgramWithCallbacks max_pool_2d_single_core(const Tensor &input, Ten (in_cb_page_nelems_padded * out_nelems * 2) >> 5 // TODO: generalize num rows to fill in in_cb }; auto reader_config = ReaderDataMovementConfig(reader_ct_args); - std::string reader_kernel_fname("ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/reader_max_pool_2d_single_core.cpp"); + std::string reader_kernel_fname("ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_single_core.cpp"); auto reader_kernel = CreateKernel(program, reader_kernel_fname, cores, @@ -200,7 +200,7 @@ operation::ProgramWithCallbacks max_pool_2d_single_core(const Tensor &input, Ten std::vector writer_ct_args = reader_ct_args; std::vector writer_rt_args = reader_rt_args; auto writer_config = WriterDataMovementConfig(writer_ct_args); - std::string writer_kernel_fname("ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/dataflow/writer_max_pool_2d_single_core.cpp"); + std::string writer_kernel_fname("ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/writer_max_pool_2d_single_core.cpp"); auto writer_kernel = CreateKernel(program, writer_kernel_fname, cores, @@ -221,14 +221,14 @@ operation::ProgramWithCallbacks max_pool_2d_single_core(const Tensor &input, Ten kernel_size_hw_padded, out_h, out_w, - (uint32_t) ceil((float) output_shape[2] / constants::TILE_HEIGHT), - (uint32_t) ceil((float) output_shape[3] / constants::TILE_WIDTH), + (uint32_t) std::ceil((float) output_shape[2] / constants::TILE_HEIGHT), + (uint32_t) std::ceil((float) output_shape[3] / constants::TILE_WIDTH), out_nelems, out_w_loop_count, nbatch, out_h}, // out_h_per_core .defines = reduce_op_utils::get_defines(reduce_op, reduce_dim)}; - std::string compute_kernel_fname("ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/kernels/compute/max_pool.cpp"); + std::string compute_kernel_fname("ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool.cpp"); 
auto compute_kernel = CreateKernel(program, compute_kernel_fname, cores, diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/max_pool.hpp b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool.hpp similarity index 77% rename from ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/max_pool.hpp rename to ttnn/cpp/ttnn/operations/pool/maxpool/max_pool.hpp index e2a3610010c..4fa8c3752c4 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/pool/max_pool.hpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool.hpp @@ -4,10 +4,14 @@ #pragma once +#include "ttnn/core.hpp" +#include "ttnn/types.hpp" #include "tensor/tensor.hpp" +#include "ttnn/operations/conv2d/conv2d.hpp" #include "ttnn/experimental/tt_dnn/op_library/run_operation.hpp" #include "ttnn/experimental/tt_dnn/op_library/sliding_window_op_infra/sliding_window.hpp" + inline uint32_t ceil_multiple_of(uint32_t n, uint32_t m) { return (uint32_t) std::ceil((float) n / m) * m; } @@ -126,7 +130,7 @@ Tensor max_pool2d(const Tensor &input, const MemoryConfig& out_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, uint32_t nblocks = 1, bool use_multicore = true); -Tensor max_pool2d_v2(const Tensor &input, const Tensor &reader_indices, +Tensor max_pool2d_legacy(const Tensor &input, const Tensor &reader_indices, uint32_t in_n, uint32_t in_h, uint32_t in_w, uint32_t kernel_size_h, uint32_t kernel_size_w, uint32_t stride_h = 1, uint32_t stride_w = 1, @@ -148,7 +152,7 @@ struct MaxPoolNew { std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; - operation::OpPerformanceModel create_op_performance_model(const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector &output_tensors) const; + operation::OpPerformanceModel create_op_performance_model(const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors, const std::vector &output_tensors) const; static constexpr auto attribute_names = std::make_tuple( "sliding_window_config", @@ -173,3 +177,45 @@ Tensor maxpool2d_new(const Tensor &input, } // namespace tt_metal } // namespace tt + +namespace ttnn::operations { +namespace maxpool { + + +// maxpool macro-op +inline Tensor maxpool2d(const Tensor& input_tensor, uint32_t batch_size, uint32_t input_h, uint32_t input_w, uint32_t channels, std::array kernel_size, std::array stride, std::array padding, std::array dilation, Device& device) { + MemoryConfig memory_config = input_tensor.memory_config(); + const auto shard_grid = memory_config.shard_spec.value().grid; + const auto shard_scheme = memory_config.memory_layout; + const auto shard_orientation = memory_config.shard_spec.value().orientation; + + TT_FATAL(shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED, "Only height sharded tensors are supported."); + TT_FATAL(shard_orientation == ShardOrientation::ROW_MAJOR, "Only row major orientation is supported."); + + ParallelConfig parallel_config = conv2d::determine_parallel_config( + shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED, + batch_size, + 0, // in_channels -- not used + input_h, + input_w, + 0, // out_channels -- not used + device, + shard_orientation); + uint32_t num_cores_nhw = conv2d::get_num_cores_nhw_from_parallel_config(parallel_config); + + SlidingWindowConfig sliding_window_config = SlidingWindowConfig(batch_size, + input_h, 
input_w, + kernel_size.at(0), kernel_size.at(1), + stride.at(0), stride.at(1), + padding.at(0), padding.at(1), + dilation.at(0), dilation.at(1), + num_cores_nhw, + parallel_config.grid); + uint32_t neg_inf_pad_val = 0xf7ff; // TODO: double check + + auto haloed_tensor = ttnn::operations::halo::halo_op(input_tensor, sliding_window_config, neg_inf_pad_val, false, parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, 0, memory_config); + return tt::tt_metal::maxpool2d_new(haloed_tensor, sliding_window_config, channels, memory_config); +} + +} // namespace maxpool +} // namespace ttnn::operations diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/maxpool_pybind.hpp b/ttnn/cpp/ttnn/operations/pool/maxpool/maxpool_pybind.hpp new file mode 100644 index 00000000000..4338fcbb1e6 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/maxpool_pybind.hpp @@ -0,0 +1,106 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "ttnn/operations/pool/maxpool/max_pool.hpp" + +#include "ttnn/types.hpp" + +namespace py = pybind11; +namespace ttnn { +namespace operations { +namespace maxpool { + + +void py_module(py::module& module) { + module.def( + "max_pool2d", + &max_pool2d, + py::arg("input").noconvert(), + py::arg("in_n").noconvert(), + py::arg("in_h").noconvert(), + py::arg("in_w").noconvert(), + py::arg("kernel_h").noconvert(), + py::arg("kernel_w").noconvert(), + py::arg("stride_h") = 1, + py::arg("stride_w") = 1, + py::arg("pad_h") = 0, + py::arg("pad_w") = 0, + py::arg("dilation_h") = 1, + py::arg("dilation_w") = 1, + py::kw_only(), + py::arg("memory_config") = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + py::arg("nblocks") = 1, + py::arg("use_multicore") = true, + R"doc( + Max Pool 2D + +-------------------+-------------------------------+---------------+-------------+----------+ + | Argument | Description | Data type | Valid range | Required | + +===================+===============================+===============+=============+==========+ + | input | Input activations tensor | Tensor | | Yes | + | in_n | Input nbatch | Tensor | | Yes | + | in_h | Input height | Tensor | | Yes | + | in_w | Input width | Tensor | | Yes | + | kernel_h | kernel window height | uint32_t | | Yes | + | kernel_w | kernel window width | uint32_t | | Yes | + | stride_h | stride in height dim | uint32_t | | No | + | stride_w | stride in width dim | uint32_t | | No | + | pad_h | padding in height dim | uint32_t | | No | + | pad_w | padding in width dim | uint32_t | | No | + | dilation_h | kernel dilation in height dim | uint32_t | | No | + | dilation_w | kernel dilation in width dim | uint32_t | | No | + | memory_config | Output memory config | MemoryConfig | | No | + +-------------------+-------------------------------+---------------+-------------+----------+ + )doc"); + + module.def( + "max_pool2d_legacy", + &max_pool2d_legacy, + py::arg("input").noconvert(), + py::arg("reader_indices").noconvert(), + py::arg("in_n").noconvert(), + py::arg("in_h").noconvert(), + py::arg("in_w").noconvert(), + py::arg("kernel_h").noconvert(), + py::arg("kernel_w").noconvert(), + py::arg("stride_h") = 1, + py::arg("stride_w") = 1, + py::arg("pad_h") = 0, + py::arg("pad_w") = 0, + py::arg("dilation_h") = 1, + py::arg("dilation_w") = 1, + py::kw_only(), + py::arg("memory_config") = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + py::arg("nblocks") = 1, + py::arg("use_multicore") = true, + R"doc( + 
Max Pool 2D + +-------------------+-------------------------------+---------------+-------------+----------+ + | Argument | Description | Data type | Valid range | Required | + +===================+===============================+===============+=============+==========+ + | input | Input activations tensor | Tensor | | Yes | + | in_n | Input nbatch | Tensor | | Yes | + | in_h | Input height | Tensor | | Yes | + | in_w | Input width | Tensor | | Yes | + | kernel_h | kernel window height | uint32_t | | Yes | + | kernel_w | kernel window width | uint32_t | | Yes | + | stride_h | stride in height dim | uint32_t | | No | + | stride_w | stride in width dim | uint32_t | | No | + | pad_h | padding in height dim | uint32_t | | No | + | pad_w | padding in width dim | uint32_t | | No | + | dilation_h | kernel dilation in height dim | uint32_t | | No | + | dilation_w | kernel dilation in width dim | uint32_t | | No | + | memory_config | output tensor memory config | MemoryConfig | | No | + +-------------------+-------------------------------+---------------+-------------+----------+ + )doc"); +} + +} // namespace maxpool +} // namespace operations +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/upsample/upsample_op.cpp b/ttnn/cpp/ttnn/operations/upsample/upsample_op.cpp index 5b5d2d823d0..1fc453867e3 100644 --- a/ttnn/cpp/ttnn/operations/upsample/upsample_op.cpp +++ b/ttnn/cpp/ttnn/operations/upsample/upsample_op.cpp @@ -10,7 +10,7 @@ #include "detail/util.hpp" #include "tensor/host_buffer/functions.hpp" #include "tensor/tensor_utils.hpp" -#include "ttnn/experimental/tt_dnn/op_library/pool/max_pool.hpp" +#include "ttnn/cpp/ttnn/operations/pool/maxpool/max_pool.hpp" #include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils #include "ttnn/experimental/tt_dnn/op_library/work_split.hpp" #include "tt_metal/host_api.hpp" diff --git a/ttnn/tt_lib/fused_ops/average_pool.py b/ttnn/tt_lib/fused_ops/average_pool.py index 22ac1781c22..a296e7ab7a0 100644 --- a/ttnn/tt_lib/fused_ops/average_pool.py +++ b/ttnn/tt_lib/fused_ops/average_pool.py @@ -5,11 +5,12 @@ import tt_lib as ttl from typing import Union, List +import ttnn def run_avg_pool_on_device_wrapper(device): - def average_pool_2d(x, output_mem_config, output_dtype=None): - out = ttl.tensor.average_pool_2d(x, output_mem_config, output_dtype) + def avg_pool2d(x, output_mem_config, output_dtype=None): + out = ttnn.avg_pool2d(x, memory_config=output_mem_config, dtype=output_dtype) return out - return average_pool_2d + return avg_pool2d diff --git a/ttnn/tt_lib/fused_ops/max_pool.py b/ttnn/tt_lib/fused_ops/max_pool.py index b8b0b8240d6..466db7f0641 100644 --- a/ttnn/tt_lib/fused_ops/max_pool.py +++ b/ttnn/tt_lib/fused_ops/max_pool.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import tt_lib as ttl - +import ttnn from typing import Union, List @@ -33,8 +33,8 @@ def max_pool_2d(x): # out_shape_nopad = compute_max_pool_shape(kernel_size, stride, padding, x_shape_nopad) # if reshape_2d and channels_last: # x = x.reshape(x_shape_nopad[0], 1, x_shape_nopad[1] * x_shape_nopad[2], x_shape_nopad[3]) - # out = ttl.tensor.max_pool2d(x, x_shape_nopad[1], x_shape_nopad[2], kernel_size, kernel_size, stride, stride, padding, padding, output_mem_config=output_mem_config, nblocks=nblocks, use_multicore=True) - out = ttl.tensor.max_pool2d( + # out = ttnn.max_pool2d(x, x_shape_nopad[1], x_shape_nopad[2], kernel_size, kernel_size, stride, stride, padding, padding, output_mem_config=output_mem_config, nblocks=nblocks, 
use_multicore=True) + out = ttnn.max_pool2d( x, in_n, in_h, @@ -45,7 +45,7 @@ def max_pool_2d(x): stride, padding, padding, - output_mem_config=output_mem_config, + memory_config=output_mem_config, nblocks=nblocks, use_multicore=True, ) diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index ae8983df9e4..ec97f0792a0 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -322,4 +322,4 @@ def prelu(*args, **kwargs): # Alias for leaky_relu. TODO(#8544): implement PReL determine_expected_group_norm_sharded_config_and_grid_size, ) from ttnn.operations.conv2d import Conv2d, Conv2dConfig, get_conv_output_dim, get_conv_padded_input_shape_and_mem_config -from ttnn.operations.pool import MaxPool2d +from ttnn.operations.pool import TTPyMaxPool, max_pool2d, max_pool2d_legacy, MaxPool2d, global_avg_pool2d, avg_pool2d diff --git a/ttnn/ttnn/operations/conv/tt_py_max_pool.py b/ttnn/ttnn/operations/conv/tt_py_max_pool.py deleted file mode 100644 index 03f8f878240..00000000000 --- a/ttnn/ttnn/operations/conv/tt_py_max_pool.py +++ /dev/null @@ -1,281 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import ttnn -from ttnn.operations.conv.tt_py_op import TTPyOp -from ttnn.operations.conv.tt_py_untilize_with_halo import TTPyUntilizeWithHalo -from ttnn.operations.conv.untilize_with_halo_config_generation_and_validation import ( - trace_conv_to_generate_data_top_left_indices_and_pad_metadata, - decompose_conv_into_shards_and_generate_tensor_metadata, -) -from ttnn.operations.conv.sliding_window_op_config_generation_and_validation import ( - generate_sliding_window_op_sharded_input_top_left_indices, -) -from ttnn.operations.conv.tt_py_composite_conv import ( - determine_parallel_config, -) -from ttnn.operations.conv.sliding_window_op_utils import ( - SlidingWindowOpParamsWithParallelConfig, - SlidingWindowOpParams, - get_hash_from_sliding_window_op_params, - calculate_shard_grid, - calculate_memory_config, -) - -from typing import Union - -from tt_lib.utils import _nearest_32 -import tt_lib as ttl - -import math -import torch - - -class TTPyMaxPool(TTPyOp): - def __init__( - self, - sliding_window_op_params: Union[SlidingWindowOpParams, SlidingWindowOpParamsWithParallelConfig], - device, - reader_patterns_cache, - pad_val=0xF7FF, - parallel_config_override=None, - output_mem_config=None, - deallocate_activation=True, - act_dtype=None, - channels=None, - ): - if parallel_config_override is None: - parallel_config_override = {} - if "max_pool" not in reader_patterns_cache: - reader_patterns_cache["max_pool"] = {} - if "halo" not in reader_patterns_cache: - reader_patterns_cache["halo"] = {} - - for key in reader_patterns_cache: - assert ( - key == "max_pool" or key == "halo" or key == "conv" - ), f"reader_patterns_cache should have 1 of the following keys - 'conv', 'max_pool' or 'halo'. 
Found key - {key}" - - snap_to_tile = parallel_config_override.get("snap_to_tile", False) - df_needs_tiled = act_dtype is not None and act_dtype == ttnn.bfloat8_b - conv_parallel_config = determine_parallel_config( - True, - 0, - 0, - sliding_window_op_params, - device, - config_override=parallel_config_override, - is_out_tiled=snap_to_tile or df_needs_tiled, - ) - self.grid_size = (conv_parallel_config.grid_size.x, conv_parallel_config.grid_size.y) - self.ncores_nhw = conv_parallel_config.num_cores_nhw - self.shard_grid, self.shard_layout = calculate_shard_grid(self.grid_size, self.ncores_nhw) - assert ( - self.shard_layout == ttnn.TensorMemoryLayout.HEIGHT_SHARDED - ), "TTPyMaxPool currently only supports height sharding" - - if isinstance(sliding_window_op_params, SlidingWindowOpParams): - self.sliding_window_op_params = SlidingWindowOpParamsWithParallelConfig( - stride_h=sliding_window_op_params.stride_h, - stride_w=sliding_window_op_params.stride_w, - pad_h=sliding_window_op_params.pad_h, - pad_w=sliding_window_op_params.pad_w, - window_h=sliding_window_op_params.window_h, - window_w=sliding_window_op_params.window_w, - batch_size=sliding_window_op_params.batch_size, - input_h=sliding_window_op_params.input_h, - input_w=sliding_window_op_params.input_w, - num_cores_h=self.grid_size[1], - num_cores_w=self.grid_size[0], - num_cores_nhw=self.ncores_nhw, - ) - else: - self.sliding_window_op_params = sliding_window_op_params - - sliding_window_op_params_hash = get_hash_from_sliding_window_op_params(self.sliding_window_op_params) - - self.device = device - - self.input_sharded_memory_config = calculate_memory_config( - self.sliding_window_op_params, - True, - 0 if channels is None else channels, - calc_input=True, - tile_size=32 if snap_to_tile else 1, - ) - self.output_sharded_memory_config = ( - calculate_memory_config( - self.sliding_window_op_params, - True, - 0 if channels is None else channels, - calc_input=False, - tile_size=32 if snap_to_tile else 1, - ) - if output_mem_config is None - else output_mem_config - ) - - self.set_op_configs( - sliding_window_op_params_hash, - reader_patterns_cache["max_pool"], - ) - assert sliding_window_op_params_hash in reader_patterns_cache["max_pool"] - reader_indices = reader_patterns_cache["max_pool"][sliding_window_op_params_hash] - - self.set_op_weights_biases( - self.sliding_window_op_params, - reader_indices, - ) - - self.pad_val = pad_val - self.untilize_with_halo = TTPyUntilizeWithHalo( - self.device, - self.sliding_window_op_params, - reader_patterns_cache["halo"], - pad_val=self.pad_val, - is_out_tiled=snap_to_tile, - ) - - self.deallocate_activation = deallocate_activation - - # override abstract methods from base class TTPyOp - def set_op_configs(self, sliding_window_op_params_hash, reader_patterns_cache): - if sliding_window_op_params_hash not in reader_patterns_cache: - stride_h = self.sliding_window_op_params.stride_h - stride_w = self.sliding_window_op_params.stride_w - pad_h = self.sliding_window_op_params.pad_h - pad_w = self.sliding_window_op_params.pad_w - window_h = self.sliding_window_op_params.window_h - window_w = self.sliding_window_op_params.window_w - batch_size = self.sliding_window_op_params.batch_size - input_h = self.sliding_window_op_params.input_h - input_w = self.sliding_window_op_params.input_w - ncores_h = self.sliding_window_op_params.num_cores_h - ncores_w = self.sliding_window_op_params.num_cores_w - ncores_nhw = self.sliding_window_op_params.num_cores_nhw - - input_nchw_shape = [batch_size, 1, input_h, 
input_w] - input_shard_height = self.input_sharded_memory_config.shard_spec.shape[0] - output_shard_height = self.output_sharded_memory_config.shard_spec.shape[0] - input_padded_width = input_w + 2 * pad_w - - pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata( - (1, 1, window_h, window_w, stride_h, stride_w, pad_h, pad_w, 1, 1), input_nchw_shape - ) - - req_conv_input_shard_start_end, tensor_metadata = decompose_conv_into_shards_and_generate_tensor_metadata( - data_top_left_indices, - pad_metadata, - input_padded_width, - output_shard_height, - input_shard_height, - ncores_nhw, - window_h, - window_w, - ) - - sliding_window_op_sharded_input_top_left_indices = ( - generate_sliding_window_op_sharded_input_top_left_indices( - data_top_left_indices, req_conv_input_shard_start_end, pad_tile=True, pad_last_core=True - ) - ) - - indices_torch_dtype = torch.int16 - indices_tt_dtype = ttnn.uint16 - - # Create sharded tensor on device for conv_reader_indices - reader_indices_torch_tensor = torch.tensor( - [[sliding_window_op_sharded_input_top_left_indices]], dtype=indices_torch_dtype - ) - reader_indices_tt_tensor = ttnn.Tensor( - reader_indices_torch_tensor, - indices_tt_dtype, - ) - shard_orientation = ttnn.ShardOrientation.ROW_MAJOR - shard_halo = False - shard_spec = ttnn.ShardSpec( - self.shard_grid, [1, reader_indices_tt_tensor.get_legacy_shape()[-1]], shard_orientation, shard_halo - ) - mem_config = ttnn.MemoryConfig(self.shard_layout, ttnn.BufferType.L1_SMALL, shard_spec) - reader_indices_sharded_tensor = reader_indices_tt_tensor.to(self.device, mem_config) - - reader_patterns_cache[sliding_window_op_params_hash] = reader_indices_sharded_tensor - - return - - def set_op_weights_biases(self, op_params, reader_indices): - stride_h = op_params.stride_h - stride_w = op_params.stride_w - pad_h = op_params.pad_h - pad_w = op_params.pad_w - window_h = op_params.window_h - window_w = op_params.window_w - in_n = op_params.batch_size - in_h = op_params.input_h - in_w = op_params.input_w - - def max_pool_(activation): - act_mem_config = activation.memory_config() - haloed_act = self.untilize_with_halo(activation) - - if self.deallocate_activation: - activation.deallocate() - output = ttl.tensor.max_pool2d_v2( - haloed_act, - reader_indices, - in_n, - in_h, - in_w, - window_h, - window_w, - stride_h, - stride_w, - pad_h, - pad_w, - output_mem_config=self.output_sharded_memory_config, - ) - haloed_act.deallocate() - return output - - self.max_pool = max_pool_ - - def __call__(self, activation): - return self.max_pool(activation) - - def copy_input_to_device(self, input: ttnn.Tensor): - in_shape = input.get_legacy_shape() - in_c = in_shape[-1] - in_n = self.sliding_window_op_params.batch_size - in_h = self.sliding_window_op_params.input_h - in_w = self.sliding_window_op_params.input_w - assert in_c % 16 == 0, "Input channels should be multiple of 16. 
General case is TODO" - act_shape = (1, 1, in_n * in_h * in_w, in_c) - act_reshaped = input.reshape(act_shape) - padded_nhw = self.input_sharded_memory_config.shard_spec.shape[0] * self.sliding_window_op_params.num_cores_nhw - if padded_nhw != act_shape[-2]: - padded_shape = ttnn.Shape(act_shape, (1, 1, padded_nhw, in_c)) - act_reshaped = ttl.tensor.format_input_tensor( - act_reshaped, - self.device, - padded_shape, - -float("inf"), - act_reshaped.layout, - ) - - interleaved_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1) - mem_config = self.input_sharded_memory_config - shard_shape = mem_config.shard_spec.shape - shard_shape[1] = in_c - mem_config.shard_spec.shape = shard_shape - act_reshaped = act_reshaped.to(self.device, interleaved_mem_config) - return ttl.tensor.interleaved_to_sharded( - act_reshaped, - mem_config, - input.get_dtype(), - ) - - def copy_output_from_device(self, output_d: ttnn.Tensor): - interleaved_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM) - output_d = ttl.tensor.sharded_to_interleaved(output_d, interleaved_mem_config) - return output_d.cpu() diff --git a/ttnn/ttnn/operations/pool.py b/ttnn/ttnn/operations/pool.py index 40d678634ce..0b7a588a591 100644 --- a/ttnn/ttnn/operations/pool.py +++ b/ttnn/ttnn/operations/pool.py @@ -2,19 +2,74 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Tuple, Union, Dict - -import sys import ttnn - -from ttnn.operations.conv.tt_py_max_pool import ( - TTPyMaxPool, +from ttnn.operations.conv.tt_py_op import TTPyOp +from ttnn.operations.conv.tt_py_untilize_with_halo import TTPyUntilizeWithHalo +from ttnn.operations.conv.untilize_with_halo_config_generation_and_validation import ( + trace_conv_to_generate_data_top_left_indices_and_pad_metadata, + decompose_conv_into_shards_and_generate_tensor_metadata, +) +from ttnn.operations.conv.sliding_window_op_config_generation_and_validation import ( + generate_sliding_window_op_sharded_input_top_left_indices, +) +from ttnn.operations.conv.tt_py_composite_conv import ( + determine_parallel_config, +) +from ttnn.operations.conv.sliding_window_op_utils import ( + SlidingWindowOpParamsWithParallelConfig, SlidingWindowOpParams, + get_hash_from_sliding_window_op_params, + calculate_shard_grid, + calculate_memory_config, ) +from typing import Union, Tuple, Dict + +from tt_lib.utils import _nearest_32 import tt_lib as ttl -__all__ = [] +import math +import torch +import ttnn + + +def golden_maxpool2d( + _input_tensor: ttnn.Tensor, + in_n: int, + in_h: int, + in_w: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + *, + memory_config: ttnn.MemoryConfig, + nblocks: int, + use_multicore: bool, +): + import torch + + kernel_size = (kernel_h, kernel_w) + stride = (stride_h, stride_w) + padding = (pad_h, pad_w) + dilation = (dilation_h, dilation_w) + + return torch.nn.functional.max_pool2d( + _input_tensor, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation + ) + + +max_pool2d = ttnn.register_python_operation(name="ttnn.max_pool2d", golden_function=golden_maxpool2d)( + ttnn._ttnn.operations.pool.max_pool2d +) + +max_pool2d_legacy = ttnn.register_python_operation(name="ttnn.max_pool2d_legacy", golden_function=golden_maxpool2d)( + ttnn._ttnn.operations.pool.max_pool2d_legacy +) class MaxPool2d: @@ -112,20 +167,268 @@ def copy_output_from_device(self, output: ttnn.Tensor): return 
self.max_pool.copy_output_from_device(output) -## Average Pooling +class TTPyMaxPool(TTPyOp): + def __init__( + self, + sliding_window_op_params: Union[SlidingWindowOpParams, SlidingWindowOpParamsWithParallelConfig], + device, + reader_patterns_cache, + pad_val=0xF7FF, + parallel_config_override=None, + output_mem_config=None, + deallocate_activation=True, + act_dtype=None, + channels=None, + pool_op=None, + ): + self.pool_op = pool_op + if parallel_config_override is None: + parallel_config_override = {} + if "max_pool" not in reader_patterns_cache: + reader_patterns_cache["max_pool"] = {} + if "halo" not in reader_patterns_cache: + reader_patterns_cache["halo"] = {} + for key in reader_patterns_cache: + assert ( + key == "max_pool" or key == "halo" or key == "conv" + ), f"reader_patterns_cache should have 1 of the following keys - 'conv', 'max_pool' or 'halo'. Found key - {key}" -def _golden_function(input_tensor: ttnn.Tensor): - import torch + snap_to_tile = parallel_config_override.get("snap_to_tile", False) + df_needs_tiled = act_dtype is not None and act_dtype == ttnn.bfloat8_b + conv_parallel_config = determine_parallel_config( + True, + 0, + 0, + sliding_window_op_params, + device, + config_override=parallel_config_override, + is_out_tiled=snap_to_tile or df_needs_tiled, + ) + self.grid_size = (conv_parallel_config.grid_size.x, conv_parallel_config.grid_size.y) + self.ncores_nhw = conv_parallel_config.num_cores_nhw + self.shard_grid, self.shard_layout = calculate_shard_grid(self.grid_size, self.ncores_nhw) + assert ( + self.shard_layout == ttnn.TensorMemoryLayout.HEIGHT_SHARDED + ), "TTPyMaxPool currently only supports height sharding" + + if isinstance(sliding_window_op_params, SlidingWindowOpParams): + self.sliding_window_op_params = SlidingWindowOpParamsWithParallelConfig( + stride_h=sliding_window_op_params.stride_h, + stride_w=sliding_window_op_params.stride_w, + pad_h=sliding_window_op_params.pad_h, + pad_w=sliding_window_op_params.pad_w, + window_h=sliding_window_op_params.window_h, + window_w=sliding_window_op_params.window_w, + batch_size=sliding_window_op_params.batch_size, + input_h=sliding_window_op_params.input_h, + input_w=sliding_window_op_params.input_w, + num_cores_h=self.grid_size[1], + num_cores_w=self.grid_size[0], + num_cores_nhw=self.ncores_nhw, + ) + else: + self.sliding_window_op_params = sliding_window_op_params + + sliding_window_op_params_hash = get_hash_from_sliding_window_op_params(self.sliding_window_op_params) + + self.device = device + + self.input_sharded_memory_config = calculate_memory_config( + self.sliding_window_op_params, + True, + 0 if channels is None else channels, + calc_input=True, + tile_size=32 if snap_to_tile else 1, + ) + self.output_sharded_memory_config = ( + calculate_memory_config( + self.sliding_window_op_params, + True, + 0 if channels is None else channels, + calc_input=False, + tile_size=32 if snap_to_tile else 1, + ) + if output_mem_config is None + else output_mem_config + ) + + self.set_op_configs( + sliding_window_op_params_hash, + reader_patterns_cache["max_pool"], + ) + assert sliding_window_op_params_hash in reader_patterns_cache["max_pool"] + reader_indices = reader_patterns_cache["max_pool"][sliding_window_op_params_hash] + + self.set_op_weights_biases( + self.sliding_window_op_params, + reader_indices, + ) + + self.pad_val = pad_val + self.untilize_with_halo = TTPyUntilizeWithHalo( + self.device, + self.sliding_window_op_params, + reader_patterns_cache["halo"], + pad_val=self.pad_val, + is_out_tiled=snap_to_tile, + ) 
+ + self.deallocate_activation = deallocate_activation + + # override abstract methods from base class TTPyOp + def set_op_configs(self, sliding_window_op_params_hash, reader_patterns_cache): + if sliding_window_op_params_hash not in reader_patterns_cache: + stride_h = self.sliding_window_op_params.stride_h + stride_w = self.sliding_window_op_params.stride_w + pad_h = self.sliding_window_op_params.pad_h + pad_w = self.sliding_window_op_params.pad_w + window_h = self.sliding_window_op_params.window_h + window_w = self.sliding_window_op_params.window_w + batch_size = self.sliding_window_op_params.batch_size + input_h = self.sliding_window_op_params.input_h + input_w = self.sliding_window_op_params.input_w + ncores_h = self.sliding_window_op_params.num_cores_h + ncores_w = self.sliding_window_op_params.num_cores_w + ncores_nhw = self.sliding_window_op_params.num_cores_nhw + + input_nchw_shape = [batch_size, 1, input_h, input_w] + input_shard_height = self.input_sharded_memory_config.shard_spec.shape[0] + output_shard_height = self.output_sharded_memory_config.shard_spec.shape[0] + input_padded_width = input_w + 2 * pad_w + + pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata( + (1, 1, window_h, window_w, stride_h, stride_w, pad_h, pad_w, 1, 1), input_nchw_shape + ) + + req_conv_input_shard_start_end, tensor_metadata = decompose_conv_into_shards_and_generate_tensor_metadata( + data_top_left_indices, + pad_metadata, + input_padded_width, + output_shard_height, + input_shard_height, + ncores_nhw, + window_h, + window_w, + ) + + sliding_window_op_sharded_input_top_left_indices = ( + generate_sliding_window_op_sharded_input_top_left_indices( + data_top_left_indices, req_conv_input_shard_start_end, pad_tile=True, pad_last_core=True + ) + ) + + indices_torch_dtype = torch.int16 + indices_tt_dtype = ttnn.uint16 + + # Create sharded tensor on device for conv_reader_indices + reader_indices_torch_tensor = torch.tensor( + [[sliding_window_op_sharded_input_top_left_indices]], dtype=indices_torch_dtype + ) + reader_indices_tt_tensor = ttnn.Tensor( + reader_indices_torch_tensor, + indices_tt_dtype, + ) + shard_orientation = ttnn.ShardOrientation.ROW_MAJOR + shard_halo = False + shard_spec = ttnn.ShardSpec( + self.shard_grid, [1, reader_indices_tt_tensor.get_legacy_shape()[-1]], shard_orientation, shard_halo + ) + mem_config = ttnn.MemoryConfig(self.shard_layout, ttnn.BufferType.L1_SMALL, shard_spec) + reader_indices_sharded_tensor = reader_indices_tt_tensor.to(self.device, mem_config) + + reader_patterns_cache[sliding_window_op_params_hash] = reader_indices_sharded_tensor - input_tensor = ttnn.from_device(input_tensor) - input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT) - input_tensor = ttnn.to_torch(input_tensor) + return + + def set_op_weights_biases(self, op_params, reader_indices): + stride_h = op_params.stride_h + stride_w = op_params.stride_w + pad_h = op_params.pad_h + pad_w = op_params.pad_w + window_h = op_params.window_h + window_w = op_params.window_w + in_n = op_params.batch_size + in_h = op_params.input_h + in_w = op_params.input_w + + def max_pool_(activation): + act_mem_config = activation.memory_config() + haloed_act = self.untilize_with_halo(activation) + + if self.deallocate_activation: + activation.deallocate() + output = max_pool2d_legacy( + haloed_act, + reader_indices, + in_n, + in_h, + in_w, + window_h, + window_w, + stride_h, + stride_w, + pad_h, + pad_w, + memory_config=self.output_sharded_memory_config, + ) + 
haloed_act.deallocate() + return output + + self.max_pool = max_pool_ + + def __call__(self, activation): + return self.max_pool(activation) + + def copy_input_to_device(self, input: ttnn.Tensor): + in_shape = input.get_legacy_shape() + in_c = in_shape[-1] + in_n = self.sliding_window_op_params.batch_size + in_h = self.sliding_window_op_params.input_h + in_w = self.sliding_window_op_params.input_w + assert in_c % 16 == 0, "Input channels should be multiple of 16. General case is TODO" + act_shape = (1, 1, in_n * in_h * in_w, in_c) + act_reshaped = input.reshape(act_shape) + padded_nhw = self.input_sharded_memory_config.shard_spec.shape[0] * self.sliding_window_op_params.num_cores_nhw + if padded_nhw != act_shape[-2]: + padded_shape = ttnn.Shape(act_shape, (1, 1, padded_nhw, in_c)) + act_reshaped = ttl.tensor.format_input_tensor( + act_reshaped, + self.device, + padded_shape, + -float("inf"), + act_reshaped.layout, + ) + + interleaved_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1) + mem_config = self.input_sharded_memory_config + shard_shape = mem_config.shard_spec.shape + shard_shape[1] = in_c + mem_config.shard_spec.shape = shard_shape + act_reshaped = act_reshaped.to(self.device, interleaved_mem_config) + return ttl.tensor.interleaved_to_sharded( + act_reshaped, + mem_config, + input.get_dtype(), + ) + + def copy_output_from_device(self, output_d: ttnn.Tensor): + interleaved_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM) + output_d = ttl.tensor.sharded_to_interleaved(output_d, interleaved_mem_config) + return output_d.cpu() + + +def golden_global_avg_pool2d(input_tensor: ttnn.Tensor): + import torch output_size = (1, 1) return torch.nn.functional.global_avg_pool2d(input_tensor, output_size) -ttnn.attach_golden_function(ttnn.global_avg_pool2d, golden_function=_golden_function) +global_avg_pool2d = ttnn.register_python_operation( + name="ttnn.global_avg_pool2d", golden_function=golden_global_avg_pool2d +)(ttnn._ttnn.operations.pool.global_avg_pool2d) -__all__ = [] +avg_pool2d = ttnn.register_python_operation(name="ttnn.avg_pool2d", golden_function=golden_global_avg_pool2d)( + ttnn._ttnn.operations.pool.avg_pool2d +)
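
Usage sketch (not part of the patch): how the renamed pooling entry points are invoked after this change, mirroring the updated unit tests. The device handle, activation shape, and kernel/stride/pad/dilation values are illustrative assumptions, and ttnn.from_torch / ttnn.DRAM_MEMORY_CONFIG are assumed to be available as in current ttnn.

    import torch
    import ttnn

    def run_pool_example(device):
        n, c, h, w = 1, 64, 112, 112                      # assumed NCHW activation shape
        act = torch.randn(n, c, h, w, dtype=torch.bfloat16)
        # The pool ops consume a row-major [1, 1, N*H*W, C] (NHWC-flattened) activation,
        # as prepared in the unit tests touched by this diff.
        act_flat = act.permute(0, 2, 3, 1).reshape(1, 1, n * h * w, c)
        tt_act = ttnn.from_torch(act_flat, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

        # Max pool: geometry stays positional, memory_config/nblocks/use_multicore become keyword-only.
        tt_max = ttnn.max_pool2d(
            tt_act,
            n, h, w,          # in_n, in_h, in_w
            3, 3,             # kernel_h, kernel_w
            2, 2,             # stride_h, stride_w
            1, 1,             # pad_h, pad_w
            1, 1,             # dilation_h, dilation_w
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
            nblocks=1,
            use_multicore=True,
        )

        # Global average pool over H*W (ttl.tensor.average_pool_2d -> ttnn.avg_pool2d).
        # Note: the unit tests tile-pad the activation (pad_to_tile(0.0)) before this call;
        # that step is omitted here for brevity.
        tt_avg = ttnn.avg_pool2d(tt_act)
        return tt_max, tt_avg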
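
For the height-sharded path, the TTPyMaxPool wrapper that moved from ttnn.operations.conv.tt_py_max_pool into ttnn.operations.pool keeps the same copy-in / call / copy-out flow defined above. A minimal sketch with illustrative sliding-window values; the keyword construction of SlidingWindowOpParams and the caller-prepared input tensor are assumptions.

    from ttnn.operations.pool import TTPyMaxPool
    from ttnn.operations.conv.sliding_window_op_utils import SlidingWindowOpParams

    def run_sharded_max_pool(device, tt_act):
        # tt_act: row-major ttnn.Tensor of shape [1, 1, N*H*W, C], prepared by the caller (assumed)
        params = SlidingWindowOpParams(
            stride_h=2, stride_w=2,
            pad_h=1, pad_w=1,
            window_h=3, window_w=3,
            batch_size=8, input_h=112, input_w=112,   # illustrative values
        )
        reader_patterns_cache = {}
        max_pool = TTPyMaxPool(params, device, reader_patterns_cache, pad_val=0xF7FF)
        act_sharded = max_pool.copy_input_to_device(tt_act)    # reshape/pad and move to the sharded input config
        out_sharded = max_pool(act_sharded)                     # untilize_with_halo followed by max_pool2d_legacy
        return max_pool.copy_output_from_device(out_sharded)   # sharded -> interleaved DRAM, then to host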