#0: replace repeat_interleave implementation with a more generic implementation that lowers to concat and removes the hacky permute implementation

- delete the old repeat_interleave C++ test: uint16 is not used by any model and it only worked because of hacks
- move the uint16 coverage to the Python test to compensate, but skip it on Grayskull since Grayskull does not support uint16
sjameelTT committed Oct 11, 2024
1 parent f9ec7b1 commit c31bdbb
Showing 12 changed files with 140 additions and 179 deletions.
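
For context, here is a minimal PyTorch-level sketch of the concat-based lowering described in the commit message (the helper name is hypothetical; the actual device implementation additionally handles the last dim via transpose, typecasts non-bfloat16 inputs, and round-trips through row-major layout):

import torch

def repeat_interleave_via_concat(x: torch.Tensor, repeat: int, dim: int) -> torch.Tensor:
    # Insert a new axis right after `dim`, concatenate `repeat` copies along it,
    # then fold the two axes back together so each element repeats consecutively.
    dim = dim % x.dim()
    stacked = torch.cat([x.unsqueeze(dim + 1)] * repeat, dim=dim + 1)
    new_shape = list(x.shape)
    new_shape[dim] *= repeat
    return stacked.reshape(new_shape)

x = torch.arange(6).reshape(2, 3)
assert torch.equal(repeat_interleave_via_concat(x, 2, dim=1), torch.repeat_interleave(x, 2, dim=1))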
@@ -15,7 +15,7 @@
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
run_single_pytorch_test,
)
from models.utility_functions import is_wormhole_b0, skip_for_blackhole
from models.utility_functions import is_wormhole_b0, is_grayskull, skip_for_blackhole

shapes = (
[[1, 1, 32, 32]], # Single core
@@ -30,6 +30,8 @@
@pytest.mark.parametrize("dim", [0, 2, -4, -2, 1, 3])
@pytest.mark.parametrize("repeat", [2, 3, 4])
def test_run_repeat_interleave_test(input_shapes, dim, repeat, device):
if is_grayskull() and dim == 3:
pytest.skip("Grayskull does not support dim=3 because we cannot transpose WH reliably")
datagen_func = [
generation_funcs.gen_func_with_cast(
partial(generation_funcs.gen_rand_along_dim, low=-100, high=100),
@@ -660,22 +660,25 @@ def test_transpose_hc(dtype, shape, device):
)
@pytest.mark.parametrize(
"shape",
[(1, 32), (1, 12), (1, 35), (16, 32), (34, 8)],
[(9216, 128), (1, 32), (1, 12), (1, 35), (16, 32), (34, 8)],
)
@pytest.mark.parametrize(
"layout",
[ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT],
[ttnn.TILE_LAYOUT],
)
@pytest.mark.parametrize(
"dims",
[(1, 0), (-1, -2)],
)
def test_transpose_2D(dtype, shape, layout, dims, device):
def test_transpose_2D(dtype, shape, layout, device):
if is_grayskull() and dtype == ttnn.float32:
pytest.skip("Skipping float32 tests on Grayskull")
if layout == ttnn.ROW_MAJOR_LAYOUT and dtype == ttnn.bfloat16 and shape[-1] % 2:
if layout == ttnn.ROW_MAJOR_LAYOUT and dtype == ttnn.bfloat16 and (shape[-1] % 2 or shape[-2] % 2):
pytest.skip("Skipping RM odd inner dim test cases")
transpose(shape, device, dim0=0, dim1=1, input_dtype=dtype)

torch_input = torch.randn(shape, dtype=torch.bfloat16)
torch_output = torch_input.transpose(0, 1)

tt_input = ttnn.from_torch(torch_input, dtype=ttnn.DataType.BFLOAT16, layout=layout, device=device)
tt_output = ttnn.transpose(tt_input, 0, 1)
tt_output = ttnn.to_torch(tt_output)
assert_with_pcc(torch_output, tt_output, 0.9999)


@pytest.mark.parametrize(
@@ -689,7 +692,7 @@ def test_transpose_2D(dtype, shape, layout, dims, device):
)
@pytest.mark.parametrize(
"layout",
[ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT],
[ttnn.TILE_LAYOUT],
)
@pytest.mark.parametrize(
"dims",
@@ -698,10 +701,41 @@ def test_transpose_2D(dtype, shape, layout, dims, device):
def test_transpose_3D(dtype, shape, layout, dims, device):
if is_grayskull() and dtype == ttnn.float32:
pytest.skip("Skipping float32 tests on Grayskull")

new_shape = shape
new_shape[dims[0]], new_shape[dims[1]] = shape[dims[1]], shape[dims[0]]
if layout == ttnn.ROW_MAJOR_LAYOUT and dtype == ttnn.bfloat16 and (shape[-1] % 2 or new_shape[-1] % 2):
if layout == ttnn.ROW_MAJOR_LAYOUT and dtype == ttnn.bfloat16 and (shape[-1] % 2 or shape[dims[-1]] % 2):
pytest.skip("Skipping RM odd inner dim test cases")

transpose(shape, device, dim0=dims[0], dim1=dims[1], input_dtype=dtype)
torch_input = torch.randn(shape, dtype=torch.bfloat16)
torch_output = torch_input.transpose(dims[0], dims[1])

tt_input = ttnn.from_torch(torch_input, dtype=ttnn.DataType.BFLOAT16, layout=layout, device=device)
tt_output = ttnn.transpose(tt_input, dims[0], dims[1])
tt_output = ttnn.to_torch(tt_output)
assert_with_pcc(torch_output, tt_output, 0.9999)


@pytest.mark.parametrize(
"shape",
[[4, 3, 1280, 40], [1, 4096, 4096]],
)
def test_transpose_4d_wh_rm(shape, device):
torch_input = torch.randn(shape, dtype=torch.bfloat16)
torch_output = torch_input.transpose(-1, -2)

tt_input = ttnn.from_torch(torch_input, dtype=ttnn.DataType.BFLOAT16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
tt_output = ttnn.transpose(tt_input, -1, -2)
tt_output = ttnn.to_torch(tt_output)
assert_with_pcc(torch_output, tt_output, 0.9999)


@pytest.mark.parametrize(
"shape",
[[4, 3, 1280, 40], [1, 1200, 1280]],
)
def test_transpose_4d_wh_tile(shape, device):
torch_input = torch.randn(shape, dtype=torch.bfloat16)
torch_output = torch_input.transpose(-1, -2)

tt_input = ttnn.from_torch(torch_input, dtype=ttnn.DataType.BFLOAT16, layout=ttnn.TILE_LAYOUT, device=device)
tt_output = ttnn.transpose(tt_input, -1, -2)
tt_output = ttnn.to_torch(tt_output)
assert_with_pcc(torch_output, tt_output, 0.9999)
1 change: 0 additions & 1 deletion tests/ttnn/unit_tests/gtests/CMakeLists.txt
@@ -2,7 +2,6 @@
set(TTNN_UNIT_TESTS_SRC
${CMAKE_CURRENT_SOURCE_DIR}/test_add.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_graph_add.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_repeat_interleave.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_async_runtime.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_multiprod_queue.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_multi_cq_multi_dev.cpp
105 changes: 0 additions & 105 deletions tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp

This file was deleted.

11 changes: 11 additions & 0 deletions tests/ttnn/unit_tests/operations/test_concat.py
@@ -112,3 +112,14 @@ def test_sharded_concat(

assert_with_pcc(torch_output_tensor, output, 0.9999)
assert_with_pcc(torch_output_tensor, output)


@pytest.mark.parametrize("dim", [0, 1, 2, 3])
def test_concat_5d(device, dim):
torch_input_tensor = torch.rand(1, 1, 1, 1, 2, dtype=torch.bfloat16)
torch_result = torch.cat([torch_input_tensor, torch_input_tensor], dim=dim)

ttnn_input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
ttnn_result = ttnn.concat([ttnn_input_tensor, ttnn_input_tensor], dim=dim)
ttnn_result = ttnn.to_torch(ttnn_result)
assert_with_pcc(torch_result, ttnn_result, 0.9999)
21 changes: 15 additions & 6 deletions tests/ttnn/unit_tests/operations/test_repeat_interleave.py
@@ -10,20 +10,29 @@
from loguru import logger

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import (
is_grayskull,
is_wormhole_b0,
)


@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("dim", [0, 1, 2, 3])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.uint16])
def test_repeat_interleave(device, repeats, dim, dtype):
torch_input_tensor = torch.rand(1, 1, 32, 32, dtype=dtype)
torch_result = torch.repeat_interleave(torch_input_tensor, repeats, dim=dim)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
if dtype == ttnn.uint16:
if is_grayskull():
pytest.skip("Grayskull does not support uint16")
torch_dtype = torch.int16
torch_input_tensor = torch.randint(0, 100, (1, 1, 32, 32), dtype=torch_dtype)
else:
torch_dtype = torch.bfloat16
torch_input_tensor = torch.rand(1, 1, 32, 32, dtype=torch_dtype)

torch_result = torch.repeat_interleave(torch_input_tensor, repeats, dim=dim)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, dtype=dtype, device=device)
output = ttnn.repeat_interleave(input_tensor, repeats, dim=dim)
output = ttnn.to_torch(output)

assert_with_pcc(torch_result, output, 0.9999)


@@ -24,7 +24,7 @@ struct ToMemoryConfig {

// TODO: Move to cpp once we merge with tt_eager
static Tensor invoke(
const ttnn::Tensor& tensor, const ttnn::MemoryConfig& memory_config, std::optional<ttnn::DataType> dtype) {
const ttnn::Tensor& tensor, const ttnn::MemoryConfig& memory_config, std::optional<ttnn::DataType> dtype = std::nullopt) {
// Temporary until we see why buffer data not being populated
const auto original_shape = tensor.get_shape();

2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp
@@ -81,7 +81,7 @@ namespace data_movement {
return output;
});
// Convert dim after unsqueeze
dim = dim + 4 - rank;
dim = rank < 4 ? dim + 4 - rank : dim;
auto output_tensor = concat_impl(itensor, dim, mem_config);
while (output_tensor.get_shape().rank() > rank) {
const auto shape = output_tensor.get_shape();
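To illustrate the dim adjustment above with a hypothetical helper (same arithmetic as the changed line): tensors of rank < 4 are padded with leading 1-dims up to 4D, so a user-facing dim has to shift by the number of prepended dims, while tensors that are already rank 4 or higher (such as the new 5D concat test) must keep their dim unchanged.

def adjust_concat_dim(dim: int, rank: int) -> int:
    # Shift only when the input was padded up to 4D with leading 1-dims.
    return dim + 4 - rank if rank < 4 else dim

assert adjust_concat_dim(1, 3) == 2  # 3D input unsqueezed to 4D: dim shifts by one
assert adjust_concat_dim(3, 5) == 3  # 5D input: the old code would have shifted this to 2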
@@ -123,7 +123,9 @@ Tensor concat_impl(std::vector<Tensor> &input_tensors, const std::int64_t dim, c
"Current concat implementation requires aligned last dim when concatting on last dim");
}
}
Layout target_layout = Layout::TILE;
// Row-major inputs should default to the row-major implementation and tilized inputs to the tilized one,
// but the loop below used to convert RM to tilized whenever possible.
Layout target_layout = input_tensors[0].get_layout();
// With the target layout now initialized from the input, the loop below should be dead code.
for (const auto &input_tensor : input_tensors) {
if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
const auto &input_shape = input_tensor.get_legacy_shape();
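A hedged sketch of the layout-selection change above (string stand-ins, not the actual ttnn enums): the target layout now starts from the first input's layout instead of being forced to TILE, so purely row-major concats dispatch to the row-major implementation by default.

def pick_concat_layout(input_layouts):
    # input_layouts: e.g. ["ROW_MAJOR", "TILE"]; illustrative placeholders only.
    # Old behavior: start from TILE and tilize row-major inputs whenever possible.
    # New behavior: follow the first input's layout.
    return input_layouts[0]

assert pick_concat_layout(["ROW_MAJOR", "ROW_MAJOR"]) == "ROW_MAJOR"
assert pick_concat_layout(["TILE", "TILE"]) == "TILE"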
@@ -6,51 +6,55 @@
#include "repeat_interleave.hpp"

#include "ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.hpp"
#include "ttnn/cpp/ttnn/operations/copy.hpp"

namespace ttnn {
namespace operations {
namespace data_movement {


// repeat interleave supports repeats as 1 to inf, dim between 0 to 2
ttnn::Tensor ExecuteRepeatInterleave::invoke(const ttnn::Tensor& input_a, uint32_t repeat, int32_t dim, std::optional<MemoryConfig> output_mem_config) {
std::vector<Tensor> combined_tensors;
combined_tensors.reserve(repeat);
auto shape_wh = input_a.get_legacy_shape();
MemoryConfig mem_config = output_mem_config.value_or(input_a.memory_config());
// normalizing the negative dim
uint32_t normalized_dim = input_a.get_legacy_shape().get_normalized_index(dim);
// check if dim is 1 or 3
if (normalized_dim & 1) {
constexpr uint32_t tmp_dim = 2;
std::vector<int64_t> dims = {0, 1, 2, 3};
std::swap(dims[dim], dims[tmp_dim]);
Tensor transpose_input = ttnn::permute(input_a, dims);
Tensor ril_result = ExecuteRepeatInterleave::invoke(transpose_input, repeat, tmp_dim, mem_config);
return ttnn::permute(ril_result, dims);
if (repeat == 1) {
return ttnn::to_memory_config(input_a, mem_config);
}
uint32_t input_rank = input_a.get_shape().rank();
uint32_t normalized_dim = input_a.get_shape().get_normalized_index(dim);
if (normalized_dim == input_rank - 1) {
auto transposed_input = ttnn::transpose(input_a, -1, -2, mem_config);
auto repeated_input = ExecuteRepeatInterleave::invoke(transposed_input, repeat, -2, mem_config);
return ttnn::transpose(repeated_input, -1, -2, mem_config);
}

ttnn::Tensor rm_input = input_a;
bool typecast = input_a.get_dtype() != DataType::BFLOAT16;
if (typecast) {
rm_input = ttnn::typecast(rm_input, DataType::BFLOAT16, mem_config);
}

rm_input = ttnn::to_layout(rm_input, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device*)nullptr);
std::vector<uint32_t> final_shape;
final_shape.reserve(input_rank);
for (uint32_t i = 0; i < rm_input.get_shape().rank(); i++) {
final_shape.push_back(rm_input.get_shape()[i]);
}
if (normalized_dim <= 1) {
for (int i = 0; i < repeat; i++) {
combined_tensors.push_back(input_a);
}
// TODO: For dim = 1 facing issue with concat_op
if (normalized_dim) {
Tensor concat_out = ttnn::concat(combined_tensors, 2);
return ttnn::reshape_on_device(concat_out, shape_wh[0], shape_wh[1] * repeat, shape_wh[2], shape_wh[3]);
} else {
Tensor concat_out = ttnn::concat(combined_tensors, 1);
return ttnn::reshape_on_device(concat_out, shape_wh[0] * repeat, shape_wh[1], shape_wh[2], shape_wh[3]);
}
} else {
Tensor reshape_out = ttnn::reshape_on_device(input_a, 1, 1, shape_wh[0] * shape_wh[1] * shape_wh[2], shape_wh[3]);
for (int i = 0; i < repeat; i++) {
combined_tensors.push_back(reshape_out);
}
Tensor concat_out = ttnn::concat(combined_tensors, 1);
std::vector<int64_t> permute_dims = {0, 2, 1, 3};
Tensor permute_out = ttnn::permute(concat_out, permute_dims);
return ttnn::reshape_on_device(permute_out, shape_wh[0], shape_wh[1], shape_wh[2] * repeat, shape_wh[3]);

final_shape[normalized_dim] *= repeat;

auto unsqueezed_tensor = ttnn::unsqueeze(rm_input, normalized_dim + 1);
for (uint32_t i = 0; i < repeat; i++) {
combined_tensors.push_back(unsqueezed_tensor);
}

auto concatenated_tensor = ttnn::concat(combined_tensors, normalized_dim + 1);
auto reshaped_tensor = ttnn::reshape(concatenated_tensor, ttnn::Shape(final_shape));
auto original_layout = ttnn::to_layout(reshaped_tensor, input_a.get_layout(), std::nullopt, std::nullopt, (Device*)nullptr);
return typecast ? ttnn::typecast(original_layout, input_a.get_dtype(), mem_config) : original_layout;

}

} // namespace data_movement
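
Summarizing the new invoke flow in torch terms, a hedged sketch of the last-dim branch only (the device code also typecasts non-bfloat16 inputs to bfloat16, moves to row-major layout for the concat, and restores the original layout and dtype at the end):

import torch

def repeat_interleave_last_dim(x: torch.Tensor, repeat: int) -> torch.Tensor:
    # Avoid concatenating along the alignment-sensitive last dim: transpose it away,
    # repeat along dim -2 via unsqueeze + concat + reshape, then transpose back.
    t = x.transpose(-1, -2)
    stacked = torch.cat([t.unsqueeze(-2)] * repeat, dim=-2)
    new_shape = list(t.shape)
    new_shape[-2] *= repeat
    return stacked.reshape(new_shape).transpose(-1, -2)

x = torch.arange(12.0).reshape(1, 3, 4)
assert torch.equal(repeat_interleave_last_dim(x, 2), torch.repeat_interleave(x, 2, dim=-1))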