Merge branch 'main' into vwells/ttnn_hoisted_layout_transform
vwellsTT authored Dec 19, 2024
2 parents 12156bc + 1b6d7d8 commit 6e0fdaa
Showing 96 changed files with 1,229 additions and 566 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/on-pr.yml
@@ -5,6 +5,10 @@ on:
pull_request:
branches: [ "main" ]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
pre-commit:
uses: ./.github/workflows/pre-commit.yml
@@ -32,7 +36,7 @@ jobs:
gh workflow run ${{ env.WORKFLOW_NAME }} \
--repo ${{ env.TARGET_REPO }} --ref main \
--field test_mark=push \
--field mlir_override=${{ github.sha }}
--field mlir_override=${{ github.event.pull_request.head.sha }}
gh run list --workflow=${{ env.WORKFLOW_NAME }} --repo ${{ env.TARGET_REPO }} --limit 1
echo "Triggered ${{ env.TARGET_REPO }}"
echo "### Triggered [${{ env.TARGET_REPO }}](https://github.com/${{ env.TARGET_REPO }}/actions/workflows/${{ env.WORKFLOW_NAME }}) :rocket:" >> $GITHUB_STEP_SUMMARY
89 changes: 89 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -278,6 +278,19 @@ def TTIR_LogicalNotOp: TTIR_ElementwiseUnaryOp<"logical_not"> {
}];
}

def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not"> {
let summary = "Eltwise bitwise NOT.";
let description = [{
Performs element-wise bitwise NOT of tensor `operand` and produces a `result` tensor.

Example:
// Bitwise operation with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "ttir.bitwise_not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
}];
}

def TTIR_NegOp: TTIR_ElementwiseUnaryOp<"neg"> {
let summary = "Eltwise negate op.";
let description = [{
@@ -514,6 +527,48 @@ def TTIR_LogicalXorOp : TTIR_ElementwiseBinaryOp<"logical_xor"> {
}];
}

def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and"> {
let summary = "Eltwise bitwise AND.";
let description = [{
Performs element-wise bitwise AND of two tensors `lhs` and `rhs`
and produces a `result` tensor.

Example:
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "ttir.bitwise_and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
}];
}

def TTIR_BitwiseOrOp : TTIR_ElementwiseBinaryOp<"bitwise_or"> {
let summary = "Eltwise bitwise OR.";
let description = [{
Performs element-wise bitwise OR of two tensors `lhs` and `rhs`
and produces a `result` tensor.

Example:
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "ttir.bitwise_or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
}];
}

def TTIR_BitwiseXorOp : TTIR_ElementwiseBinaryOp<"bitwise_xor"> {
let summary = "Eltwise bitwise XOR.";
let description = [{
Performs element-wise bitwise XOR of two tensors `lhs` and `rhs`
and produces a `result` tensor.

Example:
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "ttir.bitwise_xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
}];
}

def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum"> {
let summary = "Eltwise minimum OP.";
let description = [{
@@ -1129,6 +1184,40 @@ def TTIR_OnesOp : TTIR_Op<"ones"> {
let results = (outs AnyRankedTensor:$result);
}

def TTIR_ReverseOp : TTIR_DPSOp<"reverse", [AllShapesMatch<["input", "result"]>]> {
let summary = "Reverse operation.";

let description = [{
Reverses the order of elements in the `operand` along the specified
`dimensions` and produces a `result` tensor.

Examples:
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "ttir.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "ttir.reverse"(%operand) {
dimensions = array<i64: 1, 0>
} : (tensor<3x2xi64>) -> tensor<3x2xi64>
// %result: [[6, 5], [4, 3], [2, 1]]
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
DenseI64ArrayAttr:$dimensions);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
AllShapesMatch<["value", "result"]>]> {
let summary = "Constant op.";
21 changes: 19 additions & 2 deletions include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h
@@ -9,6 +9,7 @@

namespace mlir::tt::ttir {

#ifdef TTMLIR_ENABLE_STABLEHLO
// Options for the StableHLO to TTIR pipeline.
//
struct StableHLOToTTIRPipelineOptions
@@ -31,12 +32,28 @@ struct StableHLOToTTIRPipelineOptions
// that the TTIR inliner pass may inline the ops.
llvm::cl::init(true)};
};
#endif

struct LinalgToLLVMPipelineOptions
: public PassPipelineOptions<LinalgToLLVMPipelineOptions> {
// TODO (#1634): We may want additional options, e.g. to lower through
// affine loops instead of scf directly.
Option<bool> cleanupOutputEnabled{
*this, "enable-optimization-passes",
llvm::cl::desc("Enable cleanup passes (canonicalize, SCC, CSE, "
"SymbolDCE) after basic lowering is finished."),
llvm::cl::init(true)};
};

#ifdef TTMLIR_ENABLE_STABLEHLO
void createStableHLOToTTIRPipeline(
OpPassManager &pm, const StableHLOToTTIRPipelineOptions &options);
#endif

void createLinalgToLLVMPipeline(OpPassManager &pm,
const LinalgToLLVMPipelineOptions &options);
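
A minimal usage sketch of the pipeline entry point declared above. The wrapper function and its name are hypothetical; it assumes an initialized MLIRContext and a parsed ModuleOp:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h"

// Hypothetical helper: builds and runs the Linalg-to-LLVM pipeline.
mlir::LogicalResult runLinalgToLLVM(mlir::ModuleOp module,
                                    mlir::MLIRContext &context) {
  mlir::PassManager pm(&context);
  mlir::tt::ttir::LinalgToLLVMPipelineOptions options;
  // Keep the post-lowering cleanup passes (canonicalize, CSE, ...) enabled.
  options.cleanupOutputEnabled = true;
  mlir::tt::ttir::createLinalgToLLVMPipeline(pm, options);
  return pm.run(module);
}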

/// Registers all pipelines for the TTIR dialect. Currently,
/// this includes only the "stablehlo-to-ttir-pipeline".
/// Registers all pipelines for the TTIR dialect.
void registerTTIRPipelines();
} // namespace mlir::tt::ttir

1 change: 0 additions & 1 deletion include/ttmlir/Dialect/TTNN/Analysis/BFInterleavedPolicy.h
@@ -7,7 +7,6 @@

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"
#include <cstdint>

namespace mlir::tt::ttnn {

81 changes: 60 additions & 21 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -176,17 +176,6 @@ def TTNN_AbsOp : TTNN_ElementwiseUnaryOp<"abs"> {
let description = [{
Eltwise absolute operation.
}];

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
wa::TTNNOperandWorkarounds tileLayoutWorkaround = wa::TTNNOperandWorkarounds(Layout::Tile);
return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
.addInputOperandWorkaround(tileLayoutWorkaround)
.addInputOperandWorkaround(tileLayoutWorkaround)
.addOutputOperandWorkaround(tileLayoutWorkaround);
}
}];
}

def TTNN_CbrtOp : TTNN_ElementwiseUnaryOp<"cbrt"> {
Expand Down Expand Up @@ -257,6 +246,19 @@ def TTNN_LogicalNotOp: TTNN_ElementwiseUnaryOp<"logical_not"> {
}];
}

def TTNN_BitwiseNotOp : TTNN_ElementwiseUnaryOp<"bitwise_not"> {
let summary = "Eltwise bitwise NOT.";
let description = [{
Performs element-wise bitwise NOT of tensor `operand` and produces a `result` tensor.

Example:
// Bitwise operation with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "ttnn.bitwise_not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
}];
}

def TTNN_NegOp : TTNN_ElementwiseUnaryOp<"neg"> {
let summary = "Eltwise negate.";
let description = [{
@@ -472,6 +474,48 @@ def TTNN_LogicalXorOp : TTNN_ElementwiseBinaryOp<"logical_xor"> {
}];
}

def TTNN_BitwiseAndOp : TTNN_ElementwiseBinaryOp<"bitwise_and"> {
let summary = "Eltwise bitwise AND.";
let description = [{
Performs element-wise bitwise AND of two tensors `lhs` and `rhs`
and produces a `result` tensor.

Example:
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "ttnn.bitwise_and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
}];
}

def TTNN_BitwiseOrOp : TTNN_ElementwiseBinaryOp<"bitwise_or"> {
let summary = "Eltwise bitwise OR.";
let description = [{
Performs element-wise bitwise OR of two tensors `lhs` and `rhs`
and produces a `result` tensor.

Example:
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "ttnn.bitwise_or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
}];
}

def TTNN_BitwiseXorOp : TTNN_ElementwiseBinaryOp<"bitwise_xor"> {
let summary = "Eltwise bitwise XOR.";
let description = [{
Performs element-wise bitwise XOR of two tensors `lhs` and `rhs`
and produces a `result` tensor.

Example:
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "ttnn.bitwise_xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
}];
}

def TTNN_MaximumOp : TTNN_ElementwiseBinaryOp<"maximum"> {
let summary = "Eltwise maximum OP.";
let description = [{
@@ -567,8 +611,8 @@ def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
AnyRankedTensor:$weight);
AnyRankedTensor:$weight,
AnyRankedTensor:$output);

let results = (outs AnyRankedTensor:$result);

@@ -817,6 +861,9 @@ def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
return wa::TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds();
}
}];

let hasVerifier = 1;
@@ -858,14 +905,6 @@ def TTNN_EmptyOp : TTNN_Op<"empty"> {
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
wa::TTNNOperandWorkarounds rowMajorLayoutWorkaround = wa::TTNNOperandWorkarounds(Layout::RowMajor);
return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
.addOutputOperandWorkaround(rowMajorLayoutWorkaround);
}
}];

let hasVerifier = 1;
}

2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
@@ -164,6 +164,8 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
DataType getDataType() const;
uint64_t getElementSizeBytes() const;
int64_t getTensorSizeInBytes(ArrayRef<int64_t> tensorShape, ::mlir::tt::DeviceAttr device) const;
static llvm::SmallVector<int64_t> calculateLogicalShardShapeForSharding(ArrayRef<int64_t> tensorShape, mlir::AffineMap linear, GridAttr grid);
static llvm::SmallVector<int64_t> calculateLogicalShardShapeForL1Interleaved(ArrayRef<int64_t> tensorShape, Type elementType, mlir::AffineMap linear, GridAttr grid);
llvm::SmallVector<int64_t> getStride(ArrayRef<int64_t> logicalShape) const;
llvm::SmallVector<int64_t> getShardShape() const;
llvm::SmallVector<int64_t> getScalarShardShape() const;
54 changes: 40 additions & 14 deletions include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
@@ -83,33 +83,52 @@ struct TTNNOperandWorkarounds {
}
};

// Workaround result struct that encapsulates the previous and target
// (workaround) value and a method indicating whether the workaround
// changed the value.
template <typename T>
struct WorkaroundResult {
T previousValue;
T targetValue;
bool isModified() const { return previousValue != targetValue; }
};

// Layout workaround result struct.
struct LayoutWorkaroundResult : public WorkaroundResult<Layout> {};

// Buffer type workaround result struct.
struct BufferTypeWorkaroundResult : public WorkaroundResult<BufferType> {};

// Memory layout workaround result struct.
struct MemoryLayoutWorkaroundResult
: public WorkaroundResult<std::optional<TensorMemoryLayout>> {};

// Struct that encapsulates the result of applying the workarounds.
// It contains the target tensor layout, buffer type and tensor memory layout
// results and a flag indicating whether the workarounds were applied.
struct WorkaroundResult {
// Target tensor layout.
std::pair<Layout, bool> targetTensorLayoutResult;
struct WorkaroundResults {
// Tensor layout workaround result.
LayoutWorkaroundResult tensorLayoutResult;

// Target tensor buffer type.
std::pair<BufferType, bool> targetTensorBufferTypeResult;
// Tensor buffer type workaround result.
BufferTypeWorkaroundResult tensorBufferTypeResult;

// Target tensor memory layout. Can be nullopt for tensors on host.
std::pair<std::optional<TensorMemoryLayout>, bool>
targetTensorMemoryLayoutResult;
// Tensor memory layout workaround result.
MemoryLayoutWorkaroundResult tensorMemoryLayoutResult;

// Returns true if any of the workarounds were applied.
bool modified() const {
return targetTensorLayoutResult.second ||
targetTensorBufferTypeResult.second ||
targetTensorMemoryLayoutResult.second;
bool isModified() const {
return tensorLayoutResult.isModified() ||
tensorBufferTypeResult.isModified() ||
tensorMemoryLayoutResult.isModified();
}
};

// Apply the operand workarounds to the layout attribute that contains
// tensor layout, buffer type and tensor memory layout arguments.
// Returns the result of applying the workarounds.
WorkaroundResult applyWorkarounds(const TTNNOperandWorkarounds &workaround,
const TTNNLayoutAttr &inputLayoutAttr);
WorkaroundResults applyWorkarounds(const TTNNOperandWorkarounds &workaround,
const TTNNLayoutAttr &inputLayoutAttr);
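
A hedged sketch of how a rewrite might consume these results; the surrounding variables (`workaround`, `inputLayoutAttr`) and the rewrite steps in the comments are illustrative, not taken from this diff. It assumes the enclosing code lives outside the wa namespace:

// Sketch: apply the operand workarounds and inspect which components changed.
wa::WorkaroundResults results =
    wa::applyWorkarounds(workaround, inputLayoutAttr);
if (results.isModified()) {
  if (results.tensorLayoutResult.isModified()) {
    // Previous and target layouts differ; a layout conversion is needed.
    Layout targetLayout = results.tensorLayoutResult.targetValue;
    // ... rewrite the operand to use targetLayout ...
  }
  if (results.tensorBufferTypeResult.isModified()) {
    BufferType targetBufferType = results.tensorBufferTypeResult.targetValue;
    // ... move the tensor to targetBufferType ...
  }
}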

// Class that encapsulates operands workarounds.
// It contains input and output workarounds for operands.
@@ -170,6 +189,13 @@ class TTNNOperandsWorkarounds {
llvm::SmallVector<TTNNOperandWorkarounds> outputOperandWorkarounds;
};

// Workaround factory class that creates workarounds for ops.
class TTNNOperandsWorkaroundsFactory {
public:
// Create workarounds for max_pool2d op operands.
static TTNNOperandsWorkarounds createMaxPool2DOpOperandsWorkarounds();
};
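
A possible shape for the factory method's implementation, composed from the fluent builder API above. The Layout::RowMajor choice and the operand count are assumptions for illustration (pooling kernels commonly require row-major operands), not the confirmed body of the real method:

// Hypothetical sketch, not the actual implementation from this commit.
TTNNOperandsWorkarounds
TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds() {
  TTNNOperandWorkarounds rowMajorWorkaround(Layout::RowMajor);
  return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
      .addInputOperandWorkaround(rowMajorWorkaround)
      .addOutputOperandWorkaround(rowMajorWorkaround);
}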

} // namespace mlir::tt::ttnn::wa

#endif
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -130,7 +130,7 @@ struct TTIRToTTNNBackendPipelineOptions
//
Option<bool> layouotWorkaroundsEnabled{
*this, "enable-layout-workaround-pass",
llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(false)};
llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(true)};

Option<bool> decompositionWorkaroundsEnabled{
*this, "enable-decomposition-workaround-pass",