Adding squeeze op (#543)
* Adding squeeze op

* Adding include
mtopalovicTT authored Aug 29, 2024
1 parent 3a59669 commit bf57ad8
Showing 4 changed files with 112 additions and 0 deletions.
20 changes: 20 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -401,6 +401,26 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> {
let hasVerifier = 1;
}

def TTIR_SqueezeOp : TTIR_DPSOp<"squeeze"> {
let summary = "Squeeze op.";
let description = [{
Removes the size-1 dimension dim from the input tensor.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
SI32Attr:$dim,
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}
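For reference, a minimal TTIR use of the new op (taken verbatim from the test added later in this commit) looks like the following; the two tensor operands are the input and the DPS output, and dim selects the size-1 dimension to drop:

    %0 = tensor.empty() : tensor<1x2x32x32xbf16>
    %1 = "ttir.squeeze"(%arg0, %0) <{dim = 2 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1x2x1x32x32xbf16>, tensor<1x2x32x32xbf16>) -> tensor<1x2x32x32xbf16>

(#any_device_tile is the operand-constraint alias defined at the top of the test file.)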

// ANCHOR: adding_an_op_matmul_ttir
def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
let summary = "Matrix multiply operation.";
39 changes: 39 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -186,6 +186,44 @@ class ReshapeOpConversionPattern : public OpConversionPattern<ttir::ReshapeOp> {
}
};

class SqueezeOpConversionPattern : public OpConversionPattern<ttir::SqueezeOp> {
public:
using OpConversionPattern<ttir::SqueezeOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::SqueezeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Extract input tensor types
::mlir::RankedTensorType inputType =
mlir::cast<::mlir::RankedTensorType>(adaptor.getInput().getType());

// Get the squeeze dimension
int32_t dim = adaptor.getDim();

// Get the shape of the input tensor
auto inputShape = inputType.getShape();
llvm::SmallVector<int32_t, 4> newShape;

// Build the new shape by removing the specified dimension
for (int64_t i = 0; i < inputType.getRank(); ++i) {
if (i == dim) {
continue;
}
newShape.push_back(inputShape[i]);
}

// Create the new shape attribute
auto shapeAttr = rewriter.getI32ArrayAttr(newShape);

// Replace the SqueezeOp with a ReshapeOp
rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getOutput(), shapeAttr);

return success();
}
};
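The pattern recomputes the shape with dimension dim removed and rewrites the op as a ttnn.reshape. For the 1x2x1x32x32 input in the test below, the resulting shape attribute is [1, 2, 32, 32]. A rough sketch of the lowered IR, assuming the TTNN reshape attribute is printed as "shape" (the test itself only checks for the "ttnn.reshape" op name):

    %1 = "ttnn.reshape"(%arg0, %0) <{shape = [1 : i32, 2 : i32, 32 : i32, 32 : i32]}> : (tensor<1x2x1x32x32xbf16>, tensor<1x2x32x32xbf16>) -> tensor<1x2x32x32xbf16>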

} // namespace

// ANCHOR: adding_an_op_matmul_op_rewriter
@@ -229,6 +267,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
TransposeOpConversionPattern,
ConcatOpConversionPattern,
ReshapeOpConversionPattern,
SqueezeOpConversionPattern,
MatmulOpConversionPattern
>(typeConverter, ctx);
// ANCHOR_END: op_rewriter_pattern_set
43 changes: 43 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -262,6 +262,49 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() {
return success();
}

::mlir::LogicalResult mlir::tt::ttir::SqueezeOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getOutput().getType();
int32_t dim = getDim();

// Check that the dimension `dim` is valid.
if (dim < 0 || dim >= inputType.getRank()) {
return emitOpError() << "Invalid dimension " << dim << " for squeezing.";
}

// Check that the dimension `dim` is 1 in the input tensor.
if (inputType.getDimSize(dim) != 1) {
return emitOpError() << "Dimension " << dim
<< " in the input tensor must be 1.";
}

if (outputType.getRank() == 0) {
return emitOpError() << "Output tensor must have at least one dimension.";
}

// Check that the rank of the output tensor is one less than the input tensor.
if (outputType.getRank() != inputType.getRank() - 1) {
return emitOpError()
<< "Output tensor rank must be one less than the input tensor rank.";
}

// Check that the dimensions of the output tensor are the same as the input
// tensor except for dimension `dim`.
for (int64_t i = 0, j = 0; i < inputType.getRank(); ++i) {
if (i == dim) {
continue;
}
if (inputType.getDimSize(i) != outputType.getDimSize(j)) {
return emitOpError() << "Dimensions of the output tensor must be the "
"same as the input tensor except for dimension "
<< dim << ".";
}
++j;
}

return success();
}
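As an illustration of the size check, squeezing dim = 1 of the test's tensor<1x2x1x32x32xbf16> input would be rejected, because that dimension has size 2. Hypothetical IR (attributes elided) and the approximate diagnostic:

    %1 = "ttir.squeeze"(%arg0, %0) <{dim = 1 : si32, ...}> : (tensor<1x2x1x32x32xbf16>, tensor<1x1x32x32xbf16>) -> tensor<1x1x32x32xbf16>
    // error: 'ttir.squeeze' op Dimension 1 in the input tensor must be 1.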

// ANCHOR: adding_an_op_matmul_ttir_verify
::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() {
::mlir::RankedTensorType inputAType = getA().getType();
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_squeeze.mlir
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<1x2x1x32x32xbf16>) -> tensor<1x2x32x32xbf16> {
%0 = tensor.empty() : tensor<1x2x32x32xbf16>
// CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]]
%1 = "ttir.squeeze"(%arg0, %0) <{dim = 2 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1x2x1x32x32xbf16>, tensor<1x2x32x32xbf16>) -> tensor<1x2x32x32xbf16>
return %1 : tensor<1x2x32x32xbf16>
}
}
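
To reproduce the RUN line by hand, substitute the test file path for %s (this assumes ttmlir-opt and LLVM's FileCheck are on your PATH from a build tree):

    ttmlir-opt --ttir-to-ttnn-backend-pipeline test/ttmlir/Dialect/TTNN/simple_squeeze.mlir | FileCheck test/ttmlir/Dialect/TTNN/simple_squeeze.mlir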
