-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update eltwise flatbuffer to include extra params table, move eltwise…
… composite to its own file to match ttnn implementation
- Loading branch information
Showing
16 changed files
with
277 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
47 changes: 47 additions & 0 deletions
47
runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
#include "binary_composite.h" | ||
#include "tt/runtime/detail/logger.h" | ||
#include "tt/runtime/detail/ttnn.h" | ||
#include "tt/runtime/ttnn/operations/eltwise/binary/utils.h" | ||
#include "tt/runtime/ttnn/operations/utils.h" | ||
|
||
namespace tt::runtime::ttnn::operations::binary::composite { | ||
|
||
static void runEltwiseBinaryCompositeOP( | ||
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, | ||
std::function< | ||
::ttnn::Tensor(const ::ttnn::Tensor &, const ::ttnn::Tensor &, | ||
const std::optional<::tt::tt_metal::MemoryConfig> &)> | ||
ttnnOp) { | ||
|
||
::ttnn::Tensor *lhs = nullptr; | ||
::ttnn::Tensor *rhs = nullptr; | ||
getEltwiseBinaryOPInputTensors(op, tensorPool, &lhs, &rhs); | ||
|
||
::tt::tt_metal::MemoryConfig outputMemoryConfig = | ||
utils::createMemoryConfig(op->out()); | ||
|
||
::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputMemoryConfig); | ||
tensorPool.insert_or_assign(op->out()->global_id(), out); | ||
} | ||
|
||
void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { | ||
ProgramTensorPool &tensorPool = context.getTensorPool(); | ||
switch (op->type()) { | ||
case ::tt::target::ttnn::EltwiseOpType::Maximum: { | ||
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::maximum); | ||
break; | ||
} | ||
case ::tt::target::ttnn::EltwiseOpType::Minimum: { | ||
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum); | ||
break; | ||
} | ||
default: | ||
throw std::invalid_argument( | ||
"Unsupported Eltwise Binary Composite operation"); | ||
} | ||
} | ||
|
||
} // namespace tt::runtime::ttnn::operations::binary::composite |
27 changes: 27 additions & 0 deletions
27
runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTNN_RUNTIME_ELTWISE_BINARY_COMPOSITE_H | ||
#define TTNN_RUNTIME_ELTWISE_BINARY_COMPOSITE_H | ||
|
||
#include "tt/runtime/ttnn/types.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
|
||
namespace tt::runtime::ttnn::operations::binary::composite { | ||
|
||
inline bool isBinaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) { | ||
switch (op->type()) { | ||
case ::tt::target::ttnn::EltwiseOpType::Maximum: | ||
case ::tt::target::ttnn::EltwiseOpType::Minimum: | ||
return true; | ||
default: | ||
return false; | ||
} | ||
} | ||
|
||
void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context); | ||
|
||
} // namespace tt::runtime::ttnn::operations::binary::composite | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
42 changes: 42 additions & 0 deletions
42
runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
#include "unary_composite.h" | ||
#include "tt/runtime/detail/logger.h" | ||
#include "tt/runtime/detail/ttnn.h" | ||
#include "tt/runtime/ttnn/operations/eltwise/unary/utils.h" | ||
#include "tt/runtime/ttnn/operations/utils.h" | ||
#include "ttnn/operations/eltwise/unary/unary_composite.hpp" | ||
|
||
namespace tt::runtime::ttnn::operations::unary::composite { | ||
|
||
static void runEltwiseUnaryCompositeOP( | ||
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, | ||
std::function<::ttnn::Tensor(const ::ttnn::Tensor &, | ||
const ::tt::tt_metal::MemoryConfig &)> | ||
ttnnOp) { | ||
|
||
::ttnn::Tensor *in = nullptr; | ||
getEltwiseUnaryOPInputTensor(op, tensorPool, &in); | ||
|
||
::tt::tt_metal::MemoryConfig outputMemoryConfig = | ||
utils::createMemoryConfig(op->out()); | ||
|
||
::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig); | ||
tensorPool.insert_or_assign(op->out()->global_id(), out); | ||
} | ||
|
||
void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { | ||
ProgramTensorPool &tensorPool = context.getTensorPool(); | ||
switch (op->type()) { | ||
case ::tt::target::ttnn::EltwiseOpType::Cbrt: { | ||
runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt); | ||
break; | ||
} | ||
default: | ||
throw std::invalid_argument( | ||
"Unsupported Eltwise Binary Composite operation"); | ||
} | ||
} | ||
|
||
} // namespace tt::runtime::ttnn::operations::unary::composite |
26 changes: 26 additions & 0 deletions
26
runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTNN_RUNTIME_ELTWISE_UNARY_COMPOSITE_H | ||
#define TTNN_RUNTIME_ELTWISE_UNARY_COMPOSITE_H | ||
|
||
#include "tt/runtime/ttnn/types.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
|
||
namespace tt::runtime::ttnn::operations::unary::composite { | ||
|
||
inline bool isUnaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) { | ||
switch (op->type()) { | ||
case ::tt::target::ttnn::EltwiseOpType::Cbrt: | ||
return true; | ||
default: | ||
return false; | ||
} | ||
} | ||
|
||
void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context); | ||
|
||
} // namespace tt::runtime::ttnn::operations::unary::composite | ||
|
||
#endif |
25 changes: 25 additions & 0 deletions
25
runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
#include "utils.h" | ||
#include "tt/runtime/detail/logger.h" | ||
|
||
namespace tt::runtime::ttnn::operations::binary { | ||
|
||
void getEltwiseBinaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op, | ||
ProgramTensorPool &tensorPool, | ||
::ttnn::Tensor **lhs, | ||
::ttnn::Tensor **rhs) { | ||
LOG_ASSERT(op->ins()->size() == 2, "Expected 2 inputs"); | ||
*lhs = &(tensorPool.at(op->ins()->Get(0)->global_id())); | ||
*rhs = &(tensorPool.at(op->ins()->Get(1)->global_id())); | ||
DEBUG_ASSERT((*lhs)->is_allocated()); | ||
DEBUG_ASSERT((*rhs)->is_allocated()); | ||
|
||
// Switch the order of operands if the second operand requires broadcast | ||
if ((*rhs)->volume() < (*lhs)->volume()) { | ||
std::swap(*lhs, *rhs); | ||
} | ||
} | ||
|
||
} // namespace tt::runtime::ttnn::operations::binary |
19 changes: 19 additions & 0 deletions
19
runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTNN_RUNTIME_ELTWISE_BINARY_UTILS_H | ||
#define TTNN_RUNTIME_ELTWISE_BINARY_UTILS_H | ||
|
||
#include "tt/runtime/detail/ttnn.h" | ||
#include "tt/runtime/ttnn/types.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
|
||
namespace tt::runtime::ttnn::operations::binary { | ||
void getEltwiseBinaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op, | ||
ProgramTensorPool &tensorPool, | ||
::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs); | ||
|
||
} // namespace tt::runtime::ttnn::operations::binary | ||
|
||
#endif |
18 changes: 18 additions & 0 deletions
18
runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
#include "utils.h" | ||
#include "tt/runtime/detail/logger.h" | ||
|
||
namespace tt::runtime::ttnn::operations::unary { | ||
|
||
void getEltwiseUnaryOPInputTensor(const ::tt::target::ttnn::EltwiseOp *op, | ||
ProgramTensorPool &tensorPool, | ||
::ttnn::Tensor **in) { | ||
LOG_ASSERT(op->ins()->size() == 1, "Expected 1 input, got ", | ||
op->ins()->size()); | ||
*in = &(tensorPool.at(op->ins()->Get(0)->global_id())); | ||
DEBUG_ASSERT((*in)->is_allocated()); | ||
} | ||
|
||
} // namespace tt::runtime::ttnn::operations::unary |
Oops, something went wrong.