diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py
index f094cf3b6f8..6f8c395b421 100644
--- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py
+++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -318,6 +318,156 @@ def run_conv_with_split(


 @skip_for_grayskull()
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+@pytest.mark.parametrize("stride", [1])
+@pytest.mark.parametrize(
+    "output_channels, input_channels, input_height, input_width, filter_height, filter_width, pad_h, pad_w, act_block_w_div",
+    (
+        (64, 32, 130, 130, 3, 3, 0, 0, 1),
+        (64, 32, 128, 128, 3, 3, 1, 1, 1),
+        (64, 32, 1024, 1024, 3, 3, 1, 1, 1),
+    ),
+)
+@pytest.mark.parametrize(
+    "has_bias",
+    [True],
+)
+@pytest.mark.parametrize(
+    "weights_dtype",
+    [ttnn.bfloat16],
+)
+@pytest.mark.parametrize(
+    "activations_dtype",
+    [ttnn.bfloat16],
+)
+def test_conv_dram(
+    device,
+    use_program_cache,
+    output_channels,
+    input_channels,
+    input_height,
+    input_width,
+    filter_height,
+    filter_width,
+    pad_h,
+    pad_w,
+    act_block_w_div,
+    stride,
+    has_bias,
+    weights_dtype,
+    activations_dtype,
+):
+    if is_grayskull():
+        if input_channels >= 2048:
+            pytest.skip("Skipping on grayskull due to insufficient L1")
+        if input_channels >= 768 and input_height >= 10:
+            pytest.skip("Skipping on grayskull due to insufficient L1")
+
+    stride_h = stride
+    stride_w = stride
+    batch_size = 2
+    fp32_accum = False
+    packer_l1_acc = False
+    deallocate_activation = False
+    debug = False
+    groups = 1
+
+    torch.manual_seed(0)
+    conv_input_shape = [batch_size, input_channels, input_height, input_width]
+    conv_weight_shape = [output_channels, input_channels // groups, filter_height, filter_width]
+    conv_bias_shape = [1, 1, 1, output_channels]
+
+    torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float()
+    torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))
+
+    torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float()
+
+    tt_bias_tensor = None
+    torch_bias_tensor = None
+    if has_bias:
+        torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() * 50
+        tt_bias_tensor = ttnn.from_torch(
+            torch_bias_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
+        )
+        torch_bias_tensor = torch_bias_tensor.reshape(-1)
+    ref = torch.nn.functional.conv2d(
+        torch_input_tensor_nchw,
+        torch_weight_tensor,
+        bias=torch_bias_tensor,
+        stride=(stride_h, stride_w),
+        padding=(pad_h, pad_w),
+        groups=groups,
+    )
+    output_shape_nhwc = [
+        ref.shape[0],
+        ref.shape[2],
+        ref.shape[3],
+        ref.shape[1],
+    ]
+
+    reader_patterns_cache = {}
+    tt_weight_tensor = ttnn.from_torch(
+        torch_weight_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
+    )
+
+    tt_input_tensor = ttnn.from_torch(torch_input_tensor, device=device, dtype=ttnn.bfloat16)
+
+    conv_config = ttnn.Conv2dConfig(
+        dtype=activations_dtype,
+        weights_dtype=weights_dtype,
+        math_fidelity=ttnn.MathFidelity.HiFi4,
+        shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+        input_channels_alignment=32,
+        deallocate_activation=deallocate_activation,
+        fp32_dest_acc_enabled=fp32_accum,
+        packer_l1_accum_enabled=packer_l1_acc,
+        enable_act_double_buffer=False,
+        enable_split_reader=False,
+        enable_subblock_padding=False,
+        reshard_if_not_optimal=True,
+        act_block_w_div=act_block_w_div,
+        output_height_in_l1=64,
+        act_block_h_override=64,
+    )
+    [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
+        input_tensor=tt_input_tensor,
+        weight_tensor=tt_weight_tensor,
+        in_channels=input_channels,
+        out_channels=output_channels,
+        device=device,
+        bias_tensor=tt_bias_tensor,
+        kernel_size=(filter_height, filter_width),
+        stride=(stride_h, stride_w),
+        padding=(pad_h, pad_w),
+        batch_size=batch_size,
+        input_height=input_height,
+        input_width=input_width,
+        conv_config=conv_config,
+        conv_op_cache=reader_patterns_cache,
+        debug=debug,
+        groups=groups,
+    )
+
+    tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
+    out = tt_output_tensor.cpu().to_torch()
+
+    # Output is row-major and NHWC; reshape, then permute NHWC -> NCHW to match the torch reference.
+    out = out.reshape(batch_size, out_height, out_width, output_channels)
+    out = torch.permute(out, (0, 3, 1, 2))
+    reader_patterns_cache.clear()
+
+    pcc = 0.94
+    passing, pcc_msg = check_with_pcc_without_tensor_printout(out, ref, pcc=pcc)
+    logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}")
+    if not passing:
+        logger.error(f"Fails with PCC {pcc_msg}")
+    assert passing
+
+
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
 @pytest.mark.parametrize("stride", [1, 2])
 @pytest.mark.parametrize(
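[Editor's note] The C++ change below implements the DRAM slicing the test above exercises: the output is produced in slices of output_height_in_l1 rows, and each output slice is mapped back to the input rows it needs, kernel halo included, with rows falling outside the input turned into explicit top/bottom padding. A minimal, self-contained Python sketch of that index math (illustrative names, not ttnn API):

    def slice_bounds(out_start, out_end, stride, pad, kernel, dilation, input_h):
        # Input rows [in_start, in_end) covered by output rows [out_start, out_end),
        # kernel halo included; rows outside the input become explicit zero padding.
        in_start = out_start * stride - pad
        in_end = (out_end - 1) * stride - pad + (kernel - 1) * (dilation - 1) + kernel
        pad_top = max(0, -in_start)
        pad_bottom = max(0, in_end - input_h)
        return max(0, in_start), min(input_h, in_end), pad_top, pad_bottom

    # 3x3 kernel, stride 1, pad 1, 128-row input: the first 64-row output slice
    # needs input rows [0, 65) plus one zero-padded row on top.
    assert slice_bounds(0, 64, 1, 1, 3, 1, 128) == (0, 65, 1, 0)
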
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
index 996df42ec4a..9bffd08f23a 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -5,12 +5,20 @@
 #include "conv2d.hpp"
 #include <sys/types.h>
 #include <cstdint>
+#include <algorithm>
+#include "ttnn/common/constants.hpp"
+#include "ttnn/operations/core/core.hpp"
 #include "ttnn/operations/pool/downsample/device/downsample_op.hpp"
 #include "tt_metal/detail/reports/memory_reporter.hpp"
 #include "ttnn/operations/core/to_dtype/to_dtype_op.hpp"
 #include "tt_metal/common/work_split.hpp"
 #include "ttnn/operations/eltwise/unary/common/unary_op_utils.hpp"
+#include "ttnn/operations/data_movement/slice/slice.hpp"
+#include "ttnn/operations/data_movement/concat/concat.hpp"
+#include "ttnn/tensor/tensor.hpp"
+#include "ttnn/tensor/tensor_utils.hpp"
+#include "ttnn/tensor/types.hpp"

 using namespace tt;
 namespace ttnn {
@@ -671,9 +679,103 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::Tensor>> conv2d(
     std::optional<const ttnn::Tensor> bias_tensor,
     std::optional<const Conv2dConfig> conv_config_) {
-    Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
     uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
     uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1;
+
+    Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
+    if (conv_config.output_height_in_l1 > 0) {
+        TT_FATAL((conv_config.output_height_in_l1 % 32) == 0, "Output height in L1 must be a multiple of 32, got {}", conv_config.output_height_in_l1);
+        ttnn::Tensor input_tensor_on_device;
+        if (!is_tensor_on_device_or_multidevice(input_tensor)) {
+            input_tensor_on_device = ttnn::operations::core::to_device(input_tensor, device, std::nullopt);
+        } else {
+            input_tensor_on_device = input_tensor;
+        }
+        ttnn::Tensor weight_tensor_on_device;
+        std::optional<ttnn::Tensor> bias_tensor_on_device;
+        if (input_tensor_on_device.memory_config().is_dram()) {
+            Tensor dram_output_tensor;
+            bool first_run = true;
+            for (int batch_index = 0; batch_index < batch_size; batch_index++) {
+                for (uint32_t output_slice_height_start = 0; output_slice_height_start < output_height;
+                     output_slice_height_start += conv_config.output_height_in_l1) {
+                    uint32_t output_slice_height_end =
+                        std::min(output_height, output_slice_height_start + conv_config.output_height_in_l1);
+                    uint32_t output_slice_height = output_slice_height_end - output_slice_height_start;
+
+                    if (output_slice_height == 0) {
+                        continue;
+                    }
+
+                    // Input rows needed by this output slice, kernel halo included; rows that
+                    // fall outside the input become explicit top/bottom padding on the slice.
+                    int input_slice_height_start = output_slice_height_start * stride[0] - padding[0];
+                    int input_slice_height_end = (output_slice_height_end - 1) * stride[0] - padding[0] +
+                                                 (kernel_size[0] - 1) * (dilation[0] - 1) + kernel_size[0];
+                    int pad_top = std::max(0, -input_slice_height_start);
+                    int pad_bottom = std::max(0, input_slice_height_end - (int)input_height);
+                    input_slice_height_start = std::max(0, input_slice_height_start);
+                    input_slice_height_end = std::min((int)input_height, input_slice_height_end);
+                    uint32_t input_slice_height = input_slice_height_end - input_slice_height_start;
+                    log_debug(tt::LogOp, "Output Slice : {}->{}", output_slice_height_start, output_slice_height_end);
+                    log_debug(tt::LogOp, "Input Slice : {}->{}", input_slice_height_start, input_slice_height_end);
+                    log_debug(tt::LogOp, "Padding : {}->{}", pad_top, pad_bottom);
+
+                    if (input_slice_height_start < input_slice_height_end) {
+                        auto sliced_input_tensor = ttnn::slice(
+                            input_tensor_on_device,
+                            std::array<uint32_t, 4>{(uint32_t)batch_index, (uint32_t)input_slice_height_start, 0, 0},  // Start
+                            std::array<uint32_t, 4>{(uint32_t)batch_index, (uint32_t)(input_slice_height_end - 1), input_width - 1, in_channels - 1},  // End - inclusive
+                            std::array<uint32_t, 4>{1, 1, 1, 1});  // Step
+                        log_debug(tt::LogOp, "Sliced input tensor shape: {}", sliced_input_tensor.get_shape());
+                        if (pad_top > 0 || pad_bottom > 0) {
+                            auto padded_tensor = ttnn::pad(
+                                DefaultQueueId,
+                                sliced_input_tensor,
+                                std::vector<std::pair<uint32_t, uint32_t>>{{0, 0}, {(uint32_t)pad_top, (uint32_t)pad_bottom}, {0, 0}, {0, 0}},
+                                0, true, std::nullopt);
+                            sliced_input_tensor = padded_tensor;
+                        }
+                        log_debug(tt::LogOp, "Padded sliced input tensor shape: {}", sliced_input_tensor.get_shape());
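+                        // (Editor's note) The slice itself is a plain L1 convolution:
+                        // output_height_in_l1 = 0 below keeps the recursive conv2d call from
+                        // re-entering this DRAM branch, and the prepared weight/bias tensors
+                        // returned by the first iteration are reused for every later slice.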
+                        auto conv_config_l1 = conv_config;
+                        conv_config_l1.output_height_in_l1 = 0;
+                        conv_config_l1.reshard_if_not_optimal = true;
+                        ttnn::Tensor sliced_output_tensor;
+                        std::tie(sliced_output_tensor, std::ignore, std::ignore, weight_tensor_on_device, bias_tensor_on_device) = conv2d(
+                            sliced_input_tensor,
+                            first_run ? weight_tensor : weight_tensor_on_device,
+                            device,
+                            in_channels,
+                            out_channels,
+                            1, input_slice_height + pad_top + pad_bottom, input_width,
+                            kernel_size, stride, {0, padding[1]}, dilation,
+                            groups,
+                            first_run ? bias_tensor : (std::optional<const ttnn::Tensor>)(bias_tensor_on_device),
+                            conv_config_l1);
+                        sliced_output_tensor = ttnn::to_layout(
+                            sliced_output_tensor, Layout::ROW_MAJOR, std::nullopt,
+                            MemoryConfig{
+                                .memory_layout = TensorMemoryLayout::INTERLEAVED,
+                                .buffer_type = BufferType::L1,
+                            },
+                            device);
+
+                        if (first_run) {
+                            dram_output_tensor = sliced_output_tensor;
+                        } else {
+                            dram_output_tensor = ttnn::concat(
+                                std::vector<ttnn::Tensor>{dram_output_tensor, sliced_output_tensor}, 2,
+                                MemoryConfig{
+                                    .memory_layout = TensorMemoryLayout::INTERLEAVED,
+                                    .buffer_type = BufferType::DRAM,
+                                });
+                        }
+                        log_debug(tt::LogOp, "Dram output tensor shape: {}", dram_output_tensor.get_shape());
+                        first_run = false;
+                    }
+                }
+            }
+            return {dram_output_tensor, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device};
+        }
+    }
     auto [input_tensor_post_tm, parallel_config, tensor_manipulated] = shard_or_reshard_tensor_if_required(
         device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels);
     if (tensor_manipulated) {
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
index 97302a83727..bb88e438d25 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
@@ -47,6 +47,7 @@ struct Conv2dConfig {
     bool enable_act_double_buffer = false;
     bool enable_split_reader = false;
     bool enable_subblock_padding = false;
+    uint32_t output_height_in_l1 = 0;
     static constexpr auto attribute_names = std::make_tuple(
         "math_fidelity",
         "dtype",
@@ -68,7 +69,8 @@ struct Conv2dConfig {
         "output_layout",
         "enable_act_double_buffer",
         "enable_split_reader",
-        "enable_subblock_padding");
+        "enable_subblock_padding",
+        "output_height_in_l1");
     const auto attribute_values() const {
         return std::make_tuple(
             std::cref(this->math_fidelity),
@@ -91,7 +93,8 @@ struct Conv2dConfig {
             std::cref(this->output_layout),
             std::cref(this->enable_act_double_buffer),
             std::cref(this->enable_split_reader),
-            std::cref(this->enable_subblock_padding));
+            std::cref(this->enable_subblock_padding),
+            std::cref(this->output_height_in_l1));
     }
 };
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
index 0e5a7711f64..5da54c77724 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
@@ -186,7 +186,7 @@ void py_bind_conv2d(py::module& module) {
     auto py_conv_config = py::class_<Conv2dConfig>(module, "Conv2dConfig");
     py_conv_config.def(
-        py::init<MathFidelity, DataType, DataType, bool, bool, bool, std::string, uint32_t, bool, bool, uint32_t, uint32_t, bool, bool, std::optional<TensorMemoryLayout>, std::optional<CoreRangeSet>, bool, Layout, bool, bool, bool>(),
+        py::init<MathFidelity, DataType, DataType, bool, bool, bool, std::string, uint32_t, bool, bool, uint32_t, uint32_t, bool, bool, std::optional<TensorMemoryLayout>, std::optional<CoreRangeSet>, bool, Layout, bool, bool, bool, uint32_t>(),
         py::kw_only(),
         py::arg("math_fidelity") = MathFidelity::HiFi4,
         py::arg("dtype") = DataType::BFLOAT16,
@@ -208,7 +208,8 @@ void py_bind_conv2d(py::module& module) {
         py::arg("output_layout") = Layout::TILE,
         py::arg("enable_act_double_buffer") = false,
         py::arg("enable_split_reader") = false,
-        py::arg("enable_subblock_padding") = false
+        py::arg("enable_subblock_padding") = false,
+        py::arg("output_height_in_l1") = 0
     );
     py_conv_config.def_readwrite("math_fidelity", &Conv2dConfig::math_fidelity);
     py_conv_config.def_readwrite("dtype", &Conv2dConfig::dtype);
@@ -231,6 +232,8 @@ void py_bind_conv2d(py::module& module) {
     py_conv_config.def_readwrite("enable_act_double_buffer", &Conv2dConfig::enable_act_double_buffer);
     py_conv_config.def_readwrite("enable_split_reader", &Conv2dConfig::enable_split_reader);
     py_conv_config.def_readwrite("enable_subblock_padding", &Conv2dConfig::enable_subblock_padding);
+    py_conv_config.def_readwrite("output_height_in_l1", &Conv2dConfig::output_height_in_l1);
+
     py::class_<OptimizedConvParallelizationConfig>(module, "OptimizedConvParallelizationConfig")
         .def(
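[Editor's note] With the struct field, the py::init template list, and the def_readwrite binding all updated, the new option is usable from Python both as a constructor kwarg and as a mutable attribute. A small sketch, assuming only the bindings added above:

    import ttnn

    cfg = ttnn.Conv2dConfig(output_height_in_l1=64)
    assert cfg.output_height_in_l1 == 64
    cfg.output_height_in_l1 = 96  # def_readwrite also makes the field writable
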
diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp
index 472b107a4e6..c572947f3c4 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp
@@ -123,7 +123,7 @@ Tensor concat_impl(std::vector<Tensor> &input_tensors, const std::int64_t dim, c
                 "Current concat implementation requires aligned last dim when concatting on last dim");
         }
     }
-    Layout target_layout = Layout::TILE;
+    Layout target_layout = Layout::ROW_MAJOR;
     for (const auto &input_tensor : input_tensors) {
         if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
             const auto &input_shape = input_tensor.get_legacy_shape();
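[Editor's note] The concat change is global, not conv-specific: concat_impl's fallback target layout switches from Layout::TILE to Layout::ROW_MAJOR, which appears to be what the DRAM conv path above relies on when it concatenates row-major output slices on dim 2. A sketch of that usage from Python, assuming a device handle and the existing ttnn.concat binding:

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)
    a = ttnn.from_torch(torch.randn(1, 1, 64, 32), layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
    b = ttnn.from_torch(torch.randn(1, 1, 64, 32), layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
    out = ttnn.concat([a, b], dim=2)  # fallback target layout is now ROW_MAJOR
    ttnn.close_device(device)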