Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#8117: Move global_avg_pool2d to C++ #8583

Merged
merged 1 commit into from
May 18, 2024

Conversation

ayerofieiev-tt
Copy link
Member

@ayerofieiev-tt ayerofieiev-tt commented May 17, 2024

@@ -128,32 +133,8 @@ def _torch_global_avg_pool2d(input_tensor: ttnn.Tensor):
return torch.nn.functional.global_avg_pool2d(input_tensor, output_size)


def _global_avg_pool2d_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to cpp

@ayerofieiev-tt ayerofieiev-tt force-pushed the ay/issue-8117-global_avg_pool2d_to_cpp branch from 8e471a3 to 5618f7f Compare May 17, 2024 23:46
@ayerofieiev-tt ayerofieiev-tt merged commit a60d17c into main May 18, 2024
5 checks passed
@ayerofieiev-tt ayerofieiev-tt deleted the ay/issue-8117-global_avg_pool2d_to_cpp branch May 18, 2024 00:27
@@ -50,6 +51,9 @@ void py_module(py::module& module) {

auto m_kv_cache = module.def_submodule("kv_cache", "KV cache operations");
kv_cache::py_module(m_kv_cache);

auto m_pool = module.def_submodule("pool", "pool operations");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nicer to rename this to "pooling" to better align with torch:

We can move the Max/Avg pool operations when we migrate those to C++ as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do 👍

@@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 thanks for refactoring this

namespace detail {
inline const std::array<ttnn::TensorSchema, 1> input_tensor_schemas() {
return {ttnn::TensorSchema{
4, // min rank
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

min/max rank can only be 4?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure, relied on python validation info here.
I still struggle to understand real limits of the underlying ops from reading them.
Will gladly take any advice here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants