Skip to content

Commit

Permalink
#16153: Implement binary-ng fused input activations
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickroberts committed Dec 23, 2024
1 parent 5a5126d commit de965af
Show file tree
Hide file tree
Showing 17 changed files with 465 additions and 281 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ std::map<std::string, std::string> get_defines(
defines["ELTWISE_OP"] = op_name.c_str();
defines["ELTWISE_OP_TYPE"] = op_binary_type.c_str();
if (fused_activations.has_value()) {
if (op_type == BinaryOpType::ADD and fused_activations.value().size() == 1 and
fused_activations.value().at(0).op_type == UnaryOpType::RELU) {
if (op_type == BinaryOpType::ADD and fused_activations->size() == 1 and
fused_activations->at(0).op_type == UnaryOpType::RELU and not input_tensor_a_activation.has_value()) {
defines["PACK_RELU"] = "1";
} else {
defines.merge(ttnn::operations::unary::utils::get_block_defines(fused_activations.value(), "0", idst));
defines.merge(ttnn::operations::unary::utils::get_block_defines(*fused_activations, "0", idst));
}
}

Expand Down
93 changes: 71 additions & 22 deletions ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,44 +11,93 @@ namespace ttnn::operations::binary_ng {
template <BinaryOpType binary_op_type>
Tensor BinaryNg<binary_op_type>::invoke(
uint8_t queue_id,
const Tensor &input_tensor_a,
const Tensor &input_tensor_b,
const std::optional<const DataType> &output_dtype,
const std::optional<MemoryConfig> &memory_config,
std::optional<Tensor> optional_output_tensor) {
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
const std::optional<const DataType>& output_dtype,
const std::optional<MemoryConfig>& memory_config,
std::optional<Tensor> optional_output_tensor,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
return ttnn::prim::binary_ng(
queue_id, input_tensor_a, input_tensor_b, binary_op_type, output_dtype, memory_config, optional_output_tensor);
queue_id,
input_tensor_a,
input_tensor_b,
binary_op_type,
output_dtype,
memory_config,
optional_output_tensor,
lhs_activations,
rhs_activations,
post_activations);
}

template <BinaryOpType binary_op_type>
Tensor BinaryNg<binary_op_type>::invoke(
const Tensor &input_tensor_a,
const Tensor &input_tensor_b,
const std::optional<const DataType> &output_dtype,
const std::optional<MemoryConfig> &memory_config,
std::optional<Tensor> optional_output_tensor) {
return invoke(DefaultQueueId, input_tensor_a, input_tensor_b, output_dtype, memory_config, optional_output_tensor);
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
const std::optional<const DataType>& output_dtype,
const std::optional<MemoryConfig>& memory_config,
std::optional<Tensor> optional_output_tensor,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
return invoke(
DefaultQueueId,
input_tensor_a,
input_tensor_b,
output_dtype,
memory_config,
optional_output_tensor,
lhs_activations,
rhs_activations,
post_activations);
}

template <BinaryOpType binary_op_type>
Tensor BinaryNg<binary_op_type>::invoke(
uint8_t queue_id,
const Tensor &input_tensor_a,
const Tensor& input_tensor_a,
float scalar,
const std::optional<const DataType> &output_dtype,
const std::optional<MemoryConfig> &memory_config,
std::optional<Tensor> optional_output_tensor) {
return ttnn::prim::binary_ng(queue_id, input_tensor_a, scalar, binary_op_type, output_dtype, memory_config, optional_output_tensor);
const std::optional<const DataType>& output_dtype,
const std::optional<MemoryConfig>& memory_config,
std::optional<Tensor> optional_output_tensor,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
return ttnn::prim::binary_ng(
queue_id,
input_tensor_a,
scalar,
binary_op_type,
output_dtype,
memory_config,
optional_output_tensor,
lhs_activations,
rhs_activations,
post_activations);
}

template <BinaryOpType binary_op_type>
Tensor BinaryNg<binary_op_type>::invoke(
const Tensor &input_tensor_a,
const Tensor& input_tensor_a,
float scalar,
const std::optional<const DataType> &output_dtype,
const std::optional<MemoryConfig> &memory_config,
std::optional<Tensor> optional_output_tensor) {
return invoke(DefaultQueueId, input_tensor_a, scalar, output_dtype, memory_config, optional_output_tensor);
const std::optional<const DataType>& output_dtype,
const std::optional<MemoryConfig>& memory_config,
std::optional<Tensor> optional_output_tensor,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
return invoke(
DefaultQueueId,
input_tensor_a,
scalar,
output_dtype,
memory_config,
optional_output_tensor,
lhs_activations,
rhs_activations,
post_activations);
}

template struct BinaryNg<BinaryOpType::ADD>;
Expand Down
52 changes: 33 additions & 19 deletions ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,53 @@

#include "ttnn/decorators.hpp"
#include "ttnn/operations/eltwise/binary_ng/types.hpp"
#include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp"

namespace ttnn::operations::binary_ng {

template <BinaryOpType binary_op_type>
struct BinaryNg {
static Tensor invoke(
uint8_t queue_id,
const Tensor &input_tensor_a,
const Tensor &input_tensor_b,
const std::optional<const DataType> &output_dtype = std::nullopt,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt);
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
const std::optional<const DataType>& output_dtype = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt,
tt::stl::Span<const unary::UnaryOpType> lhs_activations = {},
tt::stl::Span<const unary::UnaryOpType> rhs_activations = {},
tt::stl::Span<const unary::UnaryOpType> post_activations = {});

static Tensor invoke(
const Tensor &input_tensor_a,
const Tensor &input_tensor_b,
const std::optional<const DataType> &output_dtype = std::nullopt,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt);
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
const std::optional<const DataType>& output_dtype = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt,
tt::stl::Span<const unary::UnaryOpType> lhs_activations = {},
tt::stl::Span<const unary::UnaryOpType> rhs_activations = {},
tt::stl::Span<const unary::UnaryOpType> post_activations = {});

static Tensor invoke(
uint8_t queue_id,
const Tensor &input_tensor_a,
const Tensor& input_tensor_a,
float scalar,
const std::optional<const DataType> &output_dtype = std::nullopt,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt);
const std::optional<const DataType>& output_dtype = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt,
tt::stl::Span<const unary::UnaryOpType> lhs_activations = {},
tt::stl::Span<const unary::UnaryOpType> rhs_activations = {},
tt::stl::Span<const unary::UnaryOpType> post_activations = {});

static Tensor invoke(
const Tensor &input_tensor_a,
const Tensor& input_tensor_a,
float scalar,
const std::optional<const DataType> &output_dtype = std::nullopt,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt);
const std::optional<const DataType>& output_dtype = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt,
tt::stl::Span<const unary::UnaryOpType> lhs_activations = {},
tt::stl::Span<const unary::UnaryOpType> rhs_activations = {},
tt::stl::Span<const unary::UnaryOpType> post_activations = {});
};

} // namespace ttnn::operations::binary_ng
Expand Down Expand Up @@ -117,4 +130,5 @@ constexpr auto logaddexp = ttnn::register_operation_with_auto_launch_op<
constexpr auto logaddexp2 = ttnn::register_operation_with_auto_launch_op<
"ttnn::experimental::logaddexp2",
ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::LOGADDEXP2>>();
}

} // namespace ttnn::experimental
36 changes: 33 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,30 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst
const std::optional<const DataType>& dtype,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<ttnn::Tensor>& output_tensor,
const ttnn::SmallVector<unary::UnaryOpType>& lhs_activations,
const ttnn::SmallVector<unary::UnaryOpType>& rhs_activations,
const ttnn::SmallVector<unary::UnaryOpType>& post_activations,
const uint8_t& queue_id) -> ttnn::Tensor {
return self(queue_id, input_tensor_a, scalar, dtype, memory_config, output_tensor);
return self(
queue_id,
input_tensor_a,
scalar,
dtype,
memory_config,
output_tensor,
lhs_activations,
rhs_activations,
post_activations);
},
py::arg("input_tensor_a"),
py::arg("scalar"),
py::kw_only(),
py::arg("dtype") = std::nullopt,
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("lhs_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
py::arg("rhs_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
py::arg("post_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
py::arg("queue_id") = 0},

// tensor and tensor
Expand All @@ -43,15 +58,30 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst
const std::optional<const DataType>& dtype,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<ttnn::Tensor>& output_tensor,
const ttnn::SmallVector<unary::UnaryOpType>& lhs_activations,
const ttnn::SmallVector<unary::UnaryOpType>& rhs_activations,
const ttnn::SmallVector<unary::UnaryOpType>& post_activations,
uint8_t queue_id) -> ttnn::Tensor {
return self(queue_id, input_tensor_a, input_tensor_b, dtype, memory_config, output_tensor);
return self(
queue_id,
input_tensor_a,
input_tensor_b,
dtype,
memory_config,
output_tensor,
lhs_activations,
rhs_activations,
post_activations);
},
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::kw_only(),
py::arg("dtype") = std::nullopt,
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("lhs_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
py::arg("rhs_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
py::arg("post_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
py::arg("queue_id") = 0});
}
} // namespace detail
Expand All @@ -77,4 +107,4 @@ void py_module(py::module& module) {
detail::bind_binary_ng_operation(module, ttnn::experimental::logaddexp, "Binary Logaddexp Operation");
detail::bind_binary_ng_operation(module, ttnn::experimental::logaddexp2, "Binary Logaddexp2 Operation");
}
} // namespace ttnn::operations::eltwise::binary_ng
} // namespace ttnn::operations::binary_ng
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,14 @@ SubtileBroadcastType get_subtile_broadcast_type(uint32_t a_h, uint32_t a_w, uint

tt::stl::hash::hash_t BinaryNgDeviceOperation::operation_attributes_t::to_hash() const {
return tt::stl::hash::hash_objects_with_default_seed(
binary_op_type, memory_config, get_dtype(), compute_kernel_config, subtile_broadcast_type);
binary_op_type,
lhs_activations,
rhs_activations,
post_activations,
memory_config,
get_dtype(),
compute_kernel_config,
subtile_broadcast_type);
}

DataType BinaryNgDeviceOperation::operation_attributes_t::get_dtype() const {
Expand Down Expand Up @@ -197,7 +204,10 @@ BinaryNgDeviceOperation::invoke(
BinaryOpType binary_op_type,
const std::optional<const DataType>& output_dtype,
const std::optional<MemoryConfig>& memory_config,
std::optional<Tensor> optional_output_tensor) {
std::optional<Tensor> optional_output_tensor,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
auto subtile_broadcast_type = get_subtile_broadcast_type(
input_tensor_a_arg.get_logical_shape()[-2],
input_tensor_a_arg.get_logical_shape()[-1],
Expand All @@ -207,6 +217,9 @@ BinaryNgDeviceOperation::invoke(
return {
operation_attributes_t{
binary_op_type,
{lhs_activations.begin(), lhs_activations.end()},
{rhs_activations.begin(), rhs_activations.end()},
{post_activations.begin(), post_activations.end()},
std::nullopt,
memory_config.value_or(input_tensor_a_arg.memory_config()),
input_tensor_a_arg.get_dtype(),
Expand All @@ -223,10 +236,16 @@ BinaryNgDeviceOperation::invoke(
BinaryOpType binary_op_type,
const std::optional<const DataType>& output_dtype,
const std::optional<MemoryConfig>& memory_config,
std::optional<Tensor> optional_output_tensor) {
std::optional<Tensor> optional_output_tensor,
tt::stl::Span<const unary::UnaryOpType> lhs_activations,
tt::stl::Span<const unary::UnaryOpType> rhs_activations,
tt::stl::Span<const unary::UnaryOpType> post_activations) {
return {
operation_attributes_t{
binary_op_type,
{lhs_activations.begin(), lhs_activations.end()},
{rhs_activations.begin(), rhs_activations.end()},
{post_activations.begin(), post_activations.end()},
scalar,
memory_config.value_or(input_tensor_a_arg.memory_config()),
input_tensor_a_arg.get_dtype(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "ttnn/device_operation.hpp"
#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
#include "ttnn/operations/eltwise/binary_ng/types.hpp"

#include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp"
namespace ttnn::operations::binary_ng {

enum class SubtileBroadcastType {
Expand All @@ -31,6 +31,9 @@ struct BinaryNgDeviceOperation {

struct operation_attributes_t {
BinaryOpType binary_op_type;
ttnn::SmallVector<unary::UnaryOpType> lhs_activations;
ttnn::SmallVector<unary::UnaryOpType> rhs_activations;
ttnn::SmallVector<unary::UnaryOpType> post_activations;
std::optional<float> scalar;
MemoryConfig memory_config;
DataType input_dtype;
Expand Down Expand Up @@ -86,7 +89,10 @@ struct BinaryNgDeviceOperation {
BinaryOpType binary_op_type,
const std::optional<const DataType>& output_dtype,
const std::optional<MemoryConfig>& memory_config,
std::optional<Tensor> optional_output_tensor);
std::optional<Tensor> optional_output_tensor,
tt::stl::Span<const unary::UnaryOpType> lhs_activations,
tt::stl::Span<const unary::UnaryOpType> rhs_activations,
tt::stl::Span<const unary::UnaryOpType> post_activations);

// tensor-scalar invocation
static std::tuple<operation_attributes_t, tensor_args_t> invoke(
Expand All @@ -95,7 +101,10 @@ struct BinaryNgDeviceOperation {
BinaryOpType binary_op_type,
const std::optional<const DataType>& output_dtype,
const std::optional<MemoryConfig>& memory_config,
std::optional<Tensor> optional_output_tensor);
std::optional<Tensor> optional_output_tensor,
tt::stl::Span<const unary::UnaryOpType> lhs_activations,
tt::stl::Span<const unary::UnaryOpType> rhs_activations,
tt::stl::Span<const unary::UnaryOpType> post_activations);
};

} // namespace ttnn::operations::binary_ng
Expand Down
Loading

0 comments on commit de965af

Please sign in to comment.