diff --git a/tests/ttnn/unit_tests/operations/test_embedding.py b/tests/ttnn/unit_tests/operations/test_embedding.py
index a9b6f106a1f..89dc39a0788 100644
--- a/tests/ttnn/unit_tests/operations/test_embedding.py
+++ b/tests/ttnn/unit_tests/operations/test_embedding.py
@@ -121,3 +121,56 @@ def test_moe_embedding(
     output_tensor = ttnn.to_torch(output_tensor)
 
     assert_with_pcc(torch_output_tensor, output_tensor)
+
+
+@pytest.mark.parametrize("batch_size", [1, 8, 9])
+@pytest.mark.parametrize("sentence_size", [32, 256, 512])
+@pytest.mark.parametrize("hidden_embedding_dim", [768, 4096])  # Bert_Num_Cols_768, Llama_Num_Cols
+@pytest.mark.parametrize(
+    "vocabulary_size", [512, 30522, 2048]
+)  # Bert_Position_Embeddings_512, Bert_Word_Embeddings_30528, Llama_Position_Embeddings
+@pytest.mark.parametrize("input_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
+@pytest.mark.parametrize("output_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
+@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
+def test_embedding_tiled_input(
+    device,
+    batch_size,
+    sentence_size,
+    hidden_embedding_dim,
+    vocabulary_size,
+    input_mem_config,
+    output_mem_config,
+    layout,
+):
+    torch.manual_seed(1234)
+
+    torch_input_tensor = torch.randint(0, vocabulary_size - 1, (batch_size, sentence_size))
+    torch_weights = torch_random((vocabulary_size, hidden_embedding_dim), -0.1, 0.1, dtype=torch.bfloat16)
+    # torch_output_tensor = torch.nn.functional.embedding(torch_input_tensor, torch_weights)
+    torch_embedding = torch.nn.Embedding.from_pretrained(torch_weights)
+    torch_output_tensor = torch_embedding(torch_input_tensor)
+
+    input_tensor = ttnn.to_device(
+        ttnn.from_torch(torch_input_tensor, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT),
+        device,
+        memory_config=input_mem_config,
+    )
+    weights = ttnn.to_device(
+        ttnn.from_torch(torch_weights, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT),
+        device,
+        memory_config=input_mem_config,
+    )
+
+    # output_tensor = ttnn.embedding(input_tensor, weights, memory_config=output_mem_config, layout=ttnn.ROW_MAJOR_LAYOUT)
+    output_tensor = ttnn.embedding(
+        input_tensor,
+        weights,
+        embeddings_type=ttnn.EmbeddingsType.GENERIC,  # Default embeddings type
+        dtype=ttnn.bfloat16,
+        memory_config=output_mem_config,  # Default memory config
+        queue_id=0,  # Default queue id
+        layout=layout,
+    )
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor)
diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt
index bc2b1773cc2..8ce161317e3 100644
--- a/ttnn/CMakeLists.txt
+++ b/ttnn/CMakeLists.txt
@@ -158,6 +158,7 @@ set(ALL_TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/unary.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding/embedding.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding_backward/embedding_backward.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding_backward/embedding_backward_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/embedding_backward/device/embedding_backward_device_operation.cpp
diff --git a/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_tilize.cpp b/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_tilize.cpp
index 3361b261937..a3cd6e05a01 100644
--- a/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_tilize.cpp
+++ b/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_tilize.cpp
@@ -95,21 +95,21 @@ void kernel_main() {
             uint64_t src_noc_addr;
             uint32_t token = input_l1_ptr[k];
 #if defined PADDED
-                if (token == pad_token) {
-                    src_noc_addr = pad_noc_addr;
-                } else {
-                    src_noc_addr = get_noc_addr(token, weights);
-                }
-    #elif defined BINARY
-                if (token == 0) {
-                    src_noc_addr = zero_noc_addr;
-                } else {
-                    src_noc_addr = one_noc_addr;
-                }
+            if (token == pad_token) {
+                src_noc_addr = pad_noc_addr;
+            } else {
+                src_noc_addr = get_noc_addr(token, weights);
+            }
+#elif defined BINARY
+            if (token == 0) {
+                src_noc_addr = zero_noc_addr;
+            } else {
+                src_noc_addr = one_noc_addr;
+            }
 #else
 #if defined BFP16
             union { float f; uint32_t u; } u;
-            u.u = (uint32_t)input_l1_ptr[token_idx] << 16;
+            u.u = (uint32_t)input_l1_ptr[k] << 16;
             uint32_t token_casted = static_cast<uint32_t>(u.f);
             src_noc_addr = get_noc_addr(token_casted, weights);
 #else
diff --git a/ttnn/cpp/ttnn/operations/embedding/embedding.cpp b/ttnn/cpp/ttnn/operations/embedding/embedding.cpp
new file mode 100644
index 00000000000..8be8ab3d8c3
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/embedding/embedding.cpp
@@ -0,0 +1,85 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "ttnn/operations/embedding/embedding.hpp"
+#include "ttnn/operations/core/core.hpp"
+#include "ttnn/common/constants.hpp"
+#include "ttnn/operations/embedding/device/embedding_device_operation.hpp"
+#include "ttnn/run_operation.hpp"
+#include "ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp"
+
+namespace ttnn::operations::embedding {
+
+ttnn::Tensor EmbeddingOperation::invoke(
+    uint8_t queue_id,
+    const Tensor& input_tensor_arg,
+    const Tensor& weight_arg,
+    const std::optional<uint32_t>& pad_token,
+    const std::optional<ttnn::Layout>& layout,
+    EmbeddingsType embeddings_type,
+    const std::optional<const DataType> dtype,
+    const std::optional<ttnn::MemoryConfig>& memory_config,
+    std::optional<ttnn::Tensor> optional_output_tensor) {
+    if (pad_token.has_value()) {
+        embeddings_type = EmbeddingsType::PADDED;
+    }
+    Tensor mutable_input_tensor = input_tensor_arg;
+    Tensor mutable_weight = weight_arg;
+
+    // TODO: Add support for indices tensor in tile layout
+    // Issue #: 14915
+    TT_FATAL(input_tensor_arg.get_layout() == ttnn::ROW_MAJOR_LAYOUT, "Indices tensor must be in row major layout.");
+
+    if (mutable_weight.get_layout() == ttnn::TILE_LAYOUT) {
+        mutable_weight = ttnn::to_layout(mutable_weight, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, mutable_weight.device());
+    }
+    auto hidden_embedding_dim = mutable_weight.get_shape()[-1];
+    auto padded_hidden_embedding_dim = mutable_weight.get_shape().with_tile_padding()[-1];
+    auto weight = ttnn::unsqueeze_to_4D(mutable_weight);
+
+    auto batch_size = mutable_input_tensor.get_shape()[0];
+    auto sentence_size = mutable_input_tensor.get_shape()[-1];
+    auto input_tensor =
+        ttnn::reshape(mutable_input_tensor, ttnn::Shape{std::array<uint32_t, 4>{batch_size, 1, 1, sentence_size}});
+
+    // If layout is row major, or if the input tensor is not a multiple of TILE_HEIGHT, then we cannot use the tilized path
+    bool fused_tilized = false;
+    if (input_tensor.get_legacy_shape()[-1] % TILE_HEIGHT == 0 &&
+        weight.get_legacy_shape()[-1] % TILE_WIDTH == 0) {
+        if (layout.has_value()) {
+            if (layout.value() == ttnn::TILE_LAYOUT) fused_tilized = true;
+        }
+        else if (weight_arg.get_layout() == ttnn::TILE_LAYOUT) {
+            fused_tilized = true;
+        }
+    }
+
+    auto embeddings = operation::run(
+        Embeddings{
+            .output_mem_config = memory_config.value_or(input_tensor.memory_config()),
+            .tilized = fused_tilized,
+            .embeddings_type = embeddings_type,
+            .pad_token = pad_token,
+            .output_dtype = dtype.value_or(weight.get_dtype())},
+        {input_tensor, weight})
+        .at(0);
+    embeddings = ttnn::reshape(
+        embeddings, ttnn::Shape{std::array<uint32_t, 3>{batch_size, sentence_size, hidden_embedding_dim}});
+    embeddings = ttnn::to_layout(embeddings, layout.value_or(weight_arg.get_layout()), std::nullopt, std::nullopt, (Device*)nullptr);
+    return embeddings;
+}
+ttnn::Tensor EmbeddingOperation::invoke(
+    const Tensor& input_tensor_arg,
+    const Tensor& weight_arg,
+    const std::optional<uint32_t>& pad_token,
+    const std::optional<ttnn::Layout>& layout,
+    EmbeddingsType embeddings_type,
+    const std::optional<const DataType> dtype,
+    const std::optional<ttnn::MemoryConfig>& memory_config,
+    std::optional<ttnn::Tensor> optional_output_tensor
+    ) {
+    return invoke(DefaultQueueId, input_tensor_arg, weight_arg, pad_token, layout, embeddings_type, dtype, memory_config, optional_output_tensor);
+}
+
+}  // namespace ttnn::operations::embedding
diff --git a/ttnn/cpp/ttnn/operations/embedding/embedding.hpp b/ttnn/cpp/ttnn/operations/embedding/embedding.hpp
index 52439fd693d..03679ffd40e 100644
--- a/ttnn/cpp/ttnn/operations/embedding/embedding.hpp
+++ b/ttnn/cpp/ttnn/operations/embedding/embedding.hpp
@@ -4,11 +4,8 @@
 
 #pragma once
 
-#include "ttnn/common/constants.hpp"
 #include "ttnn/operations/embedding/device/embedding_device_operation.hpp"
-#include "ttnn/run_operation.hpp"
 #include "ttnn/decorators.hpp"
-#include "ttnn/operations/core/core.hpp"
 
 namespace ttnn {
 
@@ -17,56 +14,25 @@ namespace operations {
 namespace embedding {
 
 struct EmbeddingOperation {
-    static inline Tensor invoke(
+    static ttnn::Tensor invoke(
         uint8_t queue_id,
         const Tensor& input_tensor_arg,
         const Tensor& weight_arg,
         const std::optional<uint32_t>& pad_token = std::nullopt,
-        const Layout& layout = ttnn::ROW_MAJOR_LAYOUT,
+        const std::optional<ttnn::Layout>& layout = std::nullopt,
         EmbeddingsType embeddings_type = EmbeddingsType::GENERIC,
         const std::optional<const DataType> dtype = std::nullopt,
         const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,
-        std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
-        if (pad_token.has_value()) {
-            embeddings_type = EmbeddingsType::PADDED;
-        }
-
-        auto hidden_embedding_dim = weight_arg.get_shape()[-1];
-        auto padded_hidden_embedding_dim = weight_arg.get_shape().with_tile_padding()[-1];
-        auto weight = ttnn::unsqueeze_to_4D(weight_arg);
-
-        auto batch_size = input_tensor_arg.get_shape()[0];
-        auto sentence_size = input_tensor_arg.get_shape()[-1];
-        auto input_tensor =
-            ttnn::reshape(input_tensor_arg, ttnn::SimpleShape{std::array<uint32_t, 4>{batch_size, 1, 1, sentence_size}});
-
-        bool tilized = layout == ttnn::TILE_LAYOUT;
-        auto embeddings = operation::run(
-            Embeddings{
-                .output_mem_config = memory_config.value_or(input_tensor.memory_config()),
-                .tilized = tilized,
-                .embeddings_type = embeddings_type,
-                .pad_token = pad_token,
-                .output_dtype = dtype.value_or(weight.get_dtype())},
-            {input_tensor, weight})
-            .at(0);
-        embeddings = ttnn::reshape(
-            embeddings, ttnn::SimpleShape{std::array<uint32_t, 3>{batch_size, sentence_size, hidden_embedding_dim}});
-        return embeddings;
-    }
-
-    static inline auto invoke(
+        std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt);
+    static ttnn::Tensor invoke(
         const Tensor& input_tensor_arg,
         const Tensor& weight_arg,
         const std::optional<uint32_t>& pad_token = std::nullopt,
-        const Layout& layout = ttnn::ROW_MAJOR_LAYOUT,
+        const std::optional<ttnn::Layout>& layout = std::nullopt,
         EmbeddingsType embeddings_type = EmbeddingsType::GENERIC,
         const std::optional<const DataType> dtype = std::nullopt,
         const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,
-        std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt
-        ) {
-        return invoke(DefaultQueueId, input_tensor_arg, weight_arg, pad_token, layout, embeddings_type, dtype, memory_config, optional_output_tensor);
-    }
+        std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt);
 };
 
 }  // namespace embedding
diff --git a/ttnn/cpp/ttnn/operations/embedding/embedding_pybind.hpp b/ttnn/cpp/ttnn/operations/embedding/embedding_pybind.hpp
index dbd2f167c5b..49fa7769122 100644
--- a/ttnn/cpp/ttnn/operations/embedding/embedding_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/embedding/embedding_pybind.hpp
@@ -40,7 +40,7 @@ void py_module(py::module& module) {
 
         Returns:
-            ttnn.Tensor: the output tensor.
+            ttnn.Tensor: the output tensor, in `layout` if one is specified, otherwise in the layout of the weights tensor.
 
         Example:
 
@@ -69,7 +69,7 @@ void py_module(py::module& module) {
                const ttnn::Tensor& input_tensor,
                const ttnn::Tensor& weight,
                const std::optional<uint32_t>& padding_idx,
-               const ttnn::Layout& layout,
+               const std::optional<ttnn::Layout>& layout,
                EmbeddingsType embeddings_type,
                const std::optional<const DataType> dtype,
                std::optional<ttnn::Tensor>& optional_output_tensor,
@@ -81,7 +81,7 @@ void py_module(py::module& module) {
            py::arg("weight").noconvert(),
            py::kw_only(),
            py::arg("padding_idx") = std::nullopt,
-           py::arg("layout") = ttnn::ROW_MAJOR_LAYOUT,
+           py::arg("layout") = std::nullopt,
            py::arg("embeddings_type").noconvert() = EmbeddingsType::GENERIC,
            py::arg("dtype").noconvert() = std::nullopt,
            py::arg("output_tensor").noconvert() = std::nullopt,