#4438: Add single-core fold op for Resnet
yan-zaretskiy committed Feb 26, 2024
1 parent aa65e6d commit 912a964
Showing 10 changed files with 419 additions and 7 deletions.
1 change: 1 addition & 0 deletions tests/tt_eager/module.mk
@@ -27,6 +27,7 @@ TT_EAGER_TESTS += \
tests/tt_eager/ops/test_tilize_op_channels_last \
tests/tt_eager/ops/test_tilize_zero_padding_channels_last \
tests/tt_eager/ops/test_sfpu \
tests/tt_eager/ops/test_fold_op \
tests/tt_eager/tensors/test_copy_and_move \
tests/tt_eager/tensors/test_host_device_loopback \
tests/tt_eager/tensors/test_raw_host_memory_pointer \
43 changes: 43 additions & 0 deletions tests/tt_eager/ops/test_fold_op.cpp
@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <algorithm>
#include <functional>
#include <random>
#include <tt_numpy/functions.hpp>

#include "tt_eager/tensor/tensor.hpp"
#include "tt_eager/tt_dnn/op_library/fold/fold_op.hpp"
#include "tt_eager/tt_dnn/op_library/program_cache.hpp"
#include "tt_metal/host_api.hpp"

using namespace tt;
using namespace tt::tt_metal;
using namespace constants;

// Smoke test: run the fold op on a small device tensor and read the result back to host.
void run_fold(Device *device, Shape shape) {
    Tensor input_tensor = tt::numpy::random::random(shape).to(Layout::ROW_MAJOR).to(device);
    uint32_t stride_h = 2;
    uint32_t stride_w = 2;
    Tensor device_output_tensor = fold(input_tensor, stride_h, stride_w);
    Tensor output_tensor = device_output_tensor.cpu();
}

int main(int argc, char **argv) {
    int device_id = 0;
    tt_metal::Device *device = tt_metal::CreateDevice(device_id);

    run_fold(device, {1, 2, 2, 2});
    bool pass = CloseDevice(device);

    if (pass) {
        log_info(LogTest, "Test Passed");
    } else {
        TT_THROW("Test Failed");
    }

    TT_FATAL(pass);

    return 0;
}
52 changes: 52 additions & 0 deletions tests/tt_eager/python_api_testing/unit_testing/test_fold_op.py
@@ -0,0 +1,52 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

import tt_lib as ttl

from models.utility_functions import skip_for_wormhole_b0, torch2tt_tensor


def fold_torch(input_tensor, stride_h, stride_w):
    # Reference fold: pack each stride_h x stride_w spatial block into the channel dim.
    N, H, W, C = input_tensor.shape

    reshaped = input_tensor.reshape(N, H // stride_h, stride_h, W // stride_w, stride_w, C)
    transposed = reshaped.permute(0, 1, 3, 2, 4, 5)
    return transposed.reshape(N, H // stride_h, W // stride_w, C * stride_h * stride_w)


@skip_for_wormhole_b0()
@pytest.mark.parametrize(
    "act_shape,stride_h,stride_w",
    [
        ((1, 2, 2, 2), 2, 2),
        ((1, 2, 2, 16), 2, 2),
        ((10, 2, 2, 32), 2, 2),
        ((10, 4, 4, 32), 2, 2),
        ((10, 6, 8, 32), 3, 2),
        ((10, 6, 8, 32), 3, 1),
        ((10, 6, 8, 32), 1, 2),
        ((10, 6, 8, 32), 1, 1),
    ],
)
def test_fold(act_shape, stride_h, stride_w, device):
    torch.manual_seed(0)

    torch_input = torch.randn(act_shape, dtype=torch.bfloat16)
    expected = fold_torch(torch_input, stride_h, stride_w)

    tt_input = torch2tt_tensor(
        torch_input,
        device,
        ttl.tensor.Layout.ROW_MAJOR,
        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED),
    )

    tt_out = ttl.tensor.fold(tt_input, stride_h, stride_w)
    tt_out = tt_out.cpu()
    actual = tt_out.to_torch()

    torch.testing.assert_allclose(actual, expected)
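
For intuition, here is a minimal standalone sketch of what the fold rearrangement does, using plain torch only (no device involved) and mirroring the fold_torch reference above: every stride_h x stride_w block of pixels is packed into the channel dimension.

    import torch

    def fold_reference(x, stride_h, stride_w):
        # Same rearrangement as fold_torch: pack each stride_h x stride_w
        # spatial block into the channel dimension.
        N, H, W, C = x.shape
        x = x.reshape(N, H // stride_h, stride_h, W // stride_w, stride_w, C)
        x = x.permute(0, 1, 3, 2, 4, 5)
        return x.reshape(N, H // stride_h, W // stride_w, C * stride_h * stride_w)

    x = torch.arange(8).reshape(1, 2, 2, 2)  # NHWC: N=1, H=2, W=2, C=2
    y = fold_reference(x, 2, 2)
    print(y.shape)               # torch.Size([1, 1, 1, 8])
    print(y.flatten().tolist())  # [0, 1, 2, 3, 4, 5, 6, 7] -- the 2x2 block laid out along C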
3 changes: 2 additions & 1 deletion tt_eager/tt_dnn/module.mk
@@ -171,7 +171,8 @@ TT_DNN_SRCS = \
tt_eager/tt_dnn/op_library/upsample/multi_core/upsample_op_multi_core.cpp \
tt_eager/tt_dnn/op_library/upsample/single_core/upsample_op_single_core.cpp \
tt_eager/tt_dnn/op_library/upsample/upsample_op.cpp \

tt_eager/tt_dnn/op_library/fold/fold_op.cpp \
tt_eager/tt_dnn/op_library/fold/single_core/fold_op_single_core.cpp \

TT_DNN_LIB = $(LIBDIR)/libtt_dnn.a
TT_DNN_DEFINES =
56 changes: 56 additions & 0 deletions tt_eager/tt_dnn/op_library/fold/fold_op.cpp
@@ -0,0 +1,56 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "tt_dnn/op_library/fold/fold_op.hpp"

#include "tt_dnn/op_library/run_operation.hpp"

namespace tt::tt_metal {
void Fold::validate(const std::vector<Tensor> &input_tensors) const {
    const Tensor &input_tensor = input_tensors.at(0);

    TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Expect input tensor to be stored on device.");
    TT_FATAL(input_tensor.buffer() != nullptr, "Expect input tensor to be allocated on a device buffer.");
    TT_FATAL(input_tensor.layout() == Layout::ROW_MAJOR, "Expect input tensor in row-major layout.");
    TT_FATAL(
        input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
        "Folding of sharded tensors is not supported.");

    TT_FATAL(input_tensor.shape()[1] % stride_h == 0);
    TT_FATAL(input_tensor.shape()[2] % stride_w == 0);
}

std::vector<Shape> Fold::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
    const Shape &input_shape = input_tensors.at(0).shape();

    // we concatenate (stride_h sticks in H-dim) * (stride_w in W-dim) into 1 stick along C-dim
    return {{
        input_shape[0],
        input_shape[1] / stride_h,
        input_shape[2] / stride_w,
        input_shape[3] * stride_h * stride_w,
    }};
}

std::vector<Tensor> Fold::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
    const Tensor &input_tensor = input_tensors.at(0);
    DataType output_dtype = input_tensor.dtype();

    return operation::generic_create_output_tensors(
        *this, input_tensors, output_dtype, Layout::ROW_MAJOR, input_tensor.memory_config());
}

operation::ProgramWithCallbacks Fold::create_program(
    const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const {
    const Tensor &input_tensor = input_tensors.at(0);
    Tensor &output_tensor = output_tensors.at(0);

    return fold_single_core(input_tensor, output_tensor, stride_h, stride_w);
}

Tensor fold(const Tensor &input_tensor_a, uint8_t stride_h, uint8_t stride_w) {
    return operation::run(Fold{.stride_h = stride_h, .stride_w = stride_w}, {input_tensor_a}).at(0);
}

}  // namespace tt::tt_metal
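
As a concrete instance of Fold::compute_output_shapes: for the (10, 6, 8, 32) test case above with stride_h = 3 and stride_w = 2, the output shape works out as follows (a quick Python check, illustrative only, not part of the op):

    N, H, W, C = 10, 6, 8, 32
    stride_h, stride_w = 3, 2
    out_shape = (N, H // stride_h, W // stride_w, C * stride_h * stride_w)
    print(out_shape)  # (10, 2, 4, 192)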
35 changes: 35 additions & 0 deletions tt_eager/tt_dnn/op_library/fold/fold_op.hpp
@@ -0,0 +1,35 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <vector>

#include "tensor/tensor.hpp"
#include "tt_dnn/op_library/operation.hpp"

namespace tt::tt_metal {
struct Fold {
    uint8_t stride_h;
    uint8_t stride_w;

    void validate(const std::vector<Tensor> &input_tensors) const;

    std::vector<Shape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;

    std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;

    operation::ProgramWithCallbacks create_program(
        const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;

    static constexpr auto attribute_names = std::make_tuple("stride_h", "stride_w");

    const auto attribute_values() const { return std::make_tuple(stride_h, stride_w); }
};

operation::ProgramWithCallbacks fold_single_core(
    const Tensor &input, const Tensor &output, uint8_t stride_h, uint8_t stride_w);

Tensor fold(const Tensor &input_tensor_a, uint8_t stride_h, uint8_t stride_w);
}  // namespace tt::tt_metal
@@ -0,0 +1,82 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>

#include "dataflow_api.h"
#include "debug/dprint.h" // required in all kernels using DPRINT
#include "debug/dprint_tile.h"

void kernel_main() {
    const uint32_t dst_addr = get_arg_val<uint32_t>(0);
    const uint32_t dst_page_size = get_arg_val<uint32_t>(1);

    const uint32_t scratch_addr = get_arg_val<uint32_t>(2);

    const uint32_t pixel_size = get_arg_val<uint32_t>(3);
    const uint32_t aligned_pixel_size = get_arg_val<uint32_t>(4);
    const uint32_t aligned_chunk_size = get_arg_val<uint32_t>(5);
    const uint32_t aligned_row_size = get_arg_val<uint32_t>(6);

    const uint32_t stride_h = get_arg_val<uint32_t>(7);
    const uint32_t stride_w = get_arg_val<uint32_t>(8);

    const uint32_t num_dst_rows = get_arg_val<uint32_t>(9);
    const uint32_t num_dst_cols = get_arg_val<uint32_t>(10);
    const uint32_t cb_pages_per_dst_row = get_arg_val<uint32_t>(11);

    constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0);
    constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;

#define stick_size_is_power_of_two get_compile_time_arg_val(2) == 1

#if (stick_size_is_power_of_two)
    constexpr uint32_t log_base_2_of_page_size = get_compile_time_arg_val(3);
    const InterleavedPow2AddrGen<dst_is_dram> s = {
        .bank_base_address = dst_addr, .log_base_2_of_page_size = log_base_2_of_page_size};
#else
    const InterleavedAddrGen<dst_is_dram> s = {
        .bank_base_address = dst_addr,
        .page_size = dst_page_size,
    };
#endif

    auto extract_next_dst_page = [&](uint32_t src_address) -> uint32_t {
        // The src_address value has been offset so that it points to the start of the next output pixel.
        auto src_ptr = reinterpret_cast<volatile tt_l1_ptr uint8_t *>(src_address);

        // scratch fits exactly stride_h * stride_w * C elements
        auto scratch_ptr = reinterpret_cast<volatile tt_l1_ptr uint8_t *>(scratch_addr);

        for (uint32_t row = 0; row < stride_h; ++row) {
            uint32_t src_col_offset = 0;
            for (uint32_t col = 0; col < stride_w; ++col) {
                for (uint32_t i = 0; i < pixel_size; ++i) {
                    scratch_ptr[i] = src_ptr[src_col_offset + i];
                }
                scratch_ptr += pixel_size;
                src_col_offset += aligned_pixel_size;
            }
            src_ptr += aligned_row_size;
        }

        return src_address + aligned_chunk_size;
    };

    for (uint32_t i = 0, dst_page_id = 0; i < num_dst_rows; ++i) {
        cb_wait_front(cb_id_out0, cb_pages_per_dst_row);
        uint32_t src_addr = get_read_ptr(cb_id_out0);
        // DPRINT << TSLICE(cb_id_out0, 0, SliceRange{.h0 = 0, .h1 = 4, .hs = 1, .w0 = 0, .w1 = 2, .ws = 1}) << ENDL();

        for (uint32_t j = 0; j < num_dst_cols; ++j) {
            src_addr = extract_next_dst_page(src_addr);
            uint64_t dst_addr = get_noc_addr(dst_page_id, s);
            noc_async_write(scratch_addr, dst_addr, dst_page_size);
            dst_page_id += 1;
        }

        noc_async_write_barrier();
        cb_pop_front(cb_id_out0, cb_pages_per_dst_row);
    }
}
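
To make the writer kernel's gather step easier to follow, here is a simplified host-side Python sketch of the extract_next_dst_page logic (illustrative only; the names mirror the kernel's runtime args, and src stands in for the circular-buffer contents in L1): stride_h rows of stride_w pixels, each pixel_size bytes wide but stored aligned_pixel_size bytes apart, are packed densely into one output page.

    def extract_next_dst_page(src, src_offset, pixel_size, aligned_pixel_size,
                              aligned_row_size, stride_h, stride_w):
        # Gather stride_h x stride_w pixels into one densely packed page,
        # dropping the per-pixel alignment padding, as the kernel's copy loops do.
        page = bytearray()
        for row in range(stride_h):
            for col in range(stride_w):
                start = src_offset + row * aligned_row_size + col * aligned_pixel_size
                page += src[start:start + pixel_size]
        return bytes(page)

In the kernel proper, each assembled page is written to DRAM with noc_async_write from the scratch buffer, the source pointer is advanced by aligned_chunk_size to reach the next group of stride_w input pixels, and a write barrier is issued once per row of output pages before the circular-buffer pages are popped.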
