Skip to content

Commit

Permalink
#0: correct include
Browse files Browse the repository at this point in the history
  • Loading branch information
ayerofieiev-tt committed May 17, 2024
1 parent f3d899a commit 3e077a0
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/binary.hpp"
#include "ttnn/types.hpp"

Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/ccl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/ccl.hpp"
#include "ttnn/types.hpp"

Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/core.hpp"

namespace py = pybind11;
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/kv_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/kv_cache.hpp"
#include "ttnn/types.hpp"

Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/normalization.hpp"

namespace py = pybind11;
Expand Down
9 changes: 5 additions & 4 deletions ttnn/cpp/pybind11/operations/pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/pool.hpp"
#include "ttnn/types.hpp"

Expand All @@ -21,7 +21,7 @@ namespace detail {

void bind_global_avg_pool2d(py::module& module) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
R"doc({0}(input_tensor: ttnn.Tensor, *, memory_config: Optional[ttnn.MemoryConfig] = None, dtype: Optional[ttnn.DataType] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` by performing a 2D adaptive average pooling over an input signal composed of several input planes. This operation computes the average of all elements in each channel across the entire spatial dimensions.
Expand All @@ -33,13 +33,14 @@ void bind_global_avg_pool2d(py::module& module) {
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
* :attr:`dtype` (Optional[ttnn.DataType]): data type for the output tensor
Returns:
ttnn.Tensor: The tensor with the averaged values. The output tensor shape is (batch_size, channels, 1, 1).
Example::
>>> tensor = ttnn.from_torch(torch.randn((10, 3, 32, 32), dtype=torch.float32), device=device)
>>> tensor = ttnn.from_torch(torch.randn((10, 3, 32, 32), dtype=ttnn.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
ttnn::operations::pool::global_avg_pool2d.name(),
Expand All @@ -53,7 +54,7 @@ void bind_global_avg_pool2d(py::module& module) {
py::arg("input_tensor"),
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("output_dtype") = std::nullopt});
py::arg("dtype") = std::nullopt});
}

} // namespace detail
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/transformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/transformer.hpp"

namespace py = pybind11;
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/decorators.hpp"
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/unary.hpp"
#include "ttnn/types.hpp"

Expand Down

0 comments on commit 3e077a0

Please sign in to comment.