Skip to content

Commit

Permalink
Generic region pass (#392)
Browse files Browse the repository at this point in the history
Repurpose the existing --ttir-generic pass to pattern match against any
ops that implement `GenericRegionOpInterface`.

Ops that implement the `GenericRegionOpInterface` must implement
methods:
- `getIndexingMaps`: Return a pair of indexingMaps and iteratorTypes for
  the given GenericRegionOp.
- `buildGenericRegion`: Rewrite self into the generic region block using
  the arith and math dialects.

One of the other major differences in this change is that the
generic region operands are rewritten to be tensors with encoding `ShardLayout`.
Previously they were just memrefs, but since the arith and math dialects cannot
work with memrefs, a tensor seemed like a sensible container; `ShardLayout`
now holds the memref inside of it.

Here's an example of what this transformation does, given input:
```mlir
%1 = "ttir.add"(%arg0, %arg1, %0) <{
  operandSegmentSizes = array<i32: 2, 1>,
  operand_constraints = [#any_device, #any_device, #any_device]}> :
    (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) ->
      tensor<64x128xf32>
```

We now have:
```mlir
%5 = "ttir.generic"(%1, %3, %4) <{
  grid = #tt.grid<1x1>,
  indexing_maps = [#map, #map, #map],
  iterator_types = [#parallel, #parallel],
  operandSegmentSizes = array<i32: 2, 1>,
  operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> ({
    ^bb0(%arg2: tensor<64x128xf32, #tt.shard_layout<memref<2x4x!tt.tile<32x32, f32>, #l1_>>>,
         %arg3: tensor<64x128xf32, #tt.shard_layout<memref<2x4x!tt.tile<32x32, f32>, #l1_>>>,
         %arg4: tensor<64x128xf32, #tt.shard_layout<memref<2x4x!tt.tile<32x32, f32>, #l1_>>>):
      %8 = arith.addf %arg2, %arg3 : tensor<64x128xf32, #tt.shard_layout<memref<2x4x!tt.tile<32x32, f32>, #l1_>>>
      "ttir.yield"(%8) : (tensor<64x128xf32, #tt.shard_layout<memref<2x4x!tt.tile<32x32, f32>, #l1_>>>) -> ()
    }) : (tensor<64x128xf32, #layout1>, tensor<64x128xf32, #layout1>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1>
}
```

Currently the only ops that implement this interface are the eltwise ops.
  • Loading branch information
nsmithtt authored Aug 20, 2024
1 parent 4cac2f8 commit ed496c7
Show file tree
Hide file tree
Showing 9 changed files with 267 additions and 62 deletions.
14 changes: 14 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,20 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
}];
}

// Attribute wrapping a MemRefType that records the physical footprint and
// layout of a buffer resident in L1 memory.
def TT_BufferAttr : TT_Attr<"Buffer", "buffer", []> {
  let summary = "Buffer attribute in TT dialect";
  let description = [{
    Describes the physical footprint and layout of a buffer in L1. Its memref must also have a shape with rank equal to DeviceAttr grid.
  }];
  // Single parameter: the memref carries both element type and shape.
  let parameters = (ins AttrParameter<"MemRefType", "A memref that describes the physical footprint and layout of the buffer. It must also have a shape with rank equal to DeviceAttr grid.">:$memref);
  let assemblyFormat = "`<` $memref `>`";

  let extraClassDeclaration = [{
    // Element type of the wrapped memref.
    ::mlir::Type getElementType() const;
    // Shape of the buffer; when the element type is a TileType the shape is
    // converted to its scalar (per-element) form.
    ::llvm::SmallVector<int64_t> getShape() const;
  }];
}

def TT_DeviceAttr : TT_Attr<"Device", "device", []> {
let summary = "Device attribute in TT dialect";
let description = [{
Expand Down
39 changes: 37 additions & 2 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,42 @@ class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
];
}

def TTIR_AddOp : TTIR_ElementwiseBinaryOp<"add"> {
// Base class for eltwise binary ops that can be rewritten into a generic
// region (see GenericRegionOpInterface): adds identity indexing maps,
// all-parallel iterators, and a per-op buildGenericRegion hook.
class TTIR_GenericElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
    TTIR_ElementwiseBinaryOp<mnemonic, !listconcat(traits, [TTIR_GenericRegionOpInterface])> {

  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }

    // Defined out-of-line per concrete op (see TTIROps.cpp); lowers this op
    // into arith/math ops inside `block`.
    void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);

    // Eltwise ops are indexed pointwise: one identity map per operand and a
    // parallel iterator per dimension.
    std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) {
      assert(sameRank(getOperands()) &&
             "For now all operands must have the same rank");
      auto rank = mlir::cast<RankedTensorType>(getOperand(0).getType()).getRank();
      SmallVector<AffineMap> indexingMaps(getNumOperands(),
                                          builder.getMultiDimIdentityMap(rank));
      SmallVector<Attribute> iteratorTypes(
          rank, builder.getAttr<IteratorTypeAttr>(IteratorType::Parallel));
      return {builder.getAffineMapArrayAttr(indexingMaps),
              builder.getArrayAttr(iteratorTypes)};
    }

    // True iff every operand is a RankedTensorType with the same rank
    // (vacuously true for zero operands).
    static bool sameRank(mlir::OperandRange operands) {
      if (operands.empty()) {
        return true;
      }
      auto rank = mlir::cast<RankedTensorType>(operands.front().getType()).getRank();
      return llvm::all_of(operands, [rank](mlir::Value operand) {
        return mlir::cast<RankedTensorType>(operand.getType()).getRank() == rank;
      });
    }
  }];
}

def TTIR_AddOp : TTIR_GenericElementwiseBinaryOp<"add"> {
let summary = "Eltwise add.";
let description = [{
Eltwise add operation.
Expand All @@ -189,7 +224,7 @@ def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract"> {
}];
}

def TTIR_MultiplyOp : TTIR_ElementwiseBinaryOp<"multiply"> {
def TTIR_MultiplyOp : TTIR_GenericElementwiseBinaryOp<"multiply"> {
let summary = "Eltwise multiply.";
let description = [{
Eltwise multiply operation.
Expand Down
38 changes: 37 additions & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,48 @@ def TTIROpInterface : OpInterface<"TTIROp"> {
];
}

def TTIR_ElementwiseOpInterface : OpInterface<"TTIR_ElementwiseOpInterface"> {
// Interface implemented by all TTIR elementwise ops. Note the generated C++
// class is named `ElementwiseOp`, not `TTIR_ElementwiseOpInterface`.
def TTIR_ElementwiseOpInterface : OpInterface<"ElementwiseOp"> {
  let cppNamespace = "::mlir::tt::ttir";

  // Shared verifier applied to every op implementing this interface.
  let verify = [{
    return detail::verifyElementwiseOp($_op);
  }];
}

// Interface for ops that can be rewritten into the body of a ttir.generic
// region. Implementors describe how they parallelize (getIndexingMaps) and
// how to emit their computation using arith/math ops (buildGenericRegion).
// Generated C++ class is named `GenericRegionOp`.
def TTIR_GenericRegionOpInterface : OpInterface<"GenericRegionOp"> {
  let cppNamespace = "::mlir::tt::ttir";

  let methods = [
    InterfaceMethod<
      /*desc=*/[{
        Return a pair of indexingMaps and iteratorTypes for the given GenericRegionOp.
        Where:
          indexingMaps: a list of AffineMapAttr, one AffineMapAttr per each input and
                        output view. Such AffineMapAttr specifies the mapping between
                        the loops and the indexing within each view. It effectively
                        defines how the op can legally be parallelized.
          iteratorTypes: an ArrayAttr specifying the type of the enclosing loops. Each
                         element of the list represents an iterator of one of the
                         following types: parallel and reduction.
      }],
      /*retTy=*/"std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr>",
      /*methodName=*/"getIndexingMaps",
      /*args=*/(ins "::mlir::Builder &":$builder),
      /*methodBody=*/"",
      /*defaultImplementation=*/""
    >,
    InterfaceMethod<
      /*desc=*/[{
        Rewrite self into the generic region block using the arith and math dialects.
      }],
      /*retTy=*/"void",
      /*methodName=*/"buildGenericRegion",
      /*args=*/(ins "::mlir::OpBuilder &":$op_builder,
                    "::mlir::Block*":$block),
      /*methodBody=*/"",
      /*defaultImplementation=*/""
    >,
  ];
}

#endif
13 changes: 10 additions & 3 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,24 @@ def TTIRImplicitDevice: Pass<"ttir-implicit-device", "::mlir::ModuleOp"> {
}];
}

def TTIRGeneric: Pass<"ttir-generic", "::mlir::ModuleOp"> {
// Pass declaration; the empty summary was a documentation defect — MLIR
// renders the summary in --help output, so it should not be blank.
def TTIRGenericKernel: Pass<"ttir-generic-kernel", "::mlir::ModuleOp"> {
  let summary = "Wrap top level kernel ops in a generic op.";
  let description = [{
    Wrap top level kernel ops in a generic op.
  }];
}

// Pass declaration; the empty summary was a documentation defect — MLIR
// renders the summary in --help output, so it should not be blank.
def TTIRGenericRegion: Pass<"ttir-generic", "::mlir::ModuleOp"> {
  let summary = "Wrap top level elementwise ops in a generic op.";
  let description = [{
    Wrap top level elementwise ops in a generic op.
  }];
}

// Pass declaration; the empty summary was a documentation defect — MLIR
// renders the summary in --help output, so it should not be blank.
def TTIRGenericRegionOperandsToMemref: Pass<"ttir-generic-region-operands-to-memref", "::mlir::ModuleOp"> {
  let summary = "Convert generic region operations to work on memref operands.";
  let description = [{
    Convert region operations to work on memref instead of tensors.
  }];
}

Expand Down
13 changes: 13 additions & 0 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,19 @@ mlir::AffineMap LayoutAttr::projectOnto(ArrayRef<int64_t> logicalTensorShape,
getContext());
}

// Element type of the buffer, taken from the wrapped memref.
mlir::Type BufferAttr::getElementType() const {
  auto memref = getMemref();
  return memref.getElementType();
}

// Shape of the buffer. If the element type is a tile, the tiled shape is
// converted to its scalar (per-element) form; otherwise the memref shape is
// returned as-is.
llvm::SmallVector<int64_t> BufferAttr::getShape() const {
  SmallVector<int64_t> bufferShape(getMemref().getShape());
  // Single dyn_cast instead of isa followed by cast (LLVM casting idiom).
  if (auto tileType = mlir::dyn_cast<TileType>(getElementType())) {
    return tileType.getScalarShape(bufferShape);
  }
  return bufferShape;
}

DeviceAttr DeviceAttr::get(::mlir::MLIRContext *context,
SystemDescAttr systemDesc,
ArrayRef<unsigned> chipIds) {
Expand Down
26 changes: 25 additions & 1 deletion lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"

#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.cpp.inc"

Expand Down Expand Up @@ -48,6 +50,28 @@ mlir::tt::ttir::ToLayoutOp::compoundComponents() {
isMemorySpaceChange);
}

// Populate `block` with the binary arith op `OpTy` applied to the block's
// first two arguments, followed by a ttir.yield of the result.
template <typename OpTy>
static void buildGenericEltwiseBinaryRegion(::mlir::Location loc,
                                            ::mlir::OpBuilder &opBuilder,
                                            ::mlir::Block *block) {
  auto result = opBuilder.create<OpTy>(loc, block->getArgument(0),
                                       block->getArgument(1));
  opBuilder.create<mlir::tt::ttir::YieldOp>(loc, mlir::ValueRange({result}));
}

// GenericRegionOpInterface hook: lower ttir.add into an arith.addf on the
// block arguments plus a ttir.yield.
void mlir::tt::ttir::AddOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
                                               ::mlir::Block *block) {
  buildGenericEltwiseBinaryRegion<arith::AddFOp>(getLoc(), opBuilder, block);
}

// GenericRegionOpInterface hook: lower ttir.multiply into an arith.mulf on
// the block arguments plus a ttir.yield.
void mlir::tt::ttir::MultiplyOp::buildGenericRegion(
    ::mlir::OpBuilder &opBuilder, ::mlir::Block *block) {
  buildGenericEltwiseBinaryRegion<arith::MulFOp>(getLoc(), opBuilder, block);
}

::mlir::LogicalResult mlir::tt::ttir::SoftmaxOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getOutput().getType();
Expand Down
Loading

0 comments on commit ed496c7

Please sign in to comment.