diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 0a1d72134eb4..0bda64159e45 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -83,6 +83,7 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/avgpool/avg_pool.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool_single_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool_program_factory.cpp diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp index e824bcf280f9..8954f7dfdd52 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp @@ -14,15 +14,7 @@ MaxPoolNew::program_factory_t MaxPoolNew::select_program_factory(const operation return MultiCore{}; } -void MaxPoolNew::validate_on_program_cache_miss(const operation_attributes_t& op_attr, const tensor_args_t& tensors) { - return validate(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.memory_config_); -} - -void MaxPoolNew::validate_on_program_cache_hit(const operation_attributes_t& op_attr, const tensor_args_t& tensors) { - return validate(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.memory_config_); -} - -void MaxPoolNew::validate(const Tensor& input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, const MemoryConfig& out_mem_config) { +void validate_maxpool(const Tensor& input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, const MemoryConfig& out_mem_config) { TT_FATAL(input.storage_type() == StorageType::DEVICE, "Operands to reshape need to be on device!"); TT_FATAL(input.buffer() != nullptr , "Operands to reshape need to be allocated in buffers on device!"); TT_FATAL(input.get_dtype() == DataType::BFLOAT16, "Only BFLOAT16 supported for now"); @@ -40,11 +32,20 @@ void MaxPoolNew::validate(const Tensor& input, const tt::tt_metal::SlidingWindow TT_FATAL(out_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Only height sharded tensors are supported."); } -MaxPoolNew::shape_return_value_t MaxPoolNew::compute_output_shapes(const operation_attributes_t& op_attr, const tensor_args_t& tensors) { - return compute_output_shapes(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.output_dtype_, op_attr.memory_config_); +void MaxPoolNew::validate_on_program_cache_miss(const operation_attributes_t& op_attr, const tensor_args_t& tensors) { + return validate_maxpool(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.memory_config_); } -MaxPoolNew::shape_return_value_t MaxPoolNew::compute_output_shapes(const Tensor& input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, DataType output_dtype, const MemoryConfig& out_mem_config) { +void MaxPoolNew::validate_on_program_cache_hit(const operation_attributes_t& op_attr, const tensor_args_t& tensors) { + return validate_maxpool(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.memory_config_); +} + +MaxPoolNew::shape_return_value_t MaxPoolNew::compute_output_shapes(const operation_attributes_t& op_attr, const tensor_args_t& tensors) { + auto& input = tensors.input_tensor_; + auto& sliding_window_config = op_attr.sliding_window_config_; + auto& out_mem_config = op_attr.memory_config_; + auto& output_dtype = op_attr.output_dtype_; + // NOTE: Only for RM // NOTE2: Assuming { N, 1, H * W, C } // NOTE3: Assuming output data type is same as input @@ -74,11 +75,12 @@ MaxPoolNew::shape_return_value_t MaxPoolNew::compute_output_shapes(const Tensor& } MaxPoolNew::tensor_return_value_t MaxPoolNew::create_output_tensors(const operation_attributes_t& op_attr, const tensor_args_t& tensors) { - return create_output_tensors(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.output_dtype_, op_attr.memory_config_); -} + auto& input = tensors.input_tensor_; + auto& sliding_window_config = op_attr.sliding_window_config_; + auto& out_mem_config = op_attr.memory_config_; + auto& output_dtype = op_attr.output_dtype_; -MaxPoolNew::tensor_return_value_t MaxPoolNew::create_output_tensors(const Tensor &input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, DataType output_dtype, const MemoryConfig& out_mem_config) { - Shape output_shape = compute_output_shapes(input, sliding_window_config, output_dtype, out_mem_config); + Shape output_shape = compute_output_shapes(op_attr, tensors); auto mem_config = out_mem_config; if (mem_config.shard_spec.has_value()) { mem_config.shard_spec->shape[1] = output_shape[3]; @@ -103,7 +105,6 @@ tt::stl::hash::hash_t MaxPoolNew::compute_program_hash(const operation_attribute return operation::hash_operation(op_attr.sliding_window_config_.get_hash(), op_attr.memory_config_, input_mem_config, dtype); } - operation::OpPerformanceModel MaxPoolNew::create_op_performance_model(const operation_attributes_t& op_attr, const tensor_args_t& inputs, const Tensor& output) { const auto& input = inputs.input_tensor_; const auto& input_shape = input.get_shape(); diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.hpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.hpp index 750064045383..e9e9a9119e38 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.hpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.hpp @@ -56,10 +56,6 @@ struct MaxPoolNew { const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, tensor_return_value_t& output_tensor); - static cached_program_t max_pool_2d_multi_core_sharded_with_halo_v2_new(const Tensor &input, - Tensor& output, - const SlidingWindowConfig& sliding_window_config, - const MemoryConfig& out_mem_config); }; using program_factory_t = std::variant; @@ -70,11 +66,6 @@ struct MaxPoolNew { static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); static Tensor create_output_tensors(const operation_attributes_t&, const tensor_args_t&); static tt::stl::hash::hash_t compute_program_hash(const operation_attributes_t&, const tensor_args_t&); - - // call old funcs from the above - static void validate(const Tensor& input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, const MemoryConfig& out_mem_config); - static shape_return_value_t compute_output_shapes(const Tensor& input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, DataType output_dtype, const MemoryConfig& out_mem_config); - static tensor_return_value_t create_output_tensors(const Tensor &input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, DataType output_dtype, const MemoryConfig& out_mem_config); static operation::OpPerformanceModel create_op_performance_model(const operation_attributes_t&, const tensor_args_t&, const Tensor&); }; diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp index e5487ba818ad..d76202e24979 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp @@ -2,7 +2,19 @@ // // SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/core.hpp" +#include "ttnn/device_operation.hpp" +#include "ttnn/types.hpp" +#include "ttnn/operations/conv2d/conv2d.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/sliding_window_op_infra/sliding_window.hpp" + #include "max_pool2d_device_op.hpp" +// #include "max_pool2d_multi_core_program_factory.hpp" #include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils /** @@ -314,11 +326,11 @@ MaxPoolNew::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo }}; } -MaxPoolNew::MultiCore::cached_program_t MaxPoolNew::MultiCore::max_pool_2d_multi_core_sharded_with_halo_v2_new( - const Tensor& input, - Tensor& output, - const SlidingWindowConfig& sliding_window_config, - const MemoryConfig& out_mem_config) { +MaxPoolNew::MultiCore::cached_program_t MaxPoolNew::MultiCore::create(const operation_attributes_t& op_attr, const tensor_args_t& tensor_args, tensor_return_value_t& output_tensor) { + const auto& input = tensor_args.input_tensor_; + auto& sliding_window_config = op_attr.sliding_window_config_; + auto& out_mem_config = op_attr.memory_config_; + tt::tt_metal::Program program{}; ParallelConfig parallel_config = ParallelConfig{ @@ -362,7 +374,7 @@ MaxPoolNew::MultiCore::cached_program_t MaxPoolNew::MultiCore::max_pool_2d_multi program, input, reader_indices_on_device, - output, + output_tensor, in_n, in_h, in_w, @@ -380,15 +392,6 @@ MaxPoolNew::MultiCore::cached_program_t MaxPoolNew::MultiCore::max_pool_2d_multi 1); } -MaxPoolNew::MultiCore::cached_program_t MaxPoolNew::MultiCore::create(const operation_attributes_t& op_attr, const tensor_args_t& tensor_args, tensor_return_value_t& output_tensor) { - const auto& input = tensor_args.input_tensor_; - return max_pool_2d_multi_core_sharded_with_halo_v2_new( - input, - output_tensor, - op_attr.sliding_window_config_, - op_attr.memory_config_); -} - void MaxPoolNew::MultiCore::override_runtime_arguments(cached_program_t& cached_program, const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.cpp new file mode 100644 index 000000000000..12d5bc270efe --- /dev/null +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.cpp @@ -0,0 +1,110 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "ttnn/types.hpp" + +#include "ttnn/operations/pool/maxpool/max_pool2d.hpp" +#include "ttnn/operations/pool/maxpool/max_pool2d_pybind.hpp" + + +namespace ttnn::operations::pool { + +void bind_max_pool2d_operation(py::module& module) { + bind_registered_operation( + module, + ttnn::max_pool2d_new, + R"doc( + Max Pool 2D + +-------------------+-------------------------------+---------------+-------------+----------+ + | Argument | Description | Data type | Valid range | Required | + +===================+===============================+===============+=============+==========+ + | input | Input activations tensor | Tensor | | Yes | + | in_n | Input nbatch | Tensor | | Yes | + | in_h | Input height | Tensor | | Yes | + | in_w | Input width | Tensor | | Yes | + | kernel_h | kernel window height | uint32_t | | Yes | + | kernel_w | kernel window width | uint32_t | | Yes | + | stride_h | stride in height dim | uint32_t | | No | + | stride_w | stride in width dim | uint32_t | | No | + | pad_h | padding in height dim | uint32_t | | No | + | pad_w | padding in width dim | uint32_t | | No | + | dilation_h | kernel dilation in height dim | uint32_t | | No | + | dilation_w | kernel dilation in width dim | uint32_t | | No | + | memory_config | Output memory config | MemoryConfig | | No | + +-------------------+-------------------------------+---------------+-------------+----------+ + )doc", + ttnn::pybind_overload_t{ + [](const decltype(ttnn::max_pool2d_new)& self, const ttnn::Tensor& input_tensor, + uint32_t batch_size, + uint32_t input_h, + uint32_t input_w, + uint32_t channels, + std::array kernel_size, + std::array stride, + std::array padding, + std::array dilation, + ttnn::Device* device, + const uint8_t& queue_id) + -> ttnn::Tensor { return self(queue_id, + input_tensor, + batch_size, + input_h, + input_w, + channels, + kernel_size, + stride, + padding, + dilation, + device); }, + py::arg("input_tensor"), + py::arg("batch_size"), + py::arg("input_h"), + py::arg("input_w"), + py::arg("channels"), + py::arg("kernel_size"), + py::arg("stride"), + py::arg("padding"), + py::arg("dilation"), + py::kw_only(), + py::arg("device"), + py::arg("queue_id") = 0}, + ttnn::pybind_overload_t{ + [](const decltype(ttnn::max_pool2d_new)& self, const ttnn::Tensor& input_tensor, + uint32_t batch_size, + uint32_t input_h, + uint32_t input_w, + uint32_t channels, + std::array kernel_size, + std::array stride, + std::array padding, + std::array dilation, + DeviceMesh* device, + const uint8_t& queue_id) + -> ttnn::Tensor { return self(queue_id, + input_tensor, + batch_size, + input_h, + input_w, + channels, + kernel_size, + stride, + padding, + dilation, + device); }, + py::arg("input_tensor"), + py::arg("batch_size"), + py::arg("input_h"), + py::arg("input_w"), + py::arg("channels"), + py::arg("kernel_size"), + py::arg("stride"), + py::arg("padding"), + py::arg("dilation"), + py::kw_only(), + py::arg("device"), + py::arg("queue_id") = 0}); +} + +} // namespace ttnn::operations::pool diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.hpp b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.hpp index 757f4052e3bb..61b4290d98db 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.hpp @@ -7,110 +7,10 @@ #include #include -#include "ttnn/cpp/pybind11/decorators.hpp" -#include "ttnn/operations/pool/maxpool/max_pool2d.hpp" -#include "ttnn/types.hpp" - namespace py = pybind11; namespace ttnn::operations::pool { -void bind_max_pool2d_operation(py::module& module) { - bind_registered_operation( - module, - ttnn::max_pool2d_new, - R"doc( - Max Pool 2D - +-------------------+-------------------------------+---------------+-------------+----------+ - | Argument | Description | Data type | Valid range | Required | - +===================+===============================+===============+=============+==========+ - | input | Input activations tensor | Tensor | | Yes | - | in_n | Input nbatch | Tensor | | Yes | - | in_h | Input height | Tensor | | Yes | - | in_w | Input width | Tensor | | Yes | - | kernel_h | kernel window height | uint32_t | | Yes | - | kernel_w | kernel window width | uint32_t | | Yes | - | stride_h | stride in height dim | uint32_t | | No | - | stride_w | stride in width dim | uint32_t | | No | - | pad_h | padding in height dim | uint32_t | | No | - | pad_w | padding in width dim | uint32_t | | No | - | dilation_h | kernel dilation in height dim | uint32_t | | No | - | dilation_w | kernel dilation in width dim | uint32_t | | No | - | memory_config | Output memory config | MemoryConfig | | No | - +-------------------+-------------------------------+---------------+-------------+----------+ - )doc", - ttnn::pybind_overload_t{ - [](const decltype(ttnn::max_pool2d_new)& self, const ttnn::Tensor& input_tensor, - uint32_t batch_size, - uint32_t input_h, - uint32_t input_w, - uint32_t channels, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - ttnn::Device* device, - const uint8_t& queue_id) - -> ttnn::Tensor { return self(queue_id, - input_tensor, - batch_size, - input_h, - input_w, - channels, - kernel_size, - stride, - padding, - dilation, - device); }, - py::arg("input_tensor"), - py::arg("batch_size"), - py::arg("input_h"), - py::arg("input_w"), - py::arg("channels"), - py::arg("kernel_size"), - py::arg("stride"), - py::arg("padding"), - py::arg("dilation"), - py::kw_only(), - py::arg("device"), - py::arg("queue_id") = 0}, - ttnn::pybind_overload_t{ - [](const decltype(ttnn::max_pool2d_new)& self, const ttnn::Tensor& input_tensor, - uint32_t batch_size, - uint32_t input_h, - uint32_t input_w, - uint32_t channels, - std::array kernel_size, - std::array stride, - std::array padding, - std::array dilation, - DeviceMesh* device, - const uint8_t& queue_id) - -> ttnn::Tensor { return self(queue_id, - input_tensor, - batch_size, - input_h, - input_w, - channels, - kernel_size, - stride, - padding, - dilation, - device); }, - py::arg("input_tensor"), - py::arg("batch_size"), - py::arg("input_h"), - py::arg("input_w"), - py::arg("channels"), - py::arg("kernel_size"), - py::arg("stride"), - py::arg("padding"), - py::arg("dilation"), - py::kw_only(), - py::arg("device"), - py::arg("queue_id") = 0}); -} - -// void py_module(py::module& module) { bind_example_operation(module); } +void bind_max_pool2d_operation(py::module& module); } // namespace ttnn::operations::pool