Fix embedding data mismatch (#1633)
When serializing to FlatBuffer, we did not treat the embedding op as a
DPS op. Instead, its output was treated as a separate tensor and
assigned a different global ID than the DPS init, which later caused a
data mismatch at runtime.

So in this example:
```mlir
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x100>>, <interleaved>>, shape = #ttnn.shape<1x12x3200>}> : (!tt.device<#device>) -> tensor<1x12x3200xbf16, #ttnn_layout5>
%4 = "ttnn.embedding"(%1, %2, %3) : (tensor<1x12xi32, #ttnn_layout3>, tensor<1x12x3200xbf16, #ttnn_layout5>, tensor<32000x3200xbf16, #ttnn_layout4>) -> tensor<1x12x3200xbf16, #ttnn_layout5>
%5 = "ttnn.from_device"(%4) : (tensor<1x12x3200xbf16, #ttnn_layout5>) -> tensor<1x12x3200xbf16, #ttnn_layout2>
```
Here’s what happens:
- The `"ttnn.empty"` operation produces a tensor with a global ID, say 5.
- The `"ttnn.embedding"` operation, instead of reusing global ID 5 for its output (as would be expected for a DPS operation), is assigned a new global ID, say 6. As a result, its output is not written to the memory chunk associated with global ID 5.
- When the runtime executes the `"ttnn.from_device"` operation, it expects its input to have global ID 5 (since it follows the DPS convention). However, because nothing was written to global ID 5 due to the mismatch in how `"ttnn.embedding"` was handled, the runtime instead reads a random or uninitialized tensor from that location. This leads to a data mismatch.

The root cause of this issue is the incorrect FlatBuffer serialization
logic for the embedding operation:
```cpp
::flatbuffers::Offset<::tt::target::ttnn::EmbeddingOp>
createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) {
  auto in0 =
      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
  auto in1 = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getWeight()));
  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
                                  kHostAllocatedAddress, kHostAllocatedSize);
  return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output);
}
```

To fix this, we should replace the line:
`auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize);`

with:
`auto out = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getResult()));`

This change ensures that "ttnn.embedding" writes to the output of
"ttnn.empty" with global ID 5 (as in the example) rather than allocating
a new buffer with a different ID.

Note: This bug could potentially be present in other operations as well.
Will check and address them accordingly.

closes #1404
jserbedzijaTT authored Dec 19, 2024
1 parent 293d226 commit 305ea47
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
```diff
@@ -567,8 +567,8 @@ def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
   }];

   let arguments = (ins AnyRankedTensor:$input,
-                       AnyRankedTensor:$output,
-                       AnyRankedTensor:$weight);
+                       AnyRankedTensor:$weight,
+                       AnyRankedTensor:$output);

   let results = (outs AnyRankedTensor:$result);
```
2 changes: 1 addition & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
```diff
@@ -352,7 +352,7 @@ class EmbeddingOpConversionPattern
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<ttnn::EmbeddingOp>(
         op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), adaptor.getOutput(), adaptor.getWeight());
+        adaptor.getInput(), adaptor.getWeight(), adaptor.getOutput());

     return success();
   }
```
6 changes: 3 additions & 3 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
```diff
@@ -753,9 +753,9 @@ createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) {
       cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
   auto in1 = cache.at<::tt::target::TensorRef>(
       getOperandThroughDPSOps(op.getWeight()));
-  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
-                                  kHostAllocatedAddress, kHostAllocatedSize);
-  return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output);
+  auto out = cache.at<::tt::target::TensorRef>(
+      getOperandThroughDPSOps(op.getResult()));
+  return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, out);
 }

 template <typename EmbeddingBackwardOp>
```
