-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9281a1c
commit dff3292
Showing
11 changed files
with
261 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
.. _ttnn.unary_mul_bw: | ||
|
||
ttnn.unary_mul_bw | ||
################# | ||
|
||
.. autofunction:: ttnn.unary_mul_bw |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
51 changes: 51 additions & 0 deletions
51
ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp" | ||
|
||
#include "third_party/magic_enum/magic_enum.hpp" | ||
#include "tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp" | ||
#include "tt_eager/tt_dnn/op_library/composite/composite_ops.hpp" | ||
#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" | ||
#include "tt_eager/tt_dnn/op_library/unpad/unpad_op.hpp" | ||
#include "tt_metal/common/constants.hpp" | ||
#include "tt_metal/host_api.hpp" | ||
#include "tt_metal/tools/profiler/op_profiler.hpp" | ||
#include "ttnn/operations/eltwise/unary/unary.hpp" | ||
|
||
namespace ttnn::operations::unary_backward { | ||
|
||
namespace utils { | ||
|
||
|
||
std::vector<ttnn::Tensor> _unary_mul_bw( | ||
const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) { | ||
std::vector<Tensor> grad_tensor; | ||
Tensor result = mul_unary(grad, scalar, output_mem_config); | ||
grad_tensor.emplace_back(result); | ||
return grad_tensor; | ||
} | ||
|
||
|
||
std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(UnaryBackwardOpType OpType){ | ||
switch (OpType) { | ||
default: | ||
TT_ASSERT(false && "Undefined op type"); | ||
return 0; | ||
} | ||
} | ||
|
||
std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, float, const MemoryConfig&)> get_function_type1_w_float(UnaryBackwardOpType OpType){ | ||
switch (OpType) { | ||
case UnaryBackwardOpType::UNARY_MUL_BW: | ||
return _unary_mul_bw; | ||
default: | ||
TT_ASSERT(false && "Undefined op type"); | ||
return 0; | ||
} | ||
} | ||
|
||
} | ||
|
||
} // namespace ttnn::operations::unary |
20 changes: 20 additions & 0 deletions
20
ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include <functional> | ||
#include <optional> | ||
|
||
#include "third_party/magic_enum/magic_enum.hpp" | ||
|
||
namespace ttnn::operations::unary_backward { | ||
|
||
constexpr uint8_t DefaultQueueId = 0; | ||
enum class UnaryBackwardOpType { | ||
UNARY_MUL_BW, | ||
}; | ||
|
||
|
||
} // namespace ttnn::operations::unary |
84 changes: 84 additions & 0 deletions
84
ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
|
||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "device/unary_backward_op.cpp" | ||
#include "ttnn/device_operation.hpp" | ||
#include "ttnn/operations/data_movement.hpp" | ||
|
||
namespace ttnn { | ||
|
||
namespace operations::unary_backward { | ||
|
||
template <UnaryBackwardOpType unary_backward_op_type> | ||
struct ExecuteUnaryBackward { | ||
|
||
static inline const std::array<TensorSchema, 2> input_tensor_schemas() { | ||
return { | ||
ttnn::TensorSchema{ | ||
2, | ||
4, | ||
{ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16}, | ||
{ttnn::TILE_LAYOUT}, | ||
true, | ||
false, | ||
false, | ||
false}, | ||
ttnn::TensorSchema{ | ||
2, | ||
4, | ||
{ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b, ttnn::uint16}, | ||
{ttnn::TILE_LAYOUT}, | ||
true, | ||
false, | ||
false, | ||
false}}; | ||
} | ||
|
||
static inline std::vector<ttnn::Tensor> create_async_output_tensors( | ||
const std::vector<Tensor> &input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs) { | ||
const auto& input_tensor = input_tensors.at(0); | ||
return {Tensor(operation::get_workers_for_op_output({input_tensor}))}; | ||
} | ||
|
||
//Type 1: 2 inputs, 1 grad tensor | ||
template <typename... Args> | ||
static auto input_tensors_to_validate(const Tensor &grad_tensor, const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) { | ||
return std::forward_as_tuple(grad_tensor, input_tensor_a, input_tensor_b); | ||
} | ||
|
||
static std::vector<ttnn::Tensor> execute_on_worker_thread( | ||
const Tensor &grad_tensor_arg, | ||
const Tensor &input_tensor_arg, | ||
const std::optional<MemoryConfig> &memory_config = std::nullopt) { | ||
|
||
auto op_type = utils::get_function_type1(unary_backward_op_type); | ||
auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config()); | ||
return op_type(grad_tensor_arg, input_tensor_arg, output_memory_config); | ||
} | ||
|
||
//Type 1: Type 1 with 1 float | ||
template <typename... Args> | ||
|
||
static std::vector<ttnn::Tensor> execute_on_worker_thread( | ||
const Tensor &grad_tensor_arg, | ||
const Tensor &input_tensor_arg, | ||
float alpha, | ||
const std::optional<MemoryConfig> &memory_config = std::nullopt) { | ||
|
||
auto op_type = utils::get_function_type1_w_float(unary_backward_op_type); | ||
auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config()); | ||
return op_type(grad_tensor_arg, input_tensor_arg, alpha, output_memory_config); | ||
} | ||
|
||
}; | ||
|
||
} // operations::unary | ||
|
||
//type 1 | ||
constexpr auto unary_mul_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::UNARY_MUL_BW>>("ttnn::unary_mul_bw"); | ||
|
||
} // namespace ttnn |
93 changes: 93 additions & 0 deletions
93
ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "ttnn/cpp/pybind11/decorators.hpp" | ||
#include "ttnn/operations/eltwise/unary_backward/unary_backward.hpp" | ||
#include "ttnn/types.hpp" | ||
|
||
namespace py = pybind11; | ||
|
||
namespace ttnn { | ||
namespace operations { | ||
namespace unary_backward { | ||
|
||
namespace detail { | ||
|
||
template <typename unary_backward_operation_t> | ||
void bind_unary_backward(py::module& module, const unary_backward_operation_t& operation, const std::string& description) { | ||
auto doc = fmt::format( | ||
R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor: ttnn.Tensor *, memory_config: ttnn.MemoryConfig) -> std::vector<Tensor> | ||
{2} | ||
Args: | ||
* :attr:`grad_tensor` | ||
* :attr:`input_tensor` | ||
Keyword args: | ||
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor | ||
Example: | ||
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device) | ||
>>> input = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device) | ||
>>> output = {1}(grad_tensor, input) | ||
)doc", | ||
operation.name(), | ||
operation.python_fully_qualified_name(), | ||
description); | ||
|
||
bind_registered_operation( | ||
module, | ||
operation, | ||
doc, | ||
ttnn::pybind_overload_t{ | ||
[](const unary_backward_operation_t& self, | ||
const ttnn::Tensor& grad_tensor, | ||
const ttnn::Tensor& input_tensor, | ||
const std::optional<ttnn::MemoryConfig>& memory_config) -> std::vector<ttnn::Tensor> { | ||
auto output_memory_config = memory_config.value_or(input_tensor.memory_config()); | ||
return self(grad_tensor, input_tensor, output_memory_config); | ||
}, | ||
py::arg("grad_tensor"), | ||
py::arg("input_tensor"), | ||
py::kw_only(), | ||
py::arg("memory_config") = std::nullopt}, | ||
|
||
|
||
ttnn::pybind_overload_t{ | ||
[](const unary_backward_operation_t& self, | ||
const ttnn::Tensor& grad_tensor, | ||
const ttnn::Tensor& input_tensor, | ||
const float alpha, | ||
const std::optional<ttnn::MemoryConfig>& memory_config) -> std::vector<ttnn::Tensor> { | ||
return self(grad_tensor, input_tensor, alpha, memory_config); | ||
}, | ||
py::arg("grad_tensor"), | ||
py::arg("input_tensor"), | ||
py::arg("alpha"), | ||
py::kw_only(), | ||
py::arg("memory_config") = std::nullopt}); | ||
|
||
} | ||
|
||
} // namespace detail | ||
|
||
|
||
void py_module(py::module& module) { | ||
detail::bind_unary_backward( | ||
module, | ||
ttnn::unary_mul_bw, | ||
R"doc(Performs backward operations for multiply on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc"); | ||
|
||
} | ||
|
||
} // namespace binary_backward | ||
} // namespace operations | ||
} // namespace ttnn |