Skip to content

Commit

Permalink
#9329: Restructure ttnn::argmax (#9331)
Browse files Browse the repository at this point in the history
* #9329: Restructure ttnn::argmax
  • Loading branch information
ayerofieiev-tt authored Jun 10, 2024
1 parent c543781 commit 6958b94
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 140 deletions.
4 changes: 1 addition & 3 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ set(TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/async_runtime.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/op_library/binary/binary_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/op_library/to_layout/to_layout_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/op_library/reduction/reduction_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/op_library/reduction/argmax_create_program.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul.cpp

${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp
)
add_library(ttnn_lib OBJECT ${TTNN_SRCS})
target_compile_options(ttnn_lib PUBLIC -MP -Wno-int-to-pointer-cast -fno-var-tracking)
Expand Down
4 changes: 3 additions & 1 deletion ttnn/cpp/pybind11/operations/__init__.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
#include "pybind11/operations/normalization.hpp"
#include "pybind11/operations/pool.hpp"
#include "pybind11/operations/copy.hpp"
#include "pybind11/operations/reduction.hpp"
#include "pybind11/operations/ternary.hpp"
#include "pybind11/operations/transformer.hpp"
#include "pybind11/operations/unary.hpp"


#include "ttnn/operations/reduction/reduction_pybind.hpp"

namespace py = pybind11;

namespace ttnn {
Expand Down
58 changes: 58 additions & 0 deletions ttnn/cpp/ttnn/operations/reduction/argmax/argmax.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn/decorators.hpp"
#include "ttnn/operations/core.hpp"
#include "ttnn/validation.hpp"

#include "tt_eager/tt_dnn/op_library/run_operation.hpp"

#include "device/argmax_op.hpp"

namespace ttnn {
namespace operations::reduction {

// Host-side front end for ttnn::argmax. Validates the input tensor against
// the schema below and dispatches the device ArgMax operation (declared in
// device/argmax_op.hpp) via operation::run.
struct ExecuteArgMax {
    // Schema for the single input tensor: rank 4..4, dtype bfloat16,
    // ROW_MAJOR layout. The four trailing booleans mirror TensorSchema's
    // positional flags — NOTE(review): their exact meaning (on-device /
    // optional / etc.) is defined in ttnn/validation.hpp; confirm there.
    static inline const std::array<TensorSchema, 1> input_tensor_schemas() {
        return {ttnn::TensorSchema{4, 4, {ttnn::bfloat16}, {ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false}};
    }

    // Selects which tensors the framework validates for the queue-id overload:
    // only the input tensor participates in schema validation.
    template <typename... Args>
    static auto input_tensors_to_validate(uint8_t queue_id, const Tensor& input_tensor, Args&&... args) {
        return std::forward_as_tuple(input_tensor);
    }

    // Runs argmax on the given command queue.
    //   queue_id               - command queue to run on
    //   dim                    - optional reduction dimension; semantics of
    //                            std::nullopt are decided by the device ArgMax
    //                            op (presumably reduce over all dims — confirm)
    //   memory_config          - output memory config; defaults to the input
    //                            tensor's own memory config
    //   optional_output_tensor - preallocated output tensor, if any
    // The output dtype is fixed to UINT32 (index values).
    static ttnn::Tensor execute_on_worker_thread(
        uint8_t queue_id,
        const Tensor& input_tensor,
        const std::optional<int> dim = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt) {
        return operation::run(
                   ArgMax{tt::tt_metal::DataType::UINT32, dim, memory_config.value_or(input_tensor.memory_config())},
                   {input_tensor}, {}, {optional_output_tensor}, queue_id)
            .at(0);
    }

    // Validation-tensor selector for the overload without a queue id.
    template <typename... Args>
    static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) {
        return std::forward_as_tuple(input_tensor);
    }

    // Convenience overload: same as above but runs on DefaultQueueId.
    static ttnn::Tensor execute_on_worker_thread(
        const Tensor& input_tensor,
        const std::optional<int> dim = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt) {
        return execute_on_worker_thread(DefaultQueueId, input_tensor, dim, memory_config, optional_output_tensor);
    }
};

} // namespace operations::reduction

constexpr auto argmax = ttnn::register_operation<ttnn::operations::reduction::ExecuteArgMax>("ttnn::argmax");

} // namespace ttnn
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,12 @@
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/reduction.hpp"

namespace py = pybind11;

namespace ttnn {
namespace operations {
namespace reduction {

namespace detail {

template <typename reduction_operation_t>
void bind_reduction_operation(py::module& module, const reduction_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, dim: Optional[Union[int, Tuple[int]]] = None, keepdim: bool = True, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor)doc",
operation.name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::arg("dim") = std::nullopt,
py::arg("keepdim") = true,
py::arg("memory_config") = std::nullopt});
}
#include "argmax.hpp"

using argmax_operation_t = decltype(ttnn::argmax);
void bind_reduction_argmax_operation(py::module& module, const argmax_operation_t& operation) {
namespace ttnn::operations::reduction::detail {
namespace py = pybind11;
void bind_reduction_argmax_operation(py::module& module) {
auto doc =
R"doc(argmax(input_tensor: ttnn.Tensor, *, dim: Optional[int] = None, memory_config: MemoryConfig = std::nullopt, output_tensor : Optional[ttnn.Tensor] = std::nullopt, queue_id : [int] = 0) -> ttnn.Tensor
Expand All @@ -63,12 +40,13 @@ void bind_reduction_argmax_operation(py::module& module, const argmax_operation_
* :attr:`queue_id` (Optional[uint8]): command queue id
)doc";

using OperationType = decltype(ttnn::argmax);
bind_registered_operation(
module,
operation,
ttnn::argmax,
doc,
ttnn::pybind_overload_t{
[] (const argmax_operation_t& self,
[] (const OperationType& self,
const ttnn::Tensor& input_tensor,
const std::optional<int> dim,
const std::optional<ttnn::MemoryConfig>& memory_config,
Expand All @@ -84,21 +62,4 @@ void bind_reduction_argmax_operation(py::module& module, const argmax_operation_
py::arg("queue_id") = 0});
}

} // namespace detail

void py_module(py::module& module) {
// Generic reductions
detail::bind_reduction_operation(module, ttnn::sum);
detail::bind_reduction_operation(module, ttnn::mean);
detail::bind_reduction_operation(module, ttnn::max);
detail::bind_reduction_operation(module, ttnn::min);
detail::bind_reduction_operation(module, ttnn::std);
detail::bind_reduction_operation(module, ttnn::var);

// Special reductions
detail::bind_reduction_argmax_operation(module, ttnn::argmax);
}

} // namespace reduction
} // namespace operations
} // namespace ttnn
} // namespace ttnn::operations::reduction::detail
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/op_library/reduction/reduction_op.hpp"
#include "argmax_op.hpp"
#include "argmax_program_factory.hpp"

namespace ttnn {

namespace operations {

namespace reduction {
namespace ttnn::operations::reduction {

void ArgMax::validate_with_output_tensors(
const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>> &output_tensors) const {
Expand Down Expand Up @@ -83,8 +80,4 @@ operation::ProgramWithCallbacks ArgMax::create_program(
return detail::argmax_multi_core(input_tensor, output_tensor, this->dim);
}

} // namespace reduction

} // namespace operations

} // namespace ttnn
} // namespace ttnn::operations::reduction
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,10 @@
#include "tensor/tensor.hpp"
#include "tt_dnn/op_library/run_operation.hpp"

namespace ttnn {

namespace operations {

namespace reduction {
namespace ttnn::operations::reduction {

constexpr uint8_t DefaultQueueId = 0;

namespace detail {
operation::ProgramWithCallbacks argmax_multi_core(
const Tensor& input, const Tensor& output, const std::optional<int> dim);
} // namespace detail

struct ArgMax {
const DataType output_dtype;
const std::optional<int> dim;
Expand All @@ -39,8 +30,4 @@ struct ArgMax {
};


} // namespace reduction

} // namespace operations

} // namespace ttnn
} // namespace ttnn::operations::reduction
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,13 @@

#include <algorithm>

#include "ttnn/cpp/ttnn/op_library/reduction/reduction_op.hpp"

#include "tt_dnn/op_library/math.hpp"
#include "tt_dnn/op_library/work_split.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/detail/util.hpp"
#include "tt_metal/host_api.hpp"

namespace ttnn {

namespace operations {

namespace reduction {

namespace detail {
namespace ttnn::operations::reduction::detail {

using namespace tt::constants;

Expand Down Expand Up @@ -96,7 +88,7 @@ operation::ProgramWithCallbacks argmax_multi_core(
std::map<string, string> kernel_defines;
tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/op_library/reduction/kernels/reader_argmax_interleaved.cpp",
"ttnn/cpp/ttnn/operations/reduction/argmax/device/kernels/reader_argmax_interleaved.cpp",
all_cores,
tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args, kernel_defines));

Expand Down Expand Up @@ -130,10 +122,4 @@ operation::ProgramWithCallbacks argmax_multi_core(
return {std::move(program), override_runtime_args_callback};
}

} // namespace detail

} // namespace reduction

} // namespace operations

} // namespace ttnn
} // namespace ttnn::operations::reduction::detail
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@
#include "tt_dnn/op_library/composite/composite_ops.hpp"
#include "tt_eager/tt_dnn/op_library/reduce/reduce_op.hpp"
#include "tt_eager/tt_dnn/op_library/run_operation.hpp"
#include "ttnn/cpp/ttnn/op_library/reduction/reduction_op.hpp"

#include "ttnn/decorators.hpp"
#include "ttnn/operations/core.hpp"
#include "ttnn/validation.hpp"

namespace ttnn {

namespace operations {

namespace reduction {
namespace operations::reduction {

enum class ReduceType {
Sum,
Expand Down Expand Up @@ -164,44 +161,7 @@ struct Reduce {
}
};

// Host-side front end for ttnn::argmax. Validates the input tensor against
// the schema below and dispatches the device ArgMax operation via
// operation::run.
struct ExecuteArgMax {
    // Schema for the single input tensor: rank 4..4, dtype bfloat16,
    // ROW_MAJOR layout. The four trailing booleans mirror TensorSchema's
    // positional flags — NOTE(review): their exact meaning is defined in
    // ttnn/validation.hpp; confirm there.
    static inline const std::array<TensorSchema, 1> input_tensor_schemas() {
        return {ttnn::TensorSchema{4, 4, {ttnn::bfloat16}, {ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false}};
    }

    // Selects which tensors the framework validates for the queue-id overload:
    // only the input tensor participates in schema validation.
    template <typename... Args>
    static auto input_tensors_to_validate(uint8_t queue_id, const Tensor& input_tensor, Args&&... args) {
        return std::forward_as_tuple(input_tensor);
    }

    // Runs argmax on the given command queue.
    //   queue_id               - command queue to run on
    //   dim                    - optional reduction dimension; semantics of
    //                            std::nullopt are decided by the device ArgMax
    //                            op (presumably reduce over all dims — confirm)
    //   memory_config          - output memory config; defaults to the input
    //                            tensor's own memory config
    //   optional_output_tensor - preallocated output tensor, if any
    // The output dtype is fixed to UINT32 (index values).
    static ttnn::Tensor execute_on_worker_thread(
        uint8_t queue_id,
        const Tensor& input_tensor,
        const std::optional<int> dim = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt) {
        return operation::run(
                   ArgMax{tt::tt_metal::DataType::UINT32, dim, memory_config.value_or(input_tensor.memory_config())},
                   {input_tensor}, {}, {optional_output_tensor}, queue_id)
            .at(0);
    }

    // Validation-tensor selector for the overload without a queue id.
    template <typename... Args>
    static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) {
        return std::forward_as_tuple(input_tensor);
    }

    // Convenience overload: same as above but runs on DefaultQueueId.
    static ttnn::Tensor execute_on_worker_thread(
        const Tensor& input_tensor,
        const std::optional<int> dim = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt) {
        return execute_on_worker_thread(DefaultQueueId, input_tensor, dim, memory_config, optional_output_tensor);
    }
};

} // namespace reduction
} // namespace operations
} // namespace operations::reduction

// Generic reductions
constexpr auto sum =
Expand All @@ -228,7 +188,4 @@ constexpr auto var =
ttnn::register_operation<ttnn::operations::reduction::Reduce<ttnn::operations::reduction::ReduceType::Var>>(
"ttnn::var");

// Special reductions
constexpr auto argmax = ttnn::register_operation<ttnn::operations::reduction::ExecuteArgMax>("ttnn::argmax");

} // namespace ttnn
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"

namespace ttnn::operations::reduction::detail {

template <typename reduction_operation_t>
void bind_reduction_operation(py::module& module, const reduction_operation_t& operation) {
namespace py = pybind11;
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, dim: Optional[Union[int, Tuple[int]]] = None, keepdim: bool = True, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor)doc",
operation.name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::arg("dim") = std::nullopt,
py::arg("keepdim") = true,
py::arg("memory_config") = std::nullopt});
}

} // namespace ttnn::operations::reduction::detail
Loading

0 comments on commit 6958b94

Please sign in to comment.