Skip to content

Commit

Permalink
#5560: Add cases with NC dim
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Oct 21, 2024
1 parent 1336d39 commit ee10ee4
Showing 1 changed file with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def run_all_reduce_test(
.to(layout)
.to(mesh_device.get_device(mesh_device.get_device_ids()[i]), mem_config)
)
input_tensor = input_tensor.view(1, -1, input_tensor.shape[2], input_tensor.shape[3])
input_tensors.append(input_tensor)
unchunked_input_tensor = torch.cat(input_tensors)

Expand All @@ -133,8 +134,8 @@ def run_all_reduce_test(

tt_out_tensors = ttnn.get_device_tensors(output_tensor_mesh)
logger.info(f"Compare")
golden_canonical_out_tensor = torch.sum(unchunked_input_tensor, 0)

golden_canonical_out_tensor = torch.sum(unchunked_input_tensor, 0, keepdim=True)
golden_canonical_out_tensor = golden_canonical_out_tensor.view(per_chip_output_shape)
# Compare
mismatch = False
for i, t in enumerate(tt_out_tensors):
Expand Down Expand Up @@ -176,6 +177,10 @@ def run_all_reduce_test(
([1, 1, 8192, 32]),
([1, 1, 1024, 32]),
([1, 1, 2048, 32]),
([4, 1, 32, 4096]),
([8, 1, 32, 1024]),
([1, 4, 1024, 32]),
([2, 4, 2048, 32]),
],
)
@pytest.mark.parametrize(
Expand Down

0 comments on commit ee10ee4

Please sign in to comment.