diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index 0d6b8c4406fe..e014d50b8741 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -16,18 +16,16 @@ skip_for_blackhole, is_blackhole, ) -from tests.ttnn.utils_for_testing import ( - assert_with_pcc, - check_with_pcc, - check_with_pcc_without_tensor_printout, - update_process_id, -) +from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc, check_with_pcc_without_tensor_printout import ttnn -import math -import os -import torch.nn as nn +def _nearest_32(x): + return math.ceil(x / 32) * 32 + + +from tests.ttnn.ttnn_utility_fuction import get_shard_grid_from_num_cores + # def plot_diff(vals, fid, nsticks, stick_len): # import matplotlib.pyplot as plt @@ -42,31 +40,6 @@ # plt.close() -def write_to_file_special(file_name, tensor): - tensor = tensor.cpu().detach().numpy() - with open(file_name, "w") as f: - for i in range(1): - for j in range(tensor.shape[1]): - for k in range(tensor.shape[2] // 16): - # for l in range(tensor.shape[3]): - # f.write(str(round(tensor[i][j][k][l]), 2) + " ") - f.write("{:.2f}".format(tensor[i][j][k][0]) + " ") - if k % 14 == 13: - f.write("\n") - - -def write_to_file(file_name, tensor): - tensor = tensor.cpu().detach().numpy() - with open(file_name, "w") as f: - for i in range(1): - for j in range(tensor.shape[2]): - for k in range(tensor.shape[3]): - for l in range(1): - # f.write(str(round(tensor[i][j][k][l]), 2) + " ") - f.write("{:.2f}".format(tensor[i][l][j][k]) + " ") - f.write("\n") - - def run_conv( device, math_fidelity, @@ -100,37 +73,20 @@ def run_conv( shard_layout=None, use_max_cores=False, ): - # has_bias = False - # update_process_id() - has_bias = True torch.manual_seed(0) conv_input_shape = [batch_size, input_channels, input_height, input_width] conv_weight_shape = [output_channels, input_channels // groups, filter_height, filter_width] conv_bias_shape = [1, 1, 1, output_channels] torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float() - # torch_input_tensor_nchw = torch.zeros(conv_input_shape, dtype=torch.bfloat16).float() # torch_input_tensor_nchw = torch.ones(conv_input_shape, dtype=torch.bfloat16).float() # torch_input_tensor_nchw = torch.tensor(range(input_height * input_width)).reshape([1,1,input_height,input_width]).float() # torch_input_tensor_nchw = torch_input_tensor_nchw.broadcast_to(conv_input_shape).float() torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1)) - # for i in range(1): - # for j in range(input_channels): - # for k in range(input_height): - # for l in range(input_width): - # torch_input_tensor[i][k][l][j] = 0.01 #if j < 1 else 0 torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float() # torch_weight_tensor = torch.ones(conv_weight_shape, dtype=torch.bfloat16).float() torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() if has_bias else None - # for i in range(output_channels): - # for j in range(input_channels // groups): - # for k in range(filter_height): - # for l in range(filter_width): - # torch_weight_tensor[i][j][k][l] = 0.1 - # torch_weight_tensor = 0.0 - # torch_weight_tensor = torch.ones(conv_weight_shape, dtype=torch.bfloat16).float() - torch_bias_tensor = torch.zeros(conv_bias_shape, dtype=torch.bfloat16).float() if has_bias else None torch_out_golden_tensor = torch.nn.functional.conv2d( torch_input_tensor_nchw, torch_weight_tensor, @@ -160,9 +116,6 @@ def run_conv( tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) - # if shard_layout == ttnn.TensorMemoryLayout.HEIGHT_SHARDED: - # pytest.skip("Only testing for height and block sharding. need to remove this test") - if shard_layout is None: shard_layout = ( ttnn.TensorMemoryLayout.HEIGHT_SHARDED if use_1d_systolic_array else ttnn.TensorMemoryLayout.BLOCK_SHARDED @@ -215,22 +168,9 @@ def run_conv( debug=debug, groups=groups, ) - # breakpoint() + tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) torch_output_tensor = ttnn.to_torch(tt_output_tensor) - # write_to_file_special("abc.pt", torch_output_tensor.float()) - # print(tt_output_tensor_on_device) - print(ttnn.get_memory_config(tt_output_tensor_on_device)) - # print(torch_output_tensor[0][0]) - # write_to_file("ref_hw.pt", torch_output_tensor.float()) - - # if enable_auto_formatting: - # torch_output_tensor = torch.split(torch_output_tensor, output_channels, 3)[0] - # torch_output_tensor = torch.reshape(torch_output_tensor, output_shape_nhwc) - # else: - # tt_output_tensor = conv.copy_output_from_device(tt_output_tensor_on_device) - # assert tt_output_tensor.layout == ttnn.ROW_MAJOR_LAYOUT - # torch_output_tensor = ttnn.to_torch(tt_output_tensor) # torch_output_tensor is in row major layout and NHWC shape # NHWC to NCHW @@ -246,12 +186,8 @@ def run_conv( pcc = 0.9969 else: pcc = 0.998 - passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc) - # write_to_file("ref_cpu.pt", torch_out_golden_tensor) - # write_to_file("ref_hw.pt", torch_output_tensor.float()) - - print(pcc_msg) + logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}") assert passing @@ -364,8 +300,6 @@ def run_conv_with_split( conv_op_cache=reader_patterns_cache, ) tt_conv_output_tensor = ttnn.from_device(tt_output_tensor_on_device) - print(tt_conv_output_tensor) - print(ttnn.get_memory_config(tt_conv_output_tensor)) torch_conv_output_tensor = ttnn.to_torch(tt_conv_output_tensor) print(f"Output shape : {batch_size} {out_height} {out_width} {output_channels}") torch_conv_output_tensor = torch_conv_output_tensor.reshape(batch_size, out_height, out_width, output_channels) @@ -674,6 +608,9 @@ def test_resnet50_conv_gs( use_1d_systolic_array, config_override, ): + if is_blackhole(): + pytest.skip("This test is for Grayskull only") + if batch_size > 8 and (activations_dtype != ttnn.bfloat8_b or weights_dtype != ttnn.bfloat8_b): pytest.skip("Batch > 8 must be run fully bfp8") if batch_size == 20 and input_channels >= 128 and filter_width > 1: @@ -724,22 +661,20 @@ def test_resnet50_conv_gs( # unique convs in rn50 (complete list) # first conv post folding and input_channels padding to tile width # (8, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, None), HANGS!! - # (1, 64, 64, 8*7, 8*7, 3, 3, 1, 1, 1, 1, True, None), - # (16, 64, 64, 2 * 7, 2 * 7, 3, 3, 1, 1, 1, 1, True, None), + (16, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, {"act_block_h": 256}), # (20, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, {"act_block_h": 32}), Out of Memory!! # rn50 layer1 - (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), # passing - (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), # passed - (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), # passed - # # # rn50 layer2 + (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + # rn50 layer2 (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), - (16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), # passed - # (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), # passed + (16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, {"act_block_h": 32}), - (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), # passed - (16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), # passed - (20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), # passed - # # rn50 layer3 + (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), + (16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), + (20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), + # rn50 layer3 (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), (16, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), (20, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), @@ -986,7 +921,6 @@ def test_resnet50_conv_wh_fp32( # (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, False, None), # slightly low pcc with 0.99698. bfloat16 weights doesnt fit # (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, False, None), # doesnt fit at all.. for all data types # sd convs with HxW=64x64 with batch size = 1 - # (1, 32 * 8, 32, 64, 64, 3, 3, 1, 1, 1, 1, True, None), (1, 320, 16, 64, 64, 3, 3, 1, 1, 1, 1, True, None), (1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), # bfloat16 doesnt fit (1, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, False, None), @@ -1018,8 +952,8 @@ def test_resnet50_conv_wh_fp32( # (2, 640, 960, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) # (2, 320, 960, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) # (2, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) - # # 1x1 conv - # (2, 320, 960, 64, 64, 1, 1, 1, 1, 0, 0, True, None), + # 1x1 conv + (2, 320, 960, 64, 64, 1, 1, 1, 1, 0, 0, True, None), # Small conv # (1, 32, 32, 16, 16, 3, 3, 2, 2, 1, 1, True, None), ## batch = 1 is currently not supported ), @@ -1103,7 +1037,6 @@ def test_sd_conv( ) -# @skip_for_wormhole_b0("Issue #7179: non-deterministically fails on N150 regression") @skip_for_grayskull() @skip_for_blackhole() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @@ -1341,6 +1274,9 @@ def test_unet_conv( use_shallow_conv_variant, output_layout, ): + if is_blackhole(): + pytest.skip("This test is for Grayskull only") + if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b: pytest.skip("Row major layout not compatible with bfloat8_b") if output_layout == ttnn.ROW_MAJOR_LAYOUT and input_height >= 1056: @@ -2163,51 +2099,18 @@ def test_conv_for_vanilla_unet( ( # unique convs in rn50 (complete list) # first conv post folding and input_channels padding to tile width - # (8, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, None), HANGS!! - # (1, 64, 64, 8*7, 8*7, 3, 3, 1, 1, 1, 1, True, None), - (16, 64, 64, 2 * 7, 2 * 7, 3, 3, 1, 1, 1, 1, True, None), - # (20, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, {"act_block_h": 32}), Out of Memory!! + (16, 64, 64, 14, 14, 3, 3, 1, 1, 1, 1, True, None), # rn50 layer1 (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), # passing (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), # passed (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), # passed - # # rn50 layer2 - # (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), + # rn50 layer2 + (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), (16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), # passed (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), # passed - # (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, {"act_block_h": 32}), (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), # passed (16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), # passed (20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), # passed - # # rn50 layer3 - # (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), - # (16, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), - # (20, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), - # (8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), - # (16, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), - # (20, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), - # # rn50 layer4 - # (8, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), - # (16, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), - # (20, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), - # (8, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), - # (16, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), - # (20, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), - # ## small test - # (1, 64, 64, 8, 8, 3, 3, 1, 1, 1, 1, False, {"num_cores_nhw": 2, "grid_size": (2, 2)}), - # (1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, False, {"num_cores_nhw": 4, "grid_size": (2, 4)}), - # # (1, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, False, None), sliding_window_op_infra/sliding_window.cpp:341: indices_length_last_core <= indices_length_per_core - # (8, 256, 256, 7, 7, 3, 3, 1, 1, 1, 1, False, None), - # # r50 1x1s2 shapes - # # Fails with packer_l1_acc = True (20, 256, 64, 56, 56, 1, 1, 2, 2, 0, 0, False, None), # r50 first bottleneck downsample shape - # (20, 256, 64, 56, 56, 1, 1, 2, 2, 0, 0, True, None), # r50 first bottleneck downsample shape - # # Fails with packer_l1_acc = True (20, 512, 256, 56, 56, 1, 1, 2, 2, 0, 0, False, None), # r50 second bottleneck downsample shape - # # (20, 512, 256, 56, 56, 1, 1, 2, 2, 0, 0, True, None), - doesnt fit - # (20, 1024, 512, 28, 28, 1, 1, 2, 2, 0, 0, False, None), # r50 third bottleneck downsample shape - # # (20, 1024, 512, 28, 28, 1, 1, 2, 2, 0, 0, True, None), - doesnt fit - # (20, 2048, 1024, 14, 14, 1, 1, 2, 2, 0, 0, False, None), # r50 fourth bottleneck downsample shape - # # (20, 2048, 1024, 14, 14, 1, 1, 2, 2, 0, 0, True, None), - doesnt fit - # (20, 128, 256, 56, 56, 1, 1, 2, 2, 0, 0, True, None), ## L2M1 DS: doesn't fit ), ) @pytest.mark.parametrize( @@ -2245,25 +2148,6 @@ def test_non_height_multiple_tile_conv_wh( ): if device.core_grid.y == 7: pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range") - # if batch_size > 8 and (activations_dtype != ttnn.bfloat8_b or weights_dtype != ttnn.bfloat8_b): - # pytest.skip("Batch > 8 must be run fully bfp8") - - # if ( - # ( - # activations_dtype == ttnn.bfloat16 - # and batch_size == 20 - # and ( - # output_channels == 64 - # or ( - # stride_h == 2 - # and (output_channels == 256 or (output_channels == 128 and weights_dtype == ttnn.bfloat16)) - # ) - # ) - # ) - # # packer l1 acc has separate buffers when interm != output df, cannot fit into L1 - # or (batch_size == 20 and activations_dtype == ttnn.bfloat8_b and packer_l1_acc and input_height >= 64) - # ): - # pytest.skip("Skipping test because it won't fit in L1!") use_shallow_conv_variant = (input_channels == 16) and device.arch() != ttnn.device.Arch.WORMHOLE_B0 run_conv( diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index 1d26d8b4571f..af8347839054 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -214,14 +214,13 @@ MemoryConfig create_sharded_memory_config_from_parallel_config( uint32_t nhw_shape = tensor_shape[0] * tensor_shape[1] * tensor_shape[2]; uint32_t nhw_padded = nhw_shape; - // if(shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) { - // nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size); - // } + if(shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) { + nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size); + } uint32_t nhw_shard = nhw_padded / num_cores_nhw; TT_ASSERT(channels % num_cores_channels == 0, "Channels: {}, num core channels: {}", channels, num_cores_channels); uint32_t channel_shard = channels / num_cores_channels; auto shard_spec = ShardSpec{parallel_config.grid, {nhw_shard, channel_shard}, shard_orientation}; - log_debug(tt::LogOp, "shard_spec in create_sharded_memory_config_from_parallel_config {} {} ", nhw_shard, channel_shard); return MemoryConfig{shard_scheme, BufferType::L1, shard_spec}; } @@ -231,10 +230,6 @@ OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_o TT_ASSERT(conv_output_mem_config.shard_spec.has_value()); const auto& shard_spec = conv_output_mem_config.shard_spec.value(); const auto& shard_shape = shard_spec.shape; - //TT_ASSERT(conv_output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED || shard_shape[0] % 32 == 0); - log_debug(tt::LogOp, "shard_shape in determine_conv_op_parallel_config_from_conv_output_mem_config {} {}", shard_shape[0], shard_shape[1]); - //TT_ASSERT(shard_shape[0] % 32 == 0); - //TT_ASSERT(shard_shape[1] % 32 == 0); uint32_t per_core_out_matrix_height_ntiles = div_up(shard_shape[0], 32); return { .grid_size = shard_spec.grid.bounding_box().grid_size(), @@ -411,7 +406,7 @@ std::tuple get_conv_padded_input_shape_an if (needs_shard_or_reshard) { uint32_t input_num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); uint32_t tensor_height = input_tensor.get_shape()[0] * input_tensor.get_shape()[1] * input_tensor.get_shape()[2]; - uint32_t input_tensor_height_snapped_to_tile = (conv_config.shard_layout == TensorMemoryLayout::WIDTH_SHARDED)? tensor_height : tt::round_up(tensor_height, input_num_cores_nhw); + uint32_t input_tensor_height_snapped_to_tile; if(conv_config.shard_layout == TensorMemoryLayout::WIDTH_SHARDED) { input_tensor_height_snapped_to_tile = tensor_height; }else if(conv_config.use_max_cores) { @@ -430,11 +425,12 @@ std::tuple get_conv_padded_input_shape_an 1, input_tensor_height_snapped_to_tile, input_tensor_width_snapped_to_channels_alignment}); // TODO: resolve ttnn::types::Shape and + uint32_t tile_size = !conv_config.use_max_cores ? 32 : 1; auto input_tensor_sharded_memory_config = create_sharded_memory_config_from_parallel_config( ttnn::Shape(std::array{ input_padded_shape[0], input_padded_shape[1], input_padded_shape[2], input_padded_shape[3]}), parallel_config, - 32); + tile_size); return {input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard}; } else { return {input_tensor.shape(), input_tensor.memory_config(), needs_shard_or_reshard}; @@ -501,7 +497,6 @@ std::tuple shard_or_reshard_tensor_if_requir input_tensor, device, input_tensor_sharded_memory_config); } } - //input_tensor.print(); return {input_tensor, parallel_config, needs_shard_or_reshard}; } @@ -698,13 +693,17 @@ std::tuple{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), parallel_config, - 32); + tile_size); auto opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config( conv_out_memory_config, get_num_cores_nhw_from_parallel_config(parallel_config), get_num_cores_channels_from_parallel_config(parallel_config)); + if(conv_config.use_max_cores == false){ + TT_ASSERT(conv_out_memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED || opt_conv_op_parallel_config.per_core_out_matrix_height % 32 == 0); + } auto opt_conv_op_block_config = determine_per_core_conv_block_config( parallel_config, opt_conv_op_parallel_config, @@ -830,7 +829,7 @@ std::tuple -inline void reblock_and_untilize_testing( +inline void reblock_and_untilize_rows( uint32_t num_out_subblocks_in_col, uint32_t out_subblock_num_tiles, uint32_t out_subblock_h, uint32_t output_rows_h, uint32_t interm_cb_id, uint32_t out_cb_id) { - UNPACK(DPRINT << num_out_subblocks_in_col << " " << out_subblock_num_tiles << " "<< out_subblock_h << " " << out_subblock_w << " "<< out_block_w << ENDL()); - //print_full_tile(interm_cb_id, 0, false); uint32_t num_tiles_in_row_of_subblocks = mulsi3(out_subblock_num_tiles, num_out_subblocks_in_col); - //uint32_t output_rows = 49; cb_wait_front(interm_cb_id, num_tiles_in_row_of_subblocks); uint32_t within_block_index = 0; @@ -119,12 +77,11 @@ inline void reblock_and_untilize_testing( } cb_push_back(out_cb_id, test_out_block_w); output_rows_h -= test_out_block_w; - within_block_index += out_subblock_w; } cb_pop_front(interm_cb_id, num_tiles_in_row_of_subblocks); - //print_full_tile(out_cb_id, 0, false); } +#endif template @@ -132,13 +89,12 @@ inline void reblock_and_untilize( uint32_t num_out_subblocks_in_col, uint32_t out_subblock_num_tiles, uint32_t out_subblock_h, - //uint32_t output_rows_h, uint32_t interm_cb_id, uint32_t out_cb_id) { uint32_t num_tiles_in_row_of_subblocks = mulsi3(out_subblock_num_tiles, num_out_subblocks_in_col); cb_wait_front(interm_cb_id, num_tiles_in_row_of_subblocks); - UNPACK(DPRINT << "num_out_subblocks_in_col " << num_out_subblocks_in_col << " out_subblock_num_tiles " << out_subblock_num_tiles << "out_subblock_w " << out_subblock_w << ENDL()); + uint32_t within_block_index = 0; for (uint32_t h = 0; h < out_subblock_h; h++) { uint32_t block_offset = 0; @@ -157,7 +113,7 @@ inline void reblock_and_untilize( block_offset += out_subblock_num_tiles; } cb_push_back(out_cb_id, out_block_w); - //print_full_tile(out_cb_id); + within_block_index += out_subblock_w; } cb_pop_front(interm_cb_id, num_tiles_in_row_of_subblocks); @@ -204,18 +160,14 @@ void MAIN { constexpr uint32_t out_cb_id = tt::CB::c_out0; constexpr uint32_t untilize_mode_out_cb_id = untilize_out ? matmul_partials_cb : out_cb_id; - //if() - UNPACK(DPRINT << "untilize_out " << (int)(untilize_out) << ENDL()); #ifdef FUSE_BIAS constexpr uint32_t bias_ntiles_w = get_compile_time_arg_val(16); constexpr uint32_t bias_cb_id = tt::CB::c_in2; uint32_t bias_block_offset = 0; constexpr uint32_t mm_out_cb_id = matmul_partials_cb; - UNPACK(DPRINT << "Fuse Bias" << ENDL()); #else constexpr uint32_t mm_out_cb_id = untilize_mode_out_cb_id; - UNPACK(DPRINT << "fuse bias false" << ENDL()); #endif @@ -228,12 +180,12 @@ void MAIN { constexpr uint32_t in0_num_subblocks_read = in0_num_subblocks; #endif + mm_block_init(mm_in0_cb_id, in1_cb_id, out_cb_id, false, out_subblock_w, out_subblock_h, in0_block_w); #ifdef SFPU_OP_INIT_ACTIVATION SFPU_OP_INIT_ACTIVATION #endif // in1 num blocks w is the outer loop. Output blocks are computed in col major order. - DPRINT << "in1_num_blocks_w " << in1_num_blocks_w << " "<< in0_num_blocks_h << ENDL(); for(uint32_t in1_block_w_i = 0; in1_block_w_i < in1_num_blocks_w; ++in1_block_w_i) { for(uint32_t in0_block_h_i = 0; in0_block_h_i < in0_num_blocks_h; ++in0_block_h_i) { @@ -471,13 +423,11 @@ void MAIN { cb_reserve_back(untilize_mode_out_cb_id, out_subblock_num_tiles); tile_regs_wait(); - //UNPACK(DPRINT << "untiled out result " << out_subblock_num_tiles << ENDL()); for (uint32_t i = 0; i < out_subblock_num_tiles; i++) { pack_tile(i, untilize_mode_out_cb_id); } tile_regs_release(); cb_push_back(untilize_mode_out_cb_id, out_subblock_num_tiles); - //UNPACK(DPRINT << "untiled out result" << ENDL()); in1_index_subblock_offset += out_subblock_w; } // for in1_num_subblocks @@ -495,52 +445,29 @@ void MAIN { unpack_reconfig_data_format_srca(in1_cb_id, matmul_partials_cb); #endif - #ifndef USE_MAX_CORES - pack_untilize_dst_init_short(out_cb_id); - copy_tile_to_dst_init_short(); - for (uint32_t in0_subblock_i = 0; in0_subblock_i < in0_num_subblocks; ++in0_subblock_i) { - //uint32_t output_rows_h_testing = output_rows_h_1 < 32*out_subblock_h ? output_rows_h_1 : 32*out_subblock_h; - reblock_and_untilize ( - in1_num_subblocks, - out_subblock_num_tiles, - out_subblock_h, - matmul_partials_cb, - out_cb_id); - - //output_rows_h_1 -= output_rows_h_testing; - } - #else pack_untilize_dst_init_short(out_cb_id); copy_tile_to_dst_init_short(); uint32_t output_rows_h_1 = output_rows_h; - - UNPACK(DPRINT << "TESTING " << in0_num_subblocks << " " << in1_num_subblocks << " " << out_subblock_num_tiles << " " << out_subblock_h << " " << out_subblock_w << " " << out_block_w << " output_rows_h " << output_rows_h << ENDL()); for (uint32_t in0_subblock_i = 0; in0_subblock_i < in0_num_subblocks; ++in0_subblock_i) { + #ifdef USE_MAX_CORES uint32_t output_rows_h_testing = output_rows_h_1 < 32*out_subblock_h ? output_rows_h_1 : 32*out_subblock_h; - //if(use_max_cores){ - reblock_and_untilize_testing ( + reblock_and_untilize_rows ( in1_num_subblocks, out_subblock_num_tiles, out_subblock_h, output_rows_h_testing, matmul_partials_cb, out_cb_id); - /*}else{ + output_rows_h_1 -= output_rows_h_testing; + #else reblock_and_untilize ( in1_num_subblocks, out_subblock_num_tiles, out_subblock_h, - output_rows_h_testing, matmul_partials_cb, out_cb_id); - }*/ - - output_rows_h_1 -= output_rows_h_testing; + #endif } - #endif - // uint32_t output_rows_h_1 = output_rows_h; - - // DPRINT << "completed" << ENDL(); pack_untilize_uninit(matmul_partials_cb); } if constexpr((in1_num_blocks_w > 1 || in0_num_blocks_h > 1)) { @@ -551,7 +478,6 @@ void MAIN { #endif if constexpr (!tilize_in0) { - UNPACK(DPRINT << "TESTING " <<__LINE__ << ENDL()); mm_block_init_short(mm_in0_cb_id, in1_cb_id, false, out_subblock_w, out_subblock_h, in0_block_w); } } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights_v2.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights_v2.cpp index 3e27e00ea5d1..c6b912781532 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights_v2.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_conv_activations_padded_with_halo_3x3_weights_v2.cpp @@ -7,59 +7,6 @@ #include "firmware_common.h" #define DILATION_W get_compile_time_arg_val(4) -#define DEBUG_PRINT 1 - -#if DEBUG_PRINT == 1 - #include "debug/dprint.h" - // #include "debug_macros.h" - - // SliceRange srt = SliceRange{.h0 = 0, .h1 = 32, .hs = 8, .w0 = 0, .w1 = 32, .ws = 4}; - // SliceRange srr = SliceRange{.h0 = 0, .h1 = 1, .hs = 8, .w0 = 0, .w1 = 32, .ws = 1}; - // SliceRange srr1 = SliceRange{.h0 = 1, .h1 = 2, .hs = 8, .w0 = 0, .w1 = 32, .ws = 1}; - // SliceRange src = SliceRange{.h0 = 0, .h1 = 32, .hs = 1, .w0 = 0, .w1 = 1, .ws = 1}; - - // inline void print_tile_rows(uint32_t cb_id, uint32_t rows = 32, uint32_t tile_id = 0, bool untilize = false) { - // // UNPACK(( DPRINT << "======" << ENDL() )); - // for (uint16_t r = 0; r < rows; ++ r) { - // SliceRange sr = SliceRange{.h0 = r, .h1 = (uint16_t)(r + 1), .hs = 1, .w0 = 0, .w1 = 32, .ws = 1}; - // // UNPACK(( DPRINT << (uint)r << " :: " << TileSlice(cb_id, tile_id, sr, true, untilize) << ENDL() )); - // UNPACK(( DPRINT << (uint)r << " :: " << TileSlice(cb_id, tile_id, sr, true, untilize) )); - // } - // // UNPACK(( DPRINT << "++++++" << ENDL() )); - // } - - // inline void print_full_tile(uint32_t cb_id, uint32_t tile_id = 0, bool untilize = false) { - // UNPACK(( DPRINT << "======" << ENDL() )); - // for (uint16_t r = 0; r < 32; ++ r) { - // SliceRange sr = SliceRange{.h0 = r, .h1 = (uint16_t)(r+1), .hs = 1, .w0 = 0, .w1 = 32, .ws = 1}; - // UNPACK(( DPRINT << (uint)r << TileSlice(cb_id, tile_id, sr, true, untilize) << ENDL() )); - // } - // UNPACK(( DPRINT << "++++++" << ENDL() )); - // } - - inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { - volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; - for (uint32_t page = 0; page < npages; ++ page) { - DPRINT << start + page << ": "; - for (uint32_t j = 0; j < pagelen; ++ j, ++ ptr) { - DPRINT << BF16(*ptr) << " "; - } - DPRINT << ENDL(); - } - } - - // inline void print_cb_details(uint32_t cb_id) { - // DPRINT << "cb_id " << cb_id << ": { " - // << "size: " << cb_interface[cb_id].fifo_size << ", " - // << "limit: " << cb_interface[cb_id].fifo_limit << ", " - // << "page_size: " << cb_interface[cb_id].fifo_page_size << ", " - // << "num_pages: " << cb_interface[cb_id].fifo_num_pages << ", " - // << "rd_ptr: " << cb_interface[cb_id].fifo_rd_ptr << ", " - // << "wr_ptr: " << cb_interface[cb_id].fifo_wr_ptr << ", " - // << "wr_tile_ptr: " << cb_interface[cb_id].fifo_wr_tile_ptr << " }" << ENDL(); - // } -#endif - void kernel_main() { constexpr bool act_in_dram = get_compile_time_arg_val(0)== 1; constexpr uint32_t stride_h = get_compile_time_arg_val(1); @@ -166,10 +113,6 @@ void kernel_main() { act_l1_offset = reader_offset + (reader_idx_1 * conv_act_c_read_bytes); noc_async_read_one_packet_with_state(act_l1_offset, l1_write_addr_act); l1_write_addr_act += (coalesced_read_bytes + act_block_w_extra_align_bytes); - // if(bhd == 0){ - // DPRINT << "page " << bhd << ENDL(); - // print_pages(act_l1_offset, coalesced_read_bytes/2, 1, 0); - // } act_l1_offset = reader_offset + (reader_idx_2 * conv_act_c_read_bytes); noc_async_read_one_packet_with_state(act_l1_offset, l1_write_addr_act); @@ -181,6 +124,7 @@ void kernel_main() { l1_write_addr_act += conv_act_c_read_bytes; act_l1_offset += stride_w_bytes; } + act_l1_offset = reader_offset + (reader_idx_2 * conv_act_c_read_bytes); for(uint32_t inner = 0; inner < weight_size_w; inner++) { noc_async_read_one_packet_with_state(act_l1_offset, l1_write_addr_act); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp index 4f9bc97a5d11..dfa846cae51a 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp @@ -251,6 +251,10 @@ void kernel_main() { } // out_num_blocks_w #ifdef SHARDED_OUT + #ifndef USE_MAX_CORES cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); + #else + cb_wait_front(cb_id_out0, output_rows_h); + #endif #endif } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp index f54b28bea49e..5b4f77456ecf 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/reader_writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp @@ -369,6 +369,10 @@ void kernel_main() { } // out_num_blocks_w #ifdef SHARDED_OUT + #ifndef USE_MAX_CORES cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); + #else + cb_wait_front(cb_id_out0, output_rows_h); + #endif #endif } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp index d089a5671fef..d29d233da41b 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp @@ -4,17 +4,7 @@ #include "dataflow_api.h" -#include "debug/dprint.h" -inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { - volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; - for (uint32_t page = 0; page < npages; ++ page) { - DPRINT << start + page << ": "; - for (uint32_t j = 0; j < pagelen; ++ j, ++ ptr) { - DPRINT << BF16(*ptr) << " "; - } - DPRINT << ENDL(); - } -} +//#include "debug/dprint.h" void kernel_main() { // This writer is for output tensor in tile format @@ -223,9 +213,6 @@ void kernel_main() { #ifdef SHARDED_OUT #ifndef USE_MAX_CORES - //DPRINT << out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h << ENDL(); - //DPRINT << "SHARDED_OUT_NOT_SUPPORTED" << ENDL(); - //print_pages( get_read_ptr(cb_id_out0), 64, 64, 0); cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); #else cb_wait_front(cb_id_out0, output_rows_h); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp index 465ea992d24a..a24a41c9196b 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_1d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp @@ -331,12 +331,10 @@ void kernel_main() { weight_start_tile_id += weight_next_block_stride_w; } // out_num_blocks_w #ifdef SHARDED_OUT - //DPRINT << "1 SHARDED_OUT_NOT_SUPPORTED" << ENDL(); - #ifndef USE_MAX_CORES - //DPRINT << out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h << ENDL(); - cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); - #else - cb_wait_front(cb_id_out0, output_rows_h); - #endif + #ifndef USE_MAX_CORES + cb_wait_front(cb_id_out0, out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h); + #else + cb_wait_front(cb_id_out0, output_rows_h); + #endif #endif } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp index 924da3529bce..c68e8d0fe463 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp @@ -177,19 +177,15 @@ std::tuple create_CBs_for_sharded_input_v2( // Supposed to be a small CB only responsible for reorganizing // the output blocks to fill the whole "per core output block width" - //std::cout << "num_reblock_cb_tiles: " << num_reblock_cb_tiles * 8 << std::endl; CircularBufferConfig cb_reblock_config = CircularBufferConfig(num_reblock_cb_tiles * out_tile_size, {{untilize_mode_reblock_cb, out_df}}) .set_page_size(untilize_mode_reblock_cb, out_tile_size); auto cb_reblock = tt_metal::CreateCircularBuffer(program, core, cb_reblock_config); log_debug(LogOp, "Reblock CB: {}, npages: {}, pagesize: {}", untilize_mode_reblock_cb, num_reblock_cb_tiles, out_tile_size); - //std::cout << "num_writer_output_tiles: " << num_writer_output_tiles << std::endl; auto shard_shape = output.shard_spec().value().shape; uint32_t aligned_output_stick_nbytes = use_max_cores ? shard_shape[1] * output.element_size() : out_tile_size; uint32_t aligned_output_num_pages = use_max_cores ? shard_shape[0] : num_writer_output_tiles; - //uint32_t aligned_output_num_pages = shard_shape[0]; - log_debug(LogOp , "output CB --> {} {}", aligned_output_num_pages, aligned_output_stick_nbytes); CircularBufferConfig cb_output_config = CircularBufferConfig(aligned_output_num_pages * aligned_output_stick_nbytes, {{out0_cb, out_df}}) .set_page_size(out0_cb, aligned_output_stick_nbytes); @@ -1050,7 +1046,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( // TODO: Moving this function call to after kernel logic causes pcc fails // There are additional CBs and semaphores created in 2D conv in kernel logic, // so does order of create_cb calls matter? - //uint32_t out_tile_size = tt_metal::detail::TileSize(out_df); input_output_cbs = create_CBs_for_sharded_input_v2( program, a, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op_program_factory.cpp index 76db9187da48..9204665531a2 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/optimized_conv_op_program_factory.cpp @@ -149,7 +149,6 @@ std::vector OptimizedConvNew::compute_output_shapes(c auto shape_c = output_channels; auto padded_shape_w = this->use_max_cores ? parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height : parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height_ntiles * TILE_HEIGHT; auto padded_shape_c = tt::round_up(this->output_channels, TILE_WIDTH); - //std::cout << "testing values " << " output chaneel " << shape_c << " padded_shape_c " << padded_shape_c << std::endl; auto output_padding = Padding( {{0, 0}, {0, 0}, {0, (padded_shape_w - shape_w)}, {0, (padded_shape_c - shape_c)}}, Padding::PadValue::Zero); auto output_tensor_shape = ttnn::Shape(tt::tt_metal::LegacyShape({1, 1, padded_shape_w, padded_shape_c}, output_padding)); @@ -170,10 +169,8 @@ std::vector OptimizedConvNew::create_output_tensors(const std::vectoruse_max_cores){ num_cores = this->parallelization_config.num_cores_nhw; uint32_t total_height = tt::tt_metal::compute_volume(output_shape) / output_shape[-1]; - //std::cout << "num_cores " << num_cores << " total_height " << total_height << " output_channel " << output_shape[-1] << std::endl; shard_grid = tt::tt_metal::num_cores_to_corerange_set(num_cores, this->parallelization_config.grid_size, true); shard_shape = {(uint32_t)(total_height / num_cores), output_shape[-1]}; - log_debug(tt::LogOp, "REMOVE_THIS shard_shape {} {}", shard_shape[0], shard_shape[1]); } auto shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR}; auto mem_config = this->memory_config;