Skip to content

Commit

Permalink
#0: DRAM Conv
Browse files Browse the repository at this point in the history
  • Loading branch information
sankarmanoj-tt authored and mywoodstock committed Sep 27, 2024
1 parent c076acb commit 058eb39
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
7 changes: 6 additions & 1 deletion tests/ttnn/unit_tests/operations/test_new_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,11 @@ def run_conv_with_split(
@pytest.mark.parametrize("stride", [1])
@pytest.mark.parametrize(
"output_channels, input_channels, input_height, input_width, filter_height, filter_width, pad_h, pad_w, act_block_w_div",
((64, 32, 130, 130, 3, 3, 0, 0, 1),),
(
(64, 32, 130, 130, 3, 3, 0, 0, 1),
(64, 32, 128, 128, 3, 3, 1, 1, 1),
(64, 32, 1024, 1024, 3, 3, 1, 1, 1),
),
)
@pytest.mark.parametrize(
"has_bias",
Expand Down Expand Up @@ -424,6 +428,7 @@ def test_conv_dram(
reshard_if_not_optimal=True,
act_block_w_div=act_block_w_div,
output_height_in_l1=64,
act_block_h_override=64,
)
[tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
input_tensor=tt_input_tensor,
Expand Down
10 changes: 5 additions & 5 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,13 +725,13 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
std::array<uint32_t, 4>{1, 1, 1, 1} //Step
);
log_debug(tt::LogOp, "Sliced input tensor shape: {}", sliced_input_tensor.get_shape());
if(pad_top>0)
if(pad_top>0 || pad_bottom > 0)
{
auto pad_top_tensor = ttnn::pad(
DefaultQueueId,
sliced_input_tensor,
tt::tt_metal::Array4D({1, input_slice_height + pad_top, input_width, in_channels}),
tt::tt_metal::Array4D({0, 0, 0, 0}),
0);
std::vector<std::pair<uint32_t, uint32_t>>{{0, 0}, {pad_top, pad_bottom}, {0, 0}, {0, 0}},
0, true, std::nullopt);
sliced_input_tensor = pad_top_tensor;
}
log_debug(tt::LogOp, "Padded sliced input tensor shape: {}", sliced_input_tensor.get_shape());
Expand All @@ -745,7 +745,7 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
device,
in_channels,
out_channels,
1, input_slice_height + pad_top, input_width,
1, input_slice_height + pad_top + pad_bottom, input_width,
kernel_size, stride, {0,padding[1]}, dilation,
groups,
first_run ? bias_tensor : (std::optional<const ttnn::Tensor>)(bias_tensor_on_device),
Expand Down

0 comments on commit 058eb39

Please sign in to comment.