Skip to content

Commit

Permalink
#9999: Move pybind impl to cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
mywoodstock committed Jul 31, 2024
1 parent 832a92e commit 90c51e3
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 142 deletions.
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ set(TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/avgpool/avg_pool.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool_single_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/device/max_pool_program_factory.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,7 @@ MaxPoolNew::program_factory_t MaxPoolNew::select_program_factory(const operation
return MultiCore{};
}

void MaxPoolNew::validate_on_program_cache_miss(const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
return validate(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.memory_config_);
}

void MaxPoolNew::validate_on_program_cache_hit(const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
return validate(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.memory_config_);
}

void MaxPoolNew::validate(const Tensor& input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, const MemoryConfig& out_mem_config) {
void validate_maxpool(const Tensor& input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, const MemoryConfig& out_mem_config) {
TT_FATAL(input.storage_type() == StorageType::DEVICE, "Operands to reshape need to be on device!");
TT_FATAL(input.buffer() != nullptr , "Operands to reshape need to be allocated in buffers on device!");
TT_FATAL(input.get_dtype() == DataType::BFLOAT16, "Only BFLOAT16 supported for now");
Expand All @@ -40,11 +32,20 @@ void MaxPoolNew::validate(const Tensor& input, const tt::tt_metal::SlidingWindow
TT_FATAL(out_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Only height sharded tensors are supported.");
}

MaxPoolNew::shape_return_value_t MaxPoolNew::compute_output_shapes(const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
return compute_output_shapes(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.output_dtype_, op_attr.memory_config_);
void MaxPoolNew::validate_on_program_cache_miss(const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
return validate_maxpool(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.memory_config_);
}

MaxPoolNew::shape_return_value_t MaxPoolNew::compute_output_shapes(const Tensor& input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, DataType output_dtype, const MemoryConfig& out_mem_config) {
void MaxPoolNew::validate_on_program_cache_hit(const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
return validate_maxpool(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.memory_config_);
}

MaxPoolNew::shape_return_value_t MaxPoolNew::compute_output_shapes(const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
auto& input = tensors.input_tensor_;
auto& sliding_window_config = op_attr.sliding_window_config_;
auto& out_mem_config = op_attr.memory_config_;
auto& output_dtype = op_attr.output_dtype_;

// NOTE: Only for RM
// NOTE2: Assuming { N, 1, H * W, C }
// NOTE3: Assuming output data type is same as input
Expand Down Expand Up @@ -74,11 +75,12 @@ MaxPoolNew::shape_return_value_t MaxPoolNew::compute_output_shapes(const Tensor&
}

MaxPoolNew::tensor_return_value_t MaxPoolNew::create_output_tensors(const operation_attributes_t& op_attr, const tensor_args_t& tensors) {
return create_output_tensors(tensors.input_tensor_, op_attr.sliding_window_config_, op_attr.output_dtype_, op_attr.memory_config_);
}
auto& input = tensors.input_tensor_;
auto& sliding_window_config = op_attr.sliding_window_config_;
auto& out_mem_config = op_attr.memory_config_;
auto& output_dtype = op_attr.output_dtype_;

MaxPoolNew::tensor_return_value_t MaxPoolNew::create_output_tensors(const Tensor &input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, DataType output_dtype, const MemoryConfig& out_mem_config) {
Shape output_shape = compute_output_shapes(input, sliding_window_config, output_dtype, out_mem_config);
Shape output_shape = compute_output_shapes(op_attr, tensors);
auto mem_config = out_mem_config;
if (mem_config.shard_spec.has_value()) {
mem_config.shard_spec->shape[1] = output_shape[3];
Expand All @@ -103,7 +105,6 @@ tt::stl::hash::hash_t MaxPoolNew::compute_program_hash(const operation_attribute
return operation::hash_operation<MaxPoolNew>(op_attr.sliding_window_config_.get_hash(), op_attr.memory_config_, input_mem_config, dtype);
}


operation::OpPerformanceModel MaxPoolNew::create_op_performance_model(const operation_attributes_t& op_attr, const tensor_args_t& inputs, const Tensor& output) {
const auto& input = inputs.input_tensor_;
const auto& input_shape = input.get_shape();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,6 @@ struct MaxPoolNew {
const operation_attributes_t& operation_attributes,
const tensor_args_t& tensor_args,
tensor_return_value_t& output_tensor);
static cached_program_t max_pool_2d_multi_core_sharded_with_halo_v2_new(const Tensor &input,
Tensor& output,
const SlidingWindowConfig& sliding_window_config,
const MemoryConfig& out_mem_config);
};

using program_factory_t = std::variant<MultiCore>;
Expand All @@ -70,11 +66,6 @@ struct MaxPoolNew {
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static Tensor create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
static tt::stl::hash::hash_t compute_program_hash(const operation_attributes_t&, const tensor_args_t&);

// call old funcs from the above
static void validate(const Tensor& input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, const MemoryConfig& out_mem_config);
static shape_return_value_t compute_output_shapes(const Tensor& input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, DataType output_dtype, const MemoryConfig& out_mem_config);
static tensor_return_value_t create_output_tensors(const Tensor &input, const tt::tt_metal::SlidingWindowConfig& sliding_window_config, DataType output_dtype, const MemoryConfig& out_mem_config);
static operation::OpPerformanceModel create_op_performance_model(const operation_attributes_t&, const tensor_args_t&, const Tensor&);

};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,19 @@
//
// SPDX-License-Identifier: Apache-2.0


#include <optional>
#include <variant>

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/core.hpp"
#include "ttnn/device_operation.hpp"
#include "ttnn/types.hpp"
#include "ttnn/operations/conv2d/conv2d.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/sliding_window_op_infra/sliding_window.hpp"

#include "max_pool2d_device_op.hpp"
// #include "max_pool2d_multi_core_program_factory.hpp"
#include "tt_dnn/op_library/reduce/reduce_op.hpp" // for reduce_op_utils

/**
Expand Down Expand Up @@ -314,11 +326,11 @@ MaxPoolNew::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo
}};
}

MaxPoolNew::MultiCore::cached_program_t MaxPoolNew::MultiCore::max_pool_2d_multi_core_sharded_with_halo_v2_new(
const Tensor& input,
Tensor& output,
const SlidingWindowConfig& sliding_window_config,
const MemoryConfig& out_mem_config) {
MaxPoolNew::MultiCore::cached_program_t MaxPoolNew::MultiCore::create(const operation_attributes_t& op_attr, const tensor_args_t& tensor_args, tensor_return_value_t& output_tensor) {
const auto& input = tensor_args.input_tensor_;
auto& sliding_window_config = op_attr.sliding_window_config_;
auto& out_mem_config = op_attr.memory_config_;

tt::tt_metal::Program program{};

ParallelConfig parallel_config = ParallelConfig{
Expand Down Expand Up @@ -362,7 +374,7 @@ MaxPoolNew::MultiCore::cached_program_t MaxPoolNew::MultiCore::max_pool_2d_multi
program,
input,
reader_indices_on_device,
output,
output_tensor,
in_n,
in_h,
in_w,
Expand All @@ -380,15 +392,6 @@ MaxPoolNew::MultiCore::cached_program_t MaxPoolNew::MultiCore::max_pool_2d_multi
1);
}

MaxPoolNew::MultiCore::cached_program_t MaxPoolNew::MultiCore::create(const operation_attributes_t& op_attr, const tensor_args_t& tensor_args, tensor_return_value_t& output_tensor) {
const auto& input = tensor_args.input_tensor_;
return max_pool_2d_multi_core_sharded_with_halo_v2_new(
input,
output_tensor,
op_attr.sliding_window_config_,
op_attr.memory_config_);
}

void MaxPoolNew::MultiCore::override_runtime_arguments(cached_program_t& cached_program,
const operation_attributes_t& operation_attributes,
const tensor_args_t& tensor_args,
Expand Down
110 changes: 110 additions & 0 deletions ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/types.hpp"

#include "ttnn/operations/pool/maxpool/max_pool2d.hpp"
#include "ttnn/operations/pool/maxpool/max_pool2d_pybind.hpp"


namespace ttnn::operations::pool {

void bind_max_pool2d_operation(py::module& module) {
bind_registered_operation(
module,
ttnn::max_pool2d_new,
R"doc(
Max Pool 2D
+-------------------+-------------------------------+---------------+-------------+----------+
| Argument | Description | Data type | Valid range | Required |
+===================+===============================+===============+=============+==========+
| input | Input activations tensor | Tensor | | Yes |
| in_n | Input nbatch | Tensor | | Yes |
| in_h | Input height | Tensor | | Yes |
| in_w | Input width | Tensor | | Yes |
| kernel_h | kernel window height | uint32_t | | Yes |
| kernel_w | kernel window width | uint32_t | | Yes |
| stride_h | stride in height dim | uint32_t | | No |
| stride_w | stride in width dim | uint32_t | | No |
| pad_h | padding in height dim | uint32_t | | No |
| pad_w | padding in width dim | uint32_t | | No |
| dilation_h | kernel dilation in height dim | uint32_t | | No |
| dilation_w | kernel dilation in width dim | uint32_t | | No |
| memory_config | Output memory config | MemoryConfig | | No |
+-------------------+-------------------------------+---------------+-------------+----------+
)doc",
ttnn::pybind_overload_t{
[](const decltype(ttnn::max_pool2d_new)& self, const ttnn::Tensor& input_tensor,
uint32_t batch_size,
uint32_t input_h,
uint32_t input_w,
uint32_t channels,
std::array<uint32_t, 2> kernel_size,
std::array<uint32_t, 2> stride,
std::array<uint32_t, 2> padding,
std::array<uint32_t, 2> dilation,
ttnn::Device* device,
const uint8_t& queue_id)
-> ttnn::Tensor { return self(queue_id,
input_tensor,
batch_size,
input_h,
input_w,
channels,
kernel_size,
stride,
padding,
dilation,
device); },
py::arg("input_tensor"),
py::arg("batch_size"),
py::arg("input_h"),
py::arg("input_w"),
py::arg("channels"),
py::arg("kernel_size"),
py::arg("stride"),
py::arg("padding"),
py::arg("dilation"),
py::kw_only(),
py::arg("device"),
py::arg("queue_id") = 0},
ttnn::pybind_overload_t{
[](const decltype(ttnn::max_pool2d_new)& self, const ttnn::Tensor& input_tensor,
uint32_t batch_size,
uint32_t input_h,
uint32_t input_w,
uint32_t channels,
std::array<uint32_t, 2> kernel_size,
std::array<uint32_t, 2> stride,
std::array<uint32_t, 2> padding,
std::array<uint32_t, 2> dilation,
DeviceMesh* device,
const uint8_t& queue_id)
-> ttnn::Tensor { return self(queue_id,
input_tensor,
batch_size,
input_h,
input_w,
channels,
kernel_size,
stride,
padding,
dilation,
device); },
py::arg("input_tensor"),
py::arg("batch_size"),
py::arg("input_h"),
py::arg("input_w"),
py::arg("channels"),
py::arg("kernel_size"),
py::arg("stride"),
py::arg("padding"),
py::arg("dilation"),
py::kw_only(),
py::arg("device"),
py::arg("queue_id") = 0});
}

} // namespace ttnn::operations::pool
Loading

0 comments on commit 90c51e3

Please sign in to comment.