From 66b25f2c8fc9e960a4d23f4c3437959d74cc90cc Mon Sep 17 00:00:00 2001 From: Shwetank Singh Date: Sun, 24 Nov 2024 09:26:42 +0000 Subject: [PATCH] #14257: Yolo Optimization. --- models/demos/yolov4/demo/demo.py | 29 ++++++- models/demos/yolov4/ttnn/common.py | 3 + models/demos/yolov4/ttnn/downsample1.py | 11 ++- models/demos/yolov4/ttnn/downsample4.py | 4 +- models/demos/yolov4/ttnn/head.py | 11 ++- models/demos/yolov4/ttnn/neck.py | 85 +++++++++++++++++-- .../yolov4/test_ttnn_head.py | 9 +- .../yolov4/test_ttnn_neck.py | 9 +- .../concat/device/concat_device_operation.cpp | 16 +++- 9 files changed, 157 insertions(+), 20 deletions(-) diff --git a/models/demos/yolov4/demo/demo.py b/models/demos/yolov4/demo/demo.py index 60da9eb78107..47ddc2b0abff 100644 --- a/models/demos/yolov4/demo/demo.py +++ b/models/demos/yolov4/demo/demo.py @@ -419,7 +419,34 @@ def do_detect(model, img, conf_thresh, nms_thresh, n_classes, device=None, class if not is_torch_model: input_shape = img.shape input_tensor = torch.permute(img, (0, 2, 3, 1)) - input_tensor = ttnn.from_torch(input_tensor, ttnn.bfloat16) + # input_tensor = ttnn.from_torch(input_tensor, ttnn.bfloat16) + input_tensor = torch.permute(img, (0, 2, 3, 1)) # put channel at the end + input_tensor = torch.nn.functional.pad( + input_tensor, (0, 13, 0, 0, 0, 0, 0, 0) + ) # pad channel dim from 3 to 16 + N, H, W, C = input_tensor.shape + input_tensor = torch.reshape(input_tensor, (N, 1, H * W, C)) + + shard_grid = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(7, 7), + ), + } + ) + n_cores = 64 + shard_spec = ttnn.ShardSpec(shard_grid, [N * H * W // n_cores, C], ttnn.ShardOrientation.ROW_MAJOR, False) + input_mem_config = ttnn.MemoryConfig( + ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.types.BufferType.L1, shard_spec + ) + input_tensor = ttnn.from_torch( + input_tensor, + dtype=ttnn.bfloat16, + layout=ttnn.ROW_MAJOR_LAYOUT, + device=device, + memory_config=input_mem_config, + ) img = input_tensor t1 = time.time() diff --git a/models/demos/yolov4/ttnn/common.py b/models/demos/yolov4/ttnn/common.py index 7f7a98d75b55..d9e2a8a0bd6c 100644 --- a/models/demos/yolov4/ttnn/common.py +++ b/models/demos/yolov4/ttnn/common.py @@ -42,6 +42,7 @@ def __init__( activation="", fused_op=True, width_sharding=False, + output_layout=ttnn.TILE_LAYOUT, ) -> None: if fused_op: self.weights, self.bias = fold_bn_to_conv_weights_bias(model, path) @@ -57,6 +58,7 @@ def __init__( self.out_channels = self.weights.shape[0] self.act_block_h = act_block_h self.reshard = reshard + self.output_layout = output_layout if width_sharding: self.shard_layout = ttnn.TensorMemoryLayout.WIDTH_SHARDED @@ -86,6 +88,7 @@ def __call__(self, device, input_tensor): reshard_if_not_optimal=self.reshard, deallocate_activation=self.deallocate, reallocate_halo_output=False, + output_layout=self.output_layout, ) if self.act_block_h is not None: conv_config.act_block_h_override = self.act_block_h diff --git a/models/demos/yolov4/ttnn/downsample1.py b/models/demos/yolov4/ttnn/downsample1.py index cc2f2cff37f7..33eb5d73f279 100644 --- a/models/demos/yolov4/ttnn/downsample1.py +++ b/models/demos/yolov4/ttnn/downsample1.py @@ -5,6 +5,8 @@ import torch import ttnn from models.demos.yolov4.ttnn.common import Conv +from tests.ttnn.ttnn_utility_fuction import get_shard_grid_from_num_cores +from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout class Down1: @@ -15,7 +17,7 @@ def __init__(self, model) -> None: torch_model = model.torch_model self.torch_model = torch_model self.conv1 = Conv(torch_model, "down1.conv1", [1, 320, 320, 3], (1, 1, 1, 1), act_block_h=128) - self.conv2 = Conv(torch_model, "down1.conv2", [1, 320, 320, 32], (2, 2, 1, 1), reshard=True) + self.conv2 = Conv(torch_model, "down1.conv2", [1, 320, 320, 32], (2, 2, 1, 1)) self.conv3 = Conv(torch_model, "down1.conv3", [1, 160, 160, 64], (1, 1, 0, 0), deallocate=False) self.conv4 = Conv(torch_model, "down1.conv4", [1, 160, 160, 64], (1, 1, 0, 0)) self.conv5 = Conv(torch_model, "down1.conv5", [1, 160, 160, 64], (1, 1, 0, 0), deallocate=False) @@ -30,6 +32,13 @@ def __call__(self, device, input_tensor): output_tensor_split = self.conv2(device, output_tensor) output_tensor_split = ttnn.mish(output_tensor_split) + shard_grid = get_shard_grid_from_num_cores(50, device) + shard_spec = ttnn.ShardSpec(shard_grid, (512, 64), ttnn.ShardOrientation.ROW_MAJOR, False) + in_sharded_mem_config_conv_5 = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec + ) + output_tensor_split = ttnn.to_memory_config(output_tensor_split, memory_config=in_sharded_mem_config_conv_5) + output_tensor_left = self.conv3(device, output_tensor_split) output_tensor_left = ttnn.mish(output_tensor_left) diff --git a/models/demos/yolov4/ttnn/downsample4.py b/models/demos/yolov4/ttnn/downsample4.py index 8e7d3ad78768..b47481b030cf 100644 --- a/models/demos/yolov4/ttnn/downsample4.py +++ b/models/demos/yolov4/ttnn/downsample4.py @@ -14,7 +14,9 @@ def __init__(self, model) -> None: else: torch_model = model.torch_model self.torch_model = torch_model - self.conv1 = Conv(torch_model, "down4.conv1", [1, 40, 40, 256], (2, 2, 1, 1), reshard=True) + self.conv1 = Conv( + torch_model, "down4.conv1", [1, 40, 40, 256], (2, 2, 1, 1), reshard=True, height_sharding=False + ) self.conv2 = Conv(torch_model, "down4.conv2", [1, 20, 20, 512], (1, 1, 0, 0), deallocate=False) self.conv3 = Conv(torch_model, "down4.conv3", [1, 20, 20, 512], (1, 1, 0, 0)) diff --git a/models/demos/yolov4/ttnn/head.py b/models/demos/yolov4/ttnn/head.py index 943fc086572c..7bb2036a383c 100644 --- a/models/demos/yolov4/ttnn/head.py +++ b/models/demos/yolov4/ttnn/head.py @@ -16,7 +16,15 @@ def __init__(self, model) -> None: self.torch_model = torch_model self.conv1 = Conv(torch_model, "head.conv1", [1, 40, 40, 128], (1, 1, 1, 1), reshard=True, deallocate=False) self.conv2 = Conv(torch_model, "head.conv2", [1, 40, 40, 256], (1, 1, 0, 0), fused_op=False) - self.conv3 = Conv(torch_model, "head.conv3", [1, 40, 40, 128], (2, 2, 1, 1), reshard=True, deallocate=False) + self.conv3 = Conv( + torch_model, + "head.conv3", + [1, 40, 40, 128], + (2, 2, 1, 1), + reshard=True, + deallocate=False, + height_sharding=False, + ) self.conv4 = Conv( torch_model, "head.conv4", @@ -71,6 +79,7 @@ def __init__(self, model) -> None: [1, 20, 20, 256], (2, 2, 1, 1), reshard=True, + height_sharding=False, ) self.conv12 = Conv( torch_model, diff --git a/models/demos/yolov4/ttnn/neck.py b/models/demos/yolov4/ttnn/neck.py index 59d62238709c..08f505d7a24b 100644 --- a/models/demos/yolov4/ttnn/neck.py +++ b/models/demos/yolov4/ttnn/neck.py @@ -35,7 +35,7 @@ def __init__(self, model) -> None: "neek.conv3", [1, 10, 10, 1024], (1, 1, 0, 0), - reshard=True, + reshard=False, ) self.conv4 = Conv( @@ -99,7 +99,6 @@ def __init__(self, model) -> None: "neek.conv12", [1, 20, 20, 256], (1, 1, 1, 1), - reshard=True, ) self.conv7_5 = Conv( torch_model, @@ -115,6 +114,7 @@ def __init__(self, model) -> None: [1, 20, 20, 256], (1, 1, 0, 0), deallocate=False, + height_sharding=False, ) self.conv9_2 = Conv( torch_model, @@ -223,9 +223,38 @@ def __call__(self, device, input_tensor): output_tensor = self.conv7(device, output_tensor_left_1) output_tensor = ttnn.leaky_relu(output_tensor, negative_slope=0.1) - output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) - output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) - output_tensor_upsample_1 = ttnn.upsample(output_tensor, (1, 4, 1), memory_config=ttnn.L1_MEMORY_CONFIG) + output_shape = output_tensor.shape + output_tensor = ttnn.untilize_with_unpadding( + output_tensor, + output_tensor_end=( + output_shape[0] - 1, + output_shape[1] - 1, + output_shape[2] - 1, + output_shape[3] - 1, + ), + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + output_tensor = ttnn.reshape(output_tensor, (1, 10, 10, 256)) + shard_grid = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(7, 4), + ), + } + ) + shard_spec = ttnn.ShardSpec(shard_grid, (20, 32), ttnn.ShardOrientation.ROW_MAJOR, False) + in_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec) + output_tensor = ttnn.to_memory_config(output_tensor, memory_config=in_sharded_mem_config) + shard_spec = ttnn.ShardSpec(shard_grid, (80, 32), ttnn.ShardOrientation.ROW_MAJOR, False) + out_sharded_mem_config = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec + ) + + output_tensor_upsample_1 = ttnn.upsample(output_tensor, (2, 2, 1), memory_config=out_sharded_mem_config) + output_tensor_upsample_1 = ttnn.sharded_to_interleaved(output_tensor_upsample_1, ttnn.L1_MEMORY_CONFIG) + output_tensor_upsample_1 = ttnn.reshape(output_tensor_upsample_1, (1, 1, 400, 256)) output_tensor_upsample_1 = ttnn.to_layout(output_tensor_upsample_1, layout=ttnn.TILE_LAYOUT) outDowSample5 = input_tensor[1] @@ -254,12 +283,52 @@ def __call__(self, device, input_tensor): output_tensor = self.conv7_5(device, output_tensor) output_tensor_left_2 = ttnn.leaky_relu(output_tensor, negative_slope=0.1) + shard_grid = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(6, 3), + ), + } + ) + shard_spec = ttnn.ShardSpec(shard_grid, (64, 64), ttnn.ShardOrientation.COL_MAJOR, False) + in_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec) + output_tensor_left_2 = ttnn.to_memory_config(output_tensor_left_2, memory_config=in_sharded_mem_config) output_tensor = self.conv9(device, output_tensor_left_2) output_tensor = ttnn.leaky_relu(output_tensor, negative_slope=0.1) - output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) - output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) - output_tensor_upsample_2 = ttnn.upsample(output_tensor, (1, 4, 1), memory_config=ttnn.L1_MEMORY_CONFIG) + output_shape = output_tensor.shape + output_tensor = ttnn.untilize_with_unpadding( + output_tensor, + output_tensor_end=( + output_shape[0] - 1, + output_shape[1] - 1, + output_shape[2] - 1, + output_shape[3] - 1, + ), + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + output_tensor = ttnn.reshape(output_tensor, (1, 20, 20, 128)) + shard_grid = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(7, 4), + ), + } + ) + shard_spec = ttnn.ShardSpec(shard_grid, (80, 16), ttnn.ShardOrientation.ROW_MAJOR, False) + in_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec) + output_tensor = ttnn.to_memory_config(output_tensor, memory_config=in_sharded_mem_config) + shard_spec = ttnn.ShardSpec(shard_grid, (80 * 4, 16), ttnn.ShardOrientation.ROW_MAJOR, False) + out_sharded_mem_config = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec + ) + + output_tensor_upsample_2 = ttnn.upsample(output_tensor, (2, 2, 1), memory_config=out_sharded_mem_config) + output_tensor_upsample_2 = ttnn.sharded_to_interleaved(output_tensor_upsample_2, ttnn.L1_MEMORY_CONFIG) + output_tensor_upsample_2 = ttnn.reshape(output_tensor_upsample_2, (1, 1, 1600, 128)) output_tensor_upsample_2 = ttnn.to_layout(output_tensor_upsample_2, ttnn.TILE_LAYOUT) outDowSample3 = input_tensor[2] diff --git a/tests/ttnn/integration_tests/yolov4/test_ttnn_head.py b/tests/ttnn/integration_tests/yolov4/test_ttnn_head.py index 2465e8de9821..2dc894641a3a 100644 --- a/tests/ttnn/integration_tests/yolov4/test_ttnn_head.py +++ b/tests/ttnn/integration_tests/yolov4/test_ttnn_head.py @@ -93,6 +93,9 @@ def test_head(device, reset_seeds, model_location_generator): result_2 = result_2[:, :255, :, :] result_3 = result_3[:, :255, :, :] - assert_with_pcc(result_1, ref1, 0.99) - assert_with_pcc(result_2, ref2, 0.99) - assert_with_pcc(result_3, ref3, 0.99) + pcc_passed, pcc_message = assert_with_pcc(result_1, ref1, 0.99) + logger.info(pcc_message) + pcc_passed, pcc_message = assert_with_pcc(result_2, ref2, 0.99) + logger.info(pcc_message) + pcc_passed, pcc_message = assert_with_pcc(result_3, ref3, 0.99) + logger.info(pcc_message) diff --git a/tests/ttnn/integration_tests/yolov4/test_ttnn_neck.py b/tests/ttnn/integration_tests/yolov4/test_ttnn_neck.py index bd4805ca9bb9..1943bc0dcd7b 100644 --- a/tests/ttnn/integration_tests/yolov4/test_ttnn_neck.py +++ b/tests/ttnn/integration_tests/yolov4/test_ttnn_neck.py @@ -79,6 +79,9 @@ def test_neck(device, reset_seeds, model_location_generator): result1 = result_1.reshape(ref1.shape) result2 = result_2.reshape(ref2.shape) result3 = result_3.reshape(ref3.shape) - assert_with_pcc(result1, ref1, 0.94) # PCC = 0.94 - assert_with_pcc(result2, ref2, 0.985) # PCC = 0.985 - assert_with_pcc(result3, ref3, 0.96) # PCC = 0.96 + pcc_passed, pcc_message = assert_with_pcc(result1, ref1, 0.99) # PCC = 0.99 + logger.info(pcc_message) + pcc_passed, pcc_message = assert_with_pcc(result2, ref2, 0.985) # PCC = 0.985 + logger.info(pcc_message) + pcc_passed, pcc_message = assert_with_pcc(result3, ref3, 0.96) # PCC = 0.96 + logger.info(pcc_message) diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp index dd7054b7b434..7842b75f1805 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp @@ -25,7 +25,7 @@ ConcatOpParallelizationStrategy ConcatDeviceOperation::get_parallelization_strat void ConcatDeviceOperation::validate(const std::vector &input_tensors) const { const auto &first_input = input_tensors[0]; - tt::tt_metal::LegacyShape shape_first = first_input.get_legacy_shape(); + ttnn::SimpleShape shape_first = first_input.get_logical_shape(); TT_FATAL(this->dim < shape_first.rank(), "ConcatDeviceOperation dim specified is larger than input tensor rank."); shape_first[this->dim] = 0; bool shard_first = input_tensors[0].is_sharded(); @@ -38,12 +38,24 @@ void ConcatDeviceOperation::validate(const std::vector &input_tensors) c TT_FATAL(in_ref.device() == first_input.device(), "Operands to concat need to be on the same device."); TT_FATAL(in_ref.get_layout() == first_input.get_layout(), "All Tensors should have same layouts."); TT_FATAL(in_ref.get_dtype() == first_input.get_dtype(), "All Tensors should have same dtypes."); - tt::tt_metal::LegacyShape curr_shape = in_ref.get_legacy_shape(); + ttnn::SimpleShape curr_shape = in_ref.get_logical_shape(); + TT_FATAL(curr_shape.rank() == shape_first.rank(), "Input tensor ranks must be equal"); curr_shape[this->dim] = 0; // last tensor can support without any kernel changes if(in_ref.get_layout() == Layout::TILE and in_ref.get_shape().has_tile_padding(this->dim)) { warn_about_alignment = true; + /* // last tensor can support without any kernel changes + TT_FATAL( + !in_ref.get_shape().has_tile_padding(this->dim), + "Tile padding along concatenated dim ({}) not supported for concat yet (tensor: {}).", + this->dim, + i); + TT_FATAL(curr_shape == shape_first, "concat tensors differ in shape across non-concat dimensions."); + if (in_ref.get_layout() == Layout::ROW_MAJOR && this->dim == shape_first.rank() - 1) { + TT_FATAL( + (in_ref.get_logical_shape()[this->dim] * in_ref.element_size()) % in_ref.buffer()->alignment() == 0, + "Current concat implementation requires aligned last dim when concatting on last dim");*/ } TT_FATAL(curr_shape == shape_first, "concat tensors differ in shape across non-concat dimensions."); TT_FATAL(in_ref.is_sharded() == shard_first, "All tensors must be sharded or all must be interleaved");