Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#9329: Restructure ttnn::argmax #9331

Merged
merged 5 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ set(TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/async_runtime.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/op_library/binary/binary_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/op_library/to_layout/to_layout_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/op_library/reduction/reduction_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/op_library/reduction/argmax_create_program.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/matmul.cpp

${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp
)
add_library(ttnn_lib OBJECT ${TTNN_SRCS})
target_compile_options(ttnn_lib PUBLIC -MP -Wno-int-to-pointer-cast -fno-var-tracking)
Expand Down
4 changes: 3 additions & 1 deletion ttnn/cpp/pybind11/operations/__init__.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
#include "pybind11/operations/normalization.hpp"
#include "pybind11/operations/pool.hpp"
#include "pybind11/operations/copy.hpp"
#include "pybind11/operations/reduction.hpp"
#include "pybind11/operations/ternary.hpp"
#include "pybind11/operations/transformer.hpp"
#include "pybind11/operations/unary.hpp"


#include "ttnn/operations/reduction/reduction_pybind.hpp"

namespace py = pybind11;

namespace ttnn {
Expand Down
58 changes: 58 additions & 0 deletions ttnn/cpp/ttnn/operations/reduction/argmax/argmax.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn/decorators.hpp"
#include "ttnn/operations/core.hpp"
#include "ttnn/validation.hpp"

#include "tt_eager/tt_dnn/op_library/run_operation.hpp"

#include "device/argmax_op.hpp"

namespace ttnn {
namespace operations::reduction {

// Host-side entry point for ttnn::argmax, wired into the ttnn decorator
// framework via register_operation (see `argmax` at the bottom of this file).
// Launches the device ArgMax op defined in device/argmax_op.hpp.
struct ExecuteArgMax {
    // Validation schema for the single input tensor: rank 4..4, bfloat16,
    // row-major layout. The trailing four bools follow TensorSchema's
    // constructor flag order — presumably (can_be_on_device, can_be_on_cpu,
    // can_be_scalar, is_optional) or similar; confirm against TensorSchema.
    static inline const std::array<TensorSchema, 1> input_tensor_schemas() {
        return {ttnn::TensorSchema{4, 4, {ttnn::bfloat16}, {ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false}};
    }

    // Overload used when the caller passes an explicit queue id; only the
    // input tensor is subject to schema validation.
    template <typename... Args>
    static auto input_tensors_to_validate(uint8_t queue_id, const Tensor& input_tensor, Args&&... args) {
        return std::forward_as_tuple(input_tensor);
    }

    // Runs ArgMax on the given command queue.
    //   dim:                   reduction dimension; std::nullopt reduces over all dims
    //   memory_config:         output memory config; defaults to the input's config
    //   optional_output_tensor: preallocated output, if any
    // Output dtype is fixed to UINT32 (indices). Returns the single output tensor.
    static ttnn::Tensor execute_on_worker_thread(
        uint8_t queue_id,
        const Tensor& input_tensor,
        const std::optional<int> dim = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt) {
        return operation::run(
                   ArgMax{tt::tt_metal::DataType::UINT32, dim, memory_config.value_or(input_tensor.memory_config())},
                   {input_tensor}, {}, {optional_output_tensor}, queue_id)
            .at(0);
    }

    // Overload used when no queue id is supplied (default-queue path).
    template <typename... Args>
    static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) {
        return std::forward_as_tuple(input_tensor);
    }

    // Convenience overload: same as above but on DefaultQueueId.
    static ttnn::Tensor execute_on_worker_thread(
        const Tensor& input_tensor,
        const std::optional<int> dim = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt) {
        return execute_on_worker_thread(DefaultQueueId, input_tensor, dim, memory_config, optional_output_tensor);
    }
};

}  // namespace operations::reduction

// Public registered operation: ttnn::argmax(input, dim=..., memory_config=..., ...).
constexpr auto argmax = ttnn::register_operation<ttnn::operations::reduction::ExecuteArgMax>("ttnn::argmax");

}  // namespace ttnn
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,12 @@
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/reduction.hpp"

namespace py = pybind11;

namespace ttnn {
namespace operations {
namespace reduction {

namespace detail {

template <typename reduction_operation_t>
void bind_reduction_operation(py::module& module, const reduction_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, dim: Optional[Union[int, Tuple[int]]] = None, keepdim: bool = True, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor)doc",
operation.name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::arg("dim") = std::nullopt,
py::arg("keepdim") = true,
py::arg("memory_config") = std::nullopt});
}
#include "argmax.hpp"

using argmax_operation_t = decltype(ttnn::argmax);
void bind_reduction_argmax_operation(py::module& module, const argmax_operation_t& operation) {
namespace ttnn::operations::reduction::detail {
namespace py = pybind11;
void bind_reduction_argmax_operation(py::module& module) {
auto doc =
R"doc(argmax(input_tensor: ttnn.Tensor, *, dim: Optional[int] = None, memory_config: MemoryConfig = std::nullopt, output_tensor : Optional[ttnn.Tensor] = std::nullopt, queue_id : [int] = 0) -> ttnn.Tensor

Expand All @@ -63,12 +40,13 @@ void bind_reduction_argmax_operation(py::module& module, const argmax_operation_
* :attr:`queue_id` (Optional[uint8]): command queue id
)doc";

using OperationType = decltype(ttnn::argmax);
bind_registered_operation(
module,
operation,
ttnn::argmax,
doc,
ttnn::pybind_overload_t{
[] (const argmax_operation_t& self,
[] (const OperationType& self,
const ttnn::Tensor& input_tensor,
const std::optional<int> dim,
const std::optional<ttnn::MemoryConfig>& memory_config,
Expand All @@ -84,21 +62,4 @@ void bind_reduction_argmax_operation(py::module& module, const argmax_operation_
py::arg("queue_id") = 0});
}

} // namespace detail

void py_module(py::module& module) {
// Generic reductions
detail::bind_reduction_operation(module, ttnn::sum);
detail::bind_reduction_operation(module, ttnn::mean);
detail::bind_reduction_operation(module, ttnn::max);
detail::bind_reduction_operation(module, ttnn::min);
detail::bind_reduction_operation(module, ttnn::std);
detail::bind_reduction_operation(module, ttnn::var);

// Special reductions
detail::bind_reduction_argmax_operation(module, ttnn::argmax);
}

} // namespace reduction
} // namespace operations
} // namespace ttnn
} // namespace ttnn::operations::reduction::detail
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/op_library/reduction/reduction_op.hpp"
#include "argmax_op.hpp"
#include "argmax_program_factory.hpp"

namespace ttnn {

namespace operations {

namespace reduction {
namespace ttnn::operations::reduction {

void ArgMax::validate_with_output_tensors(
const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>> &output_tensors) const {
Expand Down Expand Up @@ -83,8 +80,4 @@ operation::ProgramWithCallbacks ArgMax::create_program(
return detail::argmax_multi_core(input_tensor, output_tensor, this->dim);
}

} // namespace reduction

} // namespace operations

} // namespace ttnn
} // namespace ttnn::operations::reduction
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,10 @@
#include "tensor/tensor.hpp"
#include "tt_dnn/op_library/run_operation.hpp"

namespace ttnn {

namespace operations {

namespace reduction {
namespace ttnn::operations::reduction {

constexpr uint8_t DefaultQueueId = 0;

namespace detail {
operation::ProgramWithCallbacks argmax_multi_core(
const Tensor& input, const Tensor& output, const std::optional<int> dim);
} // namespace detail

struct ArgMax {
const DataType output_dtype;
const std::optional<int> dim;
Expand All @@ -39,8 +30,4 @@ struct ArgMax {
};


} // namespace reduction

} // namespace operations

} // namespace ttnn
} // namespace ttnn::operations::reduction
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,13 @@

#include <algorithm>

#include "ttnn/cpp/ttnn/op_library/reduction/reduction_op.hpp"

#include "tt_dnn/op_library/math.hpp"
#include "tt_dnn/op_library/work_split.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/detail/util.hpp"
#include "tt_metal/host_api.hpp"

namespace ttnn {

namespace operations {

namespace reduction {

namespace detail {
namespace ttnn::operations::reduction::detail {

using namespace tt::constants;

Expand Down Expand Up @@ -96,7 +88,7 @@ operation::ProgramWithCallbacks argmax_multi_core(
std::map<string, string> kernel_defines;
tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/op_library/reduction/kernels/reader_argmax_interleaved.cpp",
"ttnn/cpp/ttnn/operations/reduction/argmax/device/kernels/reader_argmax_interleaved.cpp",
all_cores,
tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args, kernel_defines));

Expand Down Expand Up @@ -130,10 +122,4 @@ operation::ProgramWithCallbacks argmax_multi_core(
return {std::move(program), override_runtime_args_callback};
}

} // namespace detail

} // namespace reduction

} // namespace operations

} // namespace ttnn
} // namespace ttnn::operations::reduction::detail
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@
#include "tt_dnn/op_library/composite/composite_ops.hpp"
#include "tt_eager/tt_dnn/op_library/reduce/reduce_op.hpp"
#include "tt_eager/tt_dnn/op_library/run_operation.hpp"
#include "ttnn/cpp/ttnn/op_library/reduction/reduction_op.hpp"

#include "ttnn/decorators.hpp"
#include "ttnn/operations/core.hpp"
#include "ttnn/validation.hpp"

namespace ttnn {

namespace operations {

namespace reduction {
namespace operations::reduction {

enum class ReduceType {
Sum,
Expand Down Expand Up @@ -164,44 +161,7 @@ struct Reduce {
}
};

struct ExecuteArgMax {
static inline const std::array<TensorSchema, 1> input_tensor_schemas() {
return {ttnn::TensorSchema{4, 4, {ttnn::bfloat16}, {ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false}};
}

template <typename... Args>
static auto input_tensors_to_validate(uint8_t queue_id, const Tensor& input_tensor, Args&&... args) {
return std::forward_as_tuple(input_tensor);
}

static ttnn::Tensor execute_on_worker_thread(
uint8_t queue_id,
const Tensor& input_tensor,
const std::optional<int> dim = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt) {
return operation::run(
ArgMax{tt::tt_metal::DataType::UINT32, dim, memory_config.value_or(input_tensor.memory_config())},
{input_tensor}, {}, {optional_output_tensor}, queue_id)
.at(0);
}

template <typename... Args>
static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) {
return std::forward_as_tuple(input_tensor);
}

static ttnn::Tensor execute_on_worker_thread(
const Tensor& input_tensor,
const std::optional<int> dim = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt) {
return execute_on_worker_thread(DefaultQueueId, input_tensor, dim, memory_config, optional_output_tensor);
}
};

} // namespace reduction
} // namespace operations
} // namespace operations::reduction

// Generic reductions
constexpr auto sum =
Expand All @@ -228,7 +188,4 @@ constexpr auto var =
ttnn::register_operation<ttnn::operations::reduction::Reduce<ttnn::operations::reduction::ReduceType::Var>>(
"ttnn::var");

// Special reductions
constexpr auto argmax = ttnn::register_operation<ttnn::operations::reduction::ExecuteArgMax>("ttnn::argmax");

} // namespace ttnn
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"

namespace ttnn::operations::reduction::detail {

template <typename reduction_operation_t>
void bind_reduction_operation(py::module& module, const reduction_operation_t& operation) {
namespace py = pybind11;
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, dim: Optional[Union[int, Tuple[int]]] = None, keepdim: bool = True, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor)doc",
operation.name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::arg("dim") = std::nullopt,
py::arg("keepdim") = true,
py::arg("memory_config") = std::nullopt});
}

} // namespace ttnn::operations::reduction::detail
Loading
Loading