Skip to content

Commit

Permalink
Added E2E PermuteOp support (#1505)
Browse files Browse the repository at this point in the history
Added support for PermuteOp, since it's already supported in
[ttnn](https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/api/ttnn.permute.html#ttnn.permute).

`stablehlo.transpose` and `ttir.permute` have the same semantics, so
decomposing `stablehlo.transpose` into a series of `ttir.transpose` ops
is no longer needed.

Closes #652
  • Loading branch information
azecevicTT authored Dec 20, 2024
1 parent 432f8d8 commit db26004
Show file tree
Hide file tree
Showing 31 changed files with 621 additions and 291 deletions.
27 changes: 27 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,33 @@ def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
}
// ANCHOR_END: adding_an_op_matmul_ttir

// TTIR permute op: reorders the dimensions of a ranked tensor according to a
// static `permutation` attribute. Declared as a DPS (destination-passing
// style) op, so the caller supplies the pre-allocated $output tensor.
def TTIR_PermuteOp : TTIR_DPSOp<"permute"> {
  let summary = "Permute operation.";
  let description = [{
    Permute input tensor dimensions.

    Attributes:
      - `permutation` array<i64>: The permutation of the input tensor dimensions.

    Example:
      %a = tensor.empty() : () -> tensor<2x3x4xi32>
      %output = tensor.empty() : () -> tensor<3x4x2xi32>
      %0 = "ttir.permute"(%a, %output) {permutation = array<i64: 1, 2, 0>} : (tensor<2x3x4xi32>, tensor<3x4x2xi32>) -> tensor<3x4x2xi32>
  }];

  let arguments = (ins AnyRankedTensor:$input,
                       AnyRankedTensor:$output,
                       DenseI64ArrayAttr:$permutation);

  let results = (outs AnyRankedTensor:$result);

  // DPS hook: $output is the init operand that $result aliases.
  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];

  // Verifier (defined in TTIROps.cpp) checks permutation/shape consistency.
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// TTIR top level generic ops
//===----------------------------------------------------------------------===//
Expand Down
23 changes: 23 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1062,4 +1062,27 @@ def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> {
let hasVerifier = 1;
}

// TTNN permute op: reorders the dimensions of a ranked tensor according to a
// static `permutation` attribute, lowered from ttir.permute.
def TTNN_PermuteOp : TTNN_Op<"permute"> {
  let summary = "Permute operation.";
  let description = [{
    Permute input tensor dimensions.

    Attributes:
      - `permutation` array<i64>: The permutation of the input tensor dimensions.

    Example:
      %a = tensor.empty() : () -> tensor<2x3x4xi32>
      %0 = "ttnn.permute"(%a) {permutation = array<i64: 1, 2, 0>} : (tensor<2x3x4xi32>) -> tensor<3x4x2xi32>
  }];

  // `memory_config` and `pad_value` mirror the optional arguments of the
  // ttnn.permute runtime API; both may be omitted (pad_value defaults to 0.0).
  let arguments = (ins AnyRankedTensor:$input,
                       DenseI64ArrayAttr:$permutation,
                       OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config,
                       DefaultValuedOptionalAttr<F32Attr, "0.0f">:$pad_value);

  let results = (outs AnyRankedTensor:$result);

  let hasVerifier = 1;
}

#endif
9 changes: 9 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,14 @@ table AllGatherOp {
num_links: uint32;
}

// Serialized form of ttnn.permute. Field order is part of the flatbuffer
// layout; do not reorder existing fields.
table PermuteOp {
  in: tt.target.TensorRef;
  // permutation[i] selects the input dimension that becomes output dim i.
  permutation: [int64];
  // Optional memory configuration for the output tensor.
  memory_config: MemoryConfigDesc;
  // Fill value forwarded to the runtime op (presumably used when padding is
  // required — confirm against the TTNN runtime implementation).
  pad_value: float;
  out: tt.target.TensorRef;
}

table ReduceScatterOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
Expand Down Expand Up @@ -343,6 +351,7 @@ union OpType {
ArangeOp,
UpdateCacheOp,
FillCacheOp,
PermuteOp,
}

table Operation {
Expand Down
80 changes: 66 additions & 14 deletions include/ttmlir/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
#ifndef TTMLIR_UTILS_H
#define TTMLIR_UTILS_H

#include <cstdint>
#include <sstream>

#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"

#include <cstdint>

namespace ttmlir::utils {
template <typename T>
Expand Down Expand Up @@ -72,17 +72,14 @@ constexpr std::underlying_type_t<Enum> enum_as_int(Enum e) {
return static_cast<std::underlying_type_t<Enum>>(e);
}

template <typename T>
std::string join(const llvm::SmallVector<T> &vec,
const std::string &delimiter) {
std::ostringstream result;
for (size_t i = 0; i < vec.size(); ++i) {
result << vec[i];
if (i != vec.size() - 1) {
result << delimiter;
}
}
return result.str();
// Concatenates the string representations of the elements of `R`, separated
// by `separator`. Example: join({1, 2, 3}, ", ") -> "1, 2, 3".
template <typename Range>
std::string join(Range &&R, llvm::StringRef separator) {
  // Stringify each element via Twine, then delegate to llvm::join.
  auto toString = [](auto &element) { return llvm::Twine(element).str(); };
  return llvm::join(llvm::map_range(R, toString), separator);
}

// Prepacks `MlirAttribute`s stored in input array into a vector of
Expand Down Expand Up @@ -131,6 +128,61 @@ inline bool isRankedTensor(mlir::Value v) {
return mlir::isa<mlir::RankedTensorType>(v.getType());
}

// Identity function: returns its argument unchanged. Useful as a default
// callback for higher-order functions.
template <typename T>
inline T identity(T value) {
  return value;
}

// Computes the permutation mapping `input` onto `output`, i.e. a vector
// `permutation` with input[permutation[i]] == output[i] for all i. Assumes
// both ranges contain the same elements.
// Example: input = [1, 2, 3], output = [3, 1, 2] -> [2, 0, 1]
template <typename T>
inline llvm::SmallVector<int64_t>
generatePermutation(llvm::ArrayRef<T> input, llvm::ArrayRef<T> output) {
  assert(input.size() == output.size());

  // Record where each element sits in `input`.
  llvm::DenseMap<T, int64_t> positionOf;
  for (const auto [position, element] : llvm::enumerate(input)) {
    positionOf[element] = position;
  }

  // For every output element, emit the input position it came from.
  llvm::SmallVector<int64_t> permutation;
  for (const T &element : output) {
    permutation.push_back(positionOf[element]);
  }
  return permutation;
}

// Applies `permutation` to `input`, producing `result` such that
// result[i] = input[permutation[i]] for all i. `permutation` must be a valid
// permutation of the indices of `input`.
// Example: input = [1, 2, 3], permutation = [2, 0, 1] -> [3, 1, 2]
template <typename T>
inline llvm::SmallVector<T>
applyPermutation(llvm::ArrayRef<T> input, llvm::ArrayRef<int64_t> permutation) {
  assert(input.size() == permutation.size());

  llvm::SmallVector<T> result(input.size());
  for (const auto [destination, source] : llvm::enumerate(permutation)) {
    result[destination] = input[source];
  }
  return result;
}

// Computes the inverse of `permutation`: a vector `inverse` such that
// inverse[permutation[i]] = i for all i. `permutation` must be a valid
// permutation of range(0, permutation.size()).
// Example: permutation = [2, 0, 1] -> [1, 2, 0]
inline llvm::SmallVector<int64_t>
inversePermutation(llvm::ArrayRef<int64_t> permutation) {
  llvm::SmallVector<int64_t> inverse(permutation.size());
  for (const auto [index, target] : llvm::enumerate(permutation)) {
    inverse[target] = index;
  }
  return inverse;
}

} // namespace ttmlir::utils

#endif
74 changes: 15 additions & 59 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,46 +138,16 @@ class StableHLOToTTIRTransposeOpConversionPattern
matchAndRewrite(mlir::stablehlo::TransposeOp srcOp,
mlir::stablehlo::TransposeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto input = Value(adaptor.getOperand());
auto transposes = getPermutationTransposes(adaptor.getPermutation().vec());

for (auto transposeDims : transposes) {
auto dim0 = std::get<0>(transposeDims);
auto dim1 = std::get<1>(transposeDims);

auto inputType = mlir::cast<RankedTensorType>(input.getType());
auto outputShape = inputType.getShape().vec();
std::swap(outputShape[dim0], outputShape[dim1]);

auto outputType = RankedTensorType::get(
outputShape, inputType.getElementType(), inputType.getEncoding());

auto outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputShape, outputType.getElementType());

input = rewriter.create<mlir::tt::ttir::TransposeOp>(
srcOp.getLoc(), outputType, input, outputTensor,
rewriter.getSI32IntegerAttr(dim0), rewriter.getSI32IntegerAttr(dim1));
}
rewriter.replaceOp(srcOp, input);
::mlir::RankedTensorType outputType = mlir::cast<mlir::RankedTensorType>(
this->getTypeConverter()->convertType(srcOp.getResult().getType()));
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
// stablehlo.transpose and ttir.permute have the same semantics.
rewriter.replaceOpWithNewOp<mlir::tt::ttir::PermuteOp>(
srcOp, getTypeConverter()->convertType(srcOp.getResult().getType()),
adaptor.getOperand(), outputTensor, adaptor.getPermutation());
return success();
}

private:
std::vector<std::tuple<int64_t, int64_t>>
getPermutationTransposes(std::vector<int64_t> permutation) const {
std::vector<std::tuple<int64_t, int64_t>> transposes;
for (uint32_t i = 0; i < permutation.size(); i++) {
while (i != permutation[i]) {
transposes.push_back(
std::make_tuple(permutation[i], permutation[permutation[i]]));
std::swap(permutation[i], permutation[permutation[i]]);
}
}

return transposes;
}
};

class StableHLOToTTIRReshapeOpConversionPattern
Expand All @@ -204,19 +174,6 @@ class StableHLOToTTIRReshapeOpConversionPattern
adaptor.getOperand(), outputTensor, new_shape_attr);
return success();
}

LogicalResult
checkBasicLegality(mlir::stablehlo::TransposeOp &srcOp,
mlir::stablehlo::TransposeOp::Adaptor &adaptor,
ConversionPatternRewriter &rewriter) const {

if (adaptor.getPermutation().size() != 2) {
return rewriter.notifyMatchFailure(
srcOp, "TTIR supports only two dimensional transposeOp.");
}

return success();
}
};

class StableHLOToTTIRDotGeneralOpConversionPattern
Expand Down Expand Up @@ -1831,13 +1788,6 @@ void addReduceOpsConversionPatterns(MLIRContext *ctx,
patterns.add<StableHLOToTTIRReduceOpConversionPattern>(typeConverter, ctx);
}

void addTransposeOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {

patterns.add<StableHLOToTTIRTransposeOpConversionPattern>(typeConverter, ctx);
}

void addMatmulOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down Expand Up @@ -1891,6 +1841,12 @@ void addConcatOpsConversionPatterns(MLIRContext *ctx,
patterns.add<StableHLOToTTIRConcatOpConversionPattern>(typeConverter, ctx);
}

// Registers the stablehlo.transpose -> ttir.permute conversion pattern.
void addTransposeOpConversionPattern(MLIRContext *ctx,
                                     RewritePatternSet &patterns,
                                     TypeConverter &typeConverter) {
  patterns.add<StableHLOToTTIRTransposeOpConversionPattern>(typeConverter, ctx);
}

void addReshapeOpConversionPattern(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down Expand Up @@ -1973,7 +1929,6 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addElementwiseUnaryOpsConversionPatterns(ctx, patterns, typeConverter);
addElementwiseBinaryOpsConversionPatterns(ctx, patterns, typeConverter);
addReduceOpsConversionPatterns(ctx, patterns, typeConverter);
addTransposeOpsConversionPatterns(ctx, patterns, typeConverter);
addMatmulOpsConversionPatterns(ctx, patterns, typeConverter);
addGetDimensionSizeOpsConversionPatterns(ctx, patterns, typeConverter);
addTensorCreationOpsConversionPatterns(ctx, patterns, typeConverter);
Expand All @@ -1982,6 +1937,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addReduceWindowOpConversionPattern(ctx, patterns, typeConverter);
addCompareOpsConversionPatterns(ctx, patterns, typeConverter);
addConcatOpsConversionPatterns(ctx, patterns, typeConverter);
addTransposeOpConversionPattern(ctx, patterns, typeConverter);
addReshapeOpConversionPattern(ctx, patterns, typeConverter);
addCCLOpsConversionPattern(ctx, patterns, typeConverter);
addLogicalAndBitwiseOpsConversionPatterns(ctx, patterns, typeConverter);
Expand Down
Loading

0 comments on commit db26004

Please sign in to comment.