Skip to content

Commit

Permalink
#8361: add reshard sweep test
Browse files Browse the repository at this point in the history
  • Loading branch information
ntarafdar committed May 10, 2024
1 parent 48afa1c commit ee6f9d6
Showing 1 changed file with 138 additions and 0 deletions.
138 changes: 138 additions & 0 deletions tests/ttnn/sweep_tests/sweeps/reshard_height_width.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc
from models.utility_functions import torch_random
import math

# Sweep space for the reshard test. Every key matches an argument of `run`
# below, and every list holds the candidate values for that argument.
parameters = dict(
    dtype=[ttnn.int32, ttnn.bfloat16, ttnn.bfloat8_b],
    # Tensor extents for the last two dimensions of the [1, 1, height, width] shape.
    height=[4, 8, 12, 16, 32, 64, 96, 128, 256, 512, 1024, 4096, 8192, 8196],
    width=[4, 8, 12, 16, 32, 64, 96, 128, 256, 512, 1024, 4096, 8192, 8196],
    layout=[ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
    # Sharding of the tensor before the reshard.
    input_shard_orientation=[ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR],
    input_num_cores_x=[1, 2, 4, 8],
    input_num_cores_y=[1, 2, 4, 8],
    input_shard_strategy=[ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.WIDTH],
    # Sharding of the tensor after the reshard.
    output_shard_orientation=[ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR],
    output_num_cores_x=[1, 2, 4, 8],
    output_num_cores_y=[1, 2, 4, 8],
    output_shard_strategy=[ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.WIDTH],
)


def invalid_shard_spec(
    layout,
    height,
    width,
    device,
    num_cores_x,
    num_cores_y,
    shard_strategy,
) -> bool:
    """Return True when the requested sharding cannot be used for the tensor.

    A spec is rejected when the slice of the sharded dimension that lands on
    each core is too small for ``layout``, or when the requested core counts
    do not fit on the device's compute grid.

    Args:
        layout: ``ttnn.TILE_LAYOUT`` or ``ttnn.ROW_MAJOR_LAYOUT``.
        height: Tensor height; the distributed dimension for HEIGHT sharding.
        width: Tensor width; the distributed dimension for WIDTH sharding.
        device: Object exposing ``compute_with_storage_grid_size()``; the
            ``x`` and ``y`` attributes of its result are read.
        num_cores_x: Requested number of cores along the grid's x axis.
        num_cores_y: Requested number of cores along the grid's y axis.
        shard_strategy: ``ttnn.ShardStrategy.HEIGHT`` or
            ``ttnn.ShardStrategy.WIDTH``.

    Returns:
        True if the spec is invalid, False otherwise.

    Raises:
        ValueError: If ``shard_strategy`` is neither HEIGHT nor WIDTH.
    """
    if shard_strategy == ttnn.ShardStrategy.HEIGHT:
        dim_being_distributed = height
    elif shard_strategy == ttnn.ShardStrategy.WIDTH:
        dim_being_distributed = width
    else:
        # Previously this fell through and failed later with an unbound local.
        raise ValueError(f"Unsupported shard strategy: {shard_strategy!r}")

    num_cores = num_cores_x * num_cores_y
    if num_cores <= 0:
        # Nothing to distribute onto; also avoids dividing by zero below.
        return True

    # math.ceil takes a single argument, so divide first and round the quotient up.
    size_per_core = math.ceil(dim_being_distributed / num_cores)
    if layout == ttnn.ROW_MAJOR_LAYOUT and size_per_core == 0:
        return True
    # NOTE(review): 32 is presumably the tile edge length - confirm.
    if layout == ttnn.TILE_LAYOUT and size_per_core < 32:
        return True

    # A core count equal to the grid dimension still fits on the grid, so only
    # counts strictly larger than the grid are rejected.
    full_grid = device.compute_with_storage_grid_size()
    return num_cores_x > full_grid.x or num_cores_y > full_grid.y


def skip(
    *,
    layout,
    height,
    width,
    device,
    input_num_cores_x,
    input_num_cores_y,
    input_shard_strategy,
    output_num_cores_x,
    output_num_cores_y,
    output_shard_strategy,
    **_,
) -> Tuple[bool, Optional[str]]:
    """Decide whether a parameter permutation should be skipped.

    The input shard spec is validated first, then the output shard spec; the
    first one found invalid determines the returned reason.

    NOTE(review): a later ``def skip(**_)`` in this module rebinds the name
    ``skip``, so this version is not the one the module ends up exposing.

    Returns:
        ``(True, reason)`` when either shard spec is invalid, otherwise
        ``(False, None)``.
    """
    shard_specs = (
        ("Invalid Input Shard Spec", input_num_cores_x, input_num_cores_y, input_shard_strategy),
        ("Invalid Output Shard Spec", output_num_cores_x, output_num_cores_y, output_shard_strategy),
    )
    for reason, cores_x, cores_y, strategy in shard_specs:
        if invalid_shard_spec(layout, height, width, device, cores_x, cores_y, strategy):
            return True, reason
    return False, None


def skip(**_) -> Tuple[bool, Optional[str]]:
    """Never skip: every parameter permutation is reported as runnable.

    All keyword arguments are accepted and ignored.

    NOTE(review): this definition replaces the earlier ``skip`` in this
    module, so the shard-spec validation written there never runs. Confirm
    whether that override is intentional.
    """
    should_skip: bool = False
    reason: Optional[str] = None
    return should_skip, reason


def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]:
    """Report that no parameter permutation is expected to fail.

    All keyword arguments are accepted and ignored; the result is always
    ``(False, None)``.
    """
    verdict: Tuple[bool, Optional[str]] = (False, None)
    return verdict


def run(
    dtype,
    height,
    width,
    layout,
    input_shard_orientation,
    input_num_cores_x,
    input_num_cores_y,
    input_shard_strategy,
    output_shard_orientation,
    output_num_cores_x,
    output_num_cores_y,
    output_shard_strategy,
    *,
    device,
) -> Tuple[bool, Optional[str]]:
    """Round-trip a tensor through two shard layouts and compare it to the input.

    A random ``[1, 1, height, width]`` tensor is placed in DRAM, moved to the
    input sharded memory config, resharded to the output sharded memory
    config, moved back to DRAM and converted to torch.

    Args:
        dtype: ttnn data type used for the device tensor.
        height: Size of the third tensor dimension.
        width: Size of the fourth tensor dimension.
        layout: ttnn layout used for the device tensor.
        input_shard_orientation: Shard orientation of the input config.
        input_num_cores_x: Core grid x extent of the input config.
        input_num_cores_y: Core grid y extent of the input config.
        input_shard_strategy: Shard strategy of the input config.
        output_shard_orientation: Shard orientation of the output config.
        output_num_cores_x: Core grid x extent of the output config.
        output_num_cores_y: Core grid y extent of the output config.
        output_shard_strategy: Shard strategy of the output config.
        device: Device that holds the tensors.

    Returns:
        The result of ``check_with_pcc`` comparing the original torch tensor
        with the round-tripped output at a 0.999 threshold.
    """
    tensor_shape = [1, 1, height, width]

    # NOTE(review): the input is float32 from torch.randn for every dtype,
    # including ttnn.int32 - confirm the PCC comparison is meaningful there.
    torch_input_tensor = torch.randn(tensor_shape, dtype=torch.float32)
    interleaved_input_tensor = ttnn.from_torch(
        torch_input_tensor, layout=layout, dtype=dtype, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )

    # Pass every argument by keyword exactly once. Supplying the shape
    # positionally as well as via ``shape=`` raises TypeError for the
    # duplicated argument.
    input_shard_memory_config = ttnn.create_sharded_memory_config(
        shape=tensor_shape,
        core_grid=ttnn.CoreGrid(y=input_num_cores_y, x=input_num_cores_x),
        strategy=input_shard_strategy,
        orientation=input_shard_orientation,
    )
    output_shard_memory_config = ttnn.create_sharded_memory_config(
        shape=tensor_shape,
        core_grid=ttnn.CoreGrid(y=output_num_cores_y, x=output_num_cores_x),
        strategy=output_shard_strategy,
        orientation=output_shard_orientation,
    )

    # interleaved_to_sharded
    sharded_input_tensor = ttnn.to_memory_config(interleaved_input_tensor, input_shard_memory_config)

    # reshard
    sharded_output_tensor = ttnn.to_memory_config(sharded_input_tensor, output_shard_memory_config)

    # sharded_to_interleaved
    interleaved_output_tensor = ttnn.to_memory_config(sharded_output_tensor, ttnn.DRAM_MEMORY_CONFIG)

    output = ttnn.to_torch(interleaved_output_tensor)

    return check_with_pcc(torch_input_tensor, output, 0.999)

0 comments on commit ee6f9d6

Please sign in to comment.