From 305ea4768685f8cf1852dad9df4b35d784d2bb2a Mon Sep 17 00:00:00 2001 From: Jovan Serbedzija Date: Thu, 19 Dec 2024 11:36:05 +0100 Subject: [PATCH] Fix embedding data mismatch (#1633) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When serializing to FlatBuffer, we did not treat the embedding op as a DPS op. Instead, its output was treated as a separate tensor and assigned a different global ID than the DPS init, which later caused a data mismatch at runtime. So in this example: ``` %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x100>>, >, shape = #ttnn.shape<1x12x3200>}> : (!tt.device<#device>) -> tensor<1x12x3200xbf16, #ttnn_layout5> %4 = "ttnn.embedding"(%1, %2, %3) : (tensor<1x12xi32, #ttnn_layout3>, tensor<1x12x3200xbf16, #ttnn_layout5>, tensor<32000x3200xbf16, #ttnn_layout4>) -> tensor<1x12x3200xbf16, #ttnn_layout5> %5 = "ttnn.from_device"(%4) : (tensor<1x12x3200xbf16, #ttnn_layout5>) -> tensor<1x12x3200xbf16, #ttnn_layout2> ``` Here’s what happens: • The "ttnn.empty" operation produces a tensor with a global ID, say 5. • The "ttnn.embedding" operation, instead of reusing the global ID 5 for its output (as would be expected for a DPS operation), is assigned a new global ID, say 6. As a result, its output is not written to the memory chunk associated with global ID 5. • When the runtime tries to execute the "ttnn.from_device" operation, it expects its input to have global ID 5 (since it follows the DPS convention). However, because nothing was written to global ID 5 due to the mismatch in how "ttnn.embedding" was handled, the runtime will instead read a random or uninitialized tensor from that location. This leads to a data mismatch. 
The root cause of this issue is the incorrect FlatBuffer serialization logic for the embedding operation: ``` ::flatbuffers::Offset<::tt::target::ttnn::EmbeddingOp> createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) { auto in0 = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto in1 = cache.at<::tt::target::TensorRef>( getOperandThroughDPSOps(op.getWeight())); auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize); return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output); } ``` To fix this, we should replace the line: `auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize);` with: `auto out = cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getResult()));` This change ensures that "ttnn.embedding" writes to the output of "ttnn.empty" with global ID 5 (as in the example) rather than allocating a new buffer with a different ID. Note: This bug could potentially be present in other operations as well. Will check and address them accordingly. 
closes https://github.com/tenstorrent/tt-mlir/issues/1404 --- include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | 4 ++-- lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 2 +- lib/Target/TTNN/TTNNToFlatbuffer.cpp | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 758fb41d7..013f340e1 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -567,8 +567,8 @@ def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> { }]; let arguments = (ins AnyRankedTensor:$input, - AnyRankedTensor:$output, - AnyRankedTensor:$weight); + AnyRankedTensor:$weight, + AnyRankedTensor:$output); let results = (outs AnyRankedTensor:$result); diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index e60261bad..2b18fb24f 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -352,7 +352,7 @@ class EmbeddingOpConversionPattern ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), - adaptor.getInput(), adaptor.getOutput(), adaptor.getWeight()); + adaptor.getInput(), adaptor.getWeight(), adaptor.getOutput()); return success(); } diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 6e3745c91..5fd09d5e4 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -753,9 +753,9 @@ createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) { cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto in1 = cache.at<::tt::target::TensorRef>( getOperandThroughDPSOps(op.getWeight())); - auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, - kHostAllocatedAddress, kHostAllocatedSize); - return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output); + auto 
out = cache.at<::tt::target::TensorRef>( + getOperandThroughDPSOps(op.getResult())); + return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, out); } template