-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#4438: Add single-core fold op for Resnet
- Loading branch information
1 parent
aa65e6d
commit 912a964
Showing
10 changed files
with
419 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <algorithm> | ||
#include <functional> | ||
#include <random> | ||
#include <tt_numpy/functions.hpp> | ||
|
||
#include "tt_eager/tensor/tensor.hpp" | ||
#include "tt_eager/tt_dnn/op_library/fold/fold_op.hpp" | ||
#include "tt_eager/tt_dnn/op_library/program_cache.hpp" | ||
#include "tt_metal/host_api.hpp" | ||
|
||
using namespace tt; | ||
using namespace tt::tt_metal; | ||
using namespace constants; | ||
|
||
void run_fold(Device *device, Shape shape) { | ||
Tensor input_tensor = tt::numpy::random::random(shape).to(Layout::ROW_MAJOR).to(device); | ||
uint32_t stride_h = 2; | ||
uint32_t stride_w = 2; | ||
Tensor device_output_tensor = fold(input_tensor, stride_h, stride_w); | ||
Tensor output_tensor = device_output_tensor.cpu(); | ||
} | ||
|
||
int main(int argc, char **argv) { | ||
int device_id = 0; | ||
tt_metal::Device *device = tt_metal::CreateDevice(device_id); | ||
|
||
run_fold(device, {1, 2, 2, 2}); | ||
bool pass = CloseDevice(device); | ||
|
||
if (pass) { | ||
log_info(LogTest, "Test Passed"); | ||
} else { | ||
TT_THROW("Test Failed"); | ||
} | ||
|
||
TT_FATAL(pass); | ||
|
||
return 0; | ||
} |
52 changes: 52 additions & 0 deletions
52
tests/tt_eager/python_api_testing/unit_testing/test_fold_op.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import torch | ||
|
||
import tt_lib as ttl | ||
|
||
from models.utility_functions import skip_for_wormhole_b0, torch2tt_tensor | ||
|
||
|
||
def fold_torch(input_tensor, stride_h, stride_w): | ||
N, H, W, C = input_tensor.shape | ||
|
||
reshaped = input_tensor.reshape(N, H // stride_h, stride_h, W // stride_w, stride_w, C) | ||
transposed = reshaped.permute(0, 1, 3, 2, 4, 5) | ||
return transposed.reshape(N, H // stride_h, W // stride_w, C * stride_h * stride_w) | ||
|
||
|
||
@skip_for_wormhole_b0() | ||
@pytest.mark.parametrize( | ||
"act_shape,stride_h,stride_w", | ||
[ | ||
((1, 2, 2, 2), 2, 2), | ||
((1, 2, 2, 16), 2, 2), | ||
((10, 2, 2, 32), 2, 2), | ||
((10, 4, 4, 32), 2, 2), | ||
((10, 6, 8, 32), 3, 2), | ||
((10, 6, 8, 32), 3, 1), | ||
((10, 6, 8, 32), 1, 2), | ||
((10, 6, 8, 32), 1, 1), | ||
], | ||
) | ||
def test_fold(act_shape, stride_h, stride_w, device): | ||
torch.manual_seed(0) | ||
|
||
torch_input = torch.randn(act_shape, dtype=torch.bfloat16) | ||
expected = fold_torch(torch_input, stride_h, stride_w) | ||
|
||
tt_input = torch2tt_tensor( | ||
torch_input, | ||
device, | ||
ttl.tensor.Layout.ROW_MAJOR, | ||
ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED), | ||
) | ||
|
||
tt_out = ttl.tensor.fold(tt_input, stride_h, stride_w) | ||
tt_out = tt_out.cpu() | ||
actual = tt_out.to_torch() | ||
|
||
torch.testing.assert_allclose(actual, expected) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "tt_dnn/op_library/fold/fold_op.hpp" | ||
|
||
#include "tt_dnn/op_library/run_operation.hpp" | ||
|
||
namespace tt::tt_metal { | ||
void Fold::validate(const std::vector<Tensor> &input_tensors) const { | ||
const Tensor &input_tensor = input_tensors.at(0); | ||
|
||
TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Expect input tensor to be stored on device."); | ||
TT_FATAL(input_tensor.buffer() != nullptr, "Expect input tensor to be allocated on a device buffer."); | ||
TT_FATAL(input_tensor.layout() == Layout::ROW_MAJOR, "Expect input tensor in row-major layout."); | ||
TT_FATAL( | ||
input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, | ||
"Folding of sharded tensors is not supported."); | ||
|
||
TT_FATAL(input_tensor.shape()[1] % stride_h == 0); | ||
TT_FATAL(input_tensor.shape()[2] % stride_w == 0); | ||
} | ||
|
||
std::vector<Shape> Fold::compute_output_shapes(const std::vector<Tensor> &input_tensors) const { | ||
const Shape &input_shape = input_tensors.at(0).shape(); | ||
|
||
// we concatenate (stride_h sticks in H-dim) * (stride_w in W-dim) into 1 stick along C-dim | ||
return {{ | ||
input_shape[0], | ||
input_shape[1] / stride_h, | ||
input_shape[2] / stride_w, | ||
input_shape[3] * stride_h * stride_w, | ||
}}; | ||
} | ||
|
||
std::vector<Tensor> Fold::create_output_tensors(const std::vector<Tensor> &input_tensors) const { | ||
const Tensor &input_tensor = input_tensors.at(0); | ||
DataType output_dtype = input_tensor.dtype(); | ||
|
||
return operation::generic_create_output_tensors( | ||
*this, input_tensors, output_dtype, Layout::ROW_MAJOR, input_tensor.memory_config()); | ||
} | ||
|
||
operation::ProgramWithCallbacks Fold::create_program( | ||
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const { | ||
const Tensor &input_tensor = input_tensors.at(0); | ||
Tensor &output_tensor = output_tensors.at(0); | ||
|
||
return fold_single_core(input_tensor, output_tensor, stride_h, stride_w); | ||
} | ||
|
||
Tensor fold(const Tensor &input_tensor_a, uint8_t stride_h, uint8_t stride_w) { | ||
return operation::run(Fold{.stride_h = stride_h, .stride_w = stride_w}, {input_tensor_a}).at(0); | ||
} | ||
|
||
} // namespace tt::tt_metal |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include <vector> | ||
|
||
#include "tensor/tensor.hpp" | ||
#include "tt_dnn/op_library/operation.hpp" | ||
|
||
namespace tt::tt_metal { | ||
struct Fold { | ||
uint8_t stride_h; | ||
uint8_t stride_w; | ||
|
||
void validate(const std::vector<Tensor> &input_tensors) const; | ||
|
||
std::vector<Shape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const; | ||
|
||
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const; | ||
|
||
operation::ProgramWithCallbacks create_program( | ||
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const; | ||
|
||
static constexpr auto attribute_names = std::make_tuple("stride_h", "stride_w"); | ||
|
||
const auto attribute_values() const { return std::make_tuple(stride_h, stride_w); } | ||
}; | ||
|
||
operation::ProgramWithCallbacks fold_single_core( | ||
const Tensor &input, const Tensor &output, uint8_t stride_h, uint8_t stride_w); | ||
|
||
Tensor fold(const Tensor &input_tensor_a, uint8_t stride_h, uint8_t stride_w); | ||
} // namespace tt::tt_metal |
82 changes: 82 additions & 0 deletions
82
..._library/fold/kernels/dataflow/writer_unary_stick_layout_concatenate_rows_interleaved.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
|
||
#include "dataflow_api.h" | ||
#include "debug/dprint.h" // required in all kernels using DPRINT | ||
#include "debug/dprint_tile.h" | ||
|
||
void kernel_main() { | ||
const uint32_t dst_addr = get_arg_val<uint32_t>(0); | ||
const uint32_t dst_page_size = get_arg_val<uint32_t>(1); | ||
|
||
const uint32_t scratch_addr = get_arg_val<uint32_t>(2); | ||
|
||
const uint32_t pixel_size = get_arg_val<uint32_t>(3); | ||
const uint32_t aligned_pixel_size = get_arg_val<uint32_t>(4); | ||
const uint32_t aligned_chunk_size = get_arg_val<uint32_t>(5); | ||
const uint32_t aligned_row_size = get_arg_val<uint32_t>(6); | ||
|
||
const uint32_t stride_h = get_arg_val<uint32_t>(7); | ||
const uint32_t stride_w = get_arg_val<uint32_t>(8); | ||
|
||
const uint32_t num_dst_rows = get_arg_val<uint32_t>(9); | ||
const uint32_t num_dst_cols = get_arg_val<uint32_t>(10); | ||
const uint32_t cb_pages_per_dst_row = get_arg_val<uint32_t>(11); | ||
|
||
constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0); | ||
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1; | ||
|
||
#define stick_size_is_power_of_two get_compile_time_arg_val(2) == 1 | ||
|
||
#if (stick_size_is_power_of_two) | ||
constexpr uint32_t log_base_2_of_page_size = get_compile_time_arg_val(3); | ||
const InterleavedPow2AddrGen<dst_is_dram> s = { | ||
.bank_base_address = dst_addr, .log_base_2_of_page_size = log_base_2_of_page_size}; | ||
#else | ||
const InterleavedAddrGen<dst_is_dram> s = { | ||
.bank_base_address = dst_addr, | ||
.page_size = dst_page_size, | ||
}; | ||
#endif | ||
|
||
auto extract_next_dst_page = [&](uint32_t src_address) -> uint32_t { | ||
// The src_address value has been offset so that it points to the start of the next output pixel. | ||
auto src_ptr = reinterpret_cast<volatile tt_l1_ptr uint8_t *>(src_address); | ||
|
||
// scratch fits exactly stride_h * stride_w * C elements | ||
auto scratch_ptr = reinterpret_cast<volatile tt_l1_ptr uint8_t *>(scratch_addr); | ||
|
||
for (uint32_t row = 0; row < stride_h; ++row) { | ||
uint32_t src_col_offset = 0; | ||
for (uint32_t col = 0; col < stride_w; ++col) { | ||
for (uint32_t i = 0; i < pixel_size; ++i) { | ||
scratch_ptr[i] = src_ptr[src_col_offset + i]; | ||
} | ||
scratch_ptr += pixel_size; | ||
src_col_offset += aligned_pixel_size; | ||
} | ||
src_ptr += aligned_row_size; | ||
} | ||
|
||
return src_address + aligned_chunk_size; | ||
}; | ||
|
||
for (uint32_t i = 0, dst_page_id = 0; i < num_dst_rows; ++i) { | ||
cb_wait_front(cb_id_out0, cb_pages_per_dst_row); | ||
uint32_t src_addr = get_read_ptr(cb_id_out0); | ||
// DPRINT << TSLICE(cb_id_out0, 0, SliceRange{.h0 = 0, .h1 = 4, .hs = 1, .w0 = 0, .w1 = 2, .ws = 1}) << ENDL(); | ||
|
||
for (uint32_t j = 0; j < num_dst_cols; ++j) { | ||
src_addr = extract_next_dst_page(src_addr); | ||
uint64_t dst_addr = get_noc_addr(dst_page_id, s); | ||
noc_async_write(scratch_addr, dst_addr, dst_page_size); | ||
dst_page_id += 1; | ||
} | ||
|
||
noc_async_write_barrier(); | ||
cb_pop_front(cb_id_out0, cb_pages_per_dst_row); | ||
} | ||
} |
Oops, something went wrong.