Skip to content

Commit

Permalink
#9745: move unpad to slice ttnn cpp references
Browse files Browse the repository at this point in the history
  • Loading branch information
ntarafdar committed Jul 5, 2024
1 parent c52e153 commit d7832e9
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 16 deletions.
14 changes: 7 additions & 7 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "tt_dnn/op_library/permute/permute_op.hpp"
#include "tt_dnn/op_library/reduce/reduce_op.hpp"
#include "tt_dnn/op_library/reshape/reshape_op.hpp"
#include "tt_dnn/op_library/unpad/unpad_op.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"
#include "tt_eager/tensor/tensor_utils.hpp"
#include "ttnn/operations/data_movement/pad/pad.hpp"
#include "tt_numpy/functions.hpp"
Expand Down Expand Up @@ -1532,9 +1532,9 @@ std::vector<Tensor> _prod_bw(
const Shape start_index = {0, 0, 0, 0};
const Shape end_index = {
grad.get_legacy_shape()[0] - 1, 0, grad.get_legacy_shape()[1] - 1, grad.get_legacy_shape()[2] - 1};
Tensor new_unpad_tensor = unpad(required, start_index, end_index);
Tensor new_slice_tensor = ttnn::slice(required, start_index, end_index, std::nullopt);
after_permute_dims = {0, 2, 3, 1};
updated_grad = permute(new_unpad_tensor, after_permute_dims, output_mem_config);
updated_grad = permute(new_slice_tensor, after_permute_dims, output_mem_config);
Tensor pad_updated_grad = updated_grad.pad_to_tile(1.0f);
Tensor pad_prod_result = prod_result.pad_to_tile(1.0f);
pad_updated_grad = pad_updated_grad.to(Layout::TILE);
Expand All @@ -1549,8 +1549,8 @@ std::vector<Tensor> _prod_bw(
const Shape start_index = {0, 0, 0, 0};
const Shape end_index = {
grad.get_legacy_shape()[0] - 1, 0, grad.get_legacy_shape()[1] - 1, grad.get_legacy_shape()[3] - 1};
Tensor new_unpad_tensor = unpad(required, start_index, end_index);
updated_grad = permute(new_unpad_tensor, after_permute_dims, output_mem_config);
Tensor new_slice_tensor = ttnn::slice(required, start_index, end_index, std::nullopt);
updated_grad = permute(new_slice_tensor, after_permute_dims, output_mem_config);
if(updated_grad.get_layout()==Layout::ROW_MAJOR){
updated_grad = tt::tt_metal::change_layout_to_tile(updated_grad, output_mem_config);
}
Expand Down Expand Up @@ -1599,7 +1599,7 @@ std::vector<Tensor> _prod_bw(
input.get_legacy_shape()[1] - 1,
input.get_legacy_shape()[2] - 1,
input.get_legacy_shape()[3] - 1};
grad_result = unpad(result, start_index, end_index);
grad_result = ttnn::slice(result, start_index, end_index, std::nullopt);
}
grad_tensor.emplace_back(grad_result);
return grad_tensor;
Expand Down Expand Up @@ -1633,7 +1633,7 @@ std::vector<Tensor> _prod_bw(
input.get_legacy_shape()[1] - 1,
input.get_legacy_shape()[2] - 1,
input.get_legacy_shape()[3] - 1};
grad_result = unpad(result, start_index, end_index);
grad_result = ttnn::slice(result, start_index, end_index, std::nullopt);
}
grad_tensor.emplace_back(grad_result);
return grad_tensor;
Expand Down
6 changes: 3 additions & 3 deletions tt_eager/tt_dnn/op_library/complex/complex_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "tt_dnn/op_library/concat/concat_op.hpp"
#include "tt_dnn/op_library/bmm/bmm_op.hpp"
#include "tt_dnn/op_library/reshape/reshape_op.hpp"
#include "tt_dnn/op_library/unpad/unpad_op.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"
#include "tt_numpy/functions.hpp"
#include "tt_eager/tensor/tensor_utils.hpp"

Expand All @@ -33,15 +33,15 @@ Tensor get_real(const Tensor& input, const MemoryConfig& output_mem_config) {
Shape t_Shape = input.get_legacy_shape();
Shape start = {0, 0, 0, 0} ;
Shape end = {t_Shape[0] - 1,t_Shape[1] - 1 ,t_Shape[2] - 1, (t_Shape[3] / 2) - 1};
Tensor r_tensor = unpad(input, start, end, output_mem_config);
Tensor r_tensor = ttnn::slice(input, start, end, output_mem_config);
return r_tensor;
}

// Extracts the imaginary component of a packed complex tensor: the upper half
// of the last dimension (the real component occupies the lower half).
// NOTE(review): assumes a rank-4 input whose real/imag parts are concatenated
// along dim 3 — confirm with callers in complex_ops.
//
// @param input             packed complex tensor (real | imag along last dim)
// @param output_mem_config memory configuration for the sliced output tensor
// @return tensor holding only the imaginary half
Tensor get_imag(const Tensor& input, const MemoryConfig& output_mem_config) {
    Shape t_Shape = input.get_legacy_shape();
    // Inclusive bounds: start at the midpoint of dim 3, end at its last index.
    Shape start = {0, 0, 0, (t_Shape[3] / 2)};
    Shape end = {t_Shape[0] - 1, t_Shape[1] - 1, t_Shape[2] - 1, (t_Shape[3] - 1)};
    Tensor i_tensor = ttnn::slice(input, start, end, output_mem_config);
    return i_tensor;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,13 @@ operation::ProgramWithCallbacks unpad_rm_multi_core(

tt_metal::KernelHandle unary_reader_kernel_id = tt_metal::CreateKernel(
program,
"tt_eager/tt_dnn/op_library/unpad/kernels/dataflow/reader_unary_unpad_dims_rm_interleaved_start_id.cpp",
"ttnn/cpp/ttnn/operations/data_movement/slice/kernels/dataflow/reader_unary_unpad_dims_rm_interleaved_start_id.cpp",
total_cores,
tt_metal::ReaderDataMovementConfig(reader_compile_time_args_vec));

tt_metal::KernelHandle unary_writer_kernel_id = tt_metal::CreateKernel(
program,
"tt_eager/tt_dnn/op_library/unpad/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp",
"ttnn/cpp/ttnn/operations/data_movement/slice/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp",
total_cores,
tt_metal::WriterDataMovementConfig(writer_compile_time_args_vec));

Expand Down Expand Up @@ -435,7 +435,7 @@ operation::ProgramWithCallbacks unpad_tile_multi_core(
// Tilized reader
tt_metal::KernelHandle unary_reader_kernel_id = tt_metal::CreateKernel(
program,
"tt_eager/tt_dnn/op_library/unpad/kernels/dataflow/reader_unary_unpad_dims_interleaved_start_id.cpp",
"ttnn/cpp/ttnn/operations/data_movement/slice/kernels/dataflow/reader_unary_unpad_dims_interleaved_start_id.cpp",
total_cores,
tt_metal::ReaderDataMovementConfig(reader_compile_time_args));

Expand Down
127 changes: 127 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "tt_eager/tensor/types.hpp"
#include "ttnn/cpp/ttnn/operations/core.hpp"
#include "tt_eager/tt_dnn/op_library/unpad/unpad_op.hpp"

#include <ranges>


namespace ttnn {
namespace operations {
namespace data_movement {



// Slice (formerly "unpad") operation: extracts the sub-tensor bounded
// inclusively by output_tensor_start / output_tensor_end from the input.
// Thin ttnn-facing wrapper that dispatches to the legacy tt_metal::Unpad op.
struct ExecuteSlice {
    static inline const std::array<TensorSchema, 1> input_tensor_schemas() {
        return {ttnn::TensorSchema{
            2,  // min rank
            4,  // max rank
            {ttnn::bfloat16, ttnn::bfloat8_b, ttnn::uint16, ttnn::int32, ttnn::uint32},
            {ttnn::TILE_LAYOUT},
            true,   // can_be_on_device
            false,  // can_be_on_cpu
            false,  // can_be_scalar
            false   // is_optional
        }};
    }

    template <typename... Args>
    static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) {
        // Only the first tensor participates in schema validation.
        return std::make_tuple(input_tensor);
    }

    // Primary overload: Shape bounds plus an explicit queue id.
    // NOTE(review): queue_id is currently unused by the underlying op — confirm
    // whether multi-queue dispatch is intended here.
    static ttnn::Tensor execute_on_worker_thread(
        uint8_t queue_id,
        const ttnn::Tensor& input_tensor,
        tt::tt_metal::Shape output_tensor_start,
        tt::tt_metal::Shape output_tensor_end,
        const std::optional<MemoryConfig>& memory_config_arg) {
        // Fall back to the input's own memory configuration when none is given.
        const auto mem_config = memory_config_arg.value_or(input_tensor.memory_config());
        const auto input_shape = input_tensor.get_legacy_shape();

        // Inclusive bounds: each output dim spans end - start + 1 elements.
        // NOTE(review): assumes rank-4 start/end shapes even though the schema
        // admits rank 2 — confirm lower-rank inputs are normalized upstream.
        std::vector<uint32_t> output_shape;
        output_shape.reserve(4);
        for (int dim = 0; dim < 4; ++dim) {
            output_shape.push_back(output_tensor_end[dim] - output_tensor_start[dim] + 1);
        }

        return operation::run(
                   tt::tt_metal::Unpad{
                       .output_tensor_start = output_tensor_start,
                       .output_tensor_end = output_tensor_end,
                       .output_mem_config = mem_config,
                       .output_shape = output_shape,
                       .input_shape = input_shape},
                   {input_tensor})
            .front();
    }

    // Convenience overload: Shape bounds, default queue 0.
    static ttnn::Tensor execute_on_worker_thread(
        const ttnn::Tensor& input_tensor,
        tt::tt_metal::Shape output_tensor_start,
        tt::tt_metal::Shape output_tensor_end,
        const std::optional<MemoryConfig>& memory_config_arg) {
        return execute_on_worker_thread(0, input_tensor, output_tensor_start, output_tensor_end, memory_config_arg);
    }

    // Convenience overload: vector bounds are converted to legacy Shapes.
    static ttnn::Tensor execute_on_worker_thread(
        uint8_t queue_id,
        const ttnn::Tensor& input_tensor,
        std::vector<uint32_t> output_tensor_start,
        std::vector<uint32_t> output_tensor_end,
        const std::optional<MemoryConfig>& memory_config_arg) {
        return execute_on_worker_thread(
            queue_id,
            input_tensor,
            tt::tt_metal::Shape(output_tensor_start),
            tt::tt_metal::Shape(output_tensor_end),
            memory_config_arg);
    }

    // Convenience overload: vector bounds, default queue 0.
    static ttnn::Tensor execute_on_worker_thread(
        const ttnn::Tensor& input_tensor,
        std::vector<uint32_t> output_tensor_start,
        std::vector<uint32_t> output_tensor_end,
        const std::optional<MemoryConfig>& memory_config_arg) {
        return execute_on_worker_thread(
            0, input_tensor, std::move(output_tensor_start), std::move(output_tensor_end), memory_config_arg);
    }
};

} // namespace data_movement
} // namespace operations

// Registers ExecuteSlice with the ttnn operation infrastructure under the
// public name "ttnn::slice", making it callable as ttnn::slice(...).
constexpr auto slice = ttnn::register_operation<ttnn::operations::data_movement::ExecuteSlice>("ttnn::slice");

} // namespace ttnn
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp"
#include "tt_eager/tt_dnn/op_library/composite/composite_ops.hpp"
#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp"
#include "tt_eager/tt_dnn/op_library/unpad/unpad_op.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/tools/profiler/op_profiler.hpp"
Expand Down Expand Up @@ -374,7 +374,7 @@ std::vector<Tensor> _concat_bw(
input.get_legacy_shape()[2] - 1,
input.get_legacy_shape()[3] - 1};

Tensor grad_a = unpad(grad, start_index, end_index);
Tensor grad_a = ttnn::slice(grad, start_index, end_index, std::nullopt);
grad_tensor.emplace_back(grad_a);

tt::tt_metal::Shape start_index_2 = {0, 0, 0, 0};
Expand All @@ -393,7 +393,7 @@ std::vector<Tensor> _concat_bw(
grad.get_legacy_shape()[1] - 1,
grad.get_legacy_shape()[2] - 1,
grad.get_legacy_shape()[3] - 1};
Tensor grad_b = unpad(grad, start_index_2, end_index_2);
Tensor grad_b = ttnn::slice(grad, start_index_2, end_index_2, std::nullopt);
grad_tensor.emplace_back(grad_b);

return grad_tensor;
Expand Down

0 comments on commit d7832e9

Please sign in to comment.