diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index 2cc43aa8ff7..82173ef6d2a 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -16,6 +16,7 @@ #include "transformer.hpp" #include "normalization.hpp" #include "kv_cache.hpp" +#include "pool.hpp" namespace py = pybind11; @@ -50,6 +51,9 @@ void py_module(py::module& module) { auto m_kv_cache = module.def_submodule("kv_cache", "KV cache operations"); kv_cache::py_module(m_kv_cache); + + auto m_pool = module.def_submodule("pool", "pool operations"); + pool::py_module(m_pool); } } // namespace operations diff --git a/ttnn/cpp/pybind11/operations/binary.hpp b/ttnn/cpp/pybind11/operations/binary.hpp index 3ecfd30d618..095ac9b0708 100644 --- a/ttnn/cpp/pybind11/operations/binary.hpp +++ b/ttnn/cpp/pybind11/operations/binary.hpp @@ -7,7 +7,7 @@ #include #include -#include "../decorators.hpp" +#include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/binary.hpp" #include "ttnn/types.hpp" diff --git a/ttnn/cpp/pybind11/operations/ccl.hpp b/ttnn/cpp/pybind11/operations/ccl.hpp index 7701da39549..03d049b8ed9 100644 --- a/ttnn/cpp/pybind11/operations/ccl.hpp +++ b/ttnn/cpp/pybind11/operations/ccl.hpp @@ -7,7 +7,7 @@ #include #include -#include "../decorators.hpp" +#include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/ccl.hpp" #include "ttnn/types.hpp" diff --git a/ttnn/cpp/pybind11/operations/core.hpp b/ttnn/cpp/pybind11/operations/core.hpp index 76b9985305e..f92773cf3f9 100644 --- a/ttnn/cpp/pybind11/operations/core.hpp +++ b/ttnn/cpp/pybind11/operations/core.hpp @@ -7,7 +7,7 @@ #include #include -#include "../decorators.hpp" +#include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/core.hpp" namespace py = pybind11; diff --git a/ttnn/cpp/pybind11/operations/kv_cache.hpp b/ttnn/cpp/pybind11/operations/kv_cache.hpp index a616a36afcb..f994782d8dd 100644 --- a/ttnn/cpp/pybind11/operations/kv_cache.hpp +++ b/ttnn/cpp/pybind11/operations/kv_cache.hpp @@ -7,7 +7,7 @@ #include #include -#include "../decorators.hpp" +#include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/kv_cache.hpp" #include "ttnn/types.hpp" diff --git a/ttnn/cpp/pybind11/operations/normalization.hpp b/ttnn/cpp/pybind11/operations/normalization.hpp index 955dbe3e88a..8d72f20a62c 100644 --- a/ttnn/cpp/pybind11/operations/normalization.hpp +++ b/ttnn/cpp/pybind11/operations/normalization.hpp @@ -7,7 +7,7 @@ #include #include -#include "../decorators.hpp" +#include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/normalization.hpp" namespace py = pybind11; diff --git a/ttnn/cpp/pybind11/operations/pool.hpp b/ttnn/cpp/pybind11/operations/pool.hpp new file mode 100644 index 00000000000..547ff7e30a1 --- /dev/null +++ b/ttnn/cpp/pybind11/operations/pool.hpp @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "ttnn/operations/pool.hpp" +#include "ttnn/types.hpp" + +namespace py = pybind11; + +namespace ttnn { +namespace operations { +namespace pool { + +namespace detail { + +void bind_global_avg_pool2d(py::module& module) { + auto doc = fmt::format( + R"doc({0}(input_tensor: ttnn.Tensor, *, memory_config: Optional[ttnn.MemoryConfig] = None, dtype: Optional[ttnn.DataType] = None) -> ttnn.Tensor + + Applies {0} to :attr:`input_tensor` by performing a 2D adaptive average pooling over an input signal composed of several input planes. This operation computes the average of all elements in each channel across the entire spatial dimensions. + + .. math:: + {0}(\\mathrm{{input\\_tensor}}_i) + + Args: + * :attr:`input_tensor` (ttnn.Tensor): The input tensor to be pooled. Typically of shape (batch_size, channels, height, width). + + Keyword Args: + * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation. + * :attr:`dtype` (Optional[ttnn.DataType]): data type for the output tensor + + Returns: + ttnn.Tensor: The tensor with the averaged values. The output tensor shape is (batch_size, channels, 1, 1). + + Example:: + + >>> tensor = ttnn.from_torch(torch.randn((10, 3, 32, 32), dtype=ttnn.bfloat16), device=device) + >>> output = {1}(tensor) + )doc", + ttnn::operations::pool::global_avg_pool2d.name(), + ttnn::operations::pool::global_avg_pool2d.python_fully_qualified_name()); + + bind_registered_operation( + module, + ttnn::operations::pool::global_avg_pool2d, + doc, + ttnn::pybind_arguments_t{ + py::arg("input_tensor"), + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("dtype") = std::nullopt}); +} + +} // namespace detail + +void py_module(py::module& module) { + detail::bind_global_avg_pool2d(module); +} + +} // namespace pool +} // namespace operations +} // namespace ttnn diff --git a/ttnn/cpp/pybind11/operations/transformer.hpp b/ttnn/cpp/pybind11/operations/transformer.hpp index 62efe5678ee..c753a8c9d1a 100644 --- a/ttnn/cpp/pybind11/operations/transformer.hpp +++ b/ttnn/cpp/pybind11/operations/transformer.hpp @@ -7,7 +7,7 @@ #include #include -#include "../decorators.hpp" +#include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/transformer.hpp" namespace py = pybind11; diff --git a/ttnn/cpp/pybind11/operations/unary.hpp b/ttnn/cpp/pybind11/operations/unary.hpp index f98b4c1aecb..42e28445bca 100644 --- a/ttnn/cpp/pybind11/operations/unary.hpp +++ b/ttnn/cpp/pybind11/operations/unary.hpp @@ -7,7 +7,7 @@ #include #include -#include "../decorators.hpp" +#include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/unary.hpp" #include "ttnn/types.hpp" diff --git a/ttnn/cpp/ttnn/operations/pool.hpp b/ttnn/cpp/ttnn/operations/pool.hpp new file mode 100644 index 00000000000..e300503e339 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/pool.hpp @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/decorators.hpp" +#include "ttnn/operations/core.hpp" +#include "tt_eager/tt_dnn/op_library/pool/average_pool.hpp" + +namespace ttnn { +namespace operations { +namespace pool { + +namespace detail { +inline const std::array input_tensor_schemas() { + return {ttnn::TensorSchema{ + 4, // min rank + 4, // max rank + {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::uint16, ttnn::uint32}, + {ttnn::TILE_LAYOUT}, + true, // can_be_on_device + false, // can_be_on_cpu + false, // can_be_scalar + false // is_optional} + }}; +} +} // namespace details + +struct GlobalAveragePool2D { + static const std::array input_tensor_schemas() { return detail::input_tensor_schemas(); } + + template + static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) { + return std::make_tuple(input_tensor); + } + + static Tensor execute(const Tensor& input, const std::optional& memory_config_arg = std::nullopt, const std::optional& output_dtype = std::nullopt) { + auto memory_config = memory_config_arg.value_or(input.memory_config()); + auto result = tt::tt_metal::average_pool_2d(input, memory_config, output_dtype); + return result; + } +}; +constexpr auto global_avg_pool2d = ttnn::register_operation("ttnn::pool::global_avg_pool2d"); +} // namespace pool +} // namespace operations +} // namespace ttnn diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 52d3c3f409b..58804c4effc 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -465,7 +465,7 @@ def manage_config(name, value): from ttnn.operations import transformer from ttnn.operations import kv_cache from ttnn.operations.conv2d import Conv2d -from ttnn.operations.maxpool2d import ( +from ttnn.operations.pool import ( MaxPool2d, global_avg_pool2d, ) diff --git a/ttnn/ttnn/operations/maxpool2d.py b/ttnn/ttnn/operations/pool.py similarity index 79% rename from ttnn/ttnn/operations/maxpool2d.py rename to ttnn/ttnn/operations/pool.py index 3a5a3646228..3d30efbd718 100644 --- a/ttnn/ttnn/operations/maxpool2d.py +++ b/ttnn/ttnn/operations/pool.py @@ -6,6 +6,7 @@ import tt_lib as ttl +import sys import ttnn from tt_eager.tt_dnn.op_library.sliding_window_op_infra.tt_py_max_pool import ( @@ -13,6 +14,10 @@ SlidingWindowOpParams, ) +THIS_MODULE = sys.modules[__name__] + +__all__ = [] + class MaxPool2d: r""" @@ -117,7 +122,7 @@ def copy_output_from_device(self, output: ttnn.Tensor): ## Average Pooling -def _torch_global_avg_pool2d(input_tensor: ttnn.Tensor): +def _golden_function(input_tensor: ttnn.Tensor): import torch input_tensor = ttnn.from_device(input_tensor) @@ -128,32 +133,8 @@ def _torch_global_avg_pool2d(input_tensor: ttnn.Tensor): return torch.nn.functional.global_avg_pool2d(input_tensor, output_size) -def _global_avg_pool2d_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): - ttnn.validate_input_tensor( - operation_name, - input_tensor, - ranks=(4,), - dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32), - layouts=(ttnn.TILE_LAYOUT,), - can_be_on_device=True, - can_be_on_cpu=False, - ) - - -@ttnn.register_operation( - name="ttnn.global_avg_pool2d", - validate_input_tensors=_global_avg_pool2d_validate_input_tensors, - golden_function=_torch_global_avg_pool2d, +global_avg_pool2d = ttnn.register_operation(golden_function=_golden_function)( + ttnn._ttnn.operations.pool.global_avg_pool2d ) -def global_avg_pool2d(input_tensor: ttnn.Tensor, memory_config: ttnn.MemoryConfig = None) -> ttnn.Tensor: - r""" - Applies a 2D adaptive average pooling over an input signal composed of several input planes. - Arguments: - * :attr: input_tensor: the input tensor - """ - if memory_config is None: - output = ttl.tensor.average_pool_2d(input_tensor) - else: - output = ttl.tensor.average_pool_2d(input_tensor, memory_config) - return output +__all__ = []