diff --git a/ttnn/cpp/ttnn/operations/embedding/embedding.cpp b/ttnn/cpp/ttnn/operations/embedding/embedding.cpp
index 1943dd38012..5e61dc6db66 100644
--- a/ttnn/cpp/ttnn/operations/embedding/embedding.cpp
+++ b/ttnn/cpp/ttnn/operations/embedding/embedding.cpp
@@ -33,6 +33,10 @@ ttnn::Tensor EmbeddingOperation::invoke(
     // Issue #: 14915
     TT_FATAL(input_tensor_arg.get_layout() == ttnn::ROW_MAJOR_LAYOUT, "Indices tensor must be in row major layout.");
+    if (mutable_input_tensor.get_layout() == ttnn::TILE_LAYOUT) {
+        mutable_input_tensor = ttnn::to_layout(
+            mutable_input_tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, mutable_input_tensor.device());
+    }
     if (mutable_weight.get_layout() == ttnn::TILE_LAYOUT) {
         mutable_weight = ttnn::to_layout(
             mutable_weight, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, mutable_weight.device());
@@ -41,7 +45,8 @@ ttnn::Tensor EmbeddingOperation::invoke(
     auto padded_hidden_embedding_dim = mutable_weight.get_shape().with_tile_padding()[-1];
     auto weight = ttnn::unsqueeze_to_4D(mutable_weight);
 
-    auto batch_size = mutable_input_tensor.get_shape()[0];
+    // If indices tensor is 1 dimensional, batch size is 1
+    auto batch_size = (mutable_input_tensor.get_shape().rank() == 1) ? 1 : mutable_input_tensor.get_shape()[0];
     auto sentence_size = mutable_input_tensor.get_shape()[-1];
     auto input_tensor =
         ttnn::reshape(mutable_input_tensor, ttnn::Shape{std::array{batch_size, 1, 1, sentence_size}});
@@ -67,8 +72,14 @@ ttnn::Tensor EmbeddingOperation::invoke(
             .output_dtype = dtype.value_or(weight.get_dtype())},
         {input_tensor, weight})
             .at(0);
-    embeddings = ttnn::reshape(
-        embeddings, ttnn::Shape{std::array{batch_size, sentence_size, hidden_embedding_dim}});
+    // Don't include batch_size if there was none
+    if (input_tensor_arg.get_shape().rank() == 1) {
+        embeddings =
+            ttnn::reshape(embeddings, ttnn::Shape{std::array{sentence_size, hidden_embedding_dim}});
+    } else {
+        embeddings = ttnn::reshape(
+            embeddings, ttnn::Shape{std::array{batch_size, sentence_size, hidden_embedding_dim}});
+    }
     embeddings = ttnn::to_layout(
         embeddings, layout.value_or(weight_arg.get_layout()), std::nullopt, std::nullopt, (Device*)nullptr);
     return embeddings;