-
Notifications
You must be signed in to change notification settings - Fork 87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
#8117: Move global_avg_pool2d to C++ #8583
Conversation
@@ -128,32 +133,8 @@ def _torch_global_avg_pool2d(input_tensor: ttnn.Tensor): | |||
return torch.nn.functional.global_avg_pool2d(input_tensor, output_size) | |||
|
|||
|
|||
def _global_avg_pool2d_validate_input_tensors(operation_name, input_tensor, *args, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moved to cpp
c1964fc
to
8e471a3
Compare
8e471a3
to
5618f7f
Compare
@@ -50,6 +51,9 @@ void py_module(py::module& module) { | |||
|
|||
auto m_kv_cache = module.def_submodule("kv_cache", "KV cache operations"); | |||
kv_cache::py_module(m_kv_cache); | |||
|
|||
auto m_pool = module.def_submodule("pool", "pool operations"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be nicer to rename this to "pooling" to better align with torch:
We can move the Max/Avg pool operations when we migrate those to C++ as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do 👍
@@ -7,7 +7,7 @@ | |||
#include <pybind11/pybind11.h> | |||
#include <pybind11/stl.h> | |||
|
|||
#include "../decorators.hpp" | |||
#include "ttnn/cpp/pybind11/decorators.hpp" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 thanks for refactoring this
namespace detail { | ||
inline const std::array<ttnn::TensorSchema, 1> input_tensor_schemas() { | ||
return {ttnn::TensorSchema{ | ||
4, // min rank |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
min/max rank can only be 4?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure, relied on python validation info here.
I still struggle to understand real limits of the underlying ops from reading them.
Will gladly take any advice here.
#8117