#10110: conv non tile sharding
shwetankTT committed Oct 22, 2024
1 parent 219678c commit cb5a429
Showing 14 changed files with 276 additions and 99 deletions.
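In short: conv2d's height-sharded path previously always padded each core's shard of the flattened NHW dimension up to the 32-row tile height; this commit threads a `use_non_tile_height` flag through the op so that row-major outputs can shard at exact row granularity. A minimal sketch of the arithmetic (the helper and core count below are illustrative, not ttnn API):

```python
# Minimal sketch of the shard-height arithmetic; shard_height() is a
# hypothetical helper, not part of ttnn.
def shard_height(nhw_rows: int, num_cores_nhw: int, use_non_tile_height: bool) -> int:
    # Tile path pads flattened NHW to a multiple of num_cores * TILE_HEIGHT (32);
    # the non-tile path pads only to a multiple of num_cores.
    round_up_size = 1 if use_non_tile_height else 32
    unit = num_cores_nhw * round_up_size
    nhw_padded = -(-nhw_rows // unit) * unit  # round_up(nhw_rows, unit)
    return nhw_padded // num_cores_nhw

# A 1x32x240x320 activation height-sharded over 64 cores (core count illustrative):
print(shard_height(240 * 320, 64, use_non_tile_height=False))  # 1216 rows/core (38 tiles, padded)
print(shard_height(240 * 320, 64, use_non_tile_height=True))   # 1200 rows/core (exact, no padding)
```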
@@ -626,7 +626,7 @@ def __init__(

conv_dummy_tensor = torch.rand((self.fold_output_shape), dtype=torch.bfloat16)
conv_dummy_tensor = ttnn.from_torch(conv_dummy_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
_, self.override_fold_mem_config, _ = ttnn.get_conv_padded_input_shape_and_mem_config(
_, self.override_fold_mem_config, _, _ = ttnn.get_conv_padded_input_shape_and_mem_config(
device=device,
input_tensor=conv_dummy_tensor,
conv_config=self.conv1_config,
114 changes: 114 additions & 0 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -135,6 +135,7 @@ def run_conv(
enable_act_double_buffer=False,
enable_split_reader=False,
enable_subblock_padding=False,
output_layout=output_layout,
)
if config_override and "act_block_h" in config_override:
conv_config.act_block_h_override = config_override["act_block_h"]
@@ -2203,3 +2204,116 @@ def test_conv_for_vanilla_unet(
output_layout=output_layout,
has_bias=False,
)


@skip_for_blackhole()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override",
(
# unique convs in rn50 (complete list)
# first conv post folding and input_channels padding to tile width
(16, 64, 64, 14, 14, 3, 3, 1, 1, 1, 1, True, None),
# rn50 layer1
(8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
(16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
(20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
# rn50 layer2
(8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None),
(16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None),
(20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None),
(8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None),
(16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None),
(20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None),
(1, 32, 32, 240, 320, 3, 3, 1, 1, 1, 1, True, None),
(1, 64, 32, 240, 320, 3, 3, 1, 1, 1, 1, True, None),
),
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat8_b, ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
[ttnn.bfloat16],
)
@pytest.mark.parametrize("fp32_accum", [False, True], ids=["no_fp32_accum", "fp32_accum"])
@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi])
@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"])
@pytest.mark.parametrize("has_bias", [True, False], ids=["with_bias", "no_bias"])
def test_non_tile_multiple_height_conv_wh(
device,
use_program_cache,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
config_override,
fp32_accum,
packer_l1_acc,
has_bias,
):
if device.core_grid.y == 7:
pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range")

if (
is_grayskull()
and activations_dtype == ttnn.bfloat16
and batch_size == 20
and (
output_channels == 64
or (
stride_h == 2
and (output_channels == 256 or (output_channels == 128 and weights_dtype == ttnn.bfloat16))
)
)
):
pytest.skip("Skipping test because it won't fit in L1!")

if (
(weights_dtype == ttnn.bfloat16 and batch_size == 20 and output_channels == 128 and input_height == 56)
or (weights_dtype == ttnn.bfloat16 and batch_size == 20 and output_channels == 64)
or (weights_dtype == ttnn.bfloat8_b and batch_size == 20 and output_channels == 128 and input_height == 56)
):
pytest.skip("Skipping test because it won't fit in L1!")

if has_bias and packer_l1_acc and fp32_accum:
pytest.skip("bug!")

use_shallow_conv_variant = (input_channels == 16) and device.arch() != ttnn.device.Arch.WORMHOLE_B0
run_conv(
device,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
config_override=config_override,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH
packer_l1_acc=packer_l1_acc,
fp32_accum=fp32_accum,
has_bias=has_bias,
output_layout=ttnn.ROW_MAJOR_LAYOUT,
)
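The new test drives the non-tile path by requesting row-major output through `run_conv` (runnable with `pytest tests/ttnn/unit_tests/operations/test_new_conv2d.py -k test_non_tile_multiple_height_conv_wh`). For orientation, here is a sketch of a config that should satisfy the gating heuristic introduced in conv2d.cpp below, assuming the `Conv2dConfig` field names already used in this file:

```python
import ttnn

# Sketch of a config expected to route through the non-tile-height path:
# height sharding, <= 256 output channels, bfloat16, row-major output, no
# act_block_h override, and not the shallow-conv variant. The exact gating
# lives in get_conv_padded_input_shape_and_mem_config, shown below.
conv_config = ttnn.Conv2dConfig(
    dtype=ttnn.bfloat16,
    weights_dtype=ttnn.bfloat8_b,
    shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
    output_layout=ttnn.ROW_MAJOR_LAYOUT,
)
```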
74 changes: 44 additions & 30 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -188,9 +188,7 @@ MemoryConfig create_sharded_memory_config_from_parallel_config(

uint32_t nhw_shape = tensor_shape[0] * tensor_shape[1] * tensor_shape[2];
uint32_t nhw_padded = nhw_shape;
if(shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) {
nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size);
}
nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size);
uint32_t nhw_shard = nhw_padded / num_cores_nhw;
TT_ASSERT(channels % num_cores_channels == 0, "Channels: {}, num core channels: {}", channels, num_cores_channels);
uint32_t channel_shard = channels / num_cores_channels;
@@ -204,14 +202,16 @@ OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_output_mem_config(
TT_ASSERT(conv_output_mem_config.shard_spec.has_value());
const auto& shard_spec = conv_output_mem_config.shard_spec.value();
const auto& shard_shape = shard_spec.shape;
TT_ASSERT(conv_output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED || shard_shape[0] % 32 == 0);
TT_ASSERT(shard_shape[1] % 32 == 0);
uint32_t per_core_out_matrix_height_ntiles = div_up(shard_shape[0], 32);
return {
.grid_size = shard_spec.grid.bounding_box().grid_size(),
.num_cores_nhw = num_cores_nhw,
.num_cores_c = num_cores_c,
.per_core_out_matrix_height_ntiles = tt::round_up(shard_shape[0], 32) / 32,
.per_core_out_matrix_height_ntiles = per_core_out_matrix_height_ntiles,
.per_core_out_matrix_width_ntiles = shard_shape[1] / 32,
.per_core_out_matrix_height = shard_shape[0],
.per_core_out_matrix_width = shard_shape[1],
};
}
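Since a non-tile shard height need not be a multiple of 32, the parallelization config now carries the exact per-core row count (`per_core_out_matrix_height`) alongside its tile-rounded counterpart. A small illustration of the two values (numbers illustrative):

```python
# One core's output shard, e.g. 28*28 = 784 output rows on one core:
shard_rows = 784
per_core_out_matrix_height = shard_rows                   # exact rows, used by the row-major path
per_core_out_matrix_height_ntiles = -(-shard_rows // 32)  # div_up(784, 32) == 25 tiles
```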

@@ -382,7 +382,7 @@ static TensorMemoryLayout select_shard_spec(
}

template <typename T>
std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> get_conv_padded_input_shape_and_mem_config(
std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool, bool> get_conv_padded_input_shape_and_mem_config(
T* device,
const ttnn::Tensor& input_tensor_,
const Conv2dConfig& conv_config,
@@ -429,6 +429,10 @@ std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> get_conv_padded_input_shape_and_mem_config(
dilation,
device);
}
bool use_non_tile_height = shard_layout == TensorMemoryLayout::HEIGHT_SHARDED && out_channels <= 256 && conv_config.act_block_h_override == 0 &&
conv_config.dtype == DataType::BFLOAT16 && conv_config.output_layout == Layout::ROW_MAJOR;
use_non_tile_height = use_non_tile_height && conv_config.input_channels_alignment != 16; // shallow conv variant

ParallelConfig input_tensor_parallel_config;
if (!input_tensor_on_device) {
needs_shard_or_reshard = true;
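The `use_non_tile_height` gate above, restated in Python for readability; a sketch, not the authoritative implementation (which is the two C++ statements just shown):

```python
import ttnn

def wants_non_tile_height(conv_config, shard_layout, out_channels) -> bool:
    # Mirrors the C++ condition: height-sharded, small channel count, default
    # act block height, bfloat16, row-major output, and not the shallow-conv
    # variant (input_channels_alignment == 16).
    return (
        shard_layout == ttnn.TensorMemoryLayout.HEIGHT_SHARDED
        and out_channels <= 256
        and conv_config.act_block_h_override == 0
        and conv_config.dtype == ttnn.bfloat16
        and conv_config.output_layout == ttnn.ROW_MAJOR_LAYOUT
        and conv_config.input_channels_alignment != 16
    )
```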
@@ -474,8 +478,16 @@ std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> get_conv_padded_input_shape_and_mem_config(
if (conv_config.reshard_if_not_optimal || needs_shard_or_reshard) {
auto block_shard_orientation =
conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
ParallelConfig optimal_parallel_config = determine_parallel_config(
shard_layout, batch_size, in_channels, height, width, out_channels, device, block_shard_orientation);
const ParallelConfig& optimal_parallel_config = determine_parallel_config(
shard_layout,
batch_size,
in_channels,
height,
width,
out_channels,
device,
block_shard_orientation,
!use_non_tile_height);

if (conv_config.override_sharding_config) {
TT_FATAL(conv_config.core_grid.has_value(), "Error");
@@ -498,7 +510,8 @@ std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> get_conv_padded_input_shape_and_mem_config(
uint32_t input_num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config);
// TT_ASSERT(input_tensor.get_legacy_shape() == input_tensor.get_shape());
uint32_t tensor_height = input_tensor.get_shape()[0] * input_tensor.get_shape()[1] * input_tensor.get_shape()[2];
uint32_t input_tensor_height_snapped_to_tile = (shard_layout == TensorMemoryLayout::WIDTH_SHARDED)? tensor_height : tt::round_up(tensor_height, input_num_cores_nhw * 32);
uint32_t round_up_size = (use_non_tile_height || conv_config.shard_layout == TensorMemoryLayout::WIDTH_SHARDED) ? 1 : tt::constants::TILE_HEIGHT;
uint32_t input_tensor_height_snapped_to_tile = tt::round_up(tensor_height, input_num_cores_nhw * round_up_size);
TT_ASSERT(input_tensor_height_snapped_to_tile >= tensor_height);
uint32_t tensor_width = input_tensor.get_shape()[3];
uint32_t input_tensor_width_snapped_to_channels_alignment =
Expand All @@ -510,20 +523,18 @@ std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> get_conv_padded_input_shape_an
1,
input_tensor_height_snapped_to_tile,
input_tensor_width_snapped_to_channels_alignment}); // TODO: resolve ttnn::types::Shape and
// tt::tt_metal::LegacyShape issue to clean up next line
auto input_tensor_sharded_memory_config = create_sharded_memory_config_from_parallel_config(
ttnn::Shape(std::array<uint32_t, 4>{
input_padded_shape[0], input_padded_shape[1], input_padded_shape[2], input_padded_shape[3]}),
parallel_config,
32);
return {input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard};
input_padded_shape[0], input_padded_shape[1], input_padded_shape[2], input_padded_shape[3]}),
parallel_config, round_up_size);
return {input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard, use_non_tile_height};
} else {
return {input_tensor.shape(), input_tensor.memory_config(), needs_shard_or_reshard};
return {input_tensor.shape(), input_tensor.memory_config(), needs_shard_or_reshard, use_non_tile_height};
}
}
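A numeric sketch of the input-height snap computed above, using the 16x64x14x14 test case and an illustrative core count:

```python
# Input-height snap for the 16x64x14x14 test case (core count illustrative):
tensor_height = 16 * 14 * 14  # flattened NHW = 3136 rows
num_cores_nhw = 20
for use_non_tile_height in (False, True):
    round_up_size = 1 if use_non_tile_height else 32  # tt::constants::TILE_HEIGHT
    unit = num_cores_nhw * round_up_size
    snapped = -(-tensor_height // unit) * unit
    print(snapped, snapped // num_cores_nhw)
# tile path:     3200 rows -> 160 per core (64 rows of padding overall)
# non-tile path: 3140 rows -> 157 per core (4 rows of padding overall)
```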

template <typename T>
std::tuple<ttnn::Tensor, ParallelConfig, bool> shard_or_reshard_tensor_if_required(
std::tuple<ttnn::Tensor, ParallelConfig, bool, bool> shard_or_reshard_tensor_if_required(
T* device,
const ttnn::Tensor& input_tensor_,
const Conv2dConfig& conv_config,
@@ -542,7 +553,7 @@ std::tuple<ttnn::Tensor, ParallelConfig, bool> shard_or_reshard_tensor_if_required(
ttnn::Tensor input_tensor = input_tensor_; // tensor to return
bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_);

auto [input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard] =
auto [input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard, use_non_tile_height] =
get_conv_padded_input_shape_and_mem_config(
device,
input_tensor_,
@@ -620,7 +631,7 @@ std::tuple<ttnn::Tensor, ParallelConfig, bool> shard_or_reshard_tensor_if_required(
}
}
}
return {input_tensor, parallel_config, needs_shard_or_reshard};
return {input_tensor, parallel_config, needs_shard_or_reshard, use_non_tile_height};
}

void validate_weight_and_bias_tensors(
@@ -817,7 +828,7 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::Tensor>>
}
uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
uint32_t output_width = ((input_width - kernel_size[1] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1;
auto [input_tensor_post_tm, parallel_config, tensor_manipulated] = shard_or_reshard_tensor_if_required(
auto [input_tensor_post_tm, parallel_config, tensor_manipulated, use_non_tile_height] = shard_or_reshard_tensor_if_required(
device,
input_tensor,
conv_config,
@@ -841,13 +852,14 @@
}
conv_config.deallocate_activation = true;
}
uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1;
auto conv_out_memory_config = create_sharded_memory_config_from_parallel_config(
ttnn::Shape(std::array<uint32_t, 4>{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}),
parallel_config,
32);
parallel_config, round_up_size);
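Note that only the height alignment is relaxed here; the channel dimension is still rounded up to a tile width of 32. A sketch with illustrative numbers:

```python
# Output shard shape for the 16x64x14x14 test case (3x3, stride 1, pad 1),
# height-sharded over an illustrative 20 cores, non-tile-height path:
batch, out_h, out_w, out_channels = 16, 14, 14, 64
nhw = batch * out_h * out_w                    # 3136
channels_padded = -(-out_channels // 32) * 32  # 64: channels stay tile-aligned
rows_per_core = -(-nhw // 20)                  # 157 (the tile path would give 160)
# resulting per-core shard spec: [157, 64] instead of [160, 64]
```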
auto opt_conv_op_parallel_config = determine_conv_op_parallel_config_from_conv_output_mem_config(
conv_out_memory_config, get_num_cores_nhw_from_parallel_config(parallel_config),
get_num_cores_channels_from_parallel_config(parallel_config));
TT_ASSERT(use_non_tile_height || conv_out_memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED || opt_conv_op_parallel_config.per_core_out_matrix_height % 32 == 0);
auto opt_conv_op_block_config = determine_per_core_conv_block_config(
parallel_config,
opt_conv_op_parallel_config,
@@ -910,7 +922,7 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::Tensor>>
.dilation_hw = {dilation[0], dilation[1]},
.num_cores_nhw = opt_conv_op_parallel_config.num_cores_nhw,
.core_range_set = input_tensor_post_tm.memory_config().shard_spec.value().grid,
.snap_to_tile = true
.snap_to_tile = !use_non_tile_height,
};

bool bypass_halo = (parallel_config.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED &&
@@ -952,7 +964,8 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::Tensor>>
false,
parallel_config.shard_orientation == ShardOrientation::COL_MAJOR,
0,
input_tensor_post_tm.memory_config());
input_tensor_post_tm.memory_config(),
!use_non_tile_height);
if (conv_config.deallocate_activation) {
ttnn::operations::core::deallocate(input_tensor_post_tm);
}
@@ -982,7 +995,8 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::Tensor>>
compute_kernel_config,
conv_config.enable_act_double_buffer,
conv_config.enable_split_reader,
conv_config.enable_subblock_padding);
conv_config.enable_subblock_padding,
use_non_tile_height);
ttnn::operations::core::deallocate(halo_output);

if (memory_config.has_value() && memory_config.value() != conv_output.memory_config()) {
@@ -1053,7 +1067,7 @@ template ParallelConfig determine_parallel_config<MeshDevice>(
ShardOrientation block_shard_orientation,
bool is_out_tiled);

template std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> get_conv_padded_input_shape_and_mem_config<Device>(
template std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool, bool> get_conv_padded_input_shape_and_mem_config<Device>(
Device* device,
const ttnn::Tensor& input_tensor_,
const Conv2dConfig& conv_config,
@@ -1070,8 +1084,8 @@ template std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> get_conv_padded_input_shape_and_mem_config
uint32_t input_width,
uint32_t groups);

template std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> get_conv_padded_input_shape_and_mem_config<MeshDevice>(
MeshDevice* device,
template std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool, bool> get_conv_padded_input_shape_and_mem_config<MeshDevice>(
MeshDevice * device,
const ttnn::Tensor& input_tensor_,
const Conv2dConfig& conv_config,
uint32_t batch_size,
@@ -1087,7 +1101,7 @@ template std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> get_conv_padded_input_shape_and_mem_config
uint32_t input_width,
uint32_t groups);

template std::tuple<ttnn::Tensor, ParallelConfig, bool> shard_or_reshard_tensor_if_required<Device>(
template std::tuple<ttnn::Tensor, ParallelConfig, bool, bool> shard_or_reshard_tensor_if_required<Device>(
Device* device,
const ttnn::Tensor& input_tensor_,
const Conv2dConfig& conv_config,
@@ -1104,8 +1118,8 @@ template std::tuple<ttnn::Tensor, ParallelConfig, bool> shard_or_reshard_tensor_if_required
uint32_t input_width,
uint32_t groups);

template std::tuple<ttnn::Tensor, ParallelConfig, bool> shard_or_reshard_tensor_if_required<MeshDevice>(
MeshDevice* device,
template std::tuple<ttnn::Tensor, ParallelConfig, bool, bool> shard_or_reshard_tensor_if_required<MeshDevice>(
MeshDevice * device,
const ttnn::Tensor& input_tensor_,
const Conv2dConfig& conv_config,
uint32_t batch_size,
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
@@ -126,7 +126,7 @@ std::pair<uint32_t, uint32_t> determine_largest_subblock_size(uint32_t block_height
OptimizedConvBlockConfig determine_per_core_conv_block_config(const sliding_window::ParallelConfig& parallel_config, const OptimizedConvParallelizationConfig& conv_op_parallel_config, uint32_t padded_in_channels, uint32_t act_block_h_override, uint32_t window_w, bool fp32_accum, bool use_shallow_conv_variant);

template<typename T>
std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> get_conv_padded_input_shape_and_mem_config(
std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool, bool> get_conv_padded_input_shape_and_mem_config(
T * device,
const ttnn::Tensor& input_tensor_,
const Conv2dConfig& conv_config,
@@ -144,7 +144,7 @@ std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> get_conv_padded_input_shape_and_mem_config(
uint32_t groups);

template<typename T>
std::tuple<ttnn::Tensor, sliding_window::ParallelConfig, bool> shard_or_reshard_tensor_if_required(
std::tuple<ttnn::Tensor, sliding_window::ParallelConfig, bool, bool> shard_or_reshard_tensor_if_required(
T * device,
const ttnn::Tensor& input_tensor_,
const Conv2dConfig& conv_config,
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
@@ -136,7 +136,7 @@ void py_bind_conv2d(py::module& module) {
std::array<uint32_t, 2> dilation,
uint32_t weights_width,
uint32_t input_width,
uint32_t groups) -> std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> {
uint32_t groups) -> std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool, bool> {
return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config<ttnn::Device>(
device,
input_tensor,
@@ -187,7 +187,7 @@ void py_bind_conv2d(py::module& module) {
std::array<uint32_t, 2> dilation,
uint32_t weights_width,
uint32_t input_width,
uint32_t groups) -> std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool> {
uint32_t groups) -> std::tuple<ttnn::Shape, ttnn::MemoryConfig, bool, bool> {
return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config<MeshDevice>(
device,
input_tensor,