Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#8117: Move global_avg_pool2d to C++ #8583

Merged
merged 1 commit into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ttnn/cpp/pybind11/operations/__init__.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "transformer.hpp"
#include "normalization.hpp"
#include "kv_cache.hpp"
#include "pool.hpp"

namespace py = pybind11;

Expand Down Expand Up @@ -50,6 +51,9 @@ void py_module(py::module& module) {

auto m_kv_cache = module.def_submodule("kv_cache", "KV cache operations");
kv_cache::py_module(m_kv_cache);

auto m_pool = module.def_submodule("pool", "pool operations");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nicer to rename this to "pooling" to better align with torch:

We can move the Max/Avg pool operations when we migrate those to C++ as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do 👍

pool::py_module(m_pool);
}

} // namespace operations
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 thanks for refactoring this

#include "ttnn/operations/binary.hpp"
#include "ttnn/types.hpp"

Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/ccl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/ccl.hpp"
#include "ttnn/types.hpp"

Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/core.hpp"

namespace py = pybind11;
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/kv_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/kv_cache.hpp"
#include "ttnn/types.hpp"

Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/normalization.hpp"

namespace py = pybind11;
Expand Down
68 changes: 68 additions & 0 deletions ttnn/cpp/pybind11/operations/pool.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/pool.hpp"
#include "ttnn/types.hpp"

namespace py = pybind11;

namespace ttnn {
namespace operations {
namespace pool {

namespace detail {

// Binds the registered `global_avg_pool2d` operation onto `module`.
//
// Builds the Python docstring with fmt ({0} = operation name, {1} = fully
// qualified Python name) and registers the operation with one positional
// argument (`input_tensor`) and two keyword-only optional arguments
// (`memory_config`, `dtype`), both defaulting to std::nullopt.
//
// Declared `inline` because the definition lives in a header; without it,
// including this header from more than one translation unit would produce
// multiple definitions of the same function.
inline void bind_global_avg_pool2d(py::module& module) {
    // Literal braces in the text below must stay doubled ({{ }}) because the
    // whole docstring is passed through fmt::format.
    auto doc = fmt::format(
        R"doc({0}(input_tensor: ttnn.Tensor, *, memory_config: Optional[ttnn.MemoryConfig] = None, dtype: Optional[ttnn.DataType] = None) -> ttnn.Tensor

            Applies {0} to :attr:`input_tensor` by performing a 2D adaptive average pooling over an input signal composed of several input planes. This operation computes the average of all elements in each channel across the entire spatial dimensions.

            .. math::
                {0}(\\mathrm{{input\\_tensor}}_i)

            Args:
                * :attr:`input_tensor` (ttnn.Tensor): The input tensor to be pooled. Typically of shape (batch_size, channels, height, width).

            Keyword Args:
                * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
                * :attr:`dtype` (Optional[ttnn.DataType]): data type for the output tensor

            Returns:
                ttnn.Tensor: The tensor with the averaged values. The output tensor shape is (batch_size, channels, 1, 1).

            Example::

                >>> tensor = ttnn.from_torch(torch.randn((10, 3, 32, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
                >>> output = {1}(tensor)
        )doc",
        ttnn::operations::pool::global_avg_pool2d.name(),
        ttnn::operations::pool::global_avg_pool2d.python_fully_qualified_name());

    bind_registered_operation(
        module,
        ttnn::operations::pool::global_avg_pool2d,
        doc,
        ttnn::pybind_arguments_t{
            py::arg("input_tensor"),
            py::kw_only(),  // everything after this marker is keyword-only in Python
            py::arg("memory_config") = std::nullopt,
            py::arg("dtype") = std::nullopt});
}

} // namespace detail

// Entry point for the `pool` Python submodule: registers every pool
// operation binding defined in this header on `module`.
//
// Declared `inline` because the definition lives in a header; a non-inline
// definition would violate the one-definition rule as soon as the header is
// included from a second translation unit.
inline void py_module(py::module& module) {
    detail::bind_global_avg_pool2d(module);
}

} // namespace pool
} // namespace operations
} // namespace ttnn
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/transformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/transformer.hpp"

namespace py = pybind11;
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/unary.hpp"
#include "ttnn/types.hpp"

Expand Down
47 changes: 47 additions & 0 deletions ttnn/cpp/ttnn/operations/pool.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn/decorators.hpp"
#include "ttnn/operations/core.hpp"
#include "tt_eager/tt_dnn/op_library/pool/average_pool.hpp"

namespace ttnn {
namespace operations {
namespace pool {

namespace detail {
inline const std::array<ttnn::TensorSchema, 1> input_tensor_schemas() {
return {ttnn::TensorSchema{
4, // min rank
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

min/max rank can only be 4?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure, relied on python validation info here.
I still struggle to understand real limits of the underlying ops from reading them.
Will gladly take any advice here.

4, // max rank
{ttnn::bfloat16, ttnn::bfloat8_b, ttnn::uint16, ttnn::uint32},
{ttnn::TILE_LAYOUT},
true, // can_be_on_device
false, // can_be_on_cpu
false, // can_be_scalar
false // is_optional
}};
}
} // namespace detail

// Operation descriptor handed to ttnn::register_operation for global average
// pooling. It exposes the input validation schema, selects which arguments
// get validated, and forwards execution to tt::tt_metal::average_pool_2d.
struct GlobalAveragePool2D {
    // Validation schema for the single tensor argument; defined in detail::.
    static const std::array<TensorSchema, 1> input_tensor_schemas() {
        return detail::input_tensor_schemas();
    }

    // Returns a one-element tuple holding the tensor to validate. Any further
    // arguments are accepted but not inspected.
    template <typename... Args>
    static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) {
        return std::make_tuple(input_tensor);
    }

    // Runs the pooling op on `input`.
    //   memory_config_arg: output memory config; when empty, the input
    //                      tensor's own memory config is used.
    //   output_dtype:      forwarded unchanged to average_pool_2d.
    static Tensor execute(
        const Tensor& input,
        const std::optional<MemoryConfig>& memory_config_arg = std::nullopt,
        const std::optional<DataType>& output_dtype = std::nullopt) {
        auto resolved_memory_config = memory_config_arg.value_or(input.memory_config());
        return tt::tt_metal::average_pool_2d(input, resolved_memory_config, output_dtype);
    }
};
// Registered operation object wrapping GlobalAveragePool2D, registered under the name "ttnn::pool::global_avg_pool2d".
constexpr auto global_avg_pool2d = ttnn::register_operation<ttnn::operations::pool::GlobalAveragePool2D>("ttnn::pool::global_avg_pool2d");
} // namespace pool
} // namespace operations
} // namespace ttnn
2 changes: 1 addition & 1 deletion ttnn/ttnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def manage_config(name, value):
from ttnn.operations import transformer
from ttnn.operations import kv_cache
from ttnn.operations.conv2d import Conv2d
from ttnn.operations.maxpool2d import (
from ttnn.operations.pool import (
MaxPool2d,
global_avg_pool2d,
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@

import tt_lib as ttl

import sys
import ttnn

from tt_eager.tt_dnn.op_library.sliding_window_op_infra.tt_py_max_pool import (
TTPyMaxPool,
SlidingWindowOpParams,
)

THIS_MODULE = sys.modules[__name__]

__all__ = []


class MaxPool2d:
r"""
Expand Down Expand Up @@ -117,7 +122,7 @@ def copy_output_from_device(self, output: ttnn.Tensor):
## Average Pooling


def _torch_global_avg_pool2d(input_tensor: ttnn.Tensor):
def _golden_function(input_tensor: ttnn.Tensor):
import torch

input_tensor = ttnn.from_device(input_tensor)
Expand All @@ -128,32 +133,8 @@ def _torch_global_avg_pool2d(input_tensor: ttnn.Tensor):
return torch.nn.functional.global_avg_pool2d(input_tensor, output_size)


def _global_avg_pool2d_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to cpp

ttnn.validate_input_tensor(
operation_name,
input_tensor,
ranks=(4,),
dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint16, ttnn.uint32),
layouts=(ttnn.TILE_LAYOUT,),
can_be_on_device=True,
can_be_on_cpu=False,
)


@ttnn.register_operation(
name="ttnn.global_avg_pool2d",
validate_input_tensors=_global_avg_pool2d_validate_input_tensors,
golden_function=_torch_global_avg_pool2d,
global_avg_pool2d = ttnn.register_operation(golden_function=_golden_function)(
ttnn._ttnn.operations.pool.global_avg_pool2d
)
def global_avg_pool2d(input_tensor: ttnn.Tensor, memory_config: ttnn.MemoryConfig = None) -> ttnn.Tensor:
r"""
Applies a 2D adaptive average pooling over an input signal composed of several input planes.

Arguments:
* :attr: input_tensor: the input tensor
"""
if memory_config is None:
ayerofieiev-tt marked this conversation as resolved.
Show resolved Hide resolved
output = ttl.tensor.average_pool_2d(input_tensor)
else:
output = ttl.tensor.average_pool_2d(input_tensor, memory_config)
return output
__all__ = []
Loading