Skip to content

Commit

Permalink
Support lowering add and mul through ttir.generic metal backend
Browse files Browse the repository at this point in the history
The core of this change is generating a loop nest from arith ops on tensors.
Consider the following `ttir.generic` body:

  ^bb0(%arg2: tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>, %arg3, %arg4):
    %8 = arith.addf %arg2, %arg3 : tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>
    "ttir.yield"(%8) : (tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>) -> ()
  })

This body is lowered into a loop nest using the scf dialect:

  "ttkernel.binary_op_init_common"(%arg2, %arg3, %arg4)
  "ttkernel.add_tiles_init"(%arg2, %arg3)
  %8 = scf.for %arg5 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg6 = %c0_i32) -> (i32)  : i32 {
    %9 = scf.for %arg7 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg8 = %arg6) -> (i32)  : i32 {
      "ttkernel.tile_regs_acquire"() : () -> ()
      "ttkernel.add_tiles"(%arg2, %arg3, %arg8, %arg8, %c0_i32)
      "ttkernel.tile_regs_commit"() : () -> ()
      "ttkernel.tile_regs_wait"() : () -> ()
      "ttkernel.pack_tile"(%c0_i32, %arg4, %arg8)
      "ttkernel.tile_regs_release"() : () -> ()
      %10 = arith.addi %arg8, %c1_i32 : i32
      scf.yield %10 : i32
    }
    scf.yield %9 : i32
  }
  "ttkernel.return"() : () -> ()
  • Loading branch information
nsmithtt committed Aug 28, 2024
1 parent 92bc980 commit 9396436
Show file tree
Hide file tree
Showing 5 changed files with 341 additions and 80 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def TTIR_GenericOp : TTIR_DPSOp<"generic", [AttrSizedOperandSegments]> {
TT_OperandConstraintArrayAttr:$operand_constraints);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region AnyRegion:$region);
let hasVerifier = 1;
}

def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpInterface]> {
Expand Down
3 changes: 1 addition & 2 deletions include/ttmlir/Dialect/TTMetal/IR/TTMetalOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def TTMetal_DispatchOp : TTMetal_Op<"dispatch", [DestinationStyleOpInterface, At
let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
Variadic<AnyRankedTensor>:$outputs,
TTMetal_CoreRangeArrayAttr:$core_ranges,
TTKernel_ThreadTypeArrayAttr:$threadTypes,
ArrayAttr:$operand_cb_port_mapping);
TTKernel_ThreadTypeArrayAttr:$threadTypes);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region VariadicRegion<AnyRegion>:$regions);

Expand Down
3 changes: 3 additions & 0 deletions include/ttmlir/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ llvm::SmallVector<int64_t> evalShape(mlir::AffineMap map, Vector shape) {
return result;
}

/// Converts an enumerator to the value of its underlying integer type.
///
/// The return type is exactly `std::underlying_type_t<Enum>`, so the value is
/// neither widened nor narrowed by the conversion.
///
/// The function is `constexpr` so the result can be used wherever a constant
/// expression is required (e.g. `static_assert`, array bounds, case labels),
/// and `noexcept` because a `static_cast` between an enum and its underlying
/// type cannot throw.
///
/// \param e The enumerator to convert.
/// \returns `e` cast to the underlying integer type of `Enum`.
template <typename Enum>
constexpr std::underlying_type_t<Enum> enum_as_int(Enum e) noexcept {
  return static_cast<std::underlying_type_t<Enum>>(e);
}
} // namespace ttmlir::utils

#endif
8 changes: 8 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ mlir::tt::ttir::ToLayoutOp::compoundComponents() {
isMemorySpaceChange);
}

// Verifier for ttir.generic: the number of operands on the op must equal the
// number of arguments on its region. Emits an op error and fails verification
// when the two counts differ.
::mlir::LogicalResult mlir::tt::ttir::GenericOp::verify() {
  const auto operandCount = getNumOperands();
  const auto regionArgCount = getRegion().getNumArguments();

  // Matching counts is the only condition checked here.
  if (operandCount == regionArgCount) {
    return success();
  }

  return emitOpError(
      "The number of op operands and region/block operands must match");
}

template <typename OpTy>
static void buildGenericEltwiseBinaryRegion(::mlir::Location loc,
::mlir::OpBuilder &opBuilder,
Expand Down
Loading

0 comments on commit 9396436

Please sign in to comment.