Fix ttnn.concat golden function when groups > 1
esmalTT committed Jan 9, 2025
1 parent 71d43c2 commit 9bb508a
Showing 2 changed files with 37 additions and 37 deletions.
36 changes: 1 addition & 35 deletions tests/ttnn/unit_tests/operations/test_concat.py
@@ -207,40 +207,6 @@ def _gen_inputs(input_specs):
    assert_with_pcc(torch_output_tensor_2, output_2)


-def grouped_concat(activations, residuals, groups):
-    """
-    Concatenate activations and residuals with flexible interleaving based on groups.
-
-    Args:
-        activations (torch.Tensor): Activation tensor with shape [N, H, W, C].
-        residuals (torch.Tensor): Residual tensor with shape [N, H, W, C].
-        groups (int): Number of groups to split channels into.
-
-    Returns:
-        torch.Tensor: Concatenated tensor with interleaved groups.
-    """
-    assert (
-        activations.shape[:-1] == residuals.shape[:-1]
-    ), "Activations and residuals must have the same shape in all dims but -1"
-
-    N, H, W, activation_channels = activations.shape
-    assert activation_channels % groups == 0, "Channel count must be divisible by the number of groups"
-
-    N, H, W, residual_channels = residuals.shape
-    assert residual_channels % groups == 0, "Channel count must be divisible by the number of groups"
-
-    act_groups = activations.view(N, H, W, groups, activation_channels // groups)
-    res_groups = residuals.view(N, H, W, groups, residual_channels // groups)
-
-    # Interleave matching channel groups along the last axis
-    interleaved = torch.cat([act_groups, res_groups], dim=-1)  # Shape: [N, H, W, groups, act_group_size + res_group_size]
-
-    # Flatten the group axis back into channels
-    interleaved = interleaved.reshape(N, H, W, residual_channels + activation_channels)
-
-    return interleaved


@pytest.mark.parametrize("dim", [3])
@pytest.mark.parametrize("groups", [1, 2, 4])
@pytest.mark.parametrize(
@@ -276,7 +242,7 @@ def grouped_concat(activations, residuals, groups):
def test_sharded_concat_with_groups(device, input_shapes, output_shape, dim, groups, core_grid):
    torch_input_tensors = [torch.full(shapes, idx + 1, dtype=torch.bfloat16) for idx, shapes in enumerate(input_shapes)]

-    expected = grouped_concat(torch_input_tensors[0], torch_input_tensors[1], groups)
+    expected = ttnn.concat.golden_function(torch_input_tensors, dim, groups)

    sharded_memory_configs = [
        ttnn.create_sharded_memory_config(
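The interleaving the updated test now expects can be checked by hand. Below is a minimal sketch (not part of the commit; the shapes and fill values are chosen for illustration, in the style of the test's torch.full inputs):

    import torch

    # Two NHWC tensors: activations filled with 1, residuals filled with 2.
    activations = torch.full((1, 1, 1, 4), 1, dtype=torch.bfloat16)
    residuals = torch.full((1, 1, 1, 4), 2, dtype=torch.bfloat16)

    # With groups=2, each tensor splits into 2 channel groups of size 2, and
    # matching groups are concatenated pairwise along the channel axis.
    act_groups = activations.view(1, 1, 1, 2, 2)
    res_groups = residuals.view(1, 1, 1, 2, 2)
    expected = torch.cat([act_groups, res_groups], dim=-1).reshape(1, 1, 1, 8)

    print(expected.flatten().int().tolist())  # [1, 1, 2, 2, 1, 1, 2, 2]

    # A plain torch.concat([activations, residuals], 3) would instead yield
    # [1, 1, 1, 1, 2, 2, 2, 2], which is what the old golden function
    # returned for every value of groups.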
38 changes: 36 additions & 2 deletions ttnn/ttnn/operations/data_movement.py
@@ -83,10 +83,44 @@ def _golden_function(input_tensor, dims, **_):
ttnn.attach_golden_function(ttnn.permute, golden_function=_golden_function)


-def _golden_function(tensors, dim=0, **_):
+def _golden_function(tensors, dim=0, groups=1, **_):
    import torch

-    return torch.concat(tensors, dim)
+    def grouped_concat(activations, residuals, groups):
+        """
+        Concatenate activations and residuals with flexible interleaving based on groups.
+
+        Args:
+            activations (torch.Tensor): Activation tensor with shape [N, H, W, C].
+            residuals (torch.Tensor): Residual tensor with shape [N, H, W, C].
+            groups (int): Number of groups to split channels into.
+
+        Returns:
+            torch.Tensor: Concatenated tensor with interleaved groups.
+        """
+        assert (
+            activations.shape[:-1] == residuals.shape[:-1]
+        ), "Activations and residuals must have the same shape in all dims but -1"
+
+        N, H, W, activation_channels = activations.shape
+        assert activation_channels % groups == 0, "Channel count must be divisible by the number of groups"
+
+        N, H, W, residual_channels = residuals.shape
+        assert residual_channels % groups == 0, "Channel count must be divisible by the number of groups"
+
+        act_groups = activations.view(N, H, W, groups, activation_channels // groups)
+        res_groups = residuals.view(N, H, W, groups, residual_channels // groups)
+
+        # Interleave matching channel groups along the last axis
+        interleaved = torch.cat([act_groups, res_groups], dim=-1)  # Shape: [N, H, W, groups, act_group_size + res_group_size]
+
+        # Flatten the group axis back into channels
+        interleaved = interleaved.reshape(N, H, W, residual_channels + activation_channels)
+
+        return interleaved
+
+    return grouped_concat(tensors[0], tensors[1], groups=groups) if groups > 1 else torch.concat(tensors, dim)


ttnn.attach_golden_function(
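For reference, a hedged usage sketch of the golden path (assuming a ttnn build where the truncated attach call above registers _golden_function on ttnn.concat, consistent with the updated test calling ttnn.concat.golden_function):

    import torch
    import ttnn

    a = torch.ones(1, 2, 2, 4, dtype=torch.bfloat16)
    b = torch.full((1, 2, 2, 4), 2, dtype=torch.bfloat16)

    # groups=1 (the default) falls through to plain torch.concat along dim.
    plain = ttnn.concat.golden_function([a, b], 3)

    # groups > 1 takes the grouped path, which supports exactly two tensors
    # and always interleaves channel groups along the last axis.
    grouped = ttnn.concat.golden_function([a, b], 3, 2)

    assert plain.shape == grouped.shape == (1, 2, 2, 8)

Note the design choice: when groups > 1 the dim argument is effectively ignored, since grouped_concat always operates on the channel (last) dimension of NHWC tensors.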
