diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 758fb41d7..013f340e1 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -567,8 +567,8 @@ def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> { }]; let arguments = (ins AnyRankedTensor:$input, - AnyRankedTensor:$output, - AnyRankedTensor:$weight); + AnyRankedTensor:$weight, + AnyRankedTensor:$output); let results = (outs AnyRankedTensor:$result); diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index e60261bad..2b18fb24f 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -352,7 +352,7 @@ class EmbeddingOpConversionPattern ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), - adaptor.getInput(), adaptor.getOutput(), adaptor.getWeight()); + adaptor.getInput(), adaptor.getWeight(), adaptor.getOutput()); return success(); } diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 6e3745c91..5fd09d5e4 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -753,9 +753,9 @@ createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) { cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto in1 = cache.at<::tt::target::TensorRef>( getOperandThroughDPSOps(op.getWeight())); - auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, - kHostAllocatedAddress, kHostAllocatedSize); - return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output); + auto out = cache.at<::tt::target::TensorRef>( + getOperandThroughDPSOps(op.getResult())); + return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, out); } template