Skip to content

Commit

Permalink
#14915: re-ordering code in test_embedding after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
yugi957 committed Dec 23, 2024
1 parent 9af47e4 commit facba13
Showing 1 changed file with 26 additions and 44 deletions.
70 changes: 26 additions & 44 deletions tests/ttnn/unit_tests/operations/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,32 @@ def reverse_embedding_output(output_tensor, weights_tensor):
return reversed_indices


def create_tile_tensor(height, width, tile_size):
"""
Creates a 2D tensor where each element represents the tile it belongs to.
Parameters:
height (int): The height of the tensor (number of rows).
width (int): The width of the tensor (number of columns).
tile_size (int): The size of each square tile (tile_size x tile_size).
Returns:
torch.Tensor: A 2D tensor with tile indices.
"""
# Calculate the number of tiles in each dimension
tiles_per_row = (width + tile_size - 1) // tile_size
tiles_per_col = (height + tile_size - 1) // tile_size

# Create row and column indices
row_indices = torch.arange(height).unsqueeze(1) // tile_size
col_indices = torch.arange(width).unsqueeze(0) // tile_size

# Calculate tile indices
tile_tensor = row_indices * tiles_per_row + col_indices
print(tile_tensor.shape)
return tile_tensor


@pytest.mark.parametrize("batch_size", [1, 8, 9])
@pytest.mark.parametrize("sentence_size", [32, 256, 512])
@pytest.mark.parametrize("hidden_embedding_dim", [768, 4096]) # Bert_Num_Cols_768, Llama_Num_Cols
Expand Down Expand Up @@ -401,47 +427,3 @@ def test_tg_llama_sharded_embedding(
)
output_tensor = ttnn.to_torch(output_tensor)
assert_with_pcc(output_tensor, torch_output_tensor[:, 0, :].unsqueeze(1))

def get_random_factors(N):
if N == 0:
raise ValueError("Input number must be non-zero.")

# Get all divisors of N
divisors = [i for i in range(1, abs(N) + 1) if N % i == 0]

# Randomly select a divisor (excluding N itself to ensure two factors)
A = random.choice(divisors[:-1])

# Compute the second factor
B = N // A

return A, B


import torch


def create_tile_tensor(height, width, tile_size):
"""
Creates a 2D tensor where each element represents the tile it belongs to.
Parameters:
height (int): The height of the tensor (number of rows).
width (int): The width of the tensor (number of columns).
tile_size (int): The size of each square tile (tile_size x tile_size).
Returns:
torch.Tensor: A 2D tensor with tile indices.
"""
# Calculate the number of tiles in each dimension
tiles_per_row = (width + tile_size - 1) // tile_size
tiles_per_col = (height + tile_size - 1) // tile_size

# Create row and column indices
row_indices = torch.arange(height).unsqueeze(1) // tile_size
col_indices = torch.arange(width).unsqueeze(0) // tile_size

# Calculate tile indices
tile_tensor = row_indices * tiles_per_row + col_indices
print(tile_tensor.shape)
return tile_tensor

0 comments on commit facba13

Please sign in to comment.