Skip to content

Commit

Permalink
#14257: Yolo Optimization.
Browse files Browse the repository at this point in the history
  • Loading branch information
shwetankTT committed Nov 24, 2024
1 parent 4df8579 commit 66b25f2
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 20 deletions.
29 changes: 28 additions & 1 deletion models/demos/yolov4/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,34 @@ def do_detect(model, img, conf_thresh, nms_thresh, n_classes, device=None, class
if not is_torch_model:
input_shape = img.shape
input_tensor = torch.permute(img, (0, 2, 3, 1))
input_tensor = ttnn.from_torch(input_tensor, ttnn.bfloat16)
# input_tensor = ttnn.from_torch(input_tensor, ttnn.bfloat16)
input_tensor = torch.permute(img, (0, 2, 3, 1)) # put channel at the end
input_tensor = torch.nn.functional.pad(
input_tensor, (0, 13, 0, 0, 0, 0, 0, 0)
) # pad channel dim from 3 to 16
N, H, W, C = input_tensor.shape
input_tensor = torch.reshape(input_tensor, (N, 1, H * W, C))

shard_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(7, 7),
),
}
)
n_cores = 64
shard_spec = ttnn.ShardSpec(shard_grid, [N * H * W // n_cores, C], ttnn.ShardOrientation.ROW_MAJOR, False)
input_mem_config = ttnn.MemoryConfig(
ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.types.BufferType.L1, shard_spec
)
input_tensor = ttnn.from_torch(
input_tensor,
dtype=ttnn.bfloat16,
layout=ttnn.ROW_MAJOR_LAYOUT,
device=device,
memory_config=input_mem_config,
)
img = input_tensor
t1 = time.time()

Expand Down
3 changes: 3 additions & 0 deletions models/demos/yolov4/ttnn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
activation="",
fused_op=True,
width_sharding=False,
output_layout=ttnn.TILE_LAYOUT,
) -> None:
if fused_op:
self.weights, self.bias = fold_bn_to_conv_weights_bias(model, path)
Expand All @@ -57,6 +58,7 @@ def __init__(
self.out_channels = self.weights.shape[0]
self.act_block_h = act_block_h
self.reshard = reshard
self.output_layout = output_layout

if width_sharding:
self.shard_layout = ttnn.TensorMemoryLayout.WIDTH_SHARDED
Expand Down Expand Up @@ -86,6 +88,7 @@ def __call__(self, device, input_tensor):
reshard_if_not_optimal=self.reshard,
deallocate_activation=self.deallocate,
reallocate_halo_output=False,
output_layout=self.output_layout,
)
if self.act_block_h is not None:
conv_config.act_block_h_override = self.act_block_h
Expand Down
11 changes: 10 additions & 1 deletion models/demos/yolov4/ttnn/downsample1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch
import ttnn
from models.demos.yolov4.ttnn.common import Conv
from tests.ttnn.ttnn_utility_fuction import get_shard_grid_from_num_cores
from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout


class Down1:
Expand All @@ -15,7 +17,7 @@ def __init__(self, model) -> None:
torch_model = model.torch_model
self.torch_model = torch_model
self.conv1 = Conv(torch_model, "down1.conv1", [1, 320, 320, 3], (1, 1, 1, 1), act_block_h=128)
self.conv2 = Conv(torch_model, "down1.conv2", [1, 320, 320, 32], (2, 2, 1, 1), reshard=True)
self.conv2 = Conv(torch_model, "down1.conv2", [1, 320, 320, 32], (2, 2, 1, 1))
self.conv3 = Conv(torch_model, "down1.conv3", [1, 160, 160, 64], (1, 1, 0, 0), deallocate=False)
self.conv4 = Conv(torch_model, "down1.conv4", [1, 160, 160, 64], (1, 1, 0, 0))
self.conv5 = Conv(torch_model, "down1.conv5", [1, 160, 160, 64], (1, 1, 0, 0), deallocate=False)
Expand All @@ -30,6 +32,13 @@ def __call__(self, device, input_tensor):
output_tensor_split = self.conv2(device, output_tensor)
output_tensor_split = ttnn.mish(output_tensor_split)

shard_grid = get_shard_grid_from_num_cores(50, device)
shard_spec = ttnn.ShardSpec(shard_grid, (512, 64), ttnn.ShardOrientation.ROW_MAJOR, False)
in_sharded_mem_config_conv_5 = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec
)
output_tensor_split = ttnn.to_memory_config(output_tensor_split, memory_config=in_sharded_mem_config_conv_5)

output_tensor_left = self.conv3(device, output_tensor_split)
output_tensor_left = ttnn.mish(output_tensor_left)

Expand Down
4 changes: 3 additions & 1 deletion models/demos/yolov4/ttnn/downsample4.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def __init__(self, model) -> None:
else:
torch_model = model.torch_model
self.torch_model = torch_model
self.conv1 = Conv(torch_model, "down4.conv1", [1, 40, 40, 256], (2, 2, 1, 1), reshard=True)
self.conv1 = Conv(
torch_model, "down4.conv1", [1, 40, 40, 256], (2, 2, 1, 1), reshard=True, height_sharding=False
)
self.conv2 = Conv(torch_model, "down4.conv2", [1, 20, 20, 512], (1, 1, 0, 0), deallocate=False)
self.conv3 = Conv(torch_model, "down4.conv3", [1, 20, 20, 512], (1, 1, 0, 0))

Expand Down
11 changes: 10 additions & 1 deletion models/demos/yolov4/ttnn/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@ def __init__(self, model) -> None:
self.torch_model = torch_model
self.conv1 = Conv(torch_model, "head.conv1", [1, 40, 40, 128], (1, 1, 1, 1), reshard=True, deallocate=False)
self.conv2 = Conv(torch_model, "head.conv2", [1, 40, 40, 256], (1, 1, 0, 0), fused_op=False)
self.conv3 = Conv(torch_model, "head.conv3", [1, 40, 40, 128], (2, 2, 1, 1), reshard=True, deallocate=False)
self.conv3 = Conv(
torch_model,
"head.conv3",
[1, 40, 40, 128],
(2, 2, 1, 1),
reshard=True,
deallocate=False,
height_sharding=False,
)
self.conv4 = Conv(
torch_model,
"head.conv4",
Expand Down Expand Up @@ -71,6 +79,7 @@ def __init__(self, model) -> None:
[1, 20, 20, 256],
(2, 2, 1, 1),
reshard=True,
height_sharding=False,
)
self.conv12 = Conv(
torch_model,
Expand Down
85 changes: 77 additions & 8 deletions models/demos/yolov4/ttnn/neck.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, model) -> None:
"neek.conv3",
[1, 10, 10, 1024],
(1, 1, 0, 0),
reshard=True,
reshard=False,
)

self.conv4 = Conv(
Expand Down Expand Up @@ -99,7 +99,6 @@ def __init__(self, model) -> None:
"neek.conv12",
[1, 20, 20, 256],
(1, 1, 1, 1),
reshard=True,
)
self.conv7_5 = Conv(
torch_model,
Expand All @@ -115,6 +114,7 @@ def __init__(self, model) -> None:
[1, 20, 20, 256],
(1, 1, 0, 0),
deallocate=False,
height_sharding=False,
)
self.conv9_2 = Conv(
torch_model,
Expand Down Expand Up @@ -223,9 +223,38 @@ def __call__(self, device, input_tensor):
output_tensor = self.conv7(device, output_tensor_left_1)
output_tensor = ttnn.leaky_relu(output_tensor, negative_slope=0.1)

output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor_upsample_1 = ttnn.upsample(output_tensor, (1, 4, 1), memory_config=ttnn.L1_MEMORY_CONFIG)
output_shape = output_tensor.shape
output_tensor = ttnn.untilize_with_unpadding(
output_tensor,
output_tensor_end=(
output_shape[0] - 1,
output_shape[1] - 1,
output_shape[2] - 1,
output_shape[3] - 1,
),
memory_config=ttnn.L1_MEMORY_CONFIG,
)

output_tensor = ttnn.reshape(output_tensor, (1, 10, 10, 256))
shard_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(7, 4),
),
}
)
shard_spec = ttnn.ShardSpec(shard_grid, (20, 32), ttnn.ShardOrientation.ROW_MAJOR, False)
in_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec)
output_tensor = ttnn.to_memory_config(output_tensor, memory_config=in_sharded_mem_config)
shard_spec = ttnn.ShardSpec(shard_grid, (80, 32), ttnn.ShardOrientation.ROW_MAJOR, False)
out_sharded_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
)

output_tensor_upsample_1 = ttnn.upsample(output_tensor, (2, 2, 1), memory_config=out_sharded_mem_config)
output_tensor_upsample_1 = ttnn.sharded_to_interleaved(output_tensor_upsample_1, ttnn.L1_MEMORY_CONFIG)
output_tensor_upsample_1 = ttnn.reshape(output_tensor_upsample_1, (1, 1, 400, 256))
output_tensor_upsample_1 = ttnn.to_layout(output_tensor_upsample_1, layout=ttnn.TILE_LAYOUT)

outDowSample5 = input_tensor[1]
Expand Down Expand Up @@ -254,12 +283,52 @@ def __call__(self, device, input_tensor):
output_tensor = self.conv7_5(device, output_tensor)
output_tensor_left_2 = ttnn.leaky_relu(output_tensor, negative_slope=0.1)

shard_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(6, 3),
),
}
)
shard_spec = ttnn.ShardSpec(shard_grid, (64, 64), ttnn.ShardOrientation.COL_MAJOR, False)
in_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec)
output_tensor_left_2 = ttnn.to_memory_config(output_tensor_left_2, memory_config=in_sharded_mem_config)
output_tensor = self.conv9(device, output_tensor_left_2)
output_tensor = ttnn.leaky_relu(output_tensor, negative_slope=0.1)

output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor_upsample_2 = ttnn.upsample(output_tensor, (1, 4, 1), memory_config=ttnn.L1_MEMORY_CONFIG)
output_shape = output_tensor.shape
output_tensor = ttnn.untilize_with_unpadding(
output_tensor,
output_tensor_end=(
output_shape[0] - 1,
output_shape[1] - 1,
output_shape[2] - 1,
output_shape[3] - 1,
),
memory_config=ttnn.L1_MEMORY_CONFIG,
)

output_tensor = ttnn.reshape(output_tensor, (1, 20, 20, 128))
shard_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(7, 4),
),
}
)
shard_spec = ttnn.ShardSpec(shard_grid, (80, 16), ttnn.ShardOrientation.ROW_MAJOR, False)
in_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec)
output_tensor = ttnn.to_memory_config(output_tensor, memory_config=in_sharded_mem_config)
shard_spec = ttnn.ShardSpec(shard_grid, (80 * 4, 16), ttnn.ShardOrientation.ROW_MAJOR, False)
out_sharded_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
)

output_tensor_upsample_2 = ttnn.upsample(output_tensor, (2, 2, 1), memory_config=out_sharded_mem_config)
output_tensor_upsample_2 = ttnn.sharded_to_interleaved(output_tensor_upsample_2, ttnn.L1_MEMORY_CONFIG)
output_tensor_upsample_2 = ttnn.reshape(output_tensor_upsample_2, (1, 1, 1600, 128))
output_tensor_upsample_2 = ttnn.to_layout(output_tensor_upsample_2, ttnn.TILE_LAYOUT)

outDowSample3 = input_tensor[2]
Expand Down
9 changes: 6 additions & 3 deletions tests/ttnn/integration_tests/yolov4/test_ttnn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def test_head(device, reset_seeds, model_location_generator):
result_2 = result_2[:, :255, :, :]
result_3 = result_3[:, :255, :, :]

assert_with_pcc(result_1, ref1, 0.99)
assert_with_pcc(result_2, ref2, 0.99)
assert_with_pcc(result_3, ref3, 0.99)
pcc_passed, pcc_message = assert_with_pcc(result_1, ref1, 0.99)
logger.info(pcc_message)
pcc_passed, pcc_message = assert_with_pcc(result_2, ref2, 0.99)
logger.info(pcc_message)
pcc_passed, pcc_message = assert_with_pcc(result_3, ref3, 0.99)
logger.info(pcc_message)
9 changes: 6 additions & 3 deletions tests/ttnn/integration_tests/yolov4/test_ttnn_neck.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def test_neck(device, reset_seeds, model_location_generator):
result1 = result_1.reshape(ref1.shape)
result2 = result_2.reshape(ref2.shape)
result3 = result_3.reshape(ref3.shape)
assert_with_pcc(result1, ref1, 0.94) # PCC = 0.94
assert_with_pcc(result2, ref2, 0.985) # PCC = 0.985
assert_with_pcc(result3, ref3, 0.96) # PCC = 0.96
pcc_passed, pcc_message = assert_with_pcc(result1, ref1, 0.99) # PCC = 0.99
logger.info(pcc_message)
pcc_passed, pcc_message = assert_with_pcc(result2, ref2, 0.985) # PCC = 0.985
logger.info(pcc_message)
pcc_passed, pcc_message = assert_with_pcc(result3, ref3, 0.96) # PCC = 0.96
logger.info(pcc_message)
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ConcatOpParallelizationStrategy ConcatDeviceOperation::get_parallelization_strat

void ConcatDeviceOperation::validate(const std::vector<Tensor> &input_tensors) const {
const auto &first_input = input_tensors[0];
tt::tt_metal::LegacyShape shape_first = first_input.get_legacy_shape();
ttnn::SimpleShape shape_first = first_input.get_logical_shape();
TT_FATAL(this->dim < shape_first.rank(), "ConcatDeviceOperation dim specified is larger than input tensor rank.");
shape_first[this->dim] = 0;
bool shard_first = input_tensors[0].is_sharded();
Expand All @@ -38,12 +38,24 @@ void ConcatDeviceOperation::validate(const std::vector<Tensor> &input_tensors) c
TT_FATAL(in_ref.device() == first_input.device(), "Operands to concat need to be on the same device.");
TT_FATAL(in_ref.get_layout() == first_input.get_layout(), "All Tensors should have same layouts.");
TT_FATAL(in_ref.get_dtype() == first_input.get_dtype(), "All Tensors should have same dtypes.");
tt::tt_metal::LegacyShape curr_shape = in_ref.get_legacy_shape();
ttnn::SimpleShape curr_shape = in_ref.get_logical_shape();

TT_FATAL(curr_shape.rank() == shape_first.rank(), "Input tensor ranks must be equal");
curr_shape[this->dim] = 0;
// last tensor can support without any kernel changes
if(in_ref.get_layout() == Layout::TILE and in_ref.get_shape().has_tile_padding(this->dim)) {
warn_about_alignment = true;
/* // last tensor can support without any kernel changes
TT_FATAL(
!in_ref.get_shape().has_tile_padding(this->dim),
"Tile padding along concatenated dim ({}) not supported for concat yet (tensor: {}).",
this->dim,
i);
TT_FATAL(curr_shape == shape_first, "concat tensors differ in shape across non-concat dimensions.");
if (in_ref.get_layout() == Layout::ROW_MAJOR && this->dim == shape_first.rank() - 1) {
TT_FATAL(
(in_ref.get_logical_shape()[this->dim] * in_ref.element_size()) % in_ref.buffer()->alignment() == 0,
"Current concat implementation requires aligned last dim when concatting on last dim");*/
}
TT_FATAL(curr_shape == shape_first, "concat tensors differ in shape across non-concat dimensions.");
TT_FATAL(in_ref.is_sharded() == shard_first, "All tensors must be sharded or all must be interleaved");
Expand Down

0 comments on commit 66b25f2

Please sign in to comment.