Skip to content

Commit

Permalink
Additional verifications of TTIR dialect ops (#1399)
Browse files Browse the repository at this point in the history
- Refactoring of ElementwiseOpInteface to better reflect intention, with
  fix of broadcast shape calculation, considering that operand that
represetnts destination shouldn't affect output shape.
- Check number of operands for AttrSizedOperandSegments ops with simple
  traits.
- Minor refactoring of TTIR_GenericOp.

Addresses #1289, but I would
leave it as open to track further progress with similar traits and
interfaces needed in the TTNN dialect.
  • Loading branch information
azecevicTT authored Nov 26, 2024
1 parent 4083e98 commit d7798cf
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 61 deletions.
51 changes: 14 additions & 37 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,12 @@ def TTIR_DeallocOp : TTIR_Op<"dealloc"> {
// TTIR top level named ops
//===----------------------------------------------------------------------===//

def TwoOperands : ParamNativeOpTrait<"NOperands", "2">;
def ThreeOperands : ParamNativeOpTrait<"NOperands", "3">;
def FourOperands : ParamNativeOpTrait<"NOperands", "4">;

class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
TTIR_DPSOp<mnemonic, !listconcat(traits, [AttrSizedOperandSegments, TTIR_ElementwiseOpInterface])> {
TTIR_DPSOp<mnemonic, !listconcat(traits, [AttrSizedOperandSegments, TTIR_Broadcastable])> {

let description = [{
Base class for elementwise operations. Elementwise operations can take inputs with different shape,
Expand All @@ -187,7 +191,7 @@ class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
}

class TTIR_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, traits> {
TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [FourOperands])> {
let summary = "Eltwise ternary op.";
let description = [{
Eltwise ternary op.
Expand All @@ -210,7 +214,7 @@ def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> {
}

class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, traits> {
TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [TwoOperands])> {
let summary = "Eltwise unary op.";
let description = [{
Eltwise unary op.
Expand Down Expand Up @@ -424,7 +428,7 @@ def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> {
}

class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, traits> {
TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [ThreeOperands])> {
let summary = "Eltwise binary op.";
let description = [{
Eltwise binary op.
Expand Down Expand Up @@ -1175,11 +1179,10 @@ class TTIR_GenericElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);

std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) {
assert(getNumOperands() == 2 && "Input and output operand must have the same rank");
assert(sameRank(getOperands()) &&
"Elementwise unary op must have only one input and one output operand.");
assert(sameRank(getOperation()->getOperands()) &&
"Input and output operand must have the same rank");

auto rank = mlir::cast<RankedTensorType>(getOperand(0).getType()).getRank();
auto rank = mlir::cast<RankedTensorType>(getOperation()->getOperand(0).getType()).getRank();

SmallVector<AffineMap> indexingMaps(2, builder.getMultiDimIdentityMap(rank));
SmallVector<Attribute> iteratorTypes(
Expand All @@ -1188,19 +1191,6 @@ class TTIR_GenericElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
return {builder.getAffineMapArrayAttr(indexingMaps),
builder.getArrayAttr(iteratorTypes)};
}

static bool sameRank(mlir::OperandRange operands) {
if (operands.empty()) {
return true;
}
auto rank = mlir::cast<RankedTensorType>(operands[0].getType()).getRank();
for (auto operand : operands) {
if (mlir::cast<RankedTensorType>(operand.getType()).getRank() != rank) {
return false;
}
}
return true;
}
}];
}

Expand All @@ -1220,29 +1210,16 @@ class TTIR_GenericElementwiseBinaryOp<string mnemonic, list<Trait> traits = []>
void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);

std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) {
assert(sameRank(getOperands()) &&
assert(sameRank(getOperation()->getOperands()) &&
"For now all operands must have the same rank");
auto rank = mlir::cast<RankedTensorType>(getOperand(0).getType()).getRank();
SmallVector<AffineMap> indexingMaps(getNumOperands(),
auto rank = mlir::cast<RankedTensorType>(getOperation()->getOperand(0).getType()).getRank();
SmallVector<AffineMap> indexingMaps(getOperation()->getNumOperands(),
builder.getMultiDimIdentityMap(rank));
SmallVector<Attribute> iteratorTypes(
rank, builder.getAttr<IteratorTypeAttr>(IteratorType::Parallel));
return {builder.getAffineMapArrayAttr(indexingMaps),
builder.getArrayAttr(iteratorTypes)};
}

static bool sameRank(mlir::OperandRange operands) {
if (operands.empty()) {
return true;
}
auto rank = mlir::cast<RankedTensorType>(operands[0].getType()).getRank();
for (auto operand : operands) {
if (mlir::cast<RankedTensorType>(operand.getType()).getRank() != rank) {
return false;
}
}
return true;
}
}];
}

Expand Down
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace mlir {
namespace tt {
namespace ttir {
namespace detail {
mlir::LogicalResult verifyElementwiseOp(mlir::Operation *op);
mlir::LogicalResult verifyBroadcastable(mlir::Operation *op);
} // namespace detail
} // namespace ttir
} // namespace tt
Expand Down
20 changes: 18 additions & 2 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@ def TTIROpInterface : OpInterface<"TTIROp"> {
];
}

def TTIR_ElementwiseOpInterface : OpInterface<"ElementwiseOp"> {
def TTIR_Broadcastable : OpInterface<"Broadcastable"> {
let cppNamespace = "::mlir::tt::ttir";

let dependentTraits = [AttrSizedOperandSegments];

let verify = [{
return detail::verifyElementwiseOp($_op);
return detail::verifyBroadcastable($_op);
}];
}

Expand Down Expand Up @@ -105,6 +107,20 @@ def TTIR_GenericRegionOpInterface : OpInterface<"GenericRegionOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
StaticInterfaceMethod<
/*desc=*/[{
Return if the given operands have the same rank.
}],
/*retTy=*/"bool",
/*methodName=*/"sameRank",
/*args=*/(ins "::mlir::OperandRange":$operands),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::all_equal(llvm::map_range(operands, [](Value operand) {
return mlir::cast<RankedTensorType>(operand.getType()).getRank();
}));
}]
>
];
}

Expand Down
38 changes: 17 additions & 21 deletions lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,33 @@
#include "llvm/ADT/SmallVector.h"

mlir::LogicalResult
mlir::tt::ttir::detail::verifyElementwiseOp(mlir::Operation *op) {
mlir::tt::ttir::detail::verifyBroadcastable(mlir::Operation *op) {
const auto getShape = [](const Value val) {
return mlir::cast<mlir::RankedTensorType>(val.getType()).getShape();
};

const auto operandSegmentSizes =
op->getAttrOfType<mlir::DenseI32ArrayAttr>("operandSegmentSizes");
// DPS operands shouldn't affect the result shape.
const auto outputSegmentSize =
operandSegmentSizes[operandSegmentSizes.size() - 1];
const auto operandShapes = llvm::map_range(op->getOperands(), getShape);
llvm::SmallVector<int64_t, 4> broadcastedShape;
mlir::OperandRange operands = op->getOperands();
mlir::OperandRange::iterator operand_it = operands.begin();
llvm::SmallVector<int64_t, 4> prevOperandShape(
mlir::cast<mlir::RankedTensorType>((*operand_it).getType()).getShape());

while (++operand_it != operands.end()) {
llvm::SmallVector<int64_t, 4> nextOperandShape(
mlir::cast<mlir::RankedTensorType>((*operand_it).getType()).getShape());

if (!OpTrait::util::getBroadcastedShape(prevOperandShape, nextOperandShape,
for (const auto operandShape :
llvm::drop_end(operandShapes, outputSegmentSize)) {
const auto prevBroadcastedShape = broadcastedShape;
if (!OpTrait::util::getBroadcastedShape(prevBroadcastedShape, operandShape,
broadcastedShape)) {
return op->emitOpError("Operands are not broadcast compatible");
}
prevOperandShape = broadcastedShape;
}

llvm::SmallVector<int64_t, 4> resultShape(
mlir::cast<mlir::RankedTensorType>(op->getResult(0).getType())
.getShape());
// Check that the result shape matches the broadcasted shape of the operands.
llvm::SmallVector<int64_t, 4> resultShape(getShape(op->getResults().front()));
if (broadcastedShape != resultShape) {
return op->emitOpError(
"Result shape must match operand shapes after broadcasting");
}

TypeID expectedBaseTy = op->getResultTypes().front().getTypeID();
if (!llvm::all_of(op->getOperandTypes(),
[&](Type t) { return t.getTypeID() == expectedBaseTy; })) {
return op->emitOpError() << "All operands/results must have the same type";
}

return success();
}
28 changes: 28 additions & 0 deletions test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
// Negative tests for Broadcastable interface

// CHECK: 'ttir.abs' op Result shape must match operand shapes after broadcasting
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
func.func @eltwise_unary(%arg0: tensor<1x64xbf16>) -> tensor<2x64xbf16> {
%0 = tensor.empty() : tensor<2x64xbf16>
%1 = "ttir.abs"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1x64xbf16>, tensor<2x64xbf16>) -> tensor<2x64xbf16>
return %1 : tensor<2x64xbf16>
}

// -----
// CHECK: error: 'ttir.add' op Result shape must match operand shapes after broadcasting
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
func.func @eltwise_binary(%arg0: tensor<2x3x64xf32>, %arg1: tensor<64xf32>) -> tensor<4x2x3x64xf32> {
%0 = tensor.empty() : tensor<4x2x3x64xf32>
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<2x3x64xf32>, tensor<64xf32>, tensor<4x2x3x64xf32>) -> tensor<4x2x3x64xf32>
return %1 : tensor<4x2x3x64xf32>
}

// -----
// CHECK: error: 'ttir.where' op Result shape must match operand shapes after broadcasting
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
func.func @eltwise_ternary(%arg0: tensor<3x64xf32>, %arg1: tensor<1x3x64xf32>, %arg2: tensor<2x1x64xf32>) -> tensor<1x2x3x64xf32> {
%0 = tensor.empty() : tensor<1x2x3x64xf32>
%1 = "ttir.where"(%arg0, %arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64xf32>, tensor<1x3x64xf32>, tensor<2x1x64xf32>, tensor<1x2x3x64xf32>) -> tensor<1x2x3x64xf32>
return %1 : tensor<1x2x3x64xf32>
}
37 changes: 37 additions & 0 deletions test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
// Negative tests for NOperands trait

// CHECK: error: 'ttir.abs' op expected 2 operands, but found 3
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
func.func @eltwise_unary(%arg0: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
%0 = tensor.empty() : tensor<64x64xbf16>
%1 = "ttir.abs"(%arg0, %arg0, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
return %1 : tensor<64x64xbf16>
}

// -----
// CHECK: error: 'ttir.add' op expected 3 operands, but found 4
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
func.func @eltwise_binary(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> {
%0 = tensor.empty() : tensor<64x64xf32>
%1 = "ttir.add"(%arg0, %arg1, %arg1, %0) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
return %1 : tensor<64x64xf32>
}

// -----
// CHECK: error: 'ttir.add' op expected 3 operands, but found 2
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
func.func @eltwise_binary(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> {
%0 = tensor.empty() : tensor<64x64xf32>
%1 = "ttir.add"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
return %1 : tensor<64x64xf32>
}

// -----
// CHECK: error: 'ttir.where' op expected 4 operands, but found 5
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
func.func @eltwise_ternary(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
%0 = tensor.empty() : tensor<64x64xf32>
%1 = "ttir.where"(%arg0, %arg1, %arg2, %arg2, %0) <{operandSegmentSizes = array<i32: 4, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
return %1 : tensor<64x64xf32>
}

0 comments on commit d7798cf

Please sign in to comment.