Skip to content

Commit

Permalink
broadcasting + fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
vwellsTT committed Dec 20, 2024
1 parent 681cf47 commit 6e8b91d
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 50 deletions.
4 changes: 2 additions & 2 deletions include/ttmlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def ConvertTTKernelToEmitC : Pass<"convert-ttkernel-to-emitc", "::func::FuncOp">
}

def ConvertTTIRToLinalg: Pass<"convert-ttir-to-linalg", "::mlir::ModuleOp"> {
let summary = "Convert TTIR dialect to LinAlg dialect.";
let constructor = "createConvertTTIRToLinAlgPass()";
let summary = "Convert TTIR dialect to Linalg dialect.";
let constructor = "createConvertTTIRToLinalgPass()";
let dependentDialects = ["mlir::tt::ttir::TTIRDialect", "mlir::linalg::LinalgDialect"];
}

Expand Down
198 changes: 154 additions & 44 deletions lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,69 +15,179 @@
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::tt;

namespace {
template <typename TTIROpTy, typename LinalgOpTy,
typename OpAdaptor = typename TTIROpTy::Adaptor>
class ElementwiseOpConversionPattern : public OpConversionPattern<TTIROpTy> {
public:
using OpConversionPattern<TTIROpTy>::OpConversionPattern;

LogicalResult
matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> resultTypes;
if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
resultTypes))) {
using TensorRanks = SmallVector<int64_t, 2>;

static LogicalResult computeBroadcastedShape(SmallVector<Value, 3> inputs,
TensorRanks &broadcastedShape) {
for (Value input : inputs) {
auto type = dyn_cast<RankedTensorType>(input.getType());
if (!type) {
return failure();
}

rewriter.replaceOpWithNewOp<LinalgOpTy>(
op, resultTypes, adaptor.getInputs(), adaptor.getOutputs());
return success();
const ArrayRef<int64_t> shape = type.getShape();
if (broadcastedShape.empty()) {
broadcastedShape.assign(shape.begin(), shape.end());
continue;
}
if (broadcastedShape.size() < shape.size()) {
broadcastedShape.resize(shape.size());
}

for (size_t i = 0; i < std::max(broadcastedShape.size(), shape.size());
++i) {
const int64_t dimA =
i < broadcastedShape.size() ? broadcastedShape[i] : 1;
const int64_t dimB = i < shape.size() ? shape[i] : 1;

if (dimA != dimB && dimA != 1 && dimB != 1) {
return failure();
}
broadcastedShape[i] = std::max(dimA, dimB);
}
}
return success();
}

// Helper func to check which dims need to be broadcast and which need to be
// collapsed. Assumes that inputShape is broadcast-able to targetShape.
static void getDimsToBroadcastAndCollapse(
ArrayRef<int64_t> inputShape, ArrayRef<int64_t> targetShape,
TensorRanks &broadcastDims, SmallVector<TensorRanks, 2> &reassocIndices) {

broadcastDims.clear();
reassocIndices.clear();

// Identify what needs broadcasting, aligning from right
int targetIdx = targetShape.size() - 1;
int inputIdx = inputShape.size() - 1;

while (targetIdx >= 0) {
if (inputIdx >= 0) {
llvm::outs() << inputShape[inputIdx] << " vs " << targetShape[targetIdx]
<< "\n";
// This should be impossible since we verify input while computing
// targetShape.
assert(
(inputShape[inputIdx] == targetShape[targetIdx] ||
inputShape[inputIdx] == 1) &&
"attempting to broadcast shape which does not broadcast to target!");
if (inputShape[inputIdx] == 1 && targetShape[targetIdx] != 1) {
broadcastDims.push_back(inputIdx);
}
inputIdx--;
} else {
// Input exhausted, we need to broadcast remaining dimensions.
broadcastDims.push_back(targetIdx);
}
targetIdx--;
}

llvm::outs() << "Found dims to broadcast: ";
for (const auto dim : broadcastDims) {
llvm::outs() << dim << " ";
}
llvm::outs() << "\n";

// Group non-broadcast dimensions together for collapse.
TensorRanks currentGroup;
size_t nextBroadcastDimIdx = 0;
bool fullDimInGroup = false;
for (size_t i = 0; i < inputShape.size(); ++i) {
if (nextBroadcastDimIdx < broadcastDims.size() &&
static_cast<int64_t>(i) == broadcastDims[nextBroadcastDimIdx]) {
nextBroadcastDimIdx++;
} else {
if (fullDimInGroup) {
// Non-broadcast dimensions end the current group.
reassocIndices.push_back(currentGroup);
currentGroup.clear();
}
fullDimInGroup = true;
}
currentGroup.push_back(i);
}
};

class SubtractOpConversionPattern
: public OpConversionPattern<ttir::SubtractOp> {
using OpConversionPattern<ttir::SubtractOp>::OpConversionPattern;
// Add any remaining dimensions in the current group.
if (!currentGroup.empty()) {
reassocIndices.push_back(currentGroup);
}
}

template <typename TTIROpTy, typename LinalgOpTy,
typename OpAdaptor = typename TTIROpTy::Adaptor>
class ElementwiseOpConversionPattern : public OpConversionPattern<TTIROpTy> {
public:
using OpConversionPattern<TTIROpTy>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::SubtractOp srcOp, ttir::SubtractOp::Adaptor adaptor,
matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType lhsType =
mlir::cast<RankedTensorType>(adaptor.getInputs().front().getType());
RankedTensorType rhsType =
mlir::cast<RankedTensorType>(adaptor.getInputs().back().getType());

if (lhsType.getShape() == rhsType.getShape()) {
rewriter.replaceOpWithNewOp<linalg::SubOp>(
srcOp, adaptor.getInputs(), adaptor.getOutputs(), srcOp->getAttrs());
Location loc = op.getLoc();

// Broadcast for rhs operand require the operation to be commutative to
// allow switching the order of operands. To allow this conversion, the
// following conversion is applied to SubtractOp: subtractOp(lhs,rhs) ->
// addOp(lhs, negOp(rhs))
// First, compute broadcasted shape from operands.
SmallVector<Value, 3> inputs = adaptor.getInputs();
TensorRanks broadcastedShape;
if (failed(computeBroadcastedShape(inputs, broadcastedShape))) {
return rewriter.notifyMatchFailure(op, "Operands are not broadcastable");
}

} else {
auto negEmptyOp = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), rhsType.getShape(), rhsType.getElementType());
auto negOp = rewriter.create<linalg::NegFOp>(
srcOp.getLoc(), ValueRange{adaptor.getInputs().back()},
ValueRange{negEmptyOp}, srcOp->getAttrs());

rewriter.replaceOpWithNewOp<linalg::AddOp>(
srcOp,
ValueRange{adaptor.getInputs().front(), negOp.getResults().front()},
adaptor.getOutputs(), srcOp->getAttrs());
// Replace any inputs which aren't in target shape with broadcast results
// which are.
SmallVector<Value, 4> broadcastedInputs;
for (Value input : inputs) {
auto inputRankedTensorType = dyn_cast<RankedTensorType>(input.getType());
if (!inputRankedTensorType) {
continue;
}
Type elementType = inputRankedTensorType.getElementType();

// Insert and use a broadcast op if input does not perfectly match target
// shape.
TensorRanks broadCastDims;
SmallVector<TensorRanks, 2> reassocIndexes;
getDimsToBroadcastAndCollapse(inputRankedTensorType.getShape(),
broadcastedShape, broadCastDims,
reassocIndexes);
if (!broadCastDims.empty()) {
Value broadcastInput = input;
// The broadcast op requires we actually collapse any dimensions with
// size 1 we want to broadcast along.
if (reassocIndexes.size() != inputRankedTensorType.getShape().size()) {
auto collapseOp = rewriter.create<tensor::CollapseShapeOp>(
loc, input, reassocIndexes);
broadcastInput = collapseOp.getResult();
}
auto initTensor = rewriter.create<tensor::EmptyOp>(
loc, broadcastedShape, elementType);
auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
loc, broadcastInput, initTensor.getResult(), broadCastDims);
for (auto result : broadcastOp.getResults()) {
broadcastedInputs.push_back(result);
}
} else {
broadcastedInputs.push_back(input);
}
}

// Perform the actual op substitution, using broadcasted operands when
// needed.
SmallVector<Type> resultTypes;
if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
resultTypes))) {
return failure();
}
rewriter.replaceOpWithNewOp<LinalgOpTy>(op, resultTypes, broadcastedInputs,
adaptor.getOutputs());
return success();
}
};
Expand All @@ -90,8 +200,8 @@ void populateTTIRToLinalgPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<ElementwiseOpConversionPattern<ttir::AddOp, linalg::AddOp>,
ElementwiseOpConversionPattern<ttir::MultiplyOp, linalg::MulOp>,

SubtractOpConversionPattern>(typeConverter, ctx);
ElementwiseOpConversionPattern<ttir::SubtractOp, linalg::SubOp>>(
typeConverter, ctx);
}

} // namespace mlir::tt
59 changes: 55 additions & 4 deletions test/ttmlir/Conversion/TTIRToLinalg/ttir.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,63 @@
// RUN: ttmlir-opt --convert-ttir-to-linalg %s | FileCheck %s
module attributes{} {
func.func @add(
%arg0: tensor<32x32xf32>, // First input tensor
%arg1: tensor<32x32xf32>, // Second input tensor
%arg2: tensor<32x32xf32> // Output tensor (result stored here)
%arg0: tensor<32x32xf32>,
%arg1: tensor<32x32xf32>,
%arg2: tensor<32x32xf32>
) -> tensor<32x32xf32> {
%1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: {{%[0-9]+}} = linalg.add ins(%arg{{[0-9]+}}, %arg{{[0-9]+}} : tensor<32x32xf32>, tensor<32x32xf32>) outs(%arg{{[0-9]+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: {{%[0-9]+}} = linalg.add ins(%arg{{[0-9]+}}, %arg{{[0-9]+}} : tensor<32x32xf32>, tensor<32x32xf32>) outs(%arg{{[0-9]+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
return %1 : tensor<32x32xf32>
}

func.func @add_with_broadcast(
%arg0: tensor<32x32xf32>,
%arg1: tensor<32x1xf32>,
%arg2: tensor<32x32xf32>
) -> tensor<32x32xf32> {
%1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xf32>, tensor<32x1xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: {{%.+}} = tensor.collapse_shape
// CHECK: {{%[0-9]+}} = tensor.empty()
// CHECK: {{%.+}} = linalg.broadcast ins({{%.+}} : tensor<{{.+}}xf32>) outs({{%[0-9]+}} : tensor<{{.+}}xf32>)
// CHECK: {{%[0-9]+}} = linalg.add ins(%{{.+}}, %{{.+}} : tensor<{{.+}}xf32>, tensor<32x1xf32>) outs(%arg{{[0-9]+}} : tensor<{{{.+}}}xf32>) -> tensor<{{.+}}xf32>
return %1 : tensor<32x32xf32>
}

func.func @add_with_broadcast_1(
%arg0: tensor<32x1xf32>,
%arg1: tensor<32x32x32xf32>,
%arg2: tensor<32x32x32xf32>
) -> tensor<32x32x32xf32> {
%1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x1xf32>, tensor<32x32x32xf32>, tensor<32x32x32xf32>) -> tensor<32x32x32xf32>
// CHECK: {{%.+}} = tensor.collapse_shape
// CHECK: {{%[0-9]+}} = tensor.empty()
// CHECK: {{%.+}} = linalg.broadcast ins({{%.+}} : tensor<{{.+}}xf32>) outs({{%[0-9]+}} : tensor<{{.+}}xf32>)
// CHECK: {{%[0-9]+}} = linalg.add ins(%{{.+}}, %{{.+}} : tensor<{{.+}}xf32>, tensor<32x1xf32>) outs(%arg{{[0-9]+}} : tensor<{{.+}}xf32>) -> tensor<{{.+}}xf32>
return %1 : tensor<32x32x32xf32>
}

func.func @add_with_broadcast_2(
%arg0: tensor<32x1x32xf32>,
%arg1: tensor<32x1x1xf32>,
%arg2: tensor<32x1x32xf32>
) -> tensor<32x1x32xf32> {
%1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x1x32xf32>, tensor<32x1x1xf32>, tensor<32x1x32xf32>) -> tensor<32x1x32xf32>
// CHECK: {{%.+}} = tensor.collapse_shape
// CHECK: {{%[0-9]+}} = tensor.empty()
// CHECK: {{%.+}} = linalg.broadcast ins({{%.+}} : tensor<{{.+}}xf32>) outs({{%[0-9]+}} : tensor<{{.+}}xf32>)
// CHECK: {{%[0-9]+}} = linalg.add ins(%{{.+}}, %{{.+}} : tensor<{{.+}}xf32>, tensor<32x1xf32>) outs(%arg{{[0-9]+}} : tensor<{{.+}}xf32>) -> tensor<{{.+}}xf32>
return %1 : tensor<32x1x32xf32>
}

func.func @add_with_broadcast_3(
%arg0: tensor<32xf32>,
%arg1: tensor<32x32xf32>,
%arg2: tensor<32x32xf32>
) -> tensor<32x32xf32> {
%1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: {{%[0-9]+}} = tensor.empty()
// CHECK: {{%.+}} = linalg.broadcast ins({{%.+}} : tensor<{{.+}}xf32>) outs({{%[0-9]+}} : tensor<{{.+}}xf32>)
// CHECK: {{%[0-9]+}} = linalg.add ins(%arg{{[0-9]+}}, %arg{{[0-9]+}} : tensor<{{.+}}xf32>, tensor<{{.+}}xf32>) outs(%arg{{[0-9]+}} : tensor<{{.+}}xf32>) -> tensor<{{.+}}xf32>
return %1 : tensor<32x32xf32>
}
}

0 comments on commit 6e8b91d

Please sign in to comment.