#12883: Add initial unit tests for N300 (#12922)
* #12883: Add initial unit tests for N300

* #12883: Commonize functions

* #12883: Add wrapper functions
Aswinmcw authored Sep 23, 2024
1 parent 5448c47 commit 96fd0df
Showing 3 changed files with 351 additions and 37 deletions.
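The change follows a straightforward refactor shape: device-count gating and skip checks move out of the shared helpers into thin per-platform wrappers (the _t3k/_t3000 and _n300 variants) that delegate to a commonized implementation. Below is a minimal sketch of that pattern, condensed from the diff that follows; **kwargs stands in for the full parameter lists of the actual functions, and the shared body is elided:

import pytest


def run_all_gather_impl(all_devices, num_devices, **kwargs):
    # Commonized body: async setup, input tensor creation, all-gather, comparison.
    ...


def run_all_gather_on_t3000_impl(all_devices, num_devices, **kwargs):
    # T3000-specific gate: exactly 8 devices, then delegate.
    if len(all_devices) != 8:
        pytest.skip("Not T3000!")
    return run_all_gather_impl(all_devices, num_devices, **kwargs)


def run_all_gather_on_n300_impl(all_devices, num_devices, **kwargs):
    # N300-specific gate: exactly 2 devices, then delegate.
    if len(all_devices) != 2:
        pytest.skip("Not N300!")
    return run_all_gather_impl(all_devices, num_devices, **kwargs)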
247 changes: 217 additions & 30 deletions tests/ttnn/unit_tests/operations/test_all_gather.py
@@ -16,11 +16,6 @@ def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, in
if layout == ttnn.ROW_MAJOR_LAYOUT and input_dtype == ttnn.bfloat8_b:
return True, "Invalid combination"

if num_devices < 2:
return True, "Requires multiple devices to run"
elif num_devices == 2 and num_links <= 2:
return True, "Not enough links to run"

if input_shape[dim] % num_devices != 0 or (dim == 3 and input_shape[dim] // num_devices % 32 != 0):
return True, "Unsupported test case"

@@ -59,6 +54,19 @@ def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, in
return False, ""


def is_unsupported_case_t3k(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
if num_devices < 2:
return True, "Requires multiple devices to run"
elif num_devices == 2 and num_links <= 2:
return True, "Not enough links to run"

return is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout)


def is_unsupported_case_n300(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
return is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout)


def run_with_trace(
t3k_mesh_device,
devices,
@@ -110,7 +118,7 @@ def run_with_trace(
return tt_out_tensor


def run_all_gather_on_t3000_impl(
def run_all_gather_impl(
all_devices,
num_devices,
input_shape,
@@ -122,12 +130,10 @@ def run_all_gather_on_t3000_impl(
use_program_cache,
function_level_defaults,
all_gather_operation,
devices,
num_iters=1,
enable_async=False,
):
if len(all_devices) != 8:
pytest.skip("Not T3000!")

# Use Async mode based on test input config
for device in all_devices:
device.enable_async(enable_async)
@@ -136,13 +142,6 @@ def run_all_gather_on_t3000_impl(
logger.info(f"Input shape: {input_shape}")
logger.info(f"dim: {dim}")

(is_known_failure, message) = is_unsupported_case(
input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
)
if is_known_failure:
pytest.skip(f"Skipping unsupported case {message}.")

devices = get_devices_for_t3000(all_devices, num_devices)
# for device in devices:
# device.disable_and_clear_program_cache()

@@ -175,6 +174,92 @@
assert eq, f"{i} FAILED: {output}"


def run_all_gather_on_n300_impl(
all_devices,
num_devices,
input_shape,
dim,
num_links,
input_dtype,
layout,
mem_config,
use_program_cache,
function_level_defaults,
all_gather_operation,
num_iters=1,
enable_async=False,
):
if len(all_devices) != 2:
pytest.skip("Not N300!")

(is_known_failure, message) = is_unsupported_case_n300(
input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
)
if is_known_failure:
pytest.skip(f"Skipping unsupported case {message}.")

return run_all_gather_impl(
all_devices,
num_devices,
input_shape,
dim,
num_links,
input_dtype,
layout,
mem_config,
use_program_cache,
function_level_defaults,
all_gather_operation,
all_devices,
num_iters,
enable_async,
)


def run_all_gather_on_t3000_impl(
all_devices,
num_devices,
input_shape,
dim,
num_links,
input_dtype,
layout,
mem_config,
use_program_cache,
function_level_defaults,
all_gather_operation,
num_iters=1,
enable_async=False,
):
if len(all_devices) != 8:
pytest.skip("Not T3000!")

(is_known_failure, message) = is_unsupported_case_t3k(
input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
)
if is_known_failure:
pytest.skip(f"Skipping unsupported case {message}.")

devices = get_devices_for_t3000(all_devices, num_devices)

return run_all_gather_impl(
all_devices,
num_devices,
input_shape,
dim,
num_links,
input_dtype,
layout,
mem_config,
use_program_cache,
function_level_defaults,
all_gather_operation,
devices,
num_iters,
enable_async,
)


def run_all_gather_on_t3000_impl_tight_loop(
all_devices,
num_devices,
@@ -440,7 +525,7 @@ def run_line_all_gather(
logger.info(f"Input shape: {input_shape}")
logger.info(f"dim: {dim}")

(is_known_failure, message) = is_unsupported_case(
(is_known_failure, message) = is_unsupported_case_t3k(
input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
)
if is_known_failure:
@@ -494,7 +579,7 @@ def run_line_all_gather_deprecated(
logger.info(f"Input shape: {input_shape}")
logger.info(f"dim: {dim}")

(is_known_failure, message) = is_unsupported_case(
(is_known_failure, message) = is_unsupported_case_t3k(
input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
)
if is_known_failure:
@@ -929,18 +1014,13 @@ def run_all_gather_sharded(
use_program_cache,
function_level_defaults,
all_gather_operation,
devices,
enable_async,
n_worker=None,
n_buffer=None,
num_iter=1,
trace_mode=False,
):
if len(t3k_mesh_device.get_device_ids()) != 8:
pytest.skip("Not T3000!")

for device_id in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device_id).enable_async(enable_async)

numel = input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3] * num_devices
unchunked_input_shape = list(input_shape)
unchunked_input_shape[dim] *= num_devices
@@ -962,7 +1042,6 @@ def run_all_gather_sharded(
unchunked_input_tensor = unchunked_input_tensor.bfloat16()

input_tensors = torch.chunk(unchunked_input_tensor, num_devices, dim)
devices = [t3k_mesh_device.get_device(t3k_mesh_device.get_device_ids()[i]) for i in range(num_devices)]

# num_cores =
# compute_grid_size = devices[0].compute_with_storage_grid_size()
@@ -1079,6 +1158,114 @@ def run_all_gather_sharded(
assert all_eq, f"{i} FAILED: {output}"


def run_all_gather_sharded_t3k(
t3k_mesh_device,
num_devices,
input_shape,
input_shard_shape,
shard_grid,
dim,
num_links,
orientation,
input_dtype,
tensor_layout,
tensor_mem_layout,
# num_cores,
use_program_cache,
function_level_defaults,
all_gather_operation,
enable_async,
n_worker=None,
n_buffer=None,
num_iter=1,
trace_mode=False,
):
if len(t3k_mesh_device.get_device_ids()) != 8:
pytest.skip("Not T3000!")

for device_id in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device_id).enable_async(enable_async)

devices = [t3k_mesh_device.get_device(t3k_mesh_device.get_device_ids()[i]) for i in range(num_devices)]

return run_all_gather_sharded(
t3k_mesh_device,
num_devices,
input_shape,
input_shard_shape,
shard_grid,
dim,
num_links,
orientation,
input_dtype,
tensor_layout,
tensor_mem_layout,
# num_cores,
use_program_cache,
function_level_defaults,
all_gather_operation,
devices,
enable_async,
n_worker,
n_buffer,
num_iter,
trace_mode,
)


def run_all_gather_sharded_n300(
all_devices,
num_devices,
input_shape,
input_shard_shape,
shard_grid,
dim,
num_links,
orientation,
input_dtype,
tensor_layout,
tensor_mem_layout,
# num_cores,
use_program_cache,
function_level_defaults,
all_gather_operation,
enable_async,
n_worker=None,
n_buffer=None,
num_iter=1,
trace_mode=False,
):
if len(all_devices) != 2:
pytest.skip("Not N300!")

for device in all_devices:
device.enable_async(enable_async)

return run_all_gather_sharded(
all_devices,
num_devices,
input_shape,
input_shard_shape,
shard_grid,
dim,
num_links,
orientation,
input_dtype,
tensor_layout,
tensor_mem_layout,
# num_cores,
use_program_cache,
function_level_defaults,
all_gather_operation,
all_devices,
enable_async,
n_worker,
n_buffer,
num_iter,
trace_mode,
)


# @pytest.mark.parametrize("num_devices", [4, 8])
@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize("num_devices", [8])
@@ -1146,7 +1333,7 @@ def test_all_gather_sharded_post_commit(
function_level_defaults,
enable_async,
):
run_all_gather_sharded(
run_all_gather_sharded_t3k(
t3k_mesh_device,
num_devices,
input_shape,
@@ -1236,7 +1423,7 @@ def test_all_gather_height_sharded_post_commit(
function_level_defaults,
enable_async,
):
run_all_gather_sharded(
run_all_gather_sharded_t3k(
t3k_mesh_device,
num_devices,
input_shape,
@@ -1320,7 +1507,7 @@ def test_all_gather_block_sharded_post_commit(
function_level_defaults,
enable_async,
):
run_all_gather_sharded(
run_all_gather_sharded_t3k(
t3k_mesh_device,
num_devices,
input_shape,
@@ -1412,7 +1599,7 @@ def test_line_all_gather_sharded_post_commit(
function_level_defaults,
enable_async,
):
run_all_gather_sharded(
run_all_gather_sharded_t3k(
t3k_mesh_device,
num_devices,
input_shape,
@@ -1576,7 +1763,7 @@ def test_sharded_all_gather_nightly(
all_gather_operation,
enable_async,
):
run_all_gather_sharded(
run_all_gather_sharded_t3k(
t3k_mesh_device,
num_devices,
input_shape,
(Diffs for the remaining 2 changed files are not shown.)
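For reference, a hypothetical N300 counterpart to the T3000 tests above would call the new wrapper the same way. This sketch is not part of the commit's visible hunks; the fixtures (all_devices, use_program_cache, function_level_defaults) and the parametrize values are assumptions modeled on the T3000 tests:

import pytest
import ttnn


@pytest.mark.parametrize("num_devices, num_links", [(2, 1)])
@pytest.mark.parametrize("input_shape, dim", [([1, 1, 32, 256], 3)])
@pytest.mark.parametrize("input_dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("mem_config", [ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM)])
@pytest.mark.parametrize("enable_async", [True, False])
def test_all_gather_on_n300_post_commit(
    all_devices,
    num_devices,
    input_shape,
    dim,
    num_links,
    input_dtype,
    layout,
    mem_config,
    use_program_cache,
    function_level_defaults,
    enable_async,
):
    # Delegates device-count gating and the N300 skip checks to the wrapper
    # added in this commit.
    run_all_gather_on_n300_impl(
        all_devices,
        num_devices,
        input_shape,
        dim,
        num_links,
        input_dtype,
        layout,
        mem_config,
        use_program_cache,
        function_level_defaults,
        all_gather_operation=ttnn.all_gather,
        enable_async=enable_async,
    )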
