Skip to content

Commit

Permalink
Added support for bitwise AND, OR, XOR and NOT ops. (#1424)
Browse files Browse the repository at this point in the history
- Defined ops in TTIR and TTNN dialects.
- Implemented StableHLO to TTIR converison (unified with logical ops)
- Implemented TTIR to TTNN conversion
- Added tests

Fixes #1202.

Half of solution (tt-xla and tt-torch tests are second half) for issues:
#1051
#1053
#1054
#1055

Left asserts in runtime code due to
tenstorrent/tt-metal#13582.
  • Loading branch information
kmitrovicTT authored Dec 19, 2024
1 parent e748e71 commit 1b6d7d8
Show file tree
Hide file tree
Showing 18 changed files with 433 additions and 61 deletions.
55 changes: 55 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,19 @@ def TTIR_LogicalNotOp: TTIR_ElementwiseUnaryOp<"logical_not"> {
}];
}

def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not"> {
let summary = "Eltwise bitwise NOT.";
let description = [{
Performs element-wise NOT of tensor `operand` and produces a `result` tensor.

Example:
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "ttir.bitwise_not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
}];
}

def TTIR_NegOp: TTIR_ElementwiseUnaryOp<"neg"> {
let summary = "Eltwise negate op.";
let description = [{
Expand Down Expand Up @@ -514,6 +527,48 @@ def TTIR_LogicalXorOp : TTIR_ElementwiseBinaryOp<"logical_xor"> {
}];
}

def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and"> {
let summary = "Eltwise bitwise AND.";
let description = [{
Performs element-wise bitwise AND of two tensors `lhs` and `rhs`
and produces a `result` tensor.

Example:
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "ttir.bitwise_and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
}];
}

def TTIR_BitwiseOrOp : TTIR_ElementwiseBinaryOp<"bitwise_or"> {
let summary = "Eltwise bitwise OR.";
let description = [{
Performs element-wise bitwise OR of two tensors `lhs` and `rhs`
and produces a `result` tensor.

Example:
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "ttir.bitwise_or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
}];
}

def TTIR_BitwiseXorOp : TTIR_ElementwiseBinaryOp<"bitwise_xor"> {
let summary = "Eltwise bitwise XOR.";
let description = [{
Performs element-wise bitwise XOR of two tensors `lhs` and `rhs`
and produces a `result` tensor.

Example:
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "ttir.bitwise_xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
}];
}

def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum"> {
let summary = "Eltwise minimum OP.";
let description = [{
Expand Down
55 changes: 55 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,19 @@ def TTNN_LogicalNotOp: TTNN_ElementwiseUnaryOp<"logical_not"> {
}];
}

def TTNN_BitwiseNotOp : TTNN_ElementwiseUnaryOp<"bitwise_not"> {
let summary = "Eltwise bitwise NOT.";
let description = [{
Performs element-wise NOT of tensor `operand` and produces a `result` tensor.

Example:
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "ttnn.bitwise_not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
}];
}

def TTNN_NegOp : TTNN_ElementwiseUnaryOp<"neg"> {
let summary = "Eltwise negate.";
let description = [{
Expand Down Expand Up @@ -461,6 +474,48 @@ def TTNN_LogicalXorOp : TTNN_ElementwiseBinaryOp<"logical_xor"> {
}];
}

def TTNN_BitwiseAndOp : TTNN_ElementwiseBinaryOp<"bitwise_and"> {
let summary = "Eltwise bitwise AND.";
let description = [{
Performs element-wise bitwise AND of two tensors `lhs` and `rhs`
and produces a `result` tensor.

Example:
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "ttnn.bitwise_and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
}];
}

def TTNN_BitwiseOrOp : TTNN_ElementwiseBinaryOp<"bitwise_or"> {
let summary = "Eltwise bitwise OR.";
let description = [{
Performs element-wise bitwise OR of two tensors `lhs` and `rhs`
and produces a `result` tensor.

Example:
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "ttnn.bitwise_or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
}];
}

def TTNN_BitwiseXorOp : TTNN_ElementwiseBinaryOp<"bitwise_xor"> {
let summary = "Eltwise bitwise XOR.";
let description = [{
Performs element-wise bitwise XOR of two tensors `lhs` and `rhs`
and produces a `result` tensor.

Example:
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "ttnn.bitwise_xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
}];
}

def TTNN_MaximumOp : TTNN_ElementwiseBinaryOp<"maximum"> {
let summary = "Eltwise maximum OP.";
let description = [{
Expand Down
6 changes: 5 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,12 @@ enum EltwiseOpType: uint32 {
GreaterThan,
LogicalAnd,
LogicalOr,
LogicalXor,
LogicalNot,
BitwiseAnd,
BitwiseOr,
BitwiseXor,
BitwiseNot,
Cbrt,
Minimum,
Ceil,
Expand All @@ -131,7 +136,6 @@ enum EltwiseOpType: uint32 {
Floor,
Where,
Gelu,
LogicalXor,
Clamp,
LeakyRelu,
Scatter,
Expand Down
98 changes: 59 additions & 39 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -966,9 +966,16 @@ class StableHLOToTTIRConcatOpConversionPattern
}
};

template <typename SrcOp, typename DestOp,
// Class implementing conversion from StableHLO to TTIR logical and bitwise ops.
// StableHLO has AND, OR, XOR and NOT ops defined in such a way that they do two
// different things based on type of inputs. In case of booleans, they perform
// logical version of the op, and in case of integers they perform bitwise
// version of the op. We made a decision to make those two cases completely
// distinct ops in TTIR. Thus, a StableHLO `SrcOp` is rewritten to one of
// `DestOp`s based on operand types.
template <typename SrcOp, typename LogicalDestOp, typename BitwiseDestOp,
typename Adaptor = typename SrcOp::Adaptor>
class StableHLOToTTIROpLogicalOpConversionPattern
class StableHLOToTTIRLogicalAndBitwiseOpConversionPattern
: public OpConversionPattern<SrcOp> {

using OpConversionPattern<SrcOp>::OpConversionPattern;
Expand All @@ -977,37 +984,49 @@ class StableHLOToTTIROpLogicalOpConversionPattern
LogicalResult
matchAndRewrite(SrcOp srcOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
if (!legalityResult.succeeded()) {
return legalityResult;
}

auto outputType = mlir::cast<RankedTensorType>(
this->getTypeConverter()->convertType(srcOp.getResult().getType()));

tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
rewriter.replaceOpWithNewOp<DestOp>(
srcOp,
TypeRange(
this->getTypeConverter()->convertType(outputTensor.getType())),
adaptor.getOperands(), ValueRange(outputTensor));

if (getStableHLOOpType(srcOp) == StableHLOOpType::kLogical) {
replaceOpWithNewOp<LogicalDestOp>(srcOp, adaptor, outputTensor, rewriter);
} else {
replaceOpWithNewOp<BitwiseDestOp>(srcOp, adaptor, outputTensor, rewriter);
}

return success();
}

private:
LogicalResult checkBasicLegality(SrcOp srcOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (mlir::cast<RankedTensorType>(srcOp->getOperand(0).getType())
.getElementTypeBitWidth() > 1 &&
mlir::cast<RankedTensorType>(srcOp->getOperand(1).getType())
.getElementTypeBitWidth() > 1) {
llvm::errs()
<< "error: TTIR does not support bitwise logical operation.\n";
return rewriter.notifyMatchFailure(
srcOp, "TTIR does not support bitwise logical operation.");
}
enum StableHLOOpType { kLogical = 0, kBitwise = 1 };

// Determines stablehlo op type based on its operand types (i.e. their
// bit width). This assumes boolean operands are modeled as 1bit wide ints.
static StableHLOOpType getStableHLOOpType(const SrcOp &srcOp) {
// Checks if all operands are boolean (have bit width equal to 1).
bool allOperandsAreBoolean = std::all_of(
srcOp->operand_begin(), srcOp->operand_end(), [](auto operand) {
return mlir::cast<RankedTensorType>(operand.getType())
.getElementTypeBitWidth() == 1;
});

return allOperandsAreBoolean ? StableHLOOpType::kLogical
: StableHLOOpType::kBitwise;
}

return success();
// Helper function to replace the operation with the new op to avoid code
// duplication.
template <typename DestOp>
void replaceOpWithNewOp(SrcOp srcOp, Adaptor adaptor,
tensor::EmptyOp outputTensor,
ConversionPatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<DestOp>(
srcOp,
TypeRange(
this->getTypeConverter()->convertType(outputTensor.getType())),
adaptor.getOperands(), ValueRange(outputTensor));
}
};

Expand Down Expand Up @@ -1885,20 +1904,21 @@ void addCCLOpsConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
ctx);
}

void addLogicalOpConversionPattern(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIROpLogicalOpConversionPattern<
mlir::stablehlo::AndOp, mlir::tt::ttir::LogicalAndOp>>(typeConverter,
ctx);
patterns.add<StableHLOToTTIROpLogicalOpConversionPattern<
mlir::stablehlo::NotOp, mlir::tt::ttir::LogicalNotOp>>(typeConverter,
ctx);
patterns.add<StableHLOToTTIROpLogicalOpConversionPattern<
mlir::stablehlo::OrOp, mlir::tt::ttir::LogicalOrOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpLogicalOpConversionPattern<
mlir::stablehlo::XorOp, mlir::tt::ttir::LogicalXorOp>>(typeConverter,
ctx);
void addLogicalAndBitwiseOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIRLogicalAndBitwiseOpConversionPattern<
mlir::stablehlo::AndOp, mlir::tt::ttir::LogicalAndOp,
mlir::tt::ttir::BitwiseAndOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIRLogicalAndBitwiseOpConversionPattern<
mlir::stablehlo::OrOp, mlir::tt::ttir::LogicalOrOp,
mlir::tt::ttir::BitwiseOrOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIRLogicalAndBitwiseOpConversionPattern<
mlir::stablehlo::XorOp, mlir::tt::ttir::LogicalXorOp,
mlir::tt::ttir::BitwiseXorOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIRLogicalAndBitwiseOpConversionPattern<
mlir::stablehlo::NotOp, mlir::tt::ttir::LogicalNotOp,
mlir::tt::ttir::BitwiseNotOp>>(typeConverter, ctx);
}

void addSliceOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
Expand Down Expand Up @@ -1963,8 +1983,8 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addCompareOpsConversionPatterns(ctx, patterns, typeConverter);
addConcatOpsConversionPatterns(ctx, patterns, typeConverter);
addReshapeOpConversionPattern(ctx, patterns, typeConverter);
addLogicalOpConversionPattern(ctx, patterns, typeConverter);
addCCLOpsConversionPattern(ctx, patterns, typeConverter);
addLogicalAndBitwiseOpsConversionPatterns(ctx, patterns, typeConverter);
addSliceOpConversionPattern(ctx, patterns, typeConverter);
addClampOpConversionPattern(ctx, patterns, typeConverter);
addGatherOpConversionPattern(ctx, patterns, typeConverter);
Expand Down
6 changes: 5 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,10 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseOpConversionPattern<ttir::LogicalOrOp, ttnn::LogicalOrOp>,
ElementwiseOpConversionPattern<ttir::LogicalNotOp, ttnn::LogicalNotOp>,
ElementwiseOpConversionPattern<ttir::LogicalXorOp, ttnn::LogicalXorOp>,
ElementwiseOpConversionPattern<ttir::BitwiseAndOp, ttnn::BitwiseAndOp>,
ElementwiseOpConversionPattern<ttir::BitwiseOrOp, ttnn::BitwiseOrOp>,
ElementwiseOpConversionPattern<ttir::BitwiseXorOp, ttnn::BitwiseXorOp>,
ElementwiseOpConversionPattern<ttir::BitwiseNotOp, ttnn::BitwiseNotOp>,
ElementwiseOpConversionPattern<ttir::MultiplyOp, ttnn::MultiplyOp>,
ElementwiseOpConversionPattern<ttir::EqualOp, ttnn::EqualOp>,
ElementwiseOpConversionPattern<ttir::NotEqualOp, ttnn::NotEqualOp>,
Expand Down Expand Up @@ -1173,7 +1177,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
ElementwiseUnaryWithFloatParameterOpConversionPattern<ttir::LeakyReluOp, ttnn::LeakyReluOp>,
ElementwiseUnaryWithFloatParameterOpConversionPattern<ttir::LeakyReluOp, ttnn::LeakyReluOp>,
EmbeddingOpConversionPattern,
EmbeddingBackwardOpConversionPattern,
SoftmaxOpConversionPattern,
Expand Down
8 changes: 6 additions & 2 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
DefaultOpConversionPattern<ttnn::FloorOp>,
DefaultOpConversionPattern<ttnn::IsFiniteOp>,
DefaultOpConversionPattern<ttnn::LogicalNotOp>,
DefaultOpConversionPattern<ttnn::BitwiseNotOp>,
DefaultOpConversionPattern<ttnn::NegOp>,
DefaultOpConversionPattern<ttnn::ReluOp>,
DefaultOpConversionPattern<ttnn::LeakyReluOp>,
Expand All @@ -703,11 +704,14 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Eltwise binary ops
//
patterns.add<EltwiseBinaryOpConversionPattern<ttnn::AddOp>,
EltwiseBinaryOpConversionPattern<ttnn::SubtractOp>,
EltwiseBinaryOpConversionPattern<ttnn::MultiplyOp>,
EltwiseBinaryOpConversionPattern<ttnn::LogicalAndOp>,
EltwiseBinaryOpConversionPattern<ttnn::LogicalOrOp>,
EltwiseBinaryOpConversionPattern<ttnn::LogicalXorOp>,
EltwiseBinaryOpConversionPattern<ttnn::SubtractOp>,
EltwiseBinaryOpConversionPattern<ttnn::MultiplyOp>,
DefaultOpConversionPattern<ttnn::BitwiseAndOp>,
DefaultOpConversionPattern<ttnn::BitwiseOrOp>,
DefaultOpConversionPattern<ttnn::BitwiseXorOp>,
DefaultOpConversionPattern<ttnn::EqualOp>,
DefaultOpConversionPattern<ttnn::NotEqualOp>,
DefaultOpConversionPattern<ttnn::GreaterEqualOp>,
Expand Down
Loading

0 comments on commit 1b6d7d8

Please sign in to comment.