Skip to content

Commit

Permalink
#7040: Revert "#8260: Revert "#8260: reshard uneven shard" because it…
Browse files Browse the repository at this point in the history
… breaks perf (#8500)"

This reverts commit 3716890.
  • Loading branch information
ntarafdar committed May 16, 2024
1 parent 44e0090 commit 7bf2680
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 132 deletions.
2 changes: 1 addition & 1 deletion models/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,7 @@ def get_debug_tensor(num_pages_width, num_pages_height, dtype, page_width=32, pa
tile_row = None
for col_idx in range(0, int(num_pages_width)):
tile_idx = col_idx + num_pages_width * row_idx
tile = torch.full((1, 1, page_width, page_height), tile_idx + 1, dtype=dtype)
tile = torch.full((1, 1, page_height, page_width), tile_idx, dtype=dtype)
if tile_row == None:
tile_row = tile
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,25 @@

from enum import Enum

from models.utility_functions import skip_for_wormhole_b0, skip_for_grayskull
from models.utility_functions import skip_for_grayskull, get_debug_tensor


tt_dtype_to_torch_dtype = {
ttl.tensor.DataType.UINT32: torch.int32,
ttl.tensor.DataType.UINT16: torch.int16,
ttl.tensor.DataType.BFLOAT16: torch.bfloat16,
ttl.tensor.DataType.BFLOAT8_B: torch.float,
}
TILE_WIDTH = 32
TILE_HEIGHT = 32


def get_tensor(shape, dtype):
if dtype in {torch.int16, torch.int32}:
torch_tensor = torch.randint(0, 1024, shape, dtype=dtype)
else:
torch_tensor = torch.rand(shape, dtype=dtype)
return torch_tensor


def run_reshard_test(
Expand All @@ -31,10 +49,16 @@ def run_reshard_test(
output_sharding_scheme,
tt_dtype,
):
full_grid = device.compute_with_storage_grid_size()

input_shard_grid_set = set()
for _input_shard_grid in input_shard_grid:
compute_grid_start = ttl.tensor.CoreCoord(_input_shard_grid[0][0], _input_shard_grid[0][1])
compute_grid_end = ttl.tensor.CoreCoord(_input_shard_grid[1][0], _input_shard_grid[1][1])
if compute_grid_start.x >= full_grid.x or compute_grid_start.y >= full_grid.y:
pytest.skip("Illegal input core_grid")
if compute_grid_end.x >= full_grid.x or compute_grid_end.y >= full_grid.y:
pytest.skip("Illegal input core_grid")
input_shard_grid_set.add(ttl.tensor.CoreRange(compute_grid_start, compute_grid_end))

input_shard_grid = ttl.tensor.CoreRangeSet(input_shard_grid_set)
Expand All @@ -43,6 +67,10 @@ def run_reshard_test(
for _output_shard_grid in output_shard_grid:
compute_grid_start = ttl.tensor.CoreCoord(_output_shard_grid[0][0], _output_shard_grid[0][1])
compute_grid_end = ttl.tensor.CoreCoord(_output_shard_grid[1][0], _output_shard_grid[1][1])
if compute_grid_start.x >= full_grid.x or compute_grid_start.y >= full_grid.y:
pytest.skip("Illegal output core_grid")
if compute_grid_end.x >= full_grid.x or compute_grid_end.y >= full_grid.y:
pytest.skip("Illegal output core_grid")
output_shard_grid_set.add(ttl.tensor.CoreRange(compute_grid_start, compute_grid_end))

output_shard_grid = ttl.tensor.CoreRangeSet(output_shard_grid_set)
Expand All @@ -56,7 +84,27 @@ def run_reshard_test(
memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED,
buffer_type=ttl.tensor.BufferType.DRAM,
)
torch_tensor = torch.randn(input_shape).bfloat16()
debug = True
dtype = tt_dtype_to_torch_dtype[tt_dtype]
if debug:
if input_layout == ttl.tensor.Layout.TILE:
num_pages_height = (input_shape[0] * input_shape[1] * input_shape[2]) / 32
num_pages_width = input_shape[3] / 32
page_height = 32
page_width = 32
else:
page_width_input = input_shard_shape[1]
page_width_output = output_shard_shape[1]
page_height = 1
page_width = int(math.gcd(page_width_input, page_width_output))
num_pages_height = int(input_shape[0] * input_shape[1] * input_shape[2])
num_pages_width = int(input_shape[3] / page_width)
torch_tensor = get_debug_tensor(
num_pages_width, num_pages_height, dtype, page_width=page_width, page_height=page_height
)
else:
torch_tensor = get_tensor(input_shape, dtype)

tt_tensor_sharded = ttl.tensor.Tensor(torch_tensor, tt_dtype).to(input_layout)
tt_tensor_sharded = tt_tensor_sharded.to(device, dram_memory_config)
tt_tensor_sharded = ttl.tensor.interleaved_to_sharded(
Expand All @@ -81,7 +129,6 @@ def run_reshard_test(
return torch_tensor, torch_tensor_after_round_trip


@skip_for_wormhole_b0()
@pytest.mark.parametrize(
"input_shape, input_layout, input_shard_grid, input_shard_shape, input_shard_orientation, input_sharding_scheme, output_shard_grid, output_shard_shape, output_shard_orientation, output_sharding_scheme",
[
Expand Down Expand Up @@ -157,6 +204,18 @@ def run_reshard_test(
ttl.tensor.ShardOrientation.COL_MAJOR,
ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED,
),
(
[1, 1, 160, 64],
ttl.tensor.Layout.TILE,
[[(0, 0), (0, 4)]],
(32, 64),
ttl.tensor.ShardOrientation.ROW_MAJOR,
ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
[[(0, 0), (1, 1)]],
(96, 32),
ttl.tensor.ShardOrientation.COL_MAJOR,
ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED,
),
],
)
@pytest.mark.parametrize("tt_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
Expand Down Expand Up @@ -257,14 +316,14 @@ def test_reshard_rn50(
"input_shape, input_layout, input_shard_grid, input_shard_shape, input_shard_orientation, input_sharding_scheme, output_shard_grid, output_shard_shape, output_shard_orientation, output_sharding_scheme",
[
(
[1, 1, 32, 6272],
[1, 1, 160, 64],
ttl.tensor.Layout.TILE,
[[(0, 0), (6, 6)]],
(32, 128),
[[(0, 0), (0, 4)]],
(32, 64),
ttl.tensor.ShardOrientation.ROW_MAJOR,
ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED,
[[(0, 0), (0, 6)]],
(32, 1024),
ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
[[(0, 0), (1, 1)]],
(96, 32),
ttl.tensor.ShardOrientation.COL_MAJOR,
ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED,
),
Expand Down
27 changes: 22 additions & 5 deletions tests/ttnn/unit_tests/operations/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,23 @@
None,
None,
),
(
160,
64,
ttnn.TILE_LAYOUT,
dict(
core_grid=ttnn.CoreGrid(y=5, x=1),
strategy=ttnn.ShardStrategy.HEIGHT,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
),
dict(
core_grid=ttnn.CoreGrid(y=2, x=2),
strategy=ttnn.ShardStrategy.BLOCK,
orientation=ttnn.ShardOrientation.COL_MAJOR,
),
(32, 64),
(32, 96),
),
],
)
def test_reshard(
Expand All @@ -193,10 +210,11 @@ def test_reshard(
input_override,
output_override,
):
if device.core_grid.y < input_sharded_memory_config_args["core_grid"].y:
pytest.skip()
if device.core_grid.y < output_sharded_memory_config_args["core_grid"].y:
pytest.skip()
if isinstance(input_sharded_memory_config_args["core_grid"], (ttnn.CoreGrid)):
if device.core_grid.y < input_sharded_memory_config_args["core_grid"].y:
pytest.skip()
if device.core_grid.y < output_sharded_memory_config_args["core_grid"].y:
pytest.skip()
input_shape = [1, 1, input_height, input_width]

torch_input_tensor = torch.rand(input_shape, dtype=torch.bfloat16)
Expand All @@ -217,7 +235,6 @@ def test_reshard(
output_shard_memory_config = ttnn.create_sharded_memory_config(
output_override, **output_sharded_memory_config_args, use_height_and_width_as_shard_shape=True
)

# interleaved_to_sharded
sharded_input_tensor = ttnn.to_memory_config(interleaved_input_tensor, input_shard_memory_config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,28 @@ void kernel_main() {
const uint32_t start_y = get_arg_val<uint32_t>(y_offset + start_y_index);

const uint32_t stride_data_offset = get_arg_val<uint32_t>(arg_index++);
const uint32_t stride_size_num_strides = get_arg_val<uint32_t>(arg_index++);
const uint32_t num_strides = ((stride_size_num_strides) & mask_short);
const uint32_t stride_size_num_strides_skip = get_arg_val<uint32_t>(arg_index++);
const uint32_t num_strides = ((stride_size_num_strides_skip) & mask_short) >> 8;
const bool skip = (((stride_size_num_strides_skip) & mask_byte) == 1);


const uint32_t stride_data = ((stride_data_offset >> 16)) * page_size;
const uint32_t offset = ((stride_data_offset) & mask_short) * page_size;
const uint32_t stride_size = ((stride_size_num_strides >> 16)) * page_size;
const uint32_t num_pages_per_stride = (stride_size_num_strides_skip >> 16);
const uint32_t stride_size = num_pages_per_stride * page_size;

uint32_t addr_offset = offset;
uint32_t core_id_x_index = start_x_index;
uint32_t core_id_y_index = start_y_index;

for(uint32_t stride_idx = 0; stride_idx < num_strides; stride_idx++) {

uint32_t core_id_x = get_arg_val<uint32_t>(core_id_x_index);
uint32_t core_id_y = get_arg_val<uint32_t>(y_offset + core_id_y_index);
uint64_t noc_address = get_noc_addr(core_id_x, core_id_y,
input_shard_addr + addr_offset);
noc_async_read(noc_address, l1_write_addr, stride_size);
if(!skip) {
uint32_t core_id_x = get_arg_val<uint32_t>(core_id_x_index);
uint32_t core_id_y = get_arg_val<uint32_t>(y_offset + core_id_y_index);
uint64_t noc_address = get_noc_addr(core_id_x, core_id_y,
input_shard_addr + addr_offset);
noc_async_read(noc_address, l1_write_addr, stride_size);
}
l1_write_addr+=stride_size;
if(stride_x == 0 and stride_y == 0) {
addr_offset += (stride_data + stride_size);
Expand Down
Loading

0 comments on commit 7bf2680

Please sign in to comment.