Embedding RM conversion and fused tilized recondition PR (#14389)
### Ticket
#13593

### Problem description
 - Fix Embedding RM conversion; the PCC errors turned out to be a sweep/untilize issue

### What's changed
 - uint32 untilize support provided by Naif's changes
 - Convert inputs to RM for the embedding op and recondition the fused tilize path (see the usage sketch below)
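
In practice the reconditioned op is called as in this minimal sketch (assuming an already-open ttnn `device`; the shapes are illustrative, not from this PR):

```python
import torch
import ttnn

torch_indices = torch.randint(0, 30522, (1, 32))
torch_weights = torch.rand((30522, 768), dtype=torch.bfloat16)

# Indices must be uint32 in ROW_MAJOR layout; the weights may stay tiled.
indices = ttnn.to_device(
    ttnn.from_torch(torch_indices, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT), device
)
weights = ttnn.to_device(
    ttnn.from_torch(torch_weights, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT), device
)

# `layout` is now optional: when omitted, the output inherits the weights' layout (TILE here).
output = ttnn.embedding(indices, weights, dtype=ttnn.bfloat16)
```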

### Checklist
- [ ] Post commit CI passes:
https://github.com/tenstorrent/tt-metal/actions/runs/11693157456
- [ ] T3K passes:
https://github.com/tenstorrent/tt-metal/actions/runs/11693180731

---------
yugi957 authored Nov 11, 2024
1 parent 3f12bb9 commit ef71901
Showing 6 changed files with 160 additions and 55 deletions.
53 changes: 53 additions & 0 deletions tests/ttnn/unit_tests/operations/test_embedding.py
@@ -121,3 +121,56 @@ def test_moe_embedding(
    output_tensor = ttnn.to_torch(output_tensor)

    assert_with_pcc(torch_output_tensor, output_tensor)


@pytest.mark.parametrize("batch_size", [1, 8, 9])
@pytest.mark.parametrize("sentence_size", [32, 256, 512])
@pytest.mark.parametrize("hidden_embedding_dim", [768, 4096]) # Bert_Num_Cols_768, Llama_Num_Cols
@pytest.mark.parametrize(
    "vocabulary_size", [512, 30522, 2048]
)  # Bert_Position_Embeddings_512, Bert_Word_Embeddings_30522, Llama_Position_Embeddings
@pytest.mark.parametrize("input_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
@pytest.mark.parametrize("output_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
def test_embedding_tiled_input(
    device,
    batch_size,
    sentence_size,
    hidden_embedding_dim,
    vocabulary_size,
    input_mem_config,
    output_mem_config,
    layout,
):
    torch.manual_seed(1234)

    torch_input_tensor = torch.randint(0, vocabulary_size - 1, (batch_size, sentence_size))
    torch_weights = torch_random((vocabulary_size, hidden_embedding_dim), -0.1, 0.1, dtype=torch.bfloat16)
    # torch_output_tensor = torch.nn.functional.embedding(torch_input_tensor, torch_weights)
    torch_embedding = torch.nn.Embedding.from_pretrained(torch_weights)
    torch_output_tensor = torch_embedding(torch_input_tensor)

    input_tensor = ttnn.to_device(
        ttnn.from_torch(torch_input_tensor, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT),
        device,
        memory_config=input_mem_config,
    )
    weights = ttnn.to_device(
        ttnn.from_torch(torch_weights, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT),
        device,
        memory_config=input_mem_config,
    )

    # output_tensor = ttnn.embedding(input_tensor, weights, memory_config=output_mem_config, layout=ttnn.ROW_MAJOR_LAYOUT)
    output_tensor = ttnn.embedding(
        input_tensor,
        weights,
        embeddings_type=ttnn.EmbeddingsType.GENERIC,  # default embeddings type
        dtype=ttnn.bfloat16,
        memory_config=output_mem_config,  # default memory config
        queue_id=0,  # default queue id
        layout=layout,
    )
    output_tensor = ttnn.to_torch(output_tensor)

    assert_with_pcc(torch_output_tensor, output_tensor)
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
@@ -158,6 +158,7 @@ set(ALL_TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/unary.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding/embedding.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding_backward/embedding_backward.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding_backward/embedding_backward_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding_backward/device/embedding_backward_device_operation.cpp
@@ -95,21 +95,21 @@ void kernel_main() {
     uint64_t src_noc_addr;
     uint32_t token = input_l1_ptr[k];
 #if defined PADDED
-    if (token == pad_token) {
-        src_noc_addr = pad_noc_addr;
-    } else {
-        src_noc_addr = get_noc_addr(token, weights);
-    }
-#elif defined BINARY
-    if (token == 0) {
-        src_noc_addr = zero_noc_addr;
-    } else {
-        src_noc_addr = one_noc_addr;
-    }
+    if (token == pad_token) {
+        src_noc_addr = pad_noc_addr;
+    } else {
+        src_noc_addr = get_noc_addr(token, weights);
+    }
+#elif defined BINARY
+    if (token == 0) {
+        src_noc_addr = zero_noc_addr;
+    } else {
+        src_noc_addr = one_noc_addr;
+    }
 #else
 #if defined BFP16
     union { float f; uint32_t u; } u;
-    u.u = (uint32_t)input_l1_ptr[token_idx] << 16;
+    u.u = (uint32_t)input_l1_ptr[k] << 16;
     uint32_t token_casted = static_cast<uint32_t>(u.f);
     src_noc_addr = get_noc_addr(token_casted, weights);
 #else
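For context on the BFP16 branch in the hunk above: the kernel widens the bfloat16 bit pattern of an index into a float32 word, reinterprets that word as a float, and truncates it to an integer token id. A minimal Python sketch of the same reinterpretation (the helper name is illustrative, not part of the kernel):

```python
import struct

def bf16_bits_to_token(bits16: int) -> int:
    # Shift the 16 bfloat16 bits into the high half of a float32 word,
    # reinterpret that word as a float, then truncate to an integer,
    # mirroring the `union { float f; uint32_t u; }` trick above.
    f32 = struct.unpack("<f", struct.pack("<I", (bits16 & 0xFFFF) << 16))[0]
    return int(f32)

# bf16_bits_to_token(0x41A0) == 20, since 0x41A00000 is 20.0f
```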
85 changes: 85 additions & 0 deletions ttnn/cpp/ttnn/operations/embedding/embedding.cpp
@@ -0,0 +1,85 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/embedding/embedding.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/common/constants.hpp"
#include "ttnn/operations/embedding/device/embedding_device_operation.hpp"
#include "ttnn/run_operation.hpp"
#include "ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp"

namespace ttnn::operations::embedding {

ttnn::Tensor EmbeddingOperation::invoke(
    uint8_t queue_id,
    const Tensor& input_tensor_arg,
    const Tensor& weight_arg,
    const std::optional<int>& pad_token,
    const std::optional<ttnn::Layout>& layout,
    EmbeddingsType embeddings_type,
    const std::optional<const DataType> dtype,
    const std::optional<MemoryConfig>& memory_config,
    std::optional<Tensor> optional_output_tensor) {
    if (pad_token.has_value()) {
        embeddings_type = EmbeddingsType::PADDED;
    }
    Tensor mutable_input_tensor = input_tensor_arg;
    Tensor mutable_weight = weight_arg;

    // TODO: Add support for indices tensor in tile layout (issue #14915)
    TT_FATAL(input_tensor_arg.get_layout() == ttnn::ROW_MAJOR_LAYOUT, "Indices tensor must be in row major layout.");

    if (mutable_weight.get_layout() == ttnn::TILE_LAYOUT) {
        mutable_weight = ttnn::to_layout(
            mutable_weight, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, mutable_weight.device());
    }
    auto hidden_embedding_dim = mutable_weight.get_shape()[-1];
    auto padded_hidden_embedding_dim = mutable_weight.get_shape().with_tile_padding()[-1];
    auto weight = ttnn::unsqueeze_to_4D(mutable_weight);

    auto batch_size = mutable_input_tensor.get_shape()[0];
    auto sentence_size = mutable_input_tensor.get_shape()[-1];
    auto input_tensor =
        ttnn::reshape(mutable_input_tensor, ttnn::Shape{std::array<uint32_t, 4>{batch_size, 1, 1, sentence_size}});

    // Fused tilize is only possible when the input width is a multiple of TILE_HEIGHT and the
    // weight width is a multiple of TILE_WIDTH; it is used when TILE_LAYOUT is requested
    // explicitly, or when no layout is requested and the weights were passed in tiled.
    bool fused_tilized = false;
    if (input_tensor.get_legacy_shape()[-1] % TILE_HEIGHT == 0 &&
        weight.get_legacy_shape()[-1] % TILE_WIDTH == 0) {
        if (layout.has_value()) {
            if (layout.value() == ttnn::TILE_LAYOUT) {
                fused_tilized = true;
            }
        } else if (weight_arg.get_layout() == ttnn::TILE_LAYOUT) {
            fused_tilized = true;
        }
    }

    auto embeddings = operation::run(
                          Embeddings{
                              .output_mem_config = memory_config.value_or(input_tensor.memory_config()),
                              .tilized = fused_tilized,
                              .embeddings_type = embeddings_type,
                              .pad_token = pad_token,
                              .output_dtype = dtype.value_or(weight.get_dtype())},
                          {input_tensor, weight})
                          .at(0);
    embeddings = ttnn::reshape(
        embeddings, ttnn::Shape{std::array<uint32_t, 3>{batch_size, sentence_size, hidden_embedding_dim}});
    embeddings = ttnn::to_layout(
        embeddings, layout.value_or(weight_arg.get_layout()), std::nullopt, std::nullopt, (Device*)nullptr);
    return embeddings;
}

ttnn::Tensor EmbeddingOperation::invoke(
    const Tensor& input_tensor_arg,
    const Tensor& weight_arg,
    const std::optional<int>& pad_token,
    const std::optional<ttnn::Layout>& layout,
    EmbeddingsType embeddings_type,
    const std::optional<const DataType> dtype,
    const std::optional<MemoryConfig>& memory_config,
    std::optional<Tensor> optional_output_tensor) {
    return invoke(
        DefaultQueueId,
        input_tensor_arg,
        weight_arg,
        pad_token,
        layout,
        embeddings_type,
        dtype,
        memory_config,
        optional_output_tensor);
}

}  // namespace ttnn::operations::embedding
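The fused-tilize eligibility check in `EmbeddingOperation::invoke` reduces to a few lines. A hedged Python restatement (assuming the usual 32x32 tile dimensions; the function and layout names are illustrative):

```python
from typing import Optional

TILE_HEIGHT = 32  # assumed standard tile dimensions
TILE_WIDTH = 32

def should_fuse_tilize(
    input_width: int,
    weight_width: int,
    requested_layout: Optional[str],
    weight_layout: str,
) -> bool:
    # A tilized output can only be fused when both widths are tile-aligned.
    if input_width % TILE_HEIGHT != 0 or weight_width % TILE_WIDTH != 0:
        return False
    # An explicit layout request wins; otherwise inherit the weights' layout.
    if requested_layout is not None:
        return requested_layout == "TILE"
    return weight_layout == "TILE"

# should_fuse_tilize(512, 768, None, "TILE")        -> True
# should_fuse_tilize(512, 768, "ROW_MAJOR", "TILE") -> False
```

Either way, the final `ttnn::to_layout` call guarantees the returned tensor matches the requested layout (or the weights' layout when none is given), so fusion is purely a performance decision.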
46 changes: 6 additions & 40 deletions ttnn/cpp/ttnn/operations/embedding/embedding.hpp
@@ -4,11 +4,8 @@

#pragma once

#include "ttnn/common/constants.hpp"
#include "ttnn/operations/embedding/device/embedding_device_operation.hpp"
#include "ttnn/run_operation.hpp"
#include "ttnn/decorators.hpp"
#include "ttnn/operations/core/core.hpp"

namespace ttnn {

@@ -17,56 +14,25 @@ namespace operations {
namespace embedding {

 struct EmbeddingOperation {
-    static inline Tensor invoke(
+    static ttnn::Tensor invoke(
         uint8_t queue_id,
         const Tensor& input_tensor_arg,
         const Tensor& weight_arg,
         const std::optional<int>& pad_token = std::nullopt,
-        const Layout& layout = ttnn::ROW_MAJOR_LAYOUT,
+        const std::optional<Layout>& layout = std::nullopt,
         EmbeddingsType embeddings_type = EmbeddingsType::GENERIC,
         const std::optional<const DataType> dtype = std::nullopt,
         const std::optional<MemoryConfig>& memory_config = std::nullopt,
-        std::optional<Tensor> optional_output_tensor = std::nullopt) {
-        if (pad_token.has_value()) {
-            embeddings_type = EmbeddingsType::PADDED;
-        }
-
-        auto hidden_embedding_dim = weight_arg.get_shape()[-1];
-        auto padded_hidden_embedding_dim = weight_arg.get_shape().with_tile_padding()[-1];
-        auto weight = ttnn::unsqueeze_to_4D(weight_arg);
-
-        auto batch_size = input_tensor_arg.get_shape()[0];
-        auto sentence_size = input_tensor_arg.get_shape()[-1];
-        auto input_tensor =
-            ttnn::reshape(input_tensor_arg, ttnn::SimpleShape{std::array<uint32_t, 4>{batch_size, 1, 1, sentence_size}});
-
-        bool tilized = layout == ttnn::TILE_LAYOUT;
-        auto embeddings = operation::run(
-                              Embeddings{
-                                  .output_mem_config = memory_config.value_or(input_tensor.memory_config()),
-                                  .tilized = tilized,
-                                  .embeddings_type = embeddings_type,
-                                  .pad_token = pad_token,
-                                  .output_dtype = dtype.value_or(weight.get_dtype())},
-                              {input_tensor, weight})
-                              .at(0);
-        embeddings = ttnn::reshape(
-            embeddings, ttnn::SimpleShape{std::array<uint32_t, 3>{batch_size, sentence_size, hidden_embedding_dim}});
-        return embeddings;
-    }
-
-    static inline auto invoke(
+        std::optional<Tensor> optional_output_tensor = std::nullopt);
+    static ttnn::Tensor invoke(
         const Tensor& input_tensor_arg,
         const Tensor& weight_arg,
         const std::optional<int>& pad_token = std::nullopt,
-        const Layout& layout = ttnn::ROW_MAJOR_LAYOUT,
+        const std::optional<Layout>& layout = std::nullopt,
         EmbeddingsType embeddings_type = EmbeddingsType::GENERIC,
         const std::optional<const DataType> dtype = std::nullopt,
         const std::optional<MemoryConfig>& memory_config = std::nullopt,
-        std::optional<Tensor> optional_output_tensor = std::nullopt) {
-        return invoke(DefaultQueueId, input_tensor_arg, weight_arg, pad_token, layout, embeddings_type, dtype, memory_config, optional_output_tensor);
-    }
+        std::optional<Tensor> optional_output_tensor = std::nullopt);
 };

} // namespace embedding
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/operations/embedding/embedding_pybind.hpp
@@ -40,7 +40,7 @@ void py_module(py::module& module) {
     Returns:
-        ttnn.Tensor: the output tensor.
+        ttnn.Tensor: the output tensor, in `layout` if specified, otherwise in the layout of the weights tensor.

     Example:

@@ -69,7 +69,7 @@ void py_module(py::module& module) {
             const ttnn::Tensor& input_tensor,
             const ttnn::Tensor& weight,
             const std::optional<int>& padding_idx,
-            const ttnn::Layout& layout,
+            const std::optional<ttnn::Layout>& layout,
             EmbeddingsType embeddings_type,
             const std::optional<const DataType> dtype,
             std::optional<ttnn::Tensor>& optional_output_tensor,
@@ -81,7 +81,7 @@ void py_module(py::module& module) {
py::arg("weight").noconvert(),
py::kw_only(),
py::arg("padding_idx") = std::nullopt,
py::arg("layout") = ttnn::ROW_MAJOR_LAYOUT,
py::arg("layout") = std::nullopt,
py::arg("embeddings_type").noconvert() = EmbeddingsType::GENERIC,
py::arg("dtype").noconvert() = std::nullopt,
py::arg("output_tensor").noconvert() = std::nullopt,
