From c1b3b6a77dfed57f00a1b8172a79f76833b6035d Mon Sep 17 00:00:00 2001 From: Pavle Josipovic Date: Tue, 5 Nov 2024 10:20:56 +0000 Subject: [PATCH] Better parallelization strategy --- .../ttnn_functional_resnet50_new_conv_api.py | 64 +++++- .../ttnn_functional_downsample_2d_new_conv.py | 37 ++- .../ttnn_functional_resnetblock2d_new_conv.py | 55 ++++- models/demos/yolov4/ttnn/downsample1.py | 2 +- .../functional_unet/tt/unet_shallow_ttnn.py | 29 ++- .../sweeps/conv2d/short/conv2d_short_sweep.py | 40 ++-- .../unit_tests/operations/test_maxpool2d.py | 2 + .../unit_tests/operations/test_new_conv2d.py | 12 +- tt_metal/third_party/sfpi | 1 + .../ttnn/operations/conv/conv2d/conv2d.cpp | 211 +++++++++++------- .../ttnn/operations/conv/conv2d/conv2d.hpp | 24 +- .../operations/conv/conv2d/conv2d_pybind.cpp | 64 +----- .../conv/conv2d/device/conv2d_op.cpp | 5 +- .../conv2d_op_sharded_program_factory.cpp | 20 +- ...onv2d_op_width_sharded_program_factory.cpp | 3 - .../conv_transpose2d/conv_transpose2d.cpp | 19 +- .../operations/matmul/device/matmul_op.cpp | 3 +- .../operations/pool/maxpool/max_pool2d.cpp | 27 ++- ttnn/ttnn/__init__.py | 2 +- ttnn/ttnn/operations/conv2d.py | 1 - 20 files changed, 385 insertions(+), 236 deletions(-) create mode 160000 tt_metal/third_party/sfpi diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py index 3a5c75967e9f..f034f1348d52 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py @@ -8,12 +8,37 @@ is_grayskull, is_wormhole_b0, _nearest_y, + nearest_y, pad_and_fold_conv_activation_for_unity_stride, ) from typing import List from loguru import logger from tests.ttnn.utils_for_testing import assert_with_pcc + +def get_core_grid_from_num_cores(num_cores: int, grid_rows: int, grid_cols: int): + columns = num_cores // grid_rows + assert columns <= grid_cols, "Not enough cores for specified core grid" + ranges = [] + if columns != 0: + ranges.append( + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(grid_rows - 1, columns - 1), + ) + ) + remainder = num_cores % grid_rows + if remainder != 0: + assert columns + 1 <= grid_cols, "Not enough cores for specified core grid" + ranges.append( + ttnn.CoreRange( + ttnn.CoreCoord(0, columns), + ttnn.CoreCoord(remainder - 1, columns), + ) + ) + return ttnn.CoreRangeSet({*ranges}) + + hardcoded_matmul_config_linear = { 8: ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( compute_with_storage_grid_size=(8, 4), @@ -632,15 +657,38 @@ def __init__( conv_dummy_tensor = torch.rand((self.fold_output_shape), dtype=torch.bfloat16) conv_dummy_tensor = ttnn.from_torch(conv_dummy_tensor, layout=ttnn.ROW_MAJOR_LAYOUT) - _, self.override_fold_mem_config, _, _ = ttnn.get_conv_padded_input_shape_and_mem_config( - device=device, - input_tensor=conv_dummy_tensor, - conv_config=self.conv1_config, + + parallel_config = ttnn._ttnn.operations.conv.determine_parallel_config( + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, batch_size=self.batch_size, - height=self.conv1_output_height, - width=self.conv1_output_width, - in_channels=self.conv1_input_channels, - out_channels=self.conv1_output_channels, + input_channels=self.conv1_input_channels, + output_height=self.conv1_output_height, + output_width=self.conv1_output_width, + output_channels=self.conv1_output_channels, + compute_grid_size=device.compute_with_storage_grid_size(), + 
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
+            is_conv2d_op=True,
+            is_out_tiled=True,
+        )
+        # Override the compute grid size for Grayskull.
+        # The first conv would go to 108 cores by default, but that would add
+        # padding to the output tensor, and the reshard that follows the first
+        # conv currently fails on padded output.
+        if is_grayskull():
+            compute_grid = device.compute_with_storage_grid_size()
+            parallel_config.grid = get_core_grid_from_num_cores(98, compute_grid.x, compute_grid.y)
+
+        self.override_fold_mem_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
+            tensor_shape=ttnn.Shape(
+                [
+                    1,
+                    1,
+                    self.conv1_input_width * self.conv1_input_height * self.batch_size,
+                    nearest_y(self.conv1_input_channels, self.conv1_config.input_channels_alignment),
+                ]
+            ),
+            parallel_config=parallel_config,
+            tile_size=32,
         )
 
     def __del__(self):
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
index 3635026d8094..57a757afb009 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
@@ -74,6 +74,33 @@ def __init__(
         self.output_height = ttnn.get_conv_output_dim(input_height, 3, self.stride, 1)
         self.output_width = ttnn.get_conv_output_dim(input_width, 3, self.stride, 1)
 
+        self.shard_layout = (
+            ttnn.TensorMemoryLayout.HEIGHT_SHARDED if self.in_channels < 320 else ttnn.TensorMemoryLayout.BLOCK_SHARDED
+        )
+
+        self.input_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
+            tensor_shape=ttnn.Shape(
+                [
+                    1,
+                    1,
+                    self.batch_size * self.input_height * self.input_width,
+                    self.out_channels,
+                ]
+            ),
+            parallel_config=ttnn._ttnn.operations.conv.determine_parallel_config(
+                shard_layout=self.shard_layout,
+                batch_size=self.batch_size,
+                input_channels=self.in_channels,
+                output_height=self.output_height,
+                output_width=self.output_width,
+                output_channels=self.out_channels,
+                compute_grid_size=self.device.compute_with_storage_grid_size(),
+                block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
+                is_conv2d_op=False,
+                is_out_tiled=True,
+            ),
+            tile_size=32,
+        )
 
     def __call__(
         self,
@@ -104,13 +131,15 @@ def __call__(
             math_approx_mode_enabled=True,
             fp32_dest_acc_enabled=True,
             packer_l1_accum_enabled=False,
-            shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
-            if self.in_channels < 320
-            else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
+            shard_layout=self.shard_layout,
             input_channels_alignment=32,
             transpose_shards=False,
-            reshard_if_not_optimal=True,
+            reshard_if_not_optimal=False,
         )
+
+        if hidden_states.memory_config() != self.input_memory_config:
+            hidden_states = ttnn.to_memory_config(hidden_states, self.input_memory_config)
+
         if self.conv_config_override and "act_block_h" in self.conv_config_override:
             conv_config.act_block_h_override = self.conv_config_override["act_block_h"]
 
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py
index 4e63fc9b13cb..c36db8c77b3f 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py
@@ -106,6 +106,33 @@ def __init__(
         self.conv1_input_width = input_width
         self.conv1_in_channels =
split_input_channels self.conv1_out_channels = out_channels + self.conv1_output_height = ttnn.get_conv_output_dim(self.conv1_input_height, 3, 1, 1) + self.conv1_output_width = ttnn.get_conv_output_dim(self.conv1_input_width, 3, 1, 1) + self.conv1_shard_layout = ttnn.TensorMemoryLayout.BLOCK_SHARDED + + self.conv1_input_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( + tensor_shape=ttnn.Shape( + [ + 1, + 1, + self.batch_size * self.conv1_input_height * self.conv1_input_width, + self.conv1_in_channels, + ] + ), + parallel_config=ttnn._ttnn.operations.conv.determine_parallel_config( + shard_layout=self.conv1_shard_layout, + batch_size=self.batch_size, + input_channels=self.conv1_in_channels, + output_height=self.conv1_output_height, + output_width=self.conv1_output_width, + output_channels=self.conv1_out_channels, + compute_grid_size=self.device.compute_with_storage_grid_size(), + block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, + is_conv2d_op=False, + is_out_tiled=True, + ), + tile_size=32, + ) for i in range(conv1_split_chunks): self.conv1s_weights.append(ttnn.from_torch(split_weight_tensors[i], ttnn.float32)) @@ -165,6 +192,29 @@ def __init__( self.conv2_in_channels = parameters.conv2.weight.shape[1] self.conv2_out_channels = parameters.conv2.weight.shape[0] # self.conv2_config_override = config_override[(out_channels, out_channels, input_height, input_width)] + self.conv2_input_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( + tensor_shape=ttnn.Shape( + [ + 1, + 1, + self.batch_size * self.conv2_input_height * self.conv2_input_width, + out_channels, + ] + ), + parallel_config=ttnn._ttnn.operations.conv.determine_parallel_config( + shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, + batch_size=self.batch_size, + input_channels=self.conv2_in_channels, + output_height=self.conv2_input_height, + output_width=self.conv2_input_width, + output_channels=self.conv2_out_channels, + compute_grid_size=self.device.compute_with_storage_grid_size(), + block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, + is_conv2d_op=False, + is_out_tiled=True, + ), + tile_size=32, + ) self.groups = 32 # if use_in_shortcut: @@ -402,12 +452,14 @@ def __call__( # hidden_states = nonlinearity(hidden_states, memory_config=ttnn.get_memory_config(hidden_states)) # hidden_states = self.conv1s[0](hidden_states) + hidden_states = ttnn.to_memory_config(hidden_states, self.conv1_input_memory_config) + conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, math_fidelity=ttnn.MathFidelity.LoFi, activation="", - shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=self.conv1_shard_layout, math_approx_mode_enabled=True, fp32_dest_acc_enabled=True, packer_l1_accum_enabled=False, @@ -598,6 +650,7 @@ def __call__( # hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG) hidden_states = ttnn.sharded_to_interleaved(hidden_states, ttnn.L1_MEMORY_CONFIG, hidden_states.dtype) + hidden_states = ttnn.to_memory_config(hidden_states, self.conv2_input_memory_config) # hidden_states = self.conv2(hidden_states) conv_config = ttnn.Conv2dConfig( diff --git a/models/demos/yolov4/ttnn/downsample1.py b/models/demos/yolov4/ttnn/downsample1.py index cc2f2cff37f7..9937457fa94e 100644 --- a/models/demos/yolov4/ttnn/downsample1.py +++ b/models/demos/yolov4/ttnn/downsample1.py @@ -48,7 +48,7 @@ def __call__(self, device, input_tensor): output_tensor = 
ttnn.to_layout(output_tensor, layout=ttnn.ROW_MAJOR_LAYOUT) output_tensor_left = ttnn.to_layout(output_tensor_left, layout=ttnn.ROW_MAJOR_LAYOUT) output_sharded_memory_config = ttnn.create_sharded_memory_config( - [512, 128], + [output_tensor.memory_config().shard_spec.shape[0], 2 * output_tensor.memory_config().shard_spec.shape[1]], core_grid=output_tensor_left.memory_config().shard_spec.grid, strategy=ttnn.ShardStrategy.HEIGHT, use_height_and_width_as_shard_shape=True, diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py index 3d47538c4e5a..08c4aea31a37 100644 --- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py +++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py @@ -246,6 +246,7 @@ def __init__( conv_cache={}, should_reshard=False, mesh_mapper=None, + conv_override_p_config=False, ): self.conv1 = UNetConv2D(conv1, bn=bn1, device=device, cache=conv_cache, mesh_mapper=mesh_mapper) self.conv2 = UNetConv2D( @@ -268,8 +269,15 @@ def __init__( output_channels=self.conv1.out_channels, compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, + is_conv2d_op=True, is_out_tiled=True, ) + + if conv_override_p_config: + pconfig_override = pool["parallel_config_override"] + num_cores_nhw = pconfig_override["num_cores_nhw"] + parallel_config.grid = get_core_grid_from_num_cores(num_cores_nhw) + self.sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( tensor_shape=ttnn.Shape( [ @@ -304,7 +312,18 @@ def __call__(self, x): class UNetUpblock: def __init__( - self, conv1, bn1, conv2, bn2, conv3, bn3, device, conv_cache={}, should_reshard=False, mesh_mapper=None + self, + conv1, + bn1, + conv2, + bn2, + conv3, + bn3, + device, + conv_cache={}, + should_reshard=False, + mesh_mapper=None, + nhw_core_override=-1, ): self.device = device self.conv1 = UNetConv2D(conv1, bn1, device, conv_cache, mesh_mapper=mesh_mapper) @@ -322,8 +341,13 @@ def __init__( output_channels=self.conv1.out_channels, compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, + is_conv2d_op=True, is_out_tiled=True, ) + + if nhw_core_override != -1: + parallel_config.grid = get_core_grid_from_num_cores(nhw_core_override) + self.sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( tensor_shape=ttnn.Shape( [ @@ -413,6 +437,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: conv_cache=self.conv_cache, should_reshard=True, mesh_mapper=mesh_mapper, + conv_override_p_config=True, ) self.downblock3 = UNetDownblock( parameters.c3, @@ -450,6 +475,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: output_channels=self.bnc.out_channels, compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, + is_conv2d_op=True, is_out_tiled=True, ) self.bnc_sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( @@ -500,6 +526,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: conv_cache=self.conv_cache, should_reshard=True, mesh_mapper=mesh_mapper, + nhw_core_override=60, ) self.upblock4 = UNetUpblock( parameters.c8, diff --git a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py 
b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
index 0f3176775cd0..5d6024474d50 100644
--- a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
+++ b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
@@ -454,30 +454,22 @@ def test_conv2d_localrun(device, input_spec):
 failing_parameters = [
     # [batch_size, output_channels, input_channels, input_height, input_width, kernel_height, kernel_width, stride_x, stride_y, pad_x, pad_y, groups, bias, dilation]
     # Input is 32 MB and maps to a 64-core matmul; we need to avoid sharding this tensor and use DRAM interleaved directly with the matmul
-    [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 6
-    [1, 1056, 1056, 48, 48, 3, 3, 1, 1, 1, 1, 4, False, 1], # 14
-    [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 15
-    [1, 2520, 2520, 14, 14, 3, 3, 2, 2, 1, 1, 15, False, 1], # 141
-    [1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 11, False, 1], # 170
-    [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 171
-    [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 173
-    [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 182
-    [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 183
-    [1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 199
-    [1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 205
-    [1, 336, 336, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1], # 241
-    [1, 336, 336, 48, 48, 5, 5, 1, 1, 2, 2, 336, False, 1], # 245
-    [1, 336, 336, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1], # 247
-    [1, 528, 528, 17, 17, 5, 5, 1, 1, 2, 2, 528, False, 1], # 292
-    [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 293
-    [1, 528, 528, 96, 96, 3, 3, 1, 1, 1, 1, 2, False, 1], # 294
-    [1, 696, 696, 28, 28, 3, 3, 1, 1, 1, 1, 3, False, 1], # 347
-    [1, 696, 696, 56, 56, 3, 3, 2, 2, 1, 1, 3, False, 1], # 348
-    [1, 720, 720, 17, 17, 5, 5, 1, 1, 2, 2, 720, False, 1], # 363
-    [1, 728, 728, 38, 38, 3, 3, 1, 1, 1, 1, 728, False, 1], # 366
-    [1, 7392, 7392, 24, 24, 3, 3, 2, 2, 1, 1, 28, False, 1], # 367
-    [1, 816, 816, 19, 19, 5, 5, 1, 1, 2, 2, 816, False, 1], # 374
-    [1, 960, 960, 24, 24, 5, 5, 1, 1, 2, 2, 960, False, 1], # 395
+    [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 5
+    [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 14
+    [1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 11, False, 1], # 169
+    [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 170
+    [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 172
+    [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 181
+    [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 182
+    [1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 198
+    [1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1], # 203
+    [1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 204
+    [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 292
+    [1, 7392, 7392, 24, 24, 3, 3, 2, 2, 1, 1, 28, False, 1], # 366
+    [1, 816, 816, 19, 19, 5, 5, 1, 1, 2, 2, 816, False, 1], # 373
+    [1, 816, 816, 23, 23, 5, 5, 2, 2, 0, 0, 816, False, 1], # 374
+    [1, 960, 960, 24, 24, 5, 5, 1, 1, 2, 2, 960, False, 1], # 394
+    [1, 960, 960, 27, 27, 5, 5, 2, 2, 0, 0, 960, False, 1], # 395
 ]
diff --git a/tests/ttnn/unit_tests/operations/test_maxpool2d.py b/tests/ttnn/unit_tests/operations/test_maxpool2d.py
index 6dab62917624..6ecfd31c33d1 100644
--- a/tests/ttnn/unit_tests/operations/test_maxpool2d.py
+++ b/tests/ttnn/unit_tests/operations/test_maxpool2d.py
@@ -158,6 +158,7 @@ def run_max_pool(
         output_channels=in_c,
compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, + is_conv2d_op=False, is_out_tiled=False, ) sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( @@ -744,6 +745,7 @@ def test_pool_core_nondivis( output_channels=in_c, compute_grid_size=device.compute_with_storage_grid_size(), block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, + is_conv2d_op=False, is_out_tiled=True, ) sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config( diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index f760b407b855..ff502a3c8196 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -137,7 +137,7 @@ def run_conv( enable_subblock_padding=False, output_layout=output_layout, ) - if config_override and "act_block_h" in config_override: + if config_override and "act_block_h" in config_override and not auto_shard: conv_config.act_block_h_override = config_override["act_block_h"] if config_override and "act_block_w_div" in config_override: @@ -183,10 +183,8 @@ def run_conv( if not fp32_accum: pcc = 0.985 - elif math_fidelity == ttnn.MathFidelity.LoFi and activations_dtype == ttnn.bfloat8_b: - pcc = 0.996 else: - pcc = 0.997 + pcc = 0.995 passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc) logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}") @@ -1426,9 +1424,9 @@ def test_sd_conv_wh( False, ), # fails. mismatch. It passes when input_channels=64. Probably an issue with padding when input_channels % 32 != 0. (2, 16, 16, 528, 80, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 22 * 32}, False), - (2, 16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 22 * 32}, False), - (2, 1, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 22 * 32}, False), + (2, 16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 8 * 32}, False), + (2, 16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 8 * 32}, False), + (2, 1, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 8 * 32}, False), ), ) @pytest.mark.parametrize( diff --git a/tt_metal/third_party/sfpi b/tt_metal/third_party/sfpi new file mode 160000 index 000000000000..1aab81a8dddd --- /dev/null +++ b/tt_metal/third_party/sfpi @@ -0,0 +1 @@ +Subproject commit 1aab81a8ddddf94702fad18efec0c1c163c96bab diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index 4dca42fbd886..2c9a8fa53425 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -8,7 +8,9 @@ #include #include "common/constants.hpp" +#include "common/math.hpp" #include "impl/buffers/buffer_constants.hpp" +#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp" #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" #include "ttnn/operations/core/core.hpp" #include "ttnn/operations/pool/downsample/device/downsample_op.hpp" @@ -18,6 +20,7 @@ #include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp" #include "ttnn/operations/sliding_window/sliding_window.hpp" #include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/types.hpp" using namespace tt; namespace ttnn { @@ -27,25 +30,37 @@ using sliding_window::ParallelConfig; namespace conv2d { 
-uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor) { +static uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor) { uint32_t divisor = start_divisor; while (num % divisor != 0) divisor = divisor - 1; return divisor; } -uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor) { +static uint32_t find_closest_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor) { + uint32_t divisor = start_divisor; + while (num1 % divisor != 0 or num2 % divisor != 0) divisor = divisor - 1; + return divisor; +} + +static uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor) { uint32_t divisor = start_divisor; uint32_t padded_num = round_up(num, divisor); - while ((padded_num - num) >= (int)(padded_num / divisor)) { + while ((padded_num - num) >= padded_num / divisor) { divisor = divisor - 1; padded_num = round_up(num, divisor); } return divisor; } -uint32_t find_closest_common_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor) { +static uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num1, uint32_t num2, uint32_t start_divisor) { uint32_t divisor = start_divisor; - while (num1 % divisor != 0 or num2 % divisor != 0) divisor = divisor - 1; + uint32_t padded_num1 = round_up(num1, divisor); + uint32_t padded_num2 = round_up(num2, divisor); + while ((padded_num1 - num1) >= (padded_num1 / divisor) || (padded_num2 - num2) >= (padded_num2 / divisor)) { + divisor = divisor - 1; + padded_num1 = round_up(num1, divisor); + padded_num2 = round_up(num2, divisor); + } return divisor; } @@ -83,6 +98,7 @@ ParallelConfig determine_parallel_config( uint32_t output_channels, const CoreCoord& compute_grid_size, ShardOrientation block_shard_orientation, + bool is_conv2d_op, bool is_out_tiled) { uint32_t effective_tile_height = is_out_tiled ? tt::constants::TILE_HEIGHT : 1; @@ -92,26 +108,35 @@ ParallelConfig determine_parallel_config( // calculate num_core_nhw and the grid uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y; - uint32_t num_cores_nhw = 0; CoreRangeSet grid; if (shard_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - num_cores_nhw = find_closest_largest_divisor(out_nhw_ntiles, max_num_cores); - if (num_cores_nhw < compute_grid_size.x && out_nhw_ntiles > compute_grid_size.x) { - num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, compute_grid_size.x); - } + uint32_t num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, max_num_cores); grid = num_cores_to_corerangeset(num_cores_nhw, compute_grid_size, true); } else if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED) { uint32_t start_divisor = block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.x : compute_grid_size.y; - num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, start_divisor); - uint32_t num_cores_c = find_closest_common_largest_divisor(out_c_ntiles, std::ceil((float)input_channels / effective_tile_width), block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.y : compute_grid_size.x); + uint32_t num_cores_nhw = find_closest_largest_divisor_with_num_padding(out_nhw_ntiles, start_divisor); + uint32_t num_cores_c = is_conv2d_op + ? find_closest_largest_divisor_with_num_padding( + out_c_ntiles, + std::ceil((float)input_channels / effective_tile_width), + block_shard_orientation == ShardOrientation::COL_MAJOR ? 
compute_grid_size.y + : compute_grid_size.x) + : find_closest_largest_divisor( + out_c_ntiles, + std::ceil((float)input_channels / effective_tile_width), + block_shard_orientation == ShardOrientation::COL_MAJOR ? compute_grid_size.y + : compute_grid_size.x) + ; uint32_t cores_x = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_nhw : num_cores_c; uint32_t cores_y = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_c : num_cores_nhw; CoreRange core_range = CoreRange(CoreCoord({0, 0}), CoreCoord({cores_x - 1, cores_y - 1})); grid = CoreRangeSet({core_range}); } else if (shard_layout == TensorMemoryLayout::WIDTH_SHARDED) { - num_cores_nhw = 1; - uint32_t num_cores_c = find_closest_largest_divisor(std::ceil((float)input_channels / effective_tile_width), max_num_cores); + uint32_t input_channles_ntiles = tt::div_up(input_channels, effective_tile_width); + uint32_t num_cores_c = is_conv2d_op + ? find_closest_largest_divisor_with_num_padding(input_channles_ntiles, max_num_cores) + : find_closest_largest_divisor(input_channles_ntiles, max_num_cores); grid = num_cores_to_corerangeset(num_cores_c, compute_grid_size, true); } else { TT_THROW("Conv2d supports Height, Block or Width Sharded Layouts but got {}", shard_layout); @@ -126,6 +151,26 @@ ParallelConfig determine_parallel_config( return pconfig; } +static ParallelConfig determine_output_parallel_config( + const ParallelConfig& input_parallel_config, + const CoreCoord& compute_grid_size, + uint32_t out_channels, + bool is_mm_conv) { + ParallelConfig output_parallel_config = input_parallel_config; + if (input_parallel_config.shard_scheme == ttnn::TensorMemoryLayout::WIDTH_SHARDED && !is_mm_conv) { + uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y; + output_parallel_config = { + .grid = num_cores_to_corerangeset( + find_closest_largest_divisor_with_num_padding( + tt::div_up(out_channels, tt::constants::TILE_WIDTH), max_num_cores), + compute_grid_size, + true), + .shard_scheme = ttnn::TensorMemoryLayout::WIDTH_SHARDED, + .shard_orientation = input_parallel_config.shard_orientation}; + } + return output_parallel_config; +} + uint32_t get_num_cores_nhw_from_parallel_config(const ParallelConfig& pconfig) { TT_ASSERT(!pconfig.grid.ranges().empty()); TT_ASSERT( @@ -192,6 +237,7 @@ MemoryConfig create_sharded_memory_config_from_parallel_config( if(shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) { nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size); } + uint32_t nhw_shard = nhw_padded / num_cores_nhw; TT_ASSERT(channels % num_cores_channels == 0, "Channels: {}, num core channels: {}", channels, num_cores_channels); uint32_t channel_shard = channels / num_cores_channels; @@ -254,6 +300,7 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( const ParallelConfig& parallel_config, const OptimizedConvParallelizationConfig& conv_op_parallel_config, uint32_t padded_in_channels, + uint32_t padded_output_height_ntiles, uint32_t act_block_h_override, uint32_t act_block_w_div, uint32_t window_h, @@ -268,11 +315,17 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( } auto grid_size = parallel_config.grid.bounding_box().grid_size(); uint32_t act_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntiles; - if(act_block_h_override > 0) { - if (parallel_config.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) { - log_info(LogOp, "act_block_h_override is set, but ignored when Width Sharding is used"); + if (parallel_config.shard_scheme != 
TensorMemoryLayout::WIDTH_SHARDED && act_block_h_override > 0) {
+        log_debug(LogOp, "act_block_h_override is set; validating it against the padded output height");
+        uint32_t act_block_h_override_ntiles = act_block_h_override / constants::TILE_HEIGHT;
+        if (padded_output_height_ntiles % act_block_h_override_ntiles == 0) {
+            act_block_h_ntiles = act_block_h_override_ntiles;
         } else {
-            act_block_h_ntiles = act_block_h_override / constants::TILE_HEIGHT;
+            log_info(
+                LogOp,
+                "act_block_h_override_ntiles {} is not a valid override for padded_output_height_ntiles {}",
+                act_block_h_override_ntiles,
+                padded_output_height_ntiles);
         }
     }
 
@@ -324,15 +377,15 @@ static TensorMemoryLayout select_shard_spec(
     const CoreCoord& compute_grid_size) {
     auto get_core_count_for_sharding = [&](TensorMemoryLayout shard_layout) {
         return determine_parallel_config(
-            shard_layout,
-            batch_size,
-            in_channels,
-            output_height,
-            output_width,
-            out_channels,
-            compute_grid_size,
-            shard_orientation)
-            .grid.num_cores();
+                   shard_layout,
+                   batch_size,
+                   in_channels,
+                   output_height,
+                   output_width,
+                   out_channels,
+                   compute_grid_size,
+                   shard_orientation,
+                   !is_mm_conv).grid.num_cores();
     };
 
     // 1d convs support only height sharding
@@ -349,7 +402,7 @@
 
     // Prefer block sharding over height sharding, but make sure that we get at least
     // some blocking on the width dimension as well.
-    if (cc_height > max_cc || (cc_height == max_cc && cc_height <= compute_grid_size.x)) {
+    if ((cc_height > max_cc && max_cc < 48) || (cc_height == max_cc && cc_height <= compute_grid_size.x)) {
         shard_layout = TensorMemoryLayout::HEIGHT_SHARDED;
         max_cc = cc_height;
     }
@@ -384,7 +437,8 @@ std::tuple get_conv_padded_input_sh
     uint32_t height,
     uint32_t width,
     uint32_t in_channels,
-    uint32_t out_channels) {
+    uint32_t out_channels,
+    bool is_mm_conv) {
     ttnn::Tensor input_tensor = input_tensor_;  // tensor to return
     bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_);
     bool needs_shard_or_reshard = false;
@@ -456,7 +510,16 @@
     auto block_shard_orientation =
         conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
     ParallelConfig optimal_parallel_config = determine_parallel_config(
-        shard_layout, batch_size, in_channels, height, width, out_channels, device->compute_with_storage_grid_size(), block_shard_orientation, !use_non_tile_height);
+        shard_layout,
+        batch_size,
+        in_channels,
+        height,
+        width,
+        out_channels,
+        device->compute_with_storage_grid_size(),
+        block_shard_orientation,
+        !is_mm_conv,
+        !use_non_tile_height);
 
     if (conv_config.override_sharding_config) {
         TT_FATAL(conv_config.core_grid.has_value(), "Error");
@@ -477,15 +540,15 @@ std::tuple get_conv_padded_input_sh
     }
     if (needs_shard_or_reshard) {
         uint32_t input_num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config);
+        uint32_t input_num_cores_c = get_num_cores_channels_from_parallel_config(parallel_config);
+
         // TT_ASSERT(input_tensor.get_legacy_shape() == input_tensor.get_shape());
         uint32_t tensor_height =
             input_tensor.get_shape()[0] * input_tensor.get_shape()[1] * input_tensor.get_shape()[2];
         uint32_t round_up_size = (use_non_tile_height || conv_config.shard_layout == TensorMemoryLayout::WIDTH_SHARDED) ?
1 : tt::constants::TILE_HEIGHT; uint32_t input_tensor_height_snapped_to_tile = tt::round_up(tensor_height, input_num_cores_nhw * round_up_size); TT_ASSERT(input_tensor_height_snapped_to_tile >= tensor_height); - uint32_t tensor_width = input_tensor.get_shape()[3]; uint32_t input_tensor_width_snapped_to_channels_alignment = - tt::round_up(tensor_width, conv_config.input_channels_alignment); - TT_ASSERT(input_tensor_width_snapped_to_channels_alignment >= tensor_width); + tt::round_up(input_tensor.get_shape()[3], input_num_cores_c * conv_config.input_channels_alignment); auto input_padded_shape = ttnn::Shape(std::array{ 1, @@ -493,10 +556,12 @@ std::tuple get_conv_padded_input_sh input_tensor_height_snapped_to_tile, input_tensor_width_snapped_to_channels_alignment}); // TODO: resolve ttnn::types::Shape and // tt::tt_metal::LegacyShape issue to clean up next line - auto input_tensor_sharded_memory_config = create_sharded_memory_config_from_parallel_config( + MemoryConfig input_tensor_sharded_memory_config = create_sharded_memory_config_from_parallel_config( ttnn::Shape(std::array{ - input_padded_shape[0], input_padded_shape[1], input_padded_shape[2], input_padded_shape[3]}), - parallel_config, round_up_size); + input_padded_shape[0], input_padded_shape[1], input_padded_shape[2], input_padded_shape[3]}), + parallel_config, + round_up_size); + return {input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard, use_non_tile_height}; } else { return {input_tensor.shape(), input_tensor.memory_config(), needs_shard_or_reshard, use_non_tile_height}; @@ -527,23 +592,17 @@ std::tuple shard_or_re height, width, in_channels, - out_channels); + out_channels, + is_mm_conv); ParallelConfig parallel_config = { .grid = input_tensor_sharded_memory_config.shard_spec.value().grid, .shard_scheme = input_tensor_sharded_memory_config.memory_layout, .shard_orientation = input_tensor_sharded_memory_config.shard_spec.value().orientation }; - auto shard_layout = input_tensor_sharded_memory_config.memory_layout; - auto output_parallel_config = parallel_config; - if(shard_layout == ttnn::TensorMemoryLayout::WIDTH_SHARDED && !is_mm_conv) { - uint32_t max_num_cores = compute_grid_size.x * compute_grid_size.y; - output_parallel_config = { - .grid = num_cores_to_corerangeset( find_closest_largest_divisor(tt::div_up(out_channels, tt::constants::TILE_WIDTH),max_num_cores), compute_grid_size, true), - .shard_scheme = ttnn::TensorMemoryLayout::WIDTH_SHARDED, - .shard_orientation = parallel_config.shard_orientation - }; - log_debug(tt::LogOp, "Changing width sharded output grid to {}",output_parallel_config.grid); - } + + ParallelConfig output_parallel_config = + determine_output_parallel_config(parallel_config, compute_grid_size, out_channels, is_mm_conv); + if (needs_shard_or_reshard) { if (input_tensor.get_shape()[0] != 1 or input_tensor.get_shape()[1] != 1) { // reshape to [1, 1, N*H*W, C] @@ -666,9 +725,14 @@ std::pair> prepare_conv_weights_biases uint32_t in_channels = weights_shape[1]; uint32_t window_h = weights_shape[2]; uint32_t window_w = weights_shape[3]; - uint32_t out_channel_padding = tt::round_up(out_channels, 32) - out_channels; + + uint32_t num_cores_channels = get_num_cores_channels_from_parallel_config(parallel_config); + uint32_t out_channels_padded = tt::round_up(out_channels, num_cores_channels * tt::constants::TILE_WIDTH); + uint32_t in_channels_padded = tt::round_up(in_channels, num_cores_channels * input_channels_alignment); + uint32_t out_channel_padding = out_channels_padded - 
out_channels; + tt::tt_metal::LegacyShape weights_channels_padded_shape = tt::tt_metal::LegacyShape(std::array( - {tt::round_up(out_channels, 32), tt::round_up(in_channels, input_channels_alignment), window_h, window_w})); + {out_channels_padded, in_channels_padded, window_h, window_w})); if (weights_bias_dtype == DataType::BFLOAT8_B) { TT_ASSERT(weight_tensor_.get_dtype() == DataType::FLOAT32); if (bias_tensor.has_value()) { @@ -862,18 +926,31 @@ Result conv2d( } uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1; - auto conv_out_memory_config = create_sharded_memory_config_from_parallel_config( - ttnn::Shape(std::array{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), + MemoryConfig conv_out_memory_config = create_sharded_memory_config_from_parallel_config( + ttnn::Shape(std::array{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, get_num_cores_channels_from_parallel_config(output_parallel_config) * tt::constants::TILE_WIDTH)}), // todo check not to round for shallow convs output_parallel_config, round_up_size); - auto largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores() ? output_parallel_config : parallel_config; + ParallelConfig largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores() ? output_parallel_config : parallel_config; - auto opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config( - conv_out_memory_config, get_num_cores_nhw_from_parallel_config(largest_parallel_config), + OptimizedConvParallelizationConfig opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config( + conv_out_memory_config, + get_num_cores_nhw_from_parallel_config(largest_parallel_config), get_num_cores_channels_from_parallel_config(largest_parallel_config)); - auto opt_conv_op_block_config = determine_per_core_conv_block_config( + + uint32_t in_channels_padded = tt::round_up( + in_channels, + get_num_cores_channels_from_parallel_config(parallel_config) * conv_config.input_channels_alignment); + + uint32_t num_output_rows = batch_size * output_height * output_width; + uint32_t act_block_h_datum = + opt_conv_op_parallel_config.per_core_out_matrix_height_ntiles * tt::constants::TILE_HEIGHT; + uint32_t num_rows_padded = tt::round_up(num_output_rows, act_block_h_datum); + uint32_t num_rows_padded_ntile = num_rows_padded / tt::constants::TILE_HEIGHT; + + OptimizedConvBlockConfig opt_conv_op_block_config = determine_per_core_conv_block_config( parallel_config, opt_conv_op_parallel_config, - tt::round_up(in_channels, conv_config.input_channels_alignment), + in_channels_padded, + num_rows_padded_ntile, conv_config.act_block_h_override, conv_config.act_block_w_div, kernel_size[0], @@ -1054,26 +1131,6 @@ Result conv2d( } } -template std::tuple get_conv_padded_input_shape_and_mem_config( - Device* device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels); - -template std::tuple get_conv_padded_input_shape_and_mem_config( - MeshDevice * device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels); - Result Conv2dOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, diff --git 
a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp index f3334c23be84..5f1655b70172 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp @@ -103,12 +103,6 @@ struct Conv2dConfig { } }; -uint32_t find_closest_largest_divisor(uint32_t num, uint32_t start_divisor); - -uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num, uint32_t start_divisor); - -uint32_t find_closest_common_largest_divisor(uint32_t num1, uint32_t num2, uint32_t start_divisor); - bool use_matmul_for_1x1_conv( const std::array& kernel_size, const std::array& stride, @@ -125,13 +119,17 @@ sliding_window::ParallelConfig determine_parallel_config( uint32_t output_channels, const CoreCoord& compute_grid_size, ShardOrientation block_shard_orientation, + bool is_conv2d_op, bool is_out_tiled=true); uint32_t get_num_cores_nhw_from_parallel_config(const sliding_window::ParallelConfig& pconfig); uint32_t get_num_cores_channels_from_parallel_config(const sliding_window::ParallelConfig& pconfig); -MemoryConfig create_sharded_memory_config_from_parallel_config(const ttnn::Shape& tensor_shape, sliding_window::ParallelConfig& parallel_config, uint32_t tile_size); +MemoryConfig create_sharded_memory_config_from_parallel_config( + const ttnn::Shape& tensor_shape, + sliding_window::ParallelConfig& parallel_config, + uint32_t tile_size); OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config( const MemoryConfig& conv_output_mem_config, uint32_t num_cores_nhw, uint32_t num_cores_c); @@ -142,6 +140,7 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( const sliding_window::ParallelConfig& parallel_config, const OptimizedConvParallelizationConfig& conv_op_parallel_config, uint32_t padded_in_channels, + uint32_t padded_input_height_ntiles, uint32_t act_block_h_override, uint32_t act_block_w_div, uint32_t window_h, @@ -149,17 +148,6 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( bool fp32_accum, bool split_reader_enabled); -template -std::tuple get_conv_padded_input_shape_and_mem_config( - T * device, - const ttnn::Tensor& input_tensor_, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels); - template std::tuple shard_or_reshard_tensor_if_required( T* device, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp index d00c50b46d05..da709c907627 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp @@ -4,6 +4,7 @@ +#include "common/constants.hpp" #include "ttnn/cpp/pybind11/decorators.hpp" #include "conv2d_pybind.hpp" @@ -120,65 +121,6 @@ void py_bind_conv2d(py::module& module) { py::arg("queue_id") = 0} ); - module.def( - "get_conv_padded_input_shape_and_mem_config", - [](ttnn::Device* device, - const ttnn::Tensor& input_tensor, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels) -> std::tuple { - return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config( - device, - input_tensor, - conv_config, - batch_size, - height, - width, - in_channels, - out_channels); - }, - py::kw_only(), - py::arg("device"), - py::arg("input_tensor"), - py::arg("conv_config"), - py::arg("batch_size"), - py::arg("height"), - 
py::arg("width"), - py::arg("in_channels"), - py::arg("out_channels")); - - module.def( - "get_conv_padded_input_shape_and_mem_config", - [](MeshDevice* device, - const ttnn::Tensor& input_tensor, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels) -> std::tuple { - return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config( - device, - input_tensor, - conv_config, - batch_size, - height, - width, - in_channels, - out_channels); - }, - py::kw_only(), - py::arg("device"), - py::arg("input_tensor"), - py::arg("conv_config"), - py::arg("batch_size"), - py::arg("height"), - py::arg("width"), - py::arg("in_channels"), - py::arg("out_channels")); module.def( "convert_conv_weight_tensor_to_tiled_layout", @@ -213,9 +155,10 @@ void py_bind_conv2d(py::module& module) { uint32_t output_channels, const CoreCoord& compute_grid_size, ShardOrientation block_shard_orientation, + bool is_conv2d_op, bool is_out_tiled) -> ttnn::operations::sliding_window::ParallelConfig { return ttnn::operations::conv::conv2d::determine_parallel_config( - shard_layout, batch_size, input_channels, output_height, output_width, output_channels, compute_grid_size, block_shard_orientation, is_out_tiled); + shard_layout, batch_size, input_channels, output_height, output_width, output_channels, compute_grid_size, block_shard_orientation, is_conv2d_op, is_out_tiled); }, py::arg("shard_layout"), py::arg("batch_size"), @@ -225,6 +168,7 @@ void py_bind_conv2d(py::module& module) { py::arg("output_channels"), py::arg("compute_grid_size"), py::arg("block_shard_orientation"), + py::arg("is_conv2d_op"), py::arg("is_out_tiled") = true); module.def( diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp index fc265b666f6e..f18790011252 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp @@ -6,6 +6,7 @@ #include #include "conv2d_op.hpp" +#include "common/math.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/detail/util.hpp" @@ -115,9 +116,9 @@ void OptimizedConvNew::validate(const std::vector& input_tensors, const // For block sharded, out_width per core is shard width, and this is split along row // TODO: We should clean this up and relax constraints on out_subblock h and w if (this->memory_config.shard_spec.value().orientation == ShardOrientation::COL_MAJOR) { - out_width_ntiles /= this->parallelization_config.grid_size.y; + out_width_ntiles = tt::div_up(out_width_ntiles, this->parallelization_config.grid_size.y); } else { - out_width_ntiles /= this->parallelization_config.grid_size.x; + out_width_ntiles = tt::div_up(out_width_ntiles, this->parallelization_config.grid_size.x); } TT_FATAL(this->block_config.out_subblock_w_ntiles == out_width_ntiles || this->block_config.out_subblock_h_ntiles == 1, "Error"); } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp index cffa1308549d..2171e966244c 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp @@ -845,10 +845,10 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( 
total_active_num_cores_per_weight_slice = act_matrix_height_ntiles / per_core_out_matrix_height_ntiles; } TT_FATAL(total_active_num_cores_per_weight_slice <= total_num_cores_per_weight_slice, "Error"); - uint32_t total_noop_cores = total_num_cores_per_weight_slice - total_active_num_cores_per_weight_slice; + //uint32_t total_noop_cores = total_num_cores_per_weight_slice - total_active_num_cores_per_weight_slice; uint32_t total_active_num_cores = total_active_num_cores_per_weight_slice * num_weight_slices_width; if (weight_width_sliced) { - TT_FATAL(total_noop_cores == 0, "Error"); + //TT_FATAL(total_noop_cores == 0, "Error"); TT_FATAL(total_active_num_cores == total_num_cores, "Error"); } @@ -874,14 +874,14 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( CoreCoord(num_active_cores_x_last_y - 1, num_active_cores_y_with_full_x))); } CoreRangeSet all_active_cores(all_active_cores_set); - std::set noop_cores_set; - if (total_noop_cores > 0) { - TT_FATAL(total_noop_cores == num_cores_x - num_active_cores_x_last_y, "Expected total_noop_cores {} to be equal to num_cores_x {} - num_active_cores_x_last_y {}", total_noop_cores, num_cores_x, num_active_cores_x_last_y); - noop_cores_set.insert(CoreRange( - CoreCoord(num_active_cores_x_last_y, num_active_cores_y_with_full_x), - CoreCoord(num_cores_x - 1, num_active_cores_y_with_full_x))); - } - CoreRangeSet noop_cores(noop_cores_set); + // std::set noop_cores_set; + // if (total_noop_cores > 0) { + // //TT_FATAL(total_noop_cores == num_cores_x - num_active_cores_x_last_y, "Expected total_noop_cores {} to be equal to num_cores_x {} - num_active_cores_x_last_y {}", total_noop_cores, num_cores_x, num_active_cores_x_last_y); + // noop_cores_set.insert(CoreRange( + // CoreCoord(num_active_cores_x_last_y, num_active_cores_y_with_full_x), + // CoreCoord(num_cores_x - 1, num_active_cores_y_with_full_x))); + // } + // CoreRangeSet noop_cores(noop_cores_set); // Mcast cores // If total_num_cores, there is no mcasting diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp index d74a1957d06f..f66c44935b01 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp @@ -339,9 +339,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl( : (output_channels_padded_to_tile_width % weight_block_w_datums); TT_FATAL(last_block_width_datums % TILE_WIDTH == 0, "last_block_width_datums {} should be divisible by TILE_WIDTH {}", last_block_width_datums, TILE_WIDTH); - // sanity check - TT_FATAL(num_blocks_output_w == num_blocks_weight_w, "num_blocks_output_w {} should be equal to num_blocks_weight_w {}", num_blocks_output_w, num_blocks_weight_w); - uint32_t out_block_h_datums = out_block_h_ntiles * TILE_HEIGHT; tt_metal::Buffer* src0_dram_buffer = a.buffer(); diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp index 8543a01d1317..17b3523d447d 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp @@ -212,21 +212,28 @@ Result conv_transpose2d( //Call Conv2d u_op with Stride = 1, Padding = 0. 
auto conv_out_memory_config = conv2d::create_sharded_memory_config_from_parallel_config( - ttnn::Shape(std::array{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), - output_parallel_config, - round_up_size); + ttnn::Shape(std::array{ + 1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), + output_parallel_config, + round_up_size); auto largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores() ? output_parallel_config : parallel_config; auto opt_conv_op_parallel_config = conv2d::determine_conv_op_parallel_config_from_conv_output_mem_config( conv_out_memory_config, conv2d::get_num_cores_nhw_from_parallel_config(largest_parallel_config), - conv2d::get_num_cores_channels_from_parallel_config(largest_parallel_config) - ); + conv2d::get_num_cores_channels_from_parallel_config(largest_parallel_config)); + + uint32_t in_channels_padded = tt::round_up( + in_channels, + conv2d::get_num_cores_channels_from_parallel_config(parallel_config) * + conv_config.input_channels_alignment); + auto opt_conv_op_block_config = conv2d::determine_per_core_conv_block_config( parallel_config, opt_conv_op_parallel_config, - tt::round_up(in_channels, conv_config.input_channels_alignment), + in_channels_padded, + (input_tensor_post_tm.shard_spec().value().shape[0] * conv2d::get_num_cores_nhw_from_parallel_config(parallel_config)) / tt::constants::TILE_HEIGHT, conv_config.act_block_h_override, conv_config.act_block_w_div, kernel_size[0], diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp index 622ebad74c00..374e8b81cbbb 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp @@ -1144,8 +1144,7 @@ void Matmul::validate( uint32_t K = input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[1]; uint32_t per_core_M = program_config.per_core_M; auto shard_shape = input_tensor_a.shard_spec().value().shape; - - TT_FATAL(div_up(M, per_core_M) == input_tensor_a.shard_spec().value().grid.num_cores(), "Error"); + TT_FATAL(div_up(M, per_core_M) <= input_tensor_a.shard_spec().value().grid.num_cores(), "Error"); TT_FATAL(per_core_M == (shard_shape[0] / in0_tile_shape[0]), "Error"); TT_FATAL(K % program_config.in0_block_w == 0, "Error"); TT_FATAL(K == (shard_shape[1] / in0_tile_shape[1]), "Error"); diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp index bf2cf007f152..d51da2a63550 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp @@ -55,15 +55,16 @@ Tensor MaxPool2DOp::invoke(uint8_t queue_id, shard_layout = applied_shard_scheme.value(); } parallel_config = conv::conv2d::determine_parallel_config( - shard_layout, - batch_size, - channels, - output_shape[1], - output_shape[2], - channels, - input_tensor.device()->compute_with_storage_grid_size(), - ShardOrientation::ROW_MAJOR, - false); + shard_layout, + batch_size, + channels, + output_shape[1], + output_shape[2], + channels, + input_tensor.device()->compute_with_storage_grid_size(), + ShardOrientation::ROW_MAJOR, + false, + false); num_cores_nhw = conv::conv2d::get_num_cores_nhw_from_parallel_config(parallel_config); num_cores_c = conv::conv2d::get_num_cores_channels_from_parallel_config(parallel_config); auto sharded_mem_config = 
conv::conv2d::create_sharded_memory_config_from_parallel_config(input_tensor_sharded.shape(), parallel_config, is_in_tiled ? tt::constants::TILE_HEIGHT : 1); @@ -85,7 +86,13 @@ Tensor MaxPool2DOp::invoke(uint8_t queue_id, // update the shard spec to match the output shape auto shard_spec = out_memory_config.shard_spec.value(); - uint32_t output_shard_width_padded = input_tensor.dtype() == DataType::BFLOAT8_B ? tt::round_up(channels / num_cores_c, tt::constants::TILE_WIDTH) : tt::round_up(channels / num_cores_c * tt::datum_size(tt::tt_metal::datatype_to_dataformat_converter(input_tensor.dtype())), tt::constants::TILE_WIDTH); + uint32_t output_shard_width_padded = + input_tensor.dtype() == DataType::BFLOAT8_B + ? tt::round_up(channels / num_cores_c, tt::constants::TILE_WIDTH) + : tt::round_up( + channels / num_cores_c * + tt::datum_size(tt::tt_metal::datatype_to_dataformat_converter(input_tensor.dtype())), + tt::constants::TILE_WIDTH); uint32_t output_nhw = output_shape[0] * output_shape[1] * output_shape[2]; uint32_t output_nhw_padded = tt::round_up(output_nhw, num_cores_nhw * (is_out_tiled ? tt::constants::TILE_HEIGHT : 1)); uint32_t output_shard_height_padded = output_nhw_padded / num_cores_nhw; diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index fba363c39718..231f8b326af3 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -295,7 +295,7 @@ def prelu(*args, **kwargs): # Alias for leaky_relu. TODO(#8544): implement PReL Topology, ) -from ttnn.operations.conv2d import Conv2dConfig, get_conv_padded_input_shape_and_mem_config, get_conv_output_dim +from ttnn.operations.conv2d import Conv2dConfig, get_conv_output_dim from ttnn.operations.pool import avg_pool2d from ttnn.operations.conv1d import Conv1d, Conv1dConfig diff --git a/ttnn/ttnn/operations/conv2d.py b/ttnn/ttnn/operations/conv2d.py index ca1f329dd696..b46a7e1fbf79 100644 --- a/ttnn/ttnn/operations/conv2d.py +++ b/ttnn/ttnn/operations/conv2d.py @@ -21,7 +21,6 @@ def _nearest_32(x): Conv2dConfig = ttnn._ttnn.operations.conv.Conv2dConfig -get_conv_padded_input_shape_and_mem_config = ttnn._ttnn.operations.conv.get_conv_padded_input_shape_and_mem_config OptimizedConvParallelizationConfig = ttnn._ttnn.operations.conv.OptimizedConvParallelizationConfig OptimizedConvBlockConfig = ttnn._ttnn.operations.conv.OptimizedConvBlockConfig
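---

Notes: the sketches below illustrate the parallelization API this patch introduces. They are reviewer-facing examples, not part of the diff; any value that does not appear in the patch (device ids, tensor extents, grid sizes) is illustrative only.

The central change is the `is_conv2d_op` flag threaded through `determine_parallel_config`: conv2d may pad its shards (output rows and channels), while pooling and the matmul fallback need exact divisors. A minimal sketch of the updated Python binding, mirroring the ResNet-50 call above (the tensor dimensions are illustrative):

    import ttnn

    device = ttnn.open_device(device_id=0)

    # is_conv2d_op=True lets the helper pick a core count that requires padding.
    parallel_config = ttnn._ttnn.operations.conv.determine_parallel_config(
        shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
        batch_size=16,
        input_channels=16,
        output_height=115,
        output_width=115,
        output_channels=64,
        compute_grid_size=device.compute_with_storage_grid_size(),
        block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
        is_conv2d_op=True,
        is_out_tiled=True,
    )
    print(parallel_config.grid)

    ttnn.close_device(device)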
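The new `get_core_grid_from_num_cores` helper in the ResNet-50 model packs N cores into the compute grid row by row. The same packing rule restated as a plain-Python sketch; the 12x9 grid is an assumed Grayskull-class compute grid, and `core_grid_ranges` is a hypothetical name:

    def core_grid_ranges(num_cores: int, grid_rows: int, grid_cols: int):
        # Fill whole rows of grid_rows cores each, then one partial row for the rest.
        full_rows, remainder = divmod(num_cores, grid_rows)
        assert full_rows + (1 if remainder else 0) <= grid_cols, "Not enough cores"
        ranges = []
        if full_rows:
            ranges.append(((0, 0), (grid_rows - 1, full_rows - 1)))
        if remainder:
            ranges.append(((0, full_rows), (remainder - 1, full_rows)))
        return ranges

    # The Grayskull override above asks for 98 cores: a 12x8 block plus a 2-core row.
    assert core_grid_ranges(98, 12, 9) == [((0, 0), (11, 7)), ((0, 8), (1, 8))]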
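Core-count selection in conv2d.cpp rests on two small divisor searches. A behavioral sketch of `find_closest_largest_divisor` and `find_closest_largest_divisor_with_num_padding` in Python (the latter now drives height sharding unconditionally, so a core count that needs a little padding is acceptable as long as the padding stays under one per-core share):

    def find_closest_largest_divisor(num: int, start_divisor: int) -> int:
        # Largest d <= start_divisor that divides num exactly.
        d = start_divisor
        while num % d != 0:
            d -= 1
        return d

    def round_up(x: int, m: int) -> int:
        return -(-x // m) * m

    def find_closest_largest_divisor_with_num_padding(num: int, start_divisor: int) -> int:
        # Largest d <= start_divisor such that padding num up to a multiple of d
        # wastes less than one per-core share (padded / d).
        d = start_divisor
        padded = round_up(num, d)
        while (padded - num) >= padded // d:
            d -= 1
            padded = round_up(num, d)
        return d

    assert find_closest_largest_divisor(96, 64) == 48                 # exact split required
    assert find_closest_largest_divisor_with_num_padding(33, 8) == 7  # 33 pads to 35 rows on 7 cores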
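`create_sharded_memory_config_from_parallel_config` pads the NHW extent to a multiple of num_cores_nhw * tile_size (width sharding excepted) and requires channels to split evenly across the channel cores. The shard-shape arithmetic as a simplified sketch; `conv_shard_shape` is a hypothetical name and the example numbers are illustrative:

    import math

    def conv_shard_shape(nhw, channels, num_cores_nhw, num_cores_c, tile_size=32):
        # NHW is padded so every core gets a whole number of tile rows.
        unit = num_cores_nhw * tile_size
        nhw_padded = math.ceil(nhw / unit) * unit
        assert channels % num_cores_c == 0, "channels must split evenly across cores"
        return nhw_padded // num_cores_nhw, channels // num_cores_c

    # 16 x 115 x 115 output rows, 64 channels, on the 98-core height-sharded grid:
    assert conv_shard_shape(16 * 115 * 115, 64, 98, 1) == (2176, 64)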
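`prepare_conv_weights_biases` now pads both weight channel dimensions per core rather than just to one tile: out_channels to a multiple of num_cores_channels * TILE_WIDTH and in_channels to a multiple of num_cores_channels * input_channels_alignment. A sketch of the new padded-shape rule (`padded_weight_shape` is a hypothetical name; sizes are illustrative):

    def padded_weight_shape(out_channels, in_channels, window_h, window_w,
                            num_cores_channels, input_channels_alignment=32, tile_width=32):
        def round_up(x, m):
            return -(-x // m) * m
        return (
            round_up(out_channels, num_cores_channels * tile_width),
            round_up(in_channels, num_cores_channels * input_channels_alignment),
            window_h,
            window_w,
        )

    # Block sharding across 8 channel cores: 320 channels pad to 512 (64 per core),
    # where the old rule padded to the nearest 32 only (320 stayed 320).
    assert padded_weight_shape(320, 320, 3, 3, num_cores_channels=8) == (512, 512, 3, 3)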
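Pooling goes through the same helper with `is_conv2d_op=False`, since max_pool2d cannot pad channels away; the updated test_maxpool2d.py calls it the same way. A sketch of that path, assuming an open `device` handle as in the first note (shapes and the row-major tile_size of 1 are illustrative assumptions):

    parallel_config = ttnn._ttnn.operations.conv.determine_parallel_config(
        shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
        batch_size=1,
        input_channels=32,
        output_height=56,
        output_width=56,
        output_channels=32,
        compute_grid_size=device.compute_with_storage_grid_size(),
        block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
        is_conv2d_op=False,  # pooling needs exact row/channel divisors
        is_out_tiled=False,
    )
    sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
        tensor_shape=ttnn.Shape([1, 1, 1 * 112 * 112, 32]),  # input NHW rows x channels
        parallel_config=parallel_config,
        tile_size=1,  # row-major input, so no tile rounding
    )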