Skip to content

Commit

Permalink
Add 1D indices support for embedding (#15726)
Browse files Browse the repository at this point in the history
### Ticket
Link to Github Issue

### Problem description
Trying to use embedding with an unbatched (1-dimensional) indices tensor results in an invalid reshape.

### What's changed
Added conditional reshaping and batch-size instantiation depending on whether the
indices tensor has 1 or 2 dimensions.

### Checklist
- [ ] Post commit CI passes:
#15726 (comment)
  • Loading branch information
yugi957 authored Dec 6, 2024
1 parent f976cb0 commit aee56ac
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions ttnn/cpp/ttnn/operations/embedding/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ ttnn::Tensor EmbeddingOperation::invoke(
// Issue #: 14915
TT_FATAL(input_tensor_arg.get_layout() == ttnn::ROW_MAJOR_LAYOUT, "Indices tensor must be in row major layout.");

if (mutable_input_tensor.get_layout() == ttnn::TILE_LAYOUT) {
mutable_input_tensor = ttnn::to_layout(
mutable_input_tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, mutable_input_tensor.device());
}
if (mutable_weight.get_layout() == ttnn::TILE_LAYOUT) {
mutable_weight = ttnn::to_layout(
mutable_weight, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, mutable_weight.device());
Expand All @@ -41,7 +45,8 @@ ttnn::Tensor EmbeddingOperation::invoke(
auto padded_hidden_embedding_dim = mutable_weight.get_shape().with_tile_padding()[-1];
auto weight = ttnn::unsqueeze_to_4D(mutable_weight);

auto batch_size = mutable_input_tensor.get_shape()[0];
// If indices tensor is 1 dimensional, batch size is 1
auto batch_size = (mutable_input_tensor.get_shape().rank() == 1) ? 1 : mutable_input_tensor.get_shape()[0];
auto sentence_size = mutable_input_tensor.get_shape()[-1];
auto input_tensor =
ttnn::reshape(mutable_input_tensor, ttnn::Shape{std::array<uint32_t, 4>{batch_size, 1, 1, sentence_size}});
Expand All @@ -67,8 +72,14 @@ ttnn::Tensor EmbeddingOperation::invoke(
.output_dtype = dtype.value_or(weight.get_dtype())},
{input_tensor, weight})
.at(0);
embeddings = ttnn::reshape(
embeddings, ttnn::Shape{std::array<uint32_t, 3>{batch_size, sentence_size, hidden_embedding_dim}});
// Don't include batch_size if there was none
if (input_tensor_arg.get_shape().rank() == 1) {
embeddings =
ttnn::reshape(embeddings, ttnn::Shape{std::array<uint32_t, 2>{sentence_size, hidden_embedding_dim}});
} else {
embeddings = ttnn::reshape(
embeddings, ttnn::Shape{std::array<uint32_t, 3>{batch_size, sentence_size, hidden_embedding_dim}});
}
embeddings = ttnn::to_layout(
embeddings, layout.value_or(weight_arg.get_layout()), std::nullopt, std::nullopt, (Device*)nullptr);
return embeddings;
Expand Down

0 comments on commit aee56ac

Please sign in to comment.