diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 710c5768a..0a913e3b4 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -401,6 +401,26 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> { let hasVerifier = 1; } +def TTIR_SqueezeOp : TTIR_DPSOp<"squeeze"> { + let summary = "Squeeze op."; + let description = [{ + Squeeze tensor. + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$output, + SI32Attr:$dim, + TT_OperandConstraintArrayAttr:$operand_constraints); + + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + // ANCHOR: adding_an_op_matmul_ttir def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> { let summary = "Matrix multiply operation."; diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 6e3aa3f4a..81c173705 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -186,6 +186,44 @@ class ReshapeOpConversionPattern : public OpConversionPattern { } }; +class SqueezeOpConversionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::SqueezeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Extract input tensor types + ::mlir::RankedTensorType inputType = + mlir::cast<::mlir::RankedTensorType>(adaptor.getInput().getType()); + + // Get the squeeze dimension + int32_t dim = adaptor.getDim(); + + // Get the shape of the input tensor + auto inputShape = inputType.getShape(); + llvm::SmallVector newShape; + + // Build the new shape by removing the specified dimension + for (int64_t i = 0; i < inputType.getRank(); ++i) { + if (i == dim) { + continue; + } + newShape.push_back(inputShape[i]); + } + + // Create the new shape attribute + auto shapeAttr = rewriter.getI32ArrayAttr(newShape); + + // Replace the SqueezeOp with a ReshapeOp + rewriter.replaceOpWithNewOp( + op, this->getTypeConverter()->convertType(op.getType()), + adaptor.getInput(), adaptor.getOutput(), shapeAttr); + + return success(); + } +}; + } // namespace // ANCHOR: adding_an_op_matmul_op_rewriter @@ -229,6 +267,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TransposeOpConversionPattern, ConcatOpConversionPattern, ReshapeOpConversionPattern, + SqueezeOpConversionPattern, MatmulOpConversionPattern >(typeConverter, ctx); // ANCHOR_END: op_rewriter_pattern_set diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 5f23e7967..721928aea 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -262,6 +262,49 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() { return success(); } +::mlir::LogicalResult mlir::tt::ttir::SqueezeOp::verify() { + ::mlir::RankedTensorType inputType = getInput().getType(); + ::mlir::RankedTensorType outputType = getOutput().getType(); + int32_t dim = getDim(); + + // Check that the dimension `dim` is valid. + if (dim < 0 || dim >= inputType.getRank()) { + return emitOpError() << "Invalid dimension " << dim << " for squeezing."; + } + + // Check that the dimension `dim` is 1 in the input tensor. + if (inputType.getDimSize(dim) != 1) { + return emitOpError() << "Dimension " << dim + << " in the input tensor must be 1."; + } + + if (outputType.getRank() == 0) { + return emitOpError() << "Output tensor must have at least one dimension."; + } + + // Check that the rank of the output tensor is one less than the input tensor. + if (outputType.getRank() != inputType.getRank() - 1) { + return emitOpError() + << "Output tensor rank must be one less than the input tensor rank."; + } + + // Check that the dimensions of the output tensor are the same as the input + // tensor except for dimension `dim`. + for (int64_t i = 0, j = 0; i < inputType.getRank(); ++i) { + if (i == dim) { + continue; + } + if (inputType.getDimSize(i) != outputType.getDimSize(j)) { + return emitOpError() << "Dimensions of the output tensor must be the " + "same as the input tensor except for dimension " + << dim << "."; + } + ++j; + } + + return success(); +} + // ANCHOR: adding_an_op_matmul_ttir_verify ::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() { ::mlir::RankedTensorType inputAType = getA().getType(); diff --git a/test/ttmlir/Dialect/TTNN/simple_squeeze.mlir b/test/ttmlir/Dialect/TTNN/simple_squeeze.mlir new file mode 100644 index 000000000..1798605d8 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/simple_squeeze.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x2x1x32x32xbf16>) -> tensor<1x2x32x32xbf16> { + %0 = tensor.empty() : tensor<1x2x32x32xbf16> + // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] + %1 = "ttir.squeeze"(%arg0, %0) <{dim = 2 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1x2x1x32x32xbf16>, tensor<1x2x32x32xbf16>) -> tensor<1x2x32x32xbf16> + return %1 : tensor<1x2x32x32xbf16> + } +}