Fix embedding data mismatch (#1633)
When serializing to FlatBuffer, we did not treat the embedding op as a DPS op. Instead, its output was treated as a separate tensor and assigned a different global ID than the DPS init, which later caused a data mismatch at runtime. Consider this example:

```mlir
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x100>>, <interleaved>>, shape = #ttnn.shape<1x12x3200>}> : (!tt.device<#device>) -> tensor<1x12x3200xbf16, #ttnn_layout5>
%4 = "ttnn.embedding"(%1, %2, %3) : (tensor<1x12xi32, #ttnn_layout3>, tensor<1x12x3200xbf16, #ttnn_layout5>, tensor<32000x3200xbf16, #ttnn_layout4>) -> tensor<1x12x3200xbf16, #ttnn_layout5>
%5 = "ttnn.from_device"(%4) : (tensor<1x12x3200xbf16, #ttnn_layout5>) -> tensor<1x12x3200xbf16, #ttnn_layout2>
```

Here's what happens:

- The `"ttnn.empty"` operation produces a tensor with a global ID, say 5.
- The `"ttnn.embedding"` operation, instead of reusing global ID 5 for its output (as would be expected for a DPS operation), is assigned a new global ID, say 6. As a result, its output is not written to the memory chunk associated with global ID 5.
- When the runtime executes the `"ttnn.from_device"` operation, it expects its input to have global ID 5 (following the DPS convention). Because nothing was written to global ID 5 due to the mismatch in how `"ttnn.embedding"` was handled, the runtime instead reads a random or uninitialized tensor from that location. This leads to a data mismatch.

The root cause of this issue is the incorrect FlatBuffer serialization logic for the embedding operation:

```cpp
::flatbuffers::Offset<::tt::target::ttnn::EmbeddingOp>
createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) {
  auto in0 = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getInput()));
  auto in1 = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getWeight()));
  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
                                  kHostAllocatedAddress, kHostAllocatedSize);
  return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output);
}
```

To fix this, we replace the line

```cpp
auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
                                kHostAllocatedAddress, kHostAllocatedSize);
```

with

```cpp
auto output = cache.at<::tt::target::TensorRef>(
    getOperandThroughDPSOps(op.getResult()));
```

This change ensures that `"ttnn.embedding"` writes to the output of `"ttnn.empty"` with global ID 5 (as in the example) rather than allocating a new buffer with a different ID.

Note: this bug could potentially be present in other operations as well; those will be checked and addressed accordingly.

Closes #1404
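For context, here is a sketch of the function after the fix, assembled from the snippets above (the surrounding helpers `FlatbufferObjectCache::at`, `getOperandThroughDPSOps`, and `CreateEmbeddingOp` are assumed to behave as shown in the original snippet):

```cpp
// Corrected serialization for ttnn.embedding. The only change from the
// buggy version is how the output TensorRef is obtained: instead of
// allocating a fresh tensor (which received a new global ID), we walk
// the DPS chain back to the init ("ttnn.empty"), so the embedding's
// result reuses that tensor's global ID and memory chunk.
::flatbuffers::Offset<::tt::target::ttnn::EmbeddingOp>
createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) {
  auto in0 = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getInput()));
  auto in1 = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getWeight()));
  // Look up the existing TensorRef of the DPS init rather than
  // creating a new one with getOrCreate.
  auto output = cache.at<::tt::target::TensorRef>(
      getOperandThroughDPSOps(op.getResult()));
  return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output);
}
```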