#11512: Include both gen_sharded_spec_unary functions

amalbasaTT committed Dec 3, 2024
1 parent a100c39 commit c2dfe37
Showing 7 changed files with 123 additions and 16 deletions.
115 changes: 111 additions & 4 deletions tests/sweep_framework/sweep_utils/sharding_utils.py
@@ -29,10 +29,117 @@ def roundup(a, b):
    return result


def gen_sharded_spec_unary(num_shapes, max_tensor_size=4 * 1024 * 1024, layouts=["TILE_LAYOUT", "ROW_MAJOR_LAYOUT"]):
    # Core grid is fixed at 8x8 rather than queried via device.compute_with_storage_grid_size()
    y = 8
    x = 8

    sharding_strategy_list = ["BLOCK", "WIDTH", "HEIGHT", "TENSOR_HW"]
    shard_orientation_list = ["COL_MAJOR", "ROW_MAJOR"]
    spec_list = []

    for sharding_strategy, shard_orientation, rank, layout in itertools.product(
        sharding_strategy_list, shard_orientation_list, [4, 3, 2], layouts
    ):
        for _ in range(num_shapes):
            if sharding_strategy == "TENSOR_HW":
                # Gets stuck:
                # X 8 Y 8 input_shape [1, 17792, 8] DataType.BFLOAT8_B Layout.TILE ShardStrategy.BLOCK ShardOrientation.COL_MAJOR tensor_hw_as_shard_shape True

                if layout == "TILE_LAYOUT":
                    # In shard mode ShardMode::PHYSICAL, physical shard shape {12, 13312} is not compatible with alignment Alignment([32, 32])!
                    min_shard_size_x = 32
                    min_shard_size_y = 32
                else:  # layout == "ROW_MAJOR_LAYOUT"
                    # Shard size must be a multiple of the input tile size (width * height is a multiple of 1024)
                    min_shard_size_x = random.choice([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024])
                    min_shard_size_y = 1024 // min_shard_size_x

                rest_volume = random.randint(1, max_tensor_size // (min_shard_size_x * min_shard_size_y * x * y))
                input_shape = random.choice(_gen_reshape_args_from_volume(rest_volume, step=1, out_dims=rank))
                input_shape = list(input_shape["reshape_dims"])
                input_shape[-2] = input_shape[-2] * min_shard_size_x
                input_shape[-1] = input_shape[-1] * min_shard_size_y

                # Shard width should be a multiple of 16 to satisfy L1 alignment (a multiple of 8 suffices for bfloat16)
                while input_shape[-1] % 16 != 0:
                    input_shape[-1] *= 2
                    input_shape[-2] //= 2

                if shard_orientation == "COL_MAJOR":
                    tmp = input_shape[-2]
                    input_shape[-2] = input_shape[-1]
                    input_shape[-1] = tmp

            elif sharding_strategy == "BLOCK":
                min_shard_size_y = 32 * y
                min_shard_size_x = 32 * x

                rest_volume = random.randint(1, max_tensor_size // (min_shard_size_x * min_shard_size_y))
                physical_shape = random.choice(_gen_reshape_args_from_volume(rest_volume, step=1, out_dims=2))
                physical_shape = list(physical_shape["reshape_dims"])
                physical_shape[1] *= min_shard_size_y
                physical_shape[0] *= min_shard_size_x

                input_shape = random.choice(_gen_reshape_args_from_volume(physical_shape[0], step=1, out_dims=rank - 1))
                input_shape = list(input_shape["reshape_dims"])
                input_shape.append(physical_shape[1])

            elif sharding_strategy == "WIDTH" or sharding_strategy == "HEIGHT":
                # if shard_width % total_cores != 0: raise RuntimeError("Invalid sharding core_grid")
                # Shard size must be a multiple of the input tile size

                if layout == "TILE_LAYOUT":
                    # In shard mode ShardMode::PHYSICAL, physical shard shape {12, 13312} is not compatible with alignment Alignment([32, 32])!
                    min_shard_size_x = 32
                    min_shard_size_y = 32 * x * y
                else:  # layout == "ROW_MAJOR_LAYOUT"
                    # Shard size must be a multiple of the input tile size
                    # Shard width should be a multiple of 16 to satisfy L1 alignment
                    mul_32_y = random.choice([16, 32, 64, 128, 256, 512, 1024])
                    mul_32_x = 1024 // mul_32_y

                    if sharding_strategy == "HEIGHT":
                        # Shard width should be a multiple of 16 to satisfy L1 alignment
                        while mul_32_x % 16 != 0:
                            mul_32_x *= 2
                            mul_32_y //= 2

                    min_shard_size_x = mul_32_x
                    min_shard_size_y = mul_32_y * x * y

                rest_volume = random.randint(1, max_tensor_size // (min_shard_size_x * min_shard_size_y))
                input_shape = random.choice(_gen_reshape_args_from_volume(rest_volume, step=1, out_dims=rank))
                input_shape = list(input_shape["reshape_dims"])
                input_shape[-2] = input_shape[-2] * min_shard_size_x
                input_shape[-1] = input_shape[-1] * min_shard_size_y

                if sharding_strategy == "HEIGHT":
                    tmp = input_shape[-2]
                    input_shape[-2] = input_shape[-1]
                    input_shape[-1] = tmp

            spec_list.append(
                {
                    "input_shape": input_shape,
                    "core_grid_size": (y, x),
                    "sharding_strategy": sharding_strategy,
                    "shard_orientation": shard_orientation,
                    "shard_height_mul_of_32": False,
                    "input_layout": layout,
                }
            )

    return spec_list
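
As a quick orientation aid, here is a minimal sketch of how the generated specs can be inspected; the dict keys are exactly those appended to spec_list above, and the slice size is arbitrary:

    specs = gen_sharded_spec_unary(num_shapes=1)
    for spec in specs[:4]:
        # One spec per (strategy, orientation, rank, layout) combination on the fixed 8x8 grid.
        print(
            spec["input_shape"],
            spec["core_grid_size"],
            spec["sharding_strategy"],
            spec["shard_orientation"],
            spec["input_layout"],
        )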


def gen_sharded_spec_unary_2(
    num_shapes,
-   layouts=["ROW_MAJOR_LAYOUT", "TILED_LAYOUT"],
    max_tensor_size_per_core=256 * 256,
+   layouts=["ROW_MAJOR_LAYOUT", "TILE_LAYOUT"],
):
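    # Unlike gen_sharded_spec_unary above, this variant samples the core-grid extent
    # per spec and bounds the tensor volume per core via max_tensor_size_per_core.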
sharding_strategy_list = ["HEIGHT", "WIDTH", "BLOCK", "TENSOR_HW"]
shard_orientation_list = ["COL_MAJOR", "ROW_MAJOR"]
@@ -47,7 +154,7 @@ def gen_sharded_spec_unary_2(
    x = random.randint(1, X)
    max_tensor_size = y * x * max_tensor_size_per_core
    if sharding_strategy == "TENSOR_HW":
-       if input_layout == "TILED_LAYOUT":
+       if input_layout == "TILE_LAYOUT":
            min_tensor_height = 32
            min_tensor_width = 32
            max_tensor_height = int(math.sqrt(max_tensor_size_per_core))
@@ -197,7 +304,7 @@ def parse_sharding_spec(input_spec):
    input_layout = input_spec["input_layout"]

    assert sharding_strategy in ["HEIGHT", "WIDTH", "BLOCK", "TENSOR_HW"]
-   assert input_layout in ["TILED_LAYOUT", "ROW_MAJOR_LAYOUT"]
+   assert input_layout in ["TILE_LAYOUT", "ROW_MAJOR_LAYOUT"]

    tensor_hw_as_shard_shape = False

@@ -216,7 +323,7 @@ def parse_sharding_spec(input_spec):
    else:
        shard_orientation = ttnn.ShardOrientation.ROW_MAJOR

-   if input_layout == "TILED_LAYOUT":
+   if input_layout == "TILE_LAYOUT":
        input_layout = ttnn.TILE_LAYOUT
    else:
        input_layout = ttnn.ROW_MAJOR_LAYOUT
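
Downstream, the sweep tests turn each spec into a sharded memory config. The following is a hypothetical sketch, not code from this commit: it assumes ttnn.create_sharded_memory_config and ttnn.CoreGrid are available with their usual signatures, and it maps TENSOR_HW onto ShardStrategy.BLOCK with use_height_and_width_as_shard_shape=True, consistent with the "Gets stuck" comment in gen_sharded_spec_unary above:

    strategy_map = {
        "HEIGHT": ttnn.ShardStrategy.HEIGHT,
        "WIDTH": ttnn.ShardStrategy.WIDTH,
        "BLOCK": ttnn.ShardStrategy.BLOCK,
        "TENSOR_HW": ttnn.ShardStrategy.BLOCK,  # shard shape taken from the tensor's own H and W
    }
    spec = gen_sharded_spec_unary(num_shapes=1)[0]
    grid_y, grid_x = spec["core_grid_size"]
    memory_config = ttnn.create_sharded_memory_config(
        shape=spec["input_shape"],
        core_grid=ttnn.CoreGrid(y=grid_y, x=grid_x),
        strategy=strategy_map[spec["sharding_strategy"]],
        orientation=(
            ttnn.ShardOrientation.COL_MAJOR
            if spec["shard_orientation"] == "COL_MAJOR"
            else ttnn.ShardOrientation.ROW_MAJOR
        ),
        use_height_and_width_as_shard_shape=(spec["sharding_strategy"] == "TENSOR_HW"),
    )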
@@ -12,7 +12,7 @@
import math
from tests.sweep_framework.sweep_utils.utils import gen_shapes, sanitize_shape_rm
from tests.sweep_framework.sweep_utils.sharding_utils import (
-   gen_sharded_spec_unary_2,
+   gen_sharded_spec_unary,
    parse_sharding_spec,
    invalidate_vector_sharding,
    roundup,
@@ -35,7 +35,7 @@
# Developers can create their own generator functions and pass them to the parameters as inputs.
parameters = {
    "nightly": {
-       "input_spec": gen_sharded_spec_unary_2(16),
+       "input_spec": gen_sharded_spec_unary(16),
        "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
    },
}
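
The same two-line change (the import and the input_spec generator) repeats across the six sweep files in this commit. For context, here is a hypothetical sketch of how such a file's invalidation hook might delegate to the shared helper; the hook body and the (skip, reason) return convention are assumptions, as neither is shown in this diff:

    def invalidate_vector(test_vector) -> tuple:
        # Assumed convention: return (True, reason) to skip an invalid vector, (False, None) otherwise.
        input_spec = test_vector["input_spec"]
        return invalidate_vector_sharding(input_spec)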
@@ -12,7 +12,7 @@
import math
from tests.sweep_framework.sweep_utils.utils import gen_shapes, sanitize_shape_rm
from tests.sweep_framework.sweep_utils.sharding_utils import (
-   gen_sharded_spec_unary_2,
+   gen_sharded_spec_unary,
    parse_sharding_spec,
    invalidate_vector_sharding,
    roundup,
@@ -35,7 +35,7 @@
# Developers can create their own generator functions and pass them to the parameters as inputs.
parameters = {
    "nightly": {
-       "input_spec": gen_sharded_spec_unary_2(16),
+       "input_spec": gen_sharded_spec_unary(16),
        "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
    },
}
@@ -12,7 +12,7 @@
import math
from tests.sweep_framework.sweep_utils.utils import gen_shapes, sanitize_shape_rm
from tests.sweep_framework.sweep_utils.sharding_utils import (
-   gen_sharded_spec_unary_2,
+   gen_sharded_spec_unary,
    parse_sharding_spec,
    invalidate_vector_sharding,
    roundup,
@@ -35,7 +35,7 @@
# Developers can create their own generator functions and pass them to the parameters as inputs.
parameters = {
    "nightly": {
-       "input_spec": gen_sharded_spec_unary_2(16),
+       "input_spec": gen_sharded_spec_unary(16),
        "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
    },
}
@@ -12,7 +12,7 @@
import math
from tests.sweep_framework.sweep_utils.utils import gen_shapes, sanitize_shape_rm
from tests.sweep_framework.sweep_utils.sharding_utils import (
-   gen_sharded_spec_unary_2,
+   gen_sharded_spec_unary,
    parse_sharding_spec,
    invalidate_vector_sharding,
    roundup,
@@ -35,7 +35,7 @@
# Developers can create their own generator functions and pass them to the parameters as inputs.
parameters = {
    "nightly": {
-       "input_spec": gen_sharded_spec_unary_2(16),
+       "input_spec": gen_sharded_spec_unary(16),
        "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
    },
}
@@ -12,7 +12,7 @@
import math
from tests.sweep_framework.sweep_utils.utils import gen_shapes, sanitize_shape_rm
from tests.sweep_framework.sweep_utils.sharding_utils import (
-   gen_sharded_spec_unary_2,
+   gen_sharded_spec_unary,
    parse_sharding_spec,
    invalidate_vector_sharding,
    roundup,
@@ -35,7 +35,7 @@
# Developers can create their own generator functions and pass them to the parameters as inputs.
parameters = {
    "nightly": {
-       "input_spec": gen_sharded_spec_unary_2(16),
+       "input_spec": gen_sharded_spec_unary(16),
        "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
    },
}
@@ -12,7 +12,7 @@
import math
from tests.sweep_framework.sweep_utils.utils import gen_shapes, sanitize_shape_rm
from tests.sweep_framework.sweep_utils.sharding_utils import (
-   gen_sharded_spec_unary_2,
+   gen_sharded_spec_unary,
    parse_sharding_spec,
    invalidate_vector_sharding,
    roundup,
@@ -35,7 +35,7 @@
# Developers can create their own generator functions and pass them to the parameters as inputs.
parameters = {
    "nightly": {
-       "input_spec": gen_sharded_spec_unary_2(16),
+       "input_spec": gen_sharded_spec_unary(16),
        "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
    },
}