Better parallelization strategy
Pavle Josipovic committed Nov 16, 2024
1 parent 8dee2c2 commit c1b3b6a
Showing 20 changed files with 385 additions and 236 deletions.
@@ -8,12 +8,37 @@
is_grayskull,
is_wormhole_b0,
_nearest_y,
nearest_y,
pad_and_fold_conv_activation_for_unity_stride,
)
from typing import List
from loguru import logger
from tests.ttnn.utils_for_testing import assert_with_pcc


def get_core_grid_from_num_cores(num_cores: int, grid_rows: int, grid_cols: int):
    columns = num_cores // grid_rows
    assert columns <= grid_cols, "Not enough cores for specified core grid"
    ranges = []
    if columns != 0:
        ranges.append(
            ttnn.CoreRange(
                ttnn.CoreCoord(0, 0),
                ttnn.CoreCoord(grid_rows - 1, columns - 1),
            )
        )
    remainder = num_cores % grid_rows
    if remainder != 0:
        assert columns + 1 <= grid_cols, "Not enough cores for specified core grid"
        ranges.append(
            ttnn.CoreRange(
                ttnn.CoreCoord(0, columns),
                ttnn.CoreCoord(remainder - 1, columns),
            )
        )
    return ttnn.CoreRangeSet({*ranges})
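
A minimal usage sketch of the helper above (not part of this commit), assuming a 12x9 (108-core) compute grid like the Grayskull default referenced later in the diff, with device assumed to be an open ttnn device:

# Hypothetical example: 98 cores on a 12x9 grid become 8 full rows of 12 cores
# (CoreRange (0,0)-(11,7)) plus a partial row of 98 % 12 = 2 cores ((0,8)-(1,8)).
compute_grid = device.compute_with_storage_grid_size()
core_grid = get_core_grid_from_num_cores(98, compute_grid.x, compute_grid.y)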


hardcoded_matmul_config_linear = {
8: ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
compute_with_storage_grid_size=(8, 4),
@@ -632,15 +657,38 @@ def __init__(

conv_dummy_tensor = torch.rand((self.fold_output_shape), dtype=torch.bfloat16)
conv_dummy_tensor = ttnn.from_torch(conv_dummy_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
_, self.override_fold_mem_config, _, _ = ttnn.get_conv_padded_input_shape_and_mem_config(
device=device,
input_tensor=conv_dummy_tensor,
conv_config=self.conv1_config,

parallel_config = ttnn._ttnn.operations.conv.determine_parallel_config(
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
batch_size=self.batch_size,
height=self.conv1_output_height,
width=self.conv1_output_width,
in_channels=self.conv1_input_channels,
out_channels=self.conv1_output_channels,
input_channels=self.conv1_input_channels,
output_height=self.conv1_output_height,
output_width=self.conv1_output_width,
output_channels=self.conv1_output_channels,
compute_grid_size=device.compute_with_storage_grid_size(),
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
is_conv2d_op=True,
is_out_tiled=True,
)
# Override the compute grid size for Grayskull.
# The first convs would go to 108 cores by default,
# but that would add padding to the output tensor,
# and the reshard that follows the first conv currently fails with padding.
if is_grayskull():
compute_grid = device.compute_with_storage_grid_size()
parallel_config.grid = get_core_grid_from_num_cores(98, compute_grid.x, compute_grid.y)

self.override_fold_mem_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
tensor_shape=ttnn.Shape(
[
1,
1,
self.conv1_input_width * self.conv1_input_height * self.batch_size,
nearest_y(self.conv1_input_channels, self.conv1_config.input_channels_alignment),
]
),
parallel_config=parallel_config,
tile_size=32,
)

def __del__(self):
@@ -74,6 +74,33 @@ def __init__(

self.output_height = ttnn.get_conv_output_dim(input_height, 3, self.stride, 1)
self.output_width = ttnn.get_conv_output_dim(input_width, 3, self.stride, 1)
self.shard_layout = (
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if self.in_channels < 320 else ttnn.TensorMemoryLayout.BLOCK_SHARDED
)

self.input_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
tensor_shape=ttnn.Shape(
[
1,
1,
self.batch_size * self.input_height * self.input_width,
self.out_channels,
]
),
parallel_config=ttnn._ttnn.operations.conv.determine_parallel_config(
shard_layout=self.shard_layout,
batch_size=self.batch_size,
input_channels=self.in_channels,
output_height=self.output_height,
output_width=self.output_width,
output_channels=self.out_channels,
compute_grid_size=self.device.compute_with_storage_grid_size(),
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
is_conv2d_op=False,
is_out_tiled=True,
),
tile_size=32,
)

def __call__(
self,
@@ -104,13 +131,15 @@ def __call__(
math_approx_mode_enabled=True,
fp32_dest_acc_enabled=True,
packer_l1_accum_enabled=False,
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if self.in_channels < 320
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=self.shard_layout,
input_channels_alignment=32,
transpose_shards=False,
reshard_if_not_optimal=True,
reshard_if_not_optimal=False,
)

if hidden_states.memory_config() != self.input_memory_config:
hidden_states = ttnn.to_memory_config(hidden_states, self.input_memory_config)

if self.conv_config_override and "act_block_h" in self.conv_config_override:
conv_config.act_block_h_override = self.conv_config_override["act_block_h"]

@@ -106,6 +106,33 @@ def __init__(
self.conv1_input_width = input_width
self.conv1_in_channels = split_input_channels
self.conv1_out_channels = out_channels
self.conv1_output_height = ttnn.get_conv_output_dim(self.conv1_input_height, 3, 1, 1)
self.conv1_output_width = ttnn.get_conv_output_dim(self.conv1_input_width, 3, 1, 1)
self.conv1_shard_layout = ttnn.TensorMemoryLayout.BLOCK_SHARDED

self.conv1_input_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
tensor_shape=ttnn.Shape(
[
1,
1,
self.batch_size * self.conv1_input_height * self.conv1_input_width,
self.conv1_in_channels,
]
),
parallel_config=ttnn._ttnn.operations.conv.determine_parallel_config(
shard_layout=self.conv1_shard_layout,
batch_size=self.batch_size,
input_channels=self.conv1_in_channels,
output_height=self.conv1_output_height,
output_width=self.conv1_output_width,
output_channels=self.conv1_out_channels,
compute_grid_size=self.device.compute_with_storage_grid_size(),
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
is_conv2d_op=False,
is_out_tiled=True,
),
tile_size=32,
)

for i in range(conv1_split_chunks):
self.conv1s_weights.append(ttnn.from_torch(split_weight_tensors[i], ttnn.float32))
@@ -165,6 +192,29 @@ def __init__(
self.conv2_in_channels = parameters.conv2.weight.shape[1]
self.conv2_out_channels = parameters.conv2.weight.shape[0]
# self.conv2_config_override = config_override[(out_channels, out_channels, input_height, input_width)]
self.conv2_input_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
tensor_shape=ttnn.Shape(
[
1,
1,
self.batch_size * self.conv2_input_height * self.conv2_input_width,
out_channels,
]
),
parallel_config=ttnn._ttnn.operations.conv.determine_parallel_config(
shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED,
batch_size=self.batch_size,
input_channels=self.conv2_in_channels,
output_height=self.conv2_input_height,
output_width=self.conv2_input_width,
output_channels=self.conv2_out_channels,
compute_grid_size=self.device.compute_with_storage_grid_size(),
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
is_conv2d_op=False,
is_out_tiled=True,
),
tile_size=32,
)

self.groups = 32
# if use_in_shortcut:
@@ -402,12 +452,14 @@ def __call__(
# hidden_states = nonlinearity(hidden_states, memory_config=ttnn.get_memory_config(hidden_states))
# hidden_states = self.conv1s[0](hidden_states)

hidden_states = ttnn.to_memory_config(hidden_states, self.conv1_input_memory_config)

conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat8_b,
weights_dtype=ttnn.bfloat8_b,
math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=self.conv1_shard_layout,
math_approx_mode_enabled=True,
fp32_dest_acc_enabled=True,
packer_l1_accum_enabled=False,
@@ -598,6 +650,7 @@ def __call__(
# hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG)
hidden_states = ttnn.sharded_to_interleaved(hidden_states, ttnn.L1_MEMORY_CONFIG, hidden_states.dtype)

hidden_states = ttnn.to_memory_config(hidden_states, self.conv2_input_memory_config)
# hidden_states = self.conv2(hidden_states)

conv_config = ttnn.Conv2dConfig(
2 changes: 1 addition & 1 deletion models/demos/yolov4/ttnn/downsample1.py
@@ -48,7 +48,7 @@ def __call__(self, device, input_tensor):
output_tensor = ttnn.to_layout(output_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
output_tensor_left = ttnn.to_layout(output_tensor_left, layout=ttnn.ROW_MAJOR_LAYOUT)
output_sharded_memory_config = ttnn.create_sharded_memory_config(
[512, 128],
[output_tensor.memory_config().shard_spec.shape[0], 2 * output_tensor.memory_config().shard_spec.shape[1]],
core_grid=output_tensor_left.memory_config().shard_spec.grid,
strategy=ttnn.ShardStrategy.HEIGHT,
use_height_and_width_as_shard_shape=True,
29 changes: 28 additions & 1 deletion models/experimental/functional_unet/tt/unet_shallow_ttnn.py
@@ -246,6 +246,7 @@ def __init__(
conv_cache={},
should_reshard=False,
mesh_mapper=None,
conv_override_p_config=False,
):
self.conv1 = UNetConv2D(conv1, bn=bn1, device=device, cache=conv_cache, mesh_mapper=mesh_mapper)
self.conv2 = UNetConv2D(
@@ -268,8 +269,15 @@ def __init__(
output_channels=self.conv1.out_channels,
compute_grid_size=device.compute_with_storage_grid_size(),
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
is_conv2d_op=True,
is_out_tiled=True,
)

if conv_override_p_config:
pconfig_override = pool["parallel_config_override"]
num_cores_nhw = pconfig_override["num_cores_nhw"]
parallel_config.grid = get_core_grid_from_num_cores(num_cores_nhw)

self.sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
tensor_shape=ttnn.Shape(
[
@@ -304,7 +312,18 @@ def __call__(self, x):

class UNetUpblock:
def __init__(
self, conv1, bn1, conv2, bn2, conv3, bn3, device, conv_cache={}, should_reshard=False, mesh_mapper=None
self,
conv1,
bn1,
conv2,
bn2,
conv3,
bn3,
device,
conv_cache={},
should_reshard=False,
mesh_mapper=None,
nhw_core_override=-1,
):
self.device = device
self.conv1 = UNetConv2D(conv1, bn1, device, conv_cache, mesh_mapper=mesh_mapper)
@@ -322,8 +341,13 @@ def __init__(
output_channels=self.conv1.out_channels,
compute_grid_size=device.compute_with_storage_grid_size(),
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
is_conv2d_op=True,
is_out_tiled=True,
)

if nhw_core_override != -1:
parallel_config.grid = get_core_grid_from_num_cores(nhw_core_override)

self.sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
tensor_shape=ttnn.Shape(
[
@@ -413,6 +437,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
conv_cache=self.conv_cache,
should_reshard=True,
mesh_mapper=mesh_mapper,
conv_override_p_config=True,
)
self.downblock3 = UNetDownblock(
parameters.c3,
@@ -450,6 +475,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
output_channels=self.bnc.out_channels,
compute_grid_size=device.compute_with_storage_grid_size(),
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
is_conv2d_op=True,
is_out_tiled=True,
)
self.bnc_sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
@@ -500,6 +526,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
conv_cache=self.conv_cache,
should_reshard=True,
mesh_mapper=mesh_mapper,
nhw_core_override=60,
)
self.upblock4 = UNetUpblock(
parameters.c8,
40 changes: 16 additions & 24 deletions tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
@@ -454,30 +454,22 @@ def test_conv2d_localrun(device, input_spec):
failing_parameters = [
# [batch_size, output_channels, input_channels, input_height, input_width, kernel_height, kernel_width, stride_x, stride_y, pad_x, pad_y, groups, bias, dilation]
# Input is 32 MB and maps to a 64-core matmul; we need to avoid sharding this tensor and use DRAM-interleaved memory directly with the matmul (see the sketch after this list)
[1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 6
[1, 1056, 1056, 48, 48, 3, 3, 1, 1, 1, 1, 4, False, 1], # 14
[1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 15
[1, 2520, 2520, 14, 14, 3, 3, 2, 2, 1, 1, 15, False, 1], # 141
[1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 11, False, 1], # 170
[1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 171
[1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 173
[1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 182
[1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 183
[1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 199
[1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 205
[1, 336, 336, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1], # 241
[1, 336, 336, 48, 48, 5, 5, 1, 1, 2, 2, 336, False, 1], # 245
[1, 336, 336, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1], # 247
[1, 528, 528, 17, 17, 5, 5, 1, 1, 2, 2, 528, False, 1], # 292
[1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 293
[1, 528, 528, 96, 96, 3, 3, 1, 1, 1, 1, 2, False, 1], # 294
[1, 696, 696, 28, 28, 3, 3, 1, 1, 1, 1, 3, False, 1], # 347
[1, 696, 696, 56, 56, 3, 3, 2, 2, 1, 1, 3, False, 1], # 348
[1, 720, 720, 17, 17, 5, 5, 1, 1, 2, 2, 720, False, 1], # 363
[1, 728, 728, 38, 38, 3, 3, 1, 1, 1, 1, 728, False, 1], # 366
[1, 7392, 7392, 24, 24, 3, 3, 2, 2, 1, 1, 28, False, 1], # 367
[1, 816, 816, 19, 19, 5, 5, 1, 1, 2, 2, 816, False, 1], # 374
[1, 960, 960, 24, 24, 5, 5, 1, 1, 2, 2, 960, False, 1], # 395
[1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 5
[1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 14
[1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 11, False, 1], # 169
[1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 170
[1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 172
[1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 181
[1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 182
[1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 198
[1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1], # 203
[1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 204
[1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 292
[1, 7392, 7392, 24, 24, 3, 3, 2, 2, 1, 1, 28, False, 1], # 366
[1, 816, 816, 19, 19, 5, 5, 1, 1, 2, 2, 816, False, 1], # 373
[1, 816, 816, 23, 23, 5, 5, 2, 2, 0, 0, 816, False, 1], # 374
[1, 960, 960, 24, 24, 5, 5, 1, 1, 2, 2, 960, False, 1], # 394
[1, 960, 960, 27, 27, 5, 5, 2, 2, 0, 0, 960, False, 1], # 395
]
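
A hedged sketch of the DRAM-interleaved fallback described in the comment at the top of this list, assuming a large 1x1-conv activation that lowers to a plain matmul; the tensor names activation and weights are illustrative and not taken from this commit:

# Keep the ~32 MB activation DRAM-interleaved and feed the matmul directly,
# instead of sharding it across 64 cores in L1.
activation = ttnn.to_memory_config(activation, ttnn.DRAM_MEMORY_CONFIG)
output = ttnn.matmul(activation, weights, memory_config=ttnn.DRAM_MEMORY_CONFIG)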


2 changes: 2 additions & 0 deletions tests/ttnn/unit_tests/operations/test_maxpool2d.py
@@ -158,6 +158,7 @@ def run_max_pool(
output_channels=in_c,
compute_grid_size=device.compute_with_storage_grid_size(),
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
is_conv2d_op=False,
is_out_tiled=False,
)
sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
@@ -744,6 +745,7 @@ def test_pool_core_nondivis(
output_channels=in_c,
compute_grid_size=device.compute_with_storage_grid_size(),
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
is_conv2d_op=False,
is_out_tiled=True,
)
sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(