Skip to content

Commit

Permalink
Update eltwise flatbuffer to include extra params table, move eltwise…
Browse files Browse the repository at this point in the history
… composite to its own file to match ttnn implementation
  • Loading branch information
jnie-TT committed Oct 31, 2024
1 parent 2a2121d commit 278a774
Show file tree
Hide file tree
Showing 16 changed files with 277 additions and 84 deletions.
5 changes: 5 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,15 @@ enum EltwiseOpType: uint32 {
Cos = 27
}

union EltwiseOpParams {

}

table EltwiseOp {
type: EltwiseOpType;
ins: [tt.target.TensorRef];
out: tt.target.TensorRef;
params: EltwiseOpParams;
}

enum ReductionOpType: uint32 {
Expand Down
6 changes: 5 additions & 1 deletion lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,9 @@ template <typename EltwiseOp>
::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp>
createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
::tt::target::ttnn::EltwiseOpType type;
::tt::target::ttnn::EltwiseOpParams paramsType =
::tt::target::ttnn::EltwiseOpParams::NONE;
::flatbuffers::Offset<void> params = 0;
if constexpr (std::is_same_v<EltwiseOp, AbsOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Abs;
} else if constexpr (std::is_same_v<EltwiseOp, AddOp>) {
Expand Down Expand Up @@ -360,7 +363,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
return ::tt::target::ttnn::CreateEltwiseOpDirect(
*cache.fbb, type, &ins,
cache.at<::tt::target::TensorRef>(
getOperandThroughDPSOps(op.getOutputs().front())));
getOperandThroughDPSOps(op.getOutputs().front())),
paramsType, params);
}

template <typename ReductionOp>
Expand Down
8 changes: 6 additions & 2 deletions runtime/lib/ttnn/operations/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
Expand All @@ -9,8 +11,10 @@ set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/data_movement/slice.cpp
${CMAKE_CURRENT_SOURCE_DIR}/data_movement/transpose.cpp
${CMAKE_CURRENT_SOURCE_DIR}/deletion/dealloc.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary_composite.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary_composite.cpp
${CMAKE_CURRENT_SOURCE_DIR}/embedding/embedding.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/to_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/from_device.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,12 @@
#include "binary.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/binary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttnn/operations/eltwise/binary/binary_composite.hpp"

namespace tt::runtime::ttnn::operations::binary {

static void
getEltwiseBinaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
ProgramTensorPool &tensorPool,
::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs) {
LOG_ASSERT(op->ins()->size() == 2, "Expected 2 inputs");
*lhs = &(tensorPool.at(op->ins()->Get(0)->global_id()));
*rhs = &(tensorPool.at(op->ins()->Get(1)->global_id()));
DEBUG_ASSERT((*lhs)->is_allocated());
DEBUG_ASSERT((*rhs)->is_allocated());

// Switch the order of operands if the second operand requires broadcast
if ((*rhs)->volume() < (*lhs)->volume()) {
std::swap(*lhs, *rhs);
}
}

static void runEltwiseBinaryOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<::ttnn::Tensor(
Expand All @@ -48,24 +34,6 @@ static void runEltwiseBinaryOP(
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseBinaryCompositeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<
::ttnn::Tensor(const ::ttnn::Tensor &, const ::ttnn::Tensor &,
const std::optional<::tt::tt_metal::MemoryConfig> &)>
ttnnOp) {

::ttnn::Tensor *lhs = nullptr;
::ttnn::Tensor *rhs = nullptr;
getEltwiseBinaryOPInputTensors(op, tensorPool, &lhs, &rhs);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
switch (op->type()) {
Expand Down Expand Up @@ -118,14 +86,6 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseBinaryOP(op, tensorPool, ::ttnn::divide);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Maximum: {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::maximum);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Minimum: {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum);
break;
}
default:
throw std::invalid_argument("Unsupported Eltwise Binary operation");
}
Expand Down
47 changes: 47 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "binary_composite.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/binary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"

namespace tt::runtime::ttnn::operations::binary::composite {

static void runEltwiseBinaryCompositeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<
::ttnn::Tensor(const ::ttnn::Tensor &, const ::ttnn::Tensor &,
const std::optional<::tt::tt_metal::MemoryConfig> &)>
ttnnOp) {

::ttnn::Tensor *lhs = nullptr;
::ttnn::Tensor *rhs = nullptr;
getEltwiseBinaryOPInputTensors(op, tensorPool, &lhs, &rhs);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
switch (op->type()) {
case ::tt::target::ttnn::EltwiseOpType::Maximum: {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::maximum);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Minimum: {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum);
break;
}
default:
throw std::invalid_argument(
"Unsupported Eltwise Binary Composite operation");
}
}

} // namespace tt::runtime::ttnn::operations::binary::composite
27 changes: 27 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_ELTWISE_BINARY_COMPOSITE_H
#define TTNN_RUNTIME_ELTWISE_BINARY_COMPOSITE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::binary::composite {

inline bool isBinaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
switch (op->type()) {
case ::tt::target::ttnn::EltwiseOpType::Maximum:
case ::tt::target::ttnn::EltwiseOpType::Minimum:
return true;
default:
return false;
}
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::binary::composite

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,12 @@
#include "unary.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/unary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttnn/operations/copy.hpp"
#include "ttnn/operations/eltwise/unary/unary_composite.hpp"

namespace tt::runtime::ttnn::operations::unary {

static void
getEltwiseUnaryOPInputTensor(const ::tt::target::ttnn::EltwiseOp *op,
ProgramTensorPool &tensorPool,
::ttnn::Tensor **in) {
LOG_ASSERT(op->ins()->size() == 1, "Expected 1 input, got ",
op->ins()->size());
*in = &(tensorPool.at(op->ins()->Get(0)->global_id()));
DEBUG_ASSERT((*in)->is_allocated());
}

static void runEltwiseUnaryOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<
Expand All @@ -38,22 +28,6 @@ static void runEltwiseUnaryOP(
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseUnaryCompositeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<::ttnn::Tensor(const ::ttnn::Tensor &,
const ::tt::tt_metal::MemoryConfig &)>
ttnnOp) {

::ttnn::Tensor *in = nullptr;
getEltwiseUnaryOPInputTensor(op, tensorPool, &in);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseUnaryWithFastAndApproximateModeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<
Expand All @@ -80,10 +54,6 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::abs);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Cbrt: {
runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Ceil: {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::ceil);
break;
Expand Down
File renamed without changes.
42 changes: 42 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "unary_composite.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/unary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttnn/operations/eltwise/unary/unary_composite.hpp"

namespace tt::runtime::ttnn::operations::unary::composite {

static void runEltwiseUnaryCompositeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<::ttnn::Tensor(const ::ttnn::Tensor &,
const ::tt::tt_metal::MemoryConfig &)>
ttnnOp) {

::ttnn::Tensor *in = nullptr;
getEltwiseUnaryOPInputTensor(op, tensorPool, &in);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
switch (op->type()) {
case ::tt::target::ttnn::EltwiseOpType::Cbrt: {
runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt);
break;
}
default:
throw std::invalid_argument(
"Unsupported Eltwise Binary Composite operation");
}
}

} // namespace tt::runtime::ttnn::operations::unary::composite
26 changes: 26 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_ELTWISE_UNARY_COMPOSITE_H
#define TTNN_RUNTIME_ELTWISE_UNARY_COMPOSITE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::unary::composite {

inline bool isUnaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
switch (op->type()) {
case ::tt::target::ttnn::EltwiseOpType::Cbrt:
return true;
default:
return false;
}
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::unary::composite

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "utils.h"
#include "tt/runtime/detail/logger.h"

namespace tt::runtime::ttnn::operations::binary {

void getEltwiseBinaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
ProgramTensorPool &tensorPool,
::ttnn::Tensor **lhs,
::ttnn::Tensor **rhs) {
LOG_ASSERT(op->ins()->size() == 2, "Expected 2 inputs");
*lhs = &(tensorPool.at(op->ins()->Get(0)->global_id()));
*rhs = &(tensorPool.at(op->ins()->Get(1)->global_id()));
DEBUG_ASSERT((*lhs)->is_allocated());
DEBUG_ASSERT((*rhs)->is_allocated());

// Switch the order of operands if the second operand requires broadcast
if ((*rhs)->volume() < (*lhs)->volume()) {
std::swap(*lhs, *rhs);
}
}

} // namespace tt::runtime::ttnn::operations::binary
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_ELTWISE_BINARY_UTILS_H
#define TTNN_RUNTIME_ELTWISE_BINARY_UTILS_H

#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::binary {
void getEltwiseBinaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
ProgramTensorPool &tensorPool,
::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs);

} // namespace tt::runtime::ttnn::operations::binary

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "utils.h"
#include "tt/runtime/detail/logger.h"

namespace tt::runtime::ttnn::operations::unary {

void getEltwiseUnaryOPInputTensor(const ::tt::target::ttnn::EltwiseOp *op,
ProgramTensorPool &tensorPool,
::ttnn::Tensor **in) {
LOG_ASSERT(op->ins()->size() == 1, "Expected 1 input, got ",
op->ins()->size());
*in = &(tensorPool.at(op->ins()->Get(0)->global_id()));
DEBUG_ASSERT((*in)->is_allocated());
}

} // namespace tt::runtime::ttnn::operations::unary
Loading

0 comments on commit 278a774

Please sign in to comment.