Implement conversions for stablehlo.broadcast_in_dim and ttir.broadcast ops (#680)

Implement conversions for stablehlo.broadcast_in_dim and ttir.broadcast ops along with tests.
uazizTT authored Sep 13, 2024
1 parent 5b997c1 commit c4d70db
Showing 5 changed files with 130 additions and 0 deletions.
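
At a glance, the lowering happens in two stages: stablehlo.broadcast_in_dim is first rewritten into the new ttir.broadcast DPS op, which TTNN lowering then folds away in favor of implicit broadcasting. A minimal sketch of both stages, adapted from the tests in this commit (#any_device_tile is the operand-constraint alias defined in the TTNN test below; %empty is a hypothetical name):

// StableHLO input:
%0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<1xf32>) -> tensor<512x512xf32>

// After --stablehlo-to-ttir-pipeline: a DPS broadcast writing into an empty tensor.
%empty = tensor.empty() : tensor<512x512xf32>
%1 = "ttir.broadcast"(%arg0, %empty) <{dimension = [1], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>

// After --ttir-to-ttnn-backend-pipeline: the broadcast is folded away and
// consumers (e.g. ttnn.maximum) receive %arg0 directly.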
18 changes: 18 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -431,6 +431,24 @@ def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {
let hasVerifier = 1;
}

def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
let summary = "Broadcast operation.";
let description = [{
Broadcasts the input tensor across the dimensions given by the `dimension`
attribute, writing the expanded result into the DPS output.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
I64ArrayAttr:$dimension,
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
}

def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
let summary = "Conv2d operation.";
let description = [{
64 changes: 64 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -4,6 +4,7 @@

#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"

#include "mlir/Dialect/Traits.h"
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/PatternMatch.h>
@@ -253,6 +254,60 @@ class StableHLOToTTIRConstantOpConversionPattern

rewriter.replaceOpWithNewOp<mlir::tt::ttir::ConstantOp>(srcOp, outputType,
srcOp.getValue());
return success();
}
};

class StableHLOToTTIRBroadcastInDimOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::BroadcastInDimOp> {
using OpConversionPattern<
mlir::stablehlo::BroadcastInDimOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::BroadcastInDimOp srcOp,
mlir::stablehlo::BroadcastInDimOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
if (failed(err)) {
return err;
}

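// Materialize an empty tensor of the result shape to act as the DPS output.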
auto outputType = mlir::cast<RankedTensorType>(srcOp.getResult().getType());
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

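// Carry the StableHLO broadcast dimensions over as an I64ArrayAttr.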
mlir::ArrayAttr dimArg =
rewriter.getI64ArrayAttr(adaptor.getBroadcastDimensions());

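// Replace the op with a ttir.broadcast; the +1 below accounts for the DPS
// output operand, so every operand carries the AnyDeviceTile constraint.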
rewriter.replaceOpWithNewOp<mlir::tt::ttir::BroadcastOp>(
srcOp, outputTensor.getType(), Value(adaptor.getOperand()),
Value(outputTensor), dimArg,
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));

return success();
}

private:
LogicalResult
checkBasicLegality(mlir::stablehlo::BroadcastInDimOp &srcOp,
mlir::stablehlo::BroadcastInDimOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {

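// The conversion is legal only if the input shape is broadcast-compatible
// with the requested broadcast dimensions.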
llvm::SmallVector<int64_t, 4> broadcastedShape;
auto inputShape =
mlir::cast<mlir::RankedTensorType>((srcOp.getOperand()).getType())
.getShape();

if (!OpTrait::util::getBroadcastedShape(
inputShape, adaptor.getBroadcastDimensions(), broadcastedShape)) {
return rewriter.notifyMatchFailure(
srcOp, "Input cannot be broadcasted to provided dimensions.");
}

return success();
}
@@ -313,6 +368,14 @@ void addTensorCreationOpsConversionPatterns(MLIRContext *ctx,
patterns.add<StableHLOToTTIRConstantOpConversionPattern>(typeConverter, ctx);
}

void addBroadcastOpConversionPattern(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {

patterns.add<StableHLOToTTIRBroadcastInDimOpConversionPattern>(typeConverter,
ctx);
}

} // namespace

namespace mlir::tt {
@@ -326,6 +389,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addTransposeOpsConversionPatterns(ctx, patterns, typeConverter);
addMatmulOpsConversionPatterns(ctx, patterns, typeConverter);
addTensorCreationOpsConversionPatterns(ctx, patterns, typeConverter);
addBroadcastOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
25 changes: 25 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -12,6 +12,7 @@

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/ValueRange.h"
@@ -594,6 +595,29 @@ class MaxPool2dOpConversionPattern
}
};

class BroadcastOpConversionPattern
: public OpConversionPattern<ttir::BroadcastOp> {
using OpConversionPattern<ttir::BroadcastOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(ttir::BroadcastOp srcOp, ttir::BroadcastOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// Fold this operation into all consumer ops. This only works for TTNN ops
// that support implicit broadcasting; we rely on each op's verifier to check
// that its operands are broadcast-compatible (see the sketch after the
// simple_broadcast.mlir test below).

mlir::Value input = srcOp.getOperand(0);
mlir::Value result = srcOp.getResult();

rewriter.replaceAllUsesWith(result, input);
rewriter.eraseOp(srcOp);

return success();
}
};

namespace mlir::tt {

void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
@@ -619,6 +643,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
BroadcastOpConversionPattern,
EmbeddingOpConversionPattern,
SoftmaxOpConversionPattern,
TransposeOpConversionPattern,
10 changes: 10 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir
@@ -0,0 +1,10 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512x512xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x512xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<1xf32>) -> tensor<512x512xf32>
%1 = stablehlo.maximum %0, %arg1 : tensor<512x512xf32>
// CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]]
return %1 : tensor<512x512xf32>
}
}
13 changes: 13 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_broadcast.mlir
@@ -0,0 +1,13 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>
module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512x512xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x512xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
// CHECK-NOT: %[[C:.*]] = "ttnn.broadcast"[[C:.*]]
%0 = tensor.empty() : tensor<512x512xf32>
%1 = "ttir.broadcast"(%arg0, %0) <{dimension = [1], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
%2 = tensor.empty() : tensor<512x512xf32>
%3 = "ttir.maximum"(%1, %arg1, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
return %3 : tensor<512x512xf32>
}
}
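
Concretely, BroadcastOpConversionPattern replaces every use of the ttir.broadcast result with its unexpanded input and erases the op. A hedged sketch of its effect on the test above (attributes elided):

// Before the rewrite:
%1 = "ttir.broadcast"(%arg0, %0) <{dimension = [1], ...}> : (tensor<1xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
%3 = "ttir.maximum"(%1, %arg1, %2) ... : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>

// After the rewrite, %1 is gone and the consumer reads %arg0 directly; the
// lowered ttnn.maximum is expected to broadcast the tensor<1xf32> operand
// implicitly (hence the CHECK-NOT for ttnn.broadcast above):
%3 = "ttir.maximum"(%arg0, %arg1, %2) ... : (tensor<1xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>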
