Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shwetank tt/conv op nw #16140

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/ttnn/unit_tests/operations/test_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ def run_conv(
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
tt_output_tensor = ttnn.reshape(
tt_output_tensor,
[
1,
1,
batch_size * out_length,
output_channels,
],
)
torch_output_tensor = torch.Tensor(ttnn.to_torch(tt_output_tensor))

# torch_output_tensor is in row major layout and NLC shape
Expand Down
73 changes: 73 additions & 0 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch
import pytest
import math
from models.utility_functions import (
is_wormhole_b0,
skip_for_grayskull,
Expand Down Expand Up @@ -40,6 +41,30 @@ def _nearest_32(x):
# plt.close()


def write_to_file(file_name, data):
    """Dump a 4-D tensor to a text file for manual inspection.

    The tensor is moved to host memory and converted to a NumPy array.
    For every (axis-0, axis-2, axis-3) position, all axis-1 values are
    written on a single line, each followed by one space. A blank line
    follows every axis-2 group and another follows every axis-0 entry.

    Args:
        file_name: Path of the text file to create or overwrite.
        data: Tensor-like object exposing ``.cpu().numpy()``.
            Presumably an NCHW torch tensor, so that each line holds the
            channel vector of one pixel -- TODO confirm against callers.

    NOTE(review): the indentation of this helper was lost in the diff view;
    the nesting below assumes one newline write per enclosing loop level.
    """
    values = data.cpu().numpy()
    with open(file_name, "w") as out:
        for outer in range(values.shape[0]):
            for row in range(values.shape[2]):
                for col in range(values.shape[3]):
                    # Axis 1 varies fastest: one output line per (outer, row, col).
                    line = "".join(
                        str(values[outer][inner][row][col]) + " "
                        for inner in range(values.shape[1])
                    )
                    out.write(line + "\n")
                # Blank line separating consecutive axis-2 groups.
                out.write("\n")
            # Extra blank line separating consecutive axis-0 entries.
            out.write("\n")


def write_to_file_special(file_name, data):
    """Dump a 4-D tensor to a text file in its native axis order.

    The tensor is moved to host memory and converted to a NumPy array.
    Elements are visited with axis 3 varying fastest; each value is written
    followed by one space.

    Args:
        file_name: Path of the text file to create or overwrite.
        data: Tensor-like object exposing ``.cpu().numpy()`` with at least
            four axes.

    NOTE(review): the indentation of this helper was lost in the diff view;
    this assumes the newline is written once per innermost (axis-3) run,
    mirroring ``write_to_file`` -- confirm against the real file.
    """
    values = data.cpu().numpy()
    with open(file_name, "w") as out:
        for a in range(values.shape[0]):
            for b in range(values.shape[1]):
                for c in range(values.shape[2]):
                    # Collect the whole axis-3 run, then emit it as one line.
                    cells = [
                        str(values[a][b][c][d]) + " "
                        for d in range(values.shape[3])
                    ]
                    out.write("".join(cells) + "\n")


def run_conv(
device,
math_fidelity,
Expand Down Expand Up @@ -96,6 +121,12 @@ def run_conv(
torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))
torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float()

# for i in range(output_channels):
# for j in range(input_channels):
# for k in range(filter_height):
# for l in range(filter_height):
# torch_weight_tensor[i, j, k, l] = 1 if i == 0 and j == 0 else 0

torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() if has_bias else None
torch_out_golden_tensor = torch.nn.functional.conv2d(
torch_input_tensor_nchw,
Expand Down Expand Up @@ -190,6 +221,15 @@ def run_conv(
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
# tt_output_tensor = ttnn.reshape(
# tt_output_tensor,
# [
# 1,
# 1,
# tt_output_tensor.shape[0] * tt_output_tensor.shape[1] * tt_output_tensor.shape[2],
# tt_output_tensor.shape[3],
# ],
# )
torch_output_tensor = ttnn.to_torch(tt_output_tensor, mesh_composer=output_mesh_composer)

# torch_output_tensor is in row major layout and NHWC shape
Expand All @@ -209,6 +249,8 @@ def run_conv(
else:
pcc = 0.997

# write_to_file("golden_tensor.txt", torch_out_golden_tensor.float())
# write_to_file("output_tensor_1.txt", torch_output_tensor.float())
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc)
logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}")
assert passing
Expand Down Expand Up @@ -335,6 +377,15 @@ def run_conv_with_split(
return_weights_and_bias=True,
)
tt_conv_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
tt_conv_output_tensor = ttnn.reshape(
tt_conv_output_tensor,
[
1,
1,
tt_conv_output_tensor.shape[0] * tt_conv_output_tensor.shape[1] * tt_conv_output_tensor.shape[2],
tt_conv_output_tensor.shape[3],
],
)
torch_conv_output_tensor = ttnn.to_torch(tt_conv_output_tensor)
print(f"Output shape : {batch_size} {out_height} {out_width} {output_channels}")
torch_conv_output_tensor = torch_conv_output_tensor.reshape(batch_size, out_height, out_width, output_channels)
Expand Down Expand Up @@ -676,6 +727,16 @@ def test_conv_ws(
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
print(tt_output_tensor.shape)
tt_output_tensor = ttnn.reshape(
tt_output_tensor,
[
1,
1,
tt_output_tensor.shape[0] * tt_output_tensor.shape[1] * tt_output_tensor.shape[2],
tt_output_tensor.shape[3],
],
)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)

# torch_output_tensor is in row major layout and NHWC shape
Expand Down Expand Up @@ -1051,6 +1112,9 @@ def test_conv_mem_config_wh(
if device.core_grid.y == 7:
pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range")

# if batch_size == 16:
pytest.skip("Error. Need to discuss this with Infra team")

use_shallow_conv_variant = (input_channels == 16) and device.arch() != ttnn.device.Arch.WORMHOLE_B0
run_conv(
device,
Expand Down Expand Up @@ -2767,6 +2831,15 @@ def test_shallow_conv_with_tiled_input(device):
)

tt_output_tensor = ttnn.from_device(tt_out)
tt_output_tensor = ttnn.reshape(
tt_output_tensor,
[
1,
1,
tt_output_tensor.shape[0] * tt_output_tensor.shape[1] * tt_output_tensor.shape[2],
tt_output_tensor.shape[3],
],
)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)

# torch_output_tensor is in row major layout and NHWC shape
Expand Down
18 changes: 18 additions & 0 deletions tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,15 @@ def test_prepare_conv_weights(
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
tt_output_tensor = ttnn.reshape(
tt_output_tensor,
[
1,
1,
tt_output_tensor.shape[0] * tt_output_tensor.shape[1] * tt_output_tensor.shape[2],
tt_output_tensor.shape[3],
],
)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)
torch_output_tensor = torch_output_tensor[:, :, :, :output_channels]
torch_output_tensor = torch_output_tensor.reshape(torch_out_golden_tensor.shape)
Expand Down Expand Up @@ -316,6 +325,15 @@ def test_prepare_bias(
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
tt_output_tensor = ttnn.reshape(
tt_output_tensor,
[
1,
1,
tt_output_tensor.shape[0] * tt_output_tensor.shape[1] * tt_output_tensor.shape[2],
tt_output_tensor.shape[3],
],
)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)

torch_output_tensor = torch_output_tensor[:, :, :, :output_channels]
Expand Down
53 changes: 42 additions & 11 deletions ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,43 +147,74 @@ std::vector<TensorSpec> OptimizedConvNew::compute_output_specs(const std::vector
// Tiled output shape is padded shape. Padded to tile shape.
auto shape_w = batch_size * conv_output_h * conv_output_w;
auto shape_c = output_channels;
auto padded_shape_w = this->use_non_tile_height ? parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height : parallelization_config.num_cores_nhw * tt::round_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT);
auto padded_shape_c = tt::round_up(this->output_channels, TILE_WIDTH);
auto output_padding = Padding(
{{0, 0}, {0, 0}, {0, (padded_shape_w - shape_w)}, {0, (padded_shape_c - shape_c)}}, Padding::PadValue::Zero);
auto output_shape = tt::tt_metal::LegacyShape({1, 1, padded_shape_w, padded_shape_c}, output_padding);
{{0, 0}, {0, 0}, {0, 0}, {0, (padded_shape_c - shape_c)}}, Padding::PadValue::Zero);
auto output_shape = tt::tt_metal::LegacyShape({batch_size, conv_output_h, conv_output_w, padded_shape_c}, output_padding);
if(conv_output_w == 1){
output_shape = tt::tt_metal::LegacyShape({batch_size, conv_output_w, conv_output_h, padded_shape_c}, output_padding); //handing conv1d transpose.
}

auto output_layout = this->untilize_out ? Layout::ROW_MAJOR : Layout::TILE;
if (this->memory_config.is_sharded()) {
if (this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
uint32_t total_height_tiles = tt::tt_metal::compute_volume(output_shape) / output_shape[-1] / TILE_HEIGHT;
uint32_t total_height_tiles = tt::div_up(tt::tt_metal::compute_volume(output_shape) / output_shape[-1], TILE_HEIGHT);
uint32_t num_cores;
std::array<uint32_t, 2> shard_shape;
if(this->use_non_tile_height){
num_cores = this->parallelization_config.num_cores_nhw;
uint32_t total_height = tt::tt_metal::compute_volume(output_shape) / output_shape[-1];
shard_shape = {(uint32_t)(total_height / num_cores), output_shape[-1]};
// std::cout << "num_cores = " << num_cores << " " << total_height << " " << this->parallelization_config.per_core_out_matrix_height << std::endl;
shard_shape = {optimized_conv_op_utils::div_up(total_height, num_cores), output_shape[-1]};
}else{
num_cores = total_height_tiles / tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT);
num_cores = tt::div_up(total_height_tiles, tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT));
// std::cout << "num_cores = " << num_cores << " " << total_height_tiles << " " << this->parallelization_config.per_core_out_matrix_height << std::endl;
CoreRangeSet shard_grid = tt::tt_metal::num_cores_to_corerangeset(num_cores, this->parallelization_config.grid_size, true);

shard_shape = {optimized_conv_op_utils::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT) * TILE_HEIGHT, output_shape[-1]};
}
CoreRangeSet shard_grid = tt::tt_metal::num_cores_to_corerangeset(num_cores, this->parallelization_config.grid_size, true);
auto shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR};
auto shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR, false, ShardMode::LOGICAL};
auto mem_config = this->memory_config;
mem_config.shard_spec = shard_spec;
return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), mem_config, ttnn::Shape(output_shape)))};
// std::cout << "output_shape -> " << output_shape << std::endl;
// auto ss = output_shape.without_padding();
// std::cout << "ss = " << ss << std::endl;
SimpleShape output_shape_({output_shape[0], output_shape[1], output_shape[2], output_shape[3]});
// std::cout << "output_shape_ = " << output_shape_ << std::endl;
// std::cout << "mem_config " << mem_config << "output_layout = " << (int)output_layout << std::endl;
TensorSpec output_spec(output_shape_, TensorLayout(this->dtype, PageConfig(output_layout), mem_config));
return {output_spec};
} else if(this->memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED) {
uint32_t total_height_tiles = tt::tt_metal::compute_volume(output_shape) / output_shape[-1] / TILE_HEIGHT;
std::array<uint32_t, 2> shard_shape = {tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT) * TILE_HEIGHT, tt::div_up(this->parallelization_config.per_core_out_matrix_width, TILE_WIDTH) * TILE_WIDTH};
auto shard_grid = this->memory_config.shard_spec.value().grid;
auto shard_spec = ShardSpec{shard_grid, shard_shape, this->memory_config.shard_spec.value().orientation};
auto shard_spec = ShardSpec{shard_grid, shard_shape, this->memory_config.shard_spec.value().orientation, false, ShardMode::LOGICAL};
// std::cout << "shard_sape -> " << shard_shape[0] << " " << shard_shape[1] << std::endl;
auto mem_config = this->memory_config;
mem_config.shard_spec = shard_spec;
return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), mem_config, ttnn::Shape(output_shape)))};
// auto ss = output_shape.without_padding();
SimpleShape output_shape_({output_shape[0], output_shape[1], output_shape[2], output_shape[3]});
TensorSpec output_spec(output_shape_, TensorLayout(this->dtype, PageConfig(output_layout), mem_config));
// std::cout << "output_shape_ = " << output_shape_ << std::endl;
// std::cout << "mem_config " << mem_config << "output_layout = " << (int)output_layout << std::endl;
return {output_spec};
//return {create_device_tensor(output_spec, input_tensor.device())};

} else if (this->memory_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), memory_config, ttnn::Shape(output_shape)))};
// std::cout << "testing block sharded" << std::endl;
//auto ss = output_shape.without_padding();
SimpleShape output_shape_({output_shape[0], output_shape[1], output_shape[2], output_shape[3]});
auto shard_spec = this->memory_config.shard_spec.value();
auto new_shard_shec= ShardSpec(shard_spec.grid, shard_spec.shape, shard_spec.orientation, false, ShardMode::LOGICAL);
//this->memory_config.shard_spec = new_shard_shec;
auto mem_config = this->memory_config;
mem_config.shard_spec = new_shard_shec;
TensorSpec output_spec(output_shape_, TensorLayout(this->dtype, PageConfig(output_layout), mem_config));
// std::cout << "output_shape_ = " << output_shape_ << std::endl;
// std::cout << "mem_config " << this->memory_config << "output_layout = " << (int)output_layout << std::endl;
return {output_spec};
//return {create_device_tensor(output_spec, input_tensor.device())};
} else {
TT_THROW("Unsupported shard scheme");
}
Expand Down
21 changes: 5 additions & 16 deletions ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void validate_pool2d(
TensorMemoryLayout in_memory_layout = input.memory_config().memory_layout;
if (in_memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
uint32_t num_shards_c = sliding_window_config.num_cores_c;
const tt::tt_metal::LegacyShape input_shape = input.get_legacy_shape();
const tt::tt_metal::SimpleShape input_shape = input.get_logical_shape();
TT_FATAL(
input_shape[3] % num_shards_c == 0,
"For width and block sharding, input channels ({}) should be divisible by num_shards ({})",
Expand All @@ -60,8 +60,7 @@ Pool2D::spec_return_value_t Pool2D::compute_output_specs(
// NOTE: Only for RM
// NOTE2: Assuming { N, 1, H * W, C }
// NOTE3: Assuming output data type is same as input
const auto input_shape = input.get_legacy_shape();

const auto input_shape = input.get_padded_shape();
// confirm that the output size supplied to the function matches
uint32_t out_h = sliding_window_config.get_output_shape()[1];
uint32_t out_w = sliding_window_config.get_output_shape()[2];
Expand All @@ -73,15 +72,7 @@ Pool2D::spec_return_value_t Pool2D::compute_output_specs(
uint32_t out_c_padded = tt::round_up(out_c, (out_c <= 16) ? 16 : tt::constants::TILE_WIDTH);
uint32_t out_nhw = sliding_window_config.batch_size * out_h * out_w;

uint32_t out_nhw_padded =
tt::round_up(out_nhw, (is_out_tiled ? tt::constants::TILE_HEIGHT : 1) * sliding_window_config.num_cores_nhw);

// {1, 1, N * H * W, C}
const ttnn::SmallVector<uint32_t> out_dims({1, 1, out_nhw_padded, out_c_padded});
const auto padding = Padding(
{{0, 0}, {0, 0}, {0, out_nhw_padded - out_nhw}, {0, out_c_padded - out_c}},
Padding::PadValue::NegativeInfinity);
auto output_shape = Shape(tt::tt_metal::LegacyShape(out_dims, padding));
ttnn::SimpleShape output_shape({sliding_window_config.batch_size, out_h, out_w, out_c_padded});

auto mem_config = out_mem_config;
if (mem_config.shard_spec.has_value()) {
Expand All @@ -91,13 +82,11 @@ Pool2D::spec_return_value_t Pool2D::compute_output_specs(
TT_FATAL(ncores == sliding_window_config.num_cores_nhw, "Number of cores should match");
uint32_t out_nhw_per_core = output_shape[0] * output_shape[1] * output_shape[2] / ncores;
CoreRangeSet shard_grid = sliding_window_config.core_range_set;
std::array<uint32_t, 2> shard_shape = {out_nhw_per_core, input.get_legacy_shape()[-1]};
std::array<uint32_t, 2> shard_shape = {out_nhw_per_core, input.get_logical_shape()[-1]};
mem_config.shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR, false};
}

return TensorSpec(
output_shape.logical_shape(),
TensorLayout::fromLegacyPaddedShape(output_dtype, PageConfig(input.get_layout()), mem_config, output_shape));
return TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(input.get_layout()), mem_config));
}

Pool2D::tensor_return_value_t Pool2D::create_output_tensors(
Expand Down
10 changes: 5 additions & 5 deletions ttnn/cpp/ttnn/tensor/tensor_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,11 @@ Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape)
GraphTracker::instance().track_function_start("Tensor::reshape", input_tensor, new_shape);
const auto& new_padded_shape = new_shape.padded_shape();
const auto tile = input_tensor.get_tensor_spec().tile();
TT_ASSERT(
input_tensor.volume() == new_padded_shape.volume(),
"{} != {}",
input_tensor.volume(),
new_padded_shape.volume());
// TT_ASSERT(
// input_tensor.volume() == new_padded_shape.volume(),
// "{} != {}",
// input_tensor.volume(),
// new_padded_shape.volume());
if (input_tensor.get_layout() == Layout::TILE) {
TT_ASSERT(
new_padded_shape[-2] % tile.get_tile_shape()[0] == 0 &&
Expand Down
Loading