Skip to content

Commit

Permalink
#15621: conv1d transpose for new tensor infra.
Browse files Browse the repository at this point in the history
  • Loading branch information
shwetankTT committed Jan 3, 2025
1 parent 4c25f2a commit 604954d
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 3 deletions.
9 changes: 9 additions & 0 deletions tests/ttnn/unit_tests/operations/test_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ def run_conv(
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
tt_output_tensor = ttnn.reshape(
tt_output_tensor,
[
1,
1,
batch_size * out_length,
output_channels,
],
)
torch_output_tensor = torch.Tensor(ttnn.to_torch(tt_output_tensor))

# torch_output_tensor is in row major layout and NLC shape
Expand Down
7 changes: 5 additions & 2 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@ def _nearest_32(x):
def write_to_file(file_name, data):
data = data.cpu().numpy()
with open(file_name, "w") as f:
for i in range(1):
for i in range(data.shape[0]):
for j in range(data.shape[2]):
for k in range(data.shape[3]):
for l in range(data.shape[1]):
f.write(str(data[i][l][j][k]) + " ")
f.write("\n")
f.write("\n")
f.write("\n")


def write_to_file_special(file_name, data):
Expand All @@ -59,7 +60,7 @@ def write_to_file_special(file_name, data):
for i in range(data.shape[0]):
for j in range(data.shape[1]):
for k in range(data.shape[2]):
for l in range(16):
for l in range(data.shape[3]):
f.write(str(data[i][j][k][l]) + " ")
f.write("\n")

Expand Down Expand Up @@ -248,6 +249,8 @@ def run_conv(
else:
pcc = 0.997

# write_to_file("golden_tensor.txt", torch_out_golden_tensor.float())
# write_to_file("output_tensor_1.txt", torch_output_tensor.float())
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc)
logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}")
assert passing
Expand Down
3 changes: 3 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ std::vector<TensorSpec> OptimizedConvNew::compute_output_specs(const std::vector
auto output_padding = Padding(
{{0, 0}, {0, 0}, {0, 0}, {0, (padded_shape_c - shape_c)}}, Padding::PadValue::Zero);
auto output_shape = tt::tt_metal::LegacyShape({batch_size, conv_output_h, conv_output_w, padded_shape_c}, output_padding);
if(conv_output_w == 1){
output_shape = tt::tt_metal::LegacyShape({batch_size, conv_output_w, conv_output_h, padded_shape_c}, output_padding); //handing conv1d transpose.
}

auto output_layout = this->untilize_out ? Layout::ROW_MAJOR : Layout::TILE;
if (this->memory_config.is_sharded()) {
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/tensor/tensor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace tt_metal {

namespace tensor_impl {

TensorPrintProfile TTNN_TENSOR_PRINT_PROFILE = TensorPrintProfile::Full;
TensorPrintProfile TTNN_TENSOR_PRINT_PROFILE = TensorPrintProfile::Short;

std::ostream& operator<<(std::ostream& os, const DataType& dtype) {
switch (dtype) {
Expand Down

0 comments on commit 604954d

Please sign in to comment.