Support lowering add and mul through ttir.generic metal backend (#468)
The core of this change is generating a loop nest from arith ops on tensors.
Consider the following `ttir.generic` body:

  ^bb0(%arg2: tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>, %arg3, %arg4):
    %8 = arith.addf %arg2, %arg3 : tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>
    "ttir.yield"(%8) : (tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>) -> ()
  })

This is lowered into a loop nest using the scf dialect:

  "ttkernel.binary_op_init_common"(%arg2, %arg3, %arg4)
  "ttkernel.add_tiles_init"(%arg2, %arg3)
  %8 = scf.for %arg5 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg6 = %c0_i32) -> (i32)  : i32 {
    %9 = scf.for %arg7 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg8 = %arg6) -> (i32)  : i32 {
      "ttkernel.tile_regs_acquire"() : () -> ()
      "ttkernel.add_tiles"(%arg2, %arg3, %arg8, %arg8, %c0_i32)
      "ttkernel.tile_regs_commit"() : () -> ()
      "ttkernel.tile_regs_wait"() : () -> ()
      "ttkernel.pack_tile"(%c0_i32, %arg4, %arg8)
      "ttkernel.tile_regs_release"() : () -> ()
      %10 = arith.addi %arg8, %c1_i32 : i32
      scf.yield %10 : i32
    }
    scf.yield %9 : i32
  }
  "ttkernel.return"() : () -> ()
nsmithtt authored Aug 29, 2024
1 parent 92bc980 commit 82c079b
Showing 5 changed files with 341 additions and 80 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -58,6 +58,7 @@ def TTIR_GenericOp : TTIR_DPSOp<"generic", [AttrSizedOperandSegments]> {
                        TT_OperandConstraintArrayAttr:$operand_constraints);
   let results = (outs Variadic<AnyRankedTensor>:$results);
   let regions = (region AnyRegion:$region);
+  let hasVerifier = 1;
 }
 
 def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpInterface]> {
3 changes: 1 addition & 2 deletions include/ttmlir/Dialect/TTMetal/IR/TTMetalOps.td
@@ -30,8 +30,7 @@ def TTMetal_DispatchOp : TTMetal_Op<"dispatch", [DestinationStyleOpInterface, At
   let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
                        Variadic<AnyRankedTensor>:$outputs,
                        TTMetal_CoreRangeArrayAttr:$core_ranges,
-                       TTKernel_ThreadTypeArrayAttr:$threadTypes,
-                       ArrayAttr:$operand_cb_port_mapping);
+                       TTKernel_ThreadTypeArrayAttr:$threadTypes);
   let results = (outs Variadic<AnyRankedTensor>:$results);
   let regions = (region VariadicRegion<AnyRegion>:$regions);
 
3 changes: 3 additions & 0 deletions include/ttmlir/Utils.h
@@ -48,6 +48,9 @@ llvm::SmallVector<int64_t> evalShape(mlir::AffineMap map, Vector shape) {
   return result;
 }
 
+template <typename Enum> std::underlying_type_t<Enum> enum_as_int(Enum e) {
+  return static_cast<std::underlying_type_t<Enum>>(e);
+}
 } // namespace ttmlir::utils
 
 #endif
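
For context, the new `enum_as_int` helper simply exposes an enum's underlying
integral value. A minimal, hypothetical usage sketch follows; the `ThreadType`
enum here is made up for illustration and is not the dialect's real definition.

  // Hypothetical usage of ttmlir::utils::enum_as_int; ThreadType below is an
  // illustrative stand-in, not the actual ttkernel enum definition.
  #include <cstdint>
  #include "ttmlir/Utils.h"

  enum class ThreadType : std::uint32_t { Noc0 = 0, Noc1 = 1, Tensix = 2 };

  std::uint32_t encodeThreadType(ThreadType t) {
    return ttmlir::utils::enum_as_int(t); // e.g. Tensix -> 2
  }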
8 changes: 8 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -51,6 +51,14 @@ mlir::tt::ttir::ToLayoutOp::compoundComponents() {
                          isMemorySpaceChange);
 }
 
+::mlir::LogicalResult mlir::tt::ttir::GenericOp::verify() {
+  if (getNumOperands() != getRegion().getNumArguments()) {
+    return emitOpError(
+        "The number of op operands and region/block operands must match");
+  }
+  return success();
+}
+
 template <typename OpTy>
 static void buildGenericEltwiseBinaryRegion(::mlir::Location loc,
                                             ::mlir::OpBuilder &opBuilder,