-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support for increasing 1-D row major int32 tensors by one (#12773)
* #12767: add ttnn.plusone * #12767: add test * #0: Update plusone_pybind.cpp * #0: Update plusone_pybind.cpp * #0: Update plusone_pybind.cpp
- Loading branch information
1 parent
07b67d6
commit 3875f61
Showing
12 changed files
with
346 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
|
||
import torch | ||
|
||
import ttnn | ||
|
||
from tests.ttnn.utils_for_testing import assert_with_pcc | ||
|
||
|
||
@pytest.mark.parametrize("w", [1, 4, 8, 32]) | ||
def test_plus_one(device, w): | ||
torch_input_tensor = torch.randint(32000, (w,)) | ||
torch_output_tensor = torch_input_tensor + 1 | ||
|
||
input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.int32, device=device) | ||
ttnn.plus_one(input_tensor) | ||
output_tensor = ttnn.to_torch(input_tensor) | ||
assert_with_pcc(torch_output_tensor, output_tensor, 0.9999) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
38 changes: 38 additions & 0 deletions
38
ttnn/cpp/ttnn/operations/experimental/plusone/device/kernels/reader_plusone_interleaved.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
|
||
#include "dataflow_api.h" | ||
|
||
//#include "debug/dprint.h" | ||
|
||
void kernel_main() { | ||
uint32_t src_addr = get_arg_val<uint32_t>(0); | ||
|
||
constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0); | ||
constexpr bool src0_is_dram = (bool)get_compile_time_arg_val(1); | ||
constexpr uint32_t stick_size = get_compile_time_arg_val(2); | ||
constexpr uint32_t W = get_compile_time_arg_val(3); | ||
|
||
const InterleavedAddrGen<src0_is_dram> s0 = {.bank_base_address = src_addr, .page_size = stick_size}; | ||
|
||
// Use cb as L1 scratch memory | ||
uint32_t cb_addr = get_write_ptr(cb_id_in0); | ||
volatile tt_l1_ptr uint32_t* stick = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(cb_addr); | ||
|
||
|
||
noc_async_read_page(0, s0, cb_addr); | ||
noc_async_read_barrier(); | ||
for(uint32_t i = 0; i < W; i++) { | ||
uint32_t val = stick[i]; | ||
stick[i] = val+1; | ||
//DPRINT << "val: " << val << ENDL(); | ||
} | ||
|
||
uint64_t dst_noc_addr = get_noc_addr(0, s0); | ||
|
||
noc_async_write(cb_addr, dst_noc_addr, stick_size); | ||
noc_async_write_barrier(); | ||
} |
39 changes: 39 additions & 0 deletions
39
ttnn/cpp/ttnn/operations/experimental/plusone/device/plusone_op.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "plusone_op.hpp" | ||
#include "plusone_program_factory.hpp" | ||
|
||
namespace ttnn::operations::experimental { | ||
|
||
void PlusOne::validate_with_output_tensors( | ||
const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>> &output_tensors) const { | ||
const auto &input_tensor_a = input_tensors.at(0); | ||
|
||
TT_FATAL(input_tensor_a.get_dtype() == DataType::INT32, "Only INT32 is supported for inputs!"); | ||
TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "Only ROW_MAJOR layout is supported for inputs!"); | ||
|
||
auto input_shape = input_tensor_a.get_legacy_shape(); | ||
TT_FATAL(input_shape.size()==1, "must have 1 dimension"); | ||
|
||
} | ||
|
||
std::vector<tt::tt_metal::LegacyShape> PlusOne::compute_output_shapes(const std::vector<Tensor> &input_tensors) const { | ||
auto input_shape = input_tensors[0].get_legacy_shape(); | ||
return {input_shape}; | ||
} | ||
|
||
std::vector<Tensor> PlusOne::create_output_tensors( | ||
const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>> &output_tensors) const { | ||
return {input_tensors.at(0)}; | ||
|
||
} | ||
|
||
operation::ProgramWithCallbacks PlusOne::create_program( | ||
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const { | ||
const auto &input_tensor = input_tensors.at(0); | ||
return detail::plusone_single_core(input_tensor); | ||
} | ||
|
||
} // namespace ttnn::operations::experimental |
25 changes: 25 additions & 0 deletions
25
ttnn/cpp/ttnn/operations/experimental/plusone/device/plusone_op.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include <optional> | ||
|
||
#include "ttnn/common/constants.hpp" | ||
#include "ttnn/tensor/tensor.hpp" | ||
#include "ttnn/run_operation.hpp" | ||
|
||
namespace ttnn::operations::experimental { | ||
|
||
struct PlusOne { | ||
|
||
void validate_with_output_tensors(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const; | ||
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const; | ||
std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const; | ||
operation::ProgramWithCallbacks create_program( | ||
const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const; | ||
}; | ||
|
||
|
||
} // namespace ttnn::operations::experimental |
87 changes: 87 additions & 0 deletions
87
ttnn/cpp/ttnn/operations/experimental/plusone/device/plusone_program_factory.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
#include <algorithm> | ||
|
||
#include "ttnn/deprecated/tt_dnn/op_library/math.hpp" | ||
#include "tt_metal/common/work_split.hpp" | ||
#include "tt_metal/common/constants.hpp" | ||
#include "tt_metal/detail/util.hpp" | ||
#include "tt_metal/host_api.hpp" | ||
|
||
namespace ttnn::operations::experimental::detail { | ||
|
||
using namespace tt::constants; | ||
|
||
operation::ProgramWithCallbacks plusone_single_core( | ||
const Tensor &input) { | ||
tt::tt_metal::Program program{}; | ||
|
||
tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); | ||
uint32_t input_unit_size = input.element_size(); | ||
|
||
tt::tt_metal::Device *device = input.device(); | ||
|
||
auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); | ||
uint32_t num_cores_x = compute_with_storage_grid_size.x; | ||
uint32_t num_cores_y = compute_with_storage_grid_size.y; | ||
uint32_t num_units = 1; // single-core | ||
auto [num_cores, all_cores, core_group_1, core_group_2, num_units_per_core_group_1, num_units_per_core_group_2] = | ||
tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_units); | ||
|
||
const auto &input_shape = input.get_legacy_shape(); | ||
const uint32_t W = input_shape[0]; | ||
|
||
uint32_t src0_cb_index = tt::CB::c_in0; | ||
uint32_t num_input_units = W; | ||
uint32_t aligned_input_unit_size = round_up_to_mul32(num_input_units * input_unit_size); | ||
tt::tt_metal::CircularBufferConfig cb_src0_config = | ||
tt::tt_metal::CircularBufferConfig(aligned_input_unit_size, {{src0_cb_index, input_cb_data_format}}) | ||
.set_page_size(src0_cb_index, aligned_input_unit_size); | ||
auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); | ||
|
||
auto src_buffer = input.buffer(); | ||
bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; | ||
|
||
std::vector<uint32_t> reader_compile_time_args = { | ||
src0_cb_index, | ||
src_is_dram, | ||
aligned_input_unit_size, | ||
W, | ||
}; | ||
|
||
std::map<string, string> kernel_defines; | ||
tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel( | ||
program, | ||
"ttnn/cpp/ttnn/operations/experimental/plusone/device/kernels/reader_plusone_interleaved.cpp", | ||
all_cores, | ||
tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args, kernel_defines)); | ||
|
||
uint32_t g1_numcores = core_group_1.num_cores(); | ||
uint32_t g2_numcores = core_group_2.num_cores(); | ||
auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, false); | ||
|
||
for (uint32_t i = 0; i < cores.size(); ++i) { | ||
const CoreCoord &core = cores.at(i); | ||
|
||
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, {src_buffer->address()}); | ||
} | ||
|
||
auto override_runtime_args_callback = [reader_kernel_id, cores]( | ||
const Program &program, | ||
const std::vector<Buffer *> &input_buffers, | ||
const std::vector<Buffer *> &) { | ||
auto src_buffer = input_buffers.at(0); | ||
|
||
for (const auto &core : cores) { | ||
{ | ||
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); | ||
runtime_args[0] = src_buffer->address(); | ||
} | ||
} | ||
}; | ||
|
||
return {std::move(program), override_runtime_args_callback}; | ||
} | ||
|
||
} // namespace ttnn::operations::experimental::detail |
13 changes: 13 additions & 0 deletions
13
ttnn/cpp/ttnn/operations/experimental/plusone/device/plusone_program_factory.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
#include "ttnn/run_operation.hpp" | ||
|
||
namespace ttnn::operations::experimental::detail { | ||
|
||
using namespace tt::constants; | ||
|
||
operation::ProgramWithCallbacks plusone_single_core( | ||
const Tensor &input); | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "device/plusone_op.hpp" | ||
#include "ttnn/operations/experimental/plusone/plusone.hpp" | ||
|
||
#include "ttnn/run_operation.hpp" | ||
#include "ttnn/decorators.hpp" | ||
#include "ttnn/operations/core/core.hpp" | ||
|
||
namespace ttnn::operations::experimental { | ||
|
||
ttnn::Tensor PlusOneOperation::invoke( | ||
uint8_t queue_id, | ||
const Tensor& input_tensor) { | ||
return operation::run( | ||
PlusOne{}, | ||
{input_tensor}, {}, {}, queue_id) | ||
.at(0); | ||
} | ||
|
||
ttnn::Tensor PlusOneOperation::invoke( | ||
const Tensor& input_tensor) { | ||
return invoke(DefaultQueueId, input_tensor); | ||
} | ||
|
||
|
||
} // namespace ttnn::operations::experimental |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "ttnn/run_operation.hpp" | ||
#include "ttnn/decorators.hpp" | ||
#include "ttnn/operations/core/core.hpp" | ||
|
||
namespace ttnn { | ||
namespace operations::experimental { | ||
|
||
struct PlusOneOperation { | ||
static ttnn::Tensor invoke( | ||
uint8_t queue_id, | ||
const Tensor& input_tensor); | ||
|
||
static ttnn::Tensor invoke( | ||
const Tensor& input_tensor); | ||
|
||
}; | ||
|
||
} // namespace operations::experimental | ||
|
||
constexpr auto plus_one = | ||
ttnn::register_operation_with_auto_launch_op<"ttnn::plus_one", ttnn::operations::experimental::PlusOneOperation>(); | ||
|
||
} // namespace ttnn |
45 changes: 45 additions & 0 deletions
45
ttnn/cpp/ttnn/operations/experimental/plusone/plusone_pybind.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "ttnn/cpp/pybind11/decorators.hpp" | ||
|
||
#include "ttnn/operations/experimental/plusone/plusone.hpp" | ||
#include "ttnn/operations/experimental/plusone/plusone_pybind.hpp" | ||
|
||
namespace ttnn::operations::experimental::plusone::detail { | ||
namespace py = pybind11; | ||
void bind_experimental_plusone_operation(py::module& module) { | ||
auto doc = | ||
R"doc(plus_one(input_tensor: ttnn.Tensor) -> ttnn.Tensor | ||
Returns input tensor elements increased by 1. | ||
Input tensor must have UINT32 data type, ROW_MAJOR layout, and 1-D shape. | ||
This op only gives decent performance for small tensors (up to 100 elements). | ||
Equivalent pytorch code: | ||
.. code-block:: python | ||
return torch.add(input_tensor, 1) | ||
Args: | ||
* :attr:`input_tensor`: Input Tensor for plusone. | ||
)doc"; | ||
|
||
using OperationType = decltype(ttnn::plus_one); | ||
bind_registered_operation( | ||
module, | ||
ttnn::plus_one, | ||
doc, | ||
ttnn::pybind_overload_t{ | ||
[] (const OperationType& self, | ||
const ttnn::Tensor& input_tensor | ||
) { | ||
return self(input_tensor); | ||
}, | ||
py::arg("input_tensor").noconvert()}); | ||
} | ||
|
||
} // namespace ttnn::operations::experimental::plusone::detail |
12 changes: 12 additions & 0 deletions
12
ttnn/cpp/ttnn/operations/experimental/plusone/plusone_pybind.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
#include "pybind11/pybind_fwd.hpp" | ||
|
||
namespace ttnn::operations::experimental::plusone::detail { | ||
namespace py = pybind11; | ||
void bind_experimental_plusone_operation(py::module& module); | ||
|
||
} // namespace ttnn::operations::experimental::plusone::detail |