
Commit

#13593: embedding fuzed tilized reconditioned, one test made reproducible, fixed index within embedding op
yugi957 committed Oct 15, 2024
1 parent aa4c83c commit e35dfe1
Showing 3 changed files with 32 additions and 25 deletions.
24 changes: 13 additions & 11 deletions tests/sweep_framework/one_test.py
@@ -1,6 +1,8 @@
from typing import Optional, Tuple

import torch

torch.manual_seed(42)
import random
import ttnn

@@ -66,14 +68,14 @@ def run(
print(torch_output_tensor.shape)

# Loop through rows of tensor and print first false element, print PASS if row is correct
for i in range(check_tensor.shape[1]):
if False in check_tensor[0, i]:
print(f"Row {i}: FAIL")
print(torch_output_tensor[0, i])
print(ttnn_output_tensor[0, i])
break
else:
print(f"Row {i}: PASS")
# for i in range(check_tensor.shape[1]):
# if False in check_tensor[0, i]:
# print(f"Row {i}: FAIL")
# print(torch_output_tensor[0, i])
# print(ttnn_output_tensor[0, i])
# break
# else:
# print(f"Row {i}: PASS")

# Compare the results and return performance and accuracy check
result = check_with_pcc(torch_output_tensor, ttnn_output_tensor, 0.999)
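check_with_pcc here is the sweep framework's helper; the 0.999 argument is a Pearson correlation threshold the ttnn output must meet against the torch reference. For intuition only, a rough equivalent of the underlying comparison (the helper's exact return format is not reproduced here):

    import torch

    def pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
        # Pearson correlation coefficient over the flattened tensors; the test
        # passes when this is at least the requested threshold (0.999 above).
        stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
        return torch.corrcoef(stacked)[0, 1].item()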
@@ -93,11 +95,11 @@ def get_devices(test_module):

# pcc errors
# embedding_specs = {'weight_shape': [2, 768], 'indices_shape': [1, 8]}
# embedding_specs = {'weight_shape': [30528, 768], 'indices_shape': [1, 8], 'padding_idx': 0}
# embedding_specs = {'weight_shape': [400, 10], 'indices_shape': [1, 24]}
embedding_specs = {"weight_shape": [30528, 768], "indices_shape": [1, 32], "padding_idx": 0}
# embedding_specs = {'weight_shape': [400, 10], 'indices_shape': [1, 24]} # Should output 1, 24, 10

# page/buffer errors
embedding_specs = {"weight_shape": [77, 512], "indices_shape": [1, 7]}
# embedding_specs = {"weight_shape": [77, 512], "indices_shape": [1, 8]} #SHould output 1, 7, 512

dtype = ttnn.bfloat16
layout = ttnn.TILE_LAYOUT
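For reference, the expected output shapes noted in the comments above follow directly from the embedding lookup: the output is the indices shape with the embedding width appended. A minimal PyTorch-only sketch (the fixed seed and the shapes mirror the test; everything else is illustrative):

    import torch

    torch.manual_seed(42)  # same fixed seed the test now sets, so index draws are reproducible

    weight = torch.rand(77, 512)            # weight_shape [77, 512]
    indices = torch.randint(0, 77, (1, 7))  # indices_shape [1, 7]

    out = torch.nn.functional.embedding(indices, weight)
    print(out.shape)  # torch.Size([1, 7, 512]), i.e. "should output 1, 7, 512"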
24 changes: 12 additions & 12 deletions (second changed file: embedding device kernel; path not shown)
@@ -95,21 +95,21 @@ void kernel_main() {
uint64_t src_noc_addr;
uint32_t token = input_l1_ptr[k];
#if defined PADDED
if (token == pad_token) {
src_noc_addr = pad_noc_addr;
} else {
src_noc_addr = get_noc_addr(token, weights);
}
#elif defined BINARY
if (token == 0) {
src_noc_addr = zero_noc_addr;
} else {
src_noc_addr = one_noc_addr;
}
if (token == pad_token) {
src_noc_addr = pad_noc_addr;
} else {
src_noc_addr = get_noc_addr(token, weights);
}
#elif defined BINARY
if (token == 0) {
src_noc_addr = zero_noc_addr;
} else {
src_noc_addr = one_noc_addr;
}
#else
#if defined BFP16
union { float f; uint32_t u; } u;
u.u = (uint32_t)input_l1_ptr[token_idx] << 16;
u.u = (uint32_t)input_l1_ptr[k] << 16;
uint32_t token_casted = static_cast<uint32_t>(u.f);
src_noc_addr = get_noc_addr(token_casted, weights);
#else
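Two things happen in the BFP16 branch above: the bfloat16 token value is widened to a float32 by shifting its bits into the top half of a 32-bit word, and that float is then truncated to an integer row index; the change from input_l1_ptr[token_idx] to input_l1_ptr[k] is the index fix named in the commit message. A rough Python rendering of the same bit trick (helper name and example value are illustrative, not taken from the kernel):

    import struct

    def bfloat16_bits_to_row_index(bf16_bits: int) -> int:
        # A bfloat16 is the top 16 bits of a float32, so shifting left by 16 rebuilds
        # the float32 bit pattern -- the same idea as the union { float f; uint32_t u; } above.
        f32_bits = (bf16_bits & 0xFFFF) << 16
        (value,) = struct.unpack("<f", struct.pack("<I", f32_bits))
        return int(value)

    # float32(7.0) has bit pattern 0x40E00000, so its bfloat16 half is 0x40E0
    assert bfloat16_bits_to_row_index(0x40E0) == 7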
9 changes: 7 additions & 2 deletions ttnn/cpp/ttnn/operations/embedding/embedding.hpp
@@ -47,11 +47,16 @@ struct EmbeddingOperation {
auto input_tensor =
ttnn::reshape(mutable_input_tensor, ttnn::Shape{std::array<uint32_t, 4>{batch_size, 1, 1, sentence_size}});

bool tilized = false;
bool fuzed_tilized = layout == ttnn::TILE_LAYOUT;

// If layout is row major, OR if the input tensor is not a multiple of TILE_HEIGHT, then we cannot use tilized
if(!fuzed_tilized || input_tensor.get_legacy_shape()[-1] % TILE_HEIGHT) fuzed_tilized = false;
if(!fuzed_tilized || weight.get_legacy_shape()[-1] % TILE_WIDTH) fuzed_tilized = false;

auto embeddings = operation::run(
Embeddings{
.output_mem_config = memory_config.value_or(input_tensor.memory_config()),
.tilized = tilized,
.tilized = fuzed_tilized,
.embeddings_type = embeddings_type,
.pad_token = pad_token,
.output_dtype = dtype.value_or(weight.get_dtype())},
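The net effect of the two new modulo checks is that the fused tilized path is only taken when the requested layout is TILE and both the index count and the embedding width are tile-aligned; otherwise the op falls back to the non-tilized path as before. A small sketch of that gate (TILE_HEIGHT and TILE_WIDTH are assumed to be the usual 32 here, and the function name is illustrative):

    TILE_HEIGHT = 32  # assumed; the real values come from tt-metal's tile constants
    TILE_WIDTH = 32

    def use_fuzed_tilized(layout_is_tile: bool, sentence_size: int, embedding_dim: int) -> bool:
        # Mirrors the checks added in embedding.hpp: TILE layout, tile-aligned
        # index count, and tile-aligned embedding width are all required.
        return (
            layout_is_tile
            and sentence_size % TILE_HEIGHT == 0
            and embedding_dim % TILE_WIDTH == 0
        )

    # The specs exercised in one_test.py:
    print(use_fuzed_tilized(True, 32, 768))  # True  -- weight [30528, 768], indices [1, 32]
    print(use_fuzed_tilized(True, 7, 512))   # False -- weight [77, 512], indices [1, 7]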
