[Metal Direct] Added support for ttir.exp op #743

Merged 1 commit on Sep 24, 2024

kmitrovicTT (Contributor) commented Sep 18, 2024

Summary

Fixes #534.

Added TTIR_GenericElementwiseUnaryOp and made TTIR_ExpOp inherit from it. Handled lowering of ttir.exp to ttkernel.

  • Refactored TTIRToTTMetalDispatchRewriter so that unary and binary ops share the common building steps and keep separate steps only where they differ
  • Deprecated AcquireDstOp and ReleaseDstOp in favor of the up-to-date metal tile_regs_* API: acquire, commit, wait, release

Example:

Running ttmlir-opt on

func.func @exp(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
  // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]]
  %0 = tensor.empty() : tensor<64x128xf32>
  // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]]
  %1 = "ttir.exp"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
  return %1 : tensor<64x128xf32>
}

produces the following (only the actual compute kernel is shown):

%3 = "ttmetal.dispatch"(%1, %2) <{core_ranges = [#ttmetal.core_range<0x0, 1x1>], kernelConfigs = [#ttkernel.tensix_config<hifi4, false, false, false>], operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg1: !ttkernel.cb<cb_in0, 140128, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>, %arg2: !ttkernel.cb<cb_out0, 172896, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>):
  %c4_i32 = arith.constant 4 : i32
  %c1_i32 = arith.constant 1 : i32
  %c2_i32 = arith.constant 2 : i32
  %c0_i32 = arith.constant 0 : i32
  "ttkernel.unary_op_init_common"(%arg1, %arg2) : (!ttkernel.cb<cb_in0, 140128, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>, !ttkernel.cb<cb_out0, 172896, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>) -> ()
  "ttkernel.exp_tile_init"() : () -> ()
  %6 = scf.for %arg3 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg4 = %c0_i32) -> (i32)  : i32 {
    %7 = scf.for %arg5 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg6 = %arg4) -> (i32)  : i32 {
      "ttkernel.tile_regs_acquire"() : () -> ()
      "ttkernel.copy_tile"(%arg1, %arg6, %c0_i32) : (!ttkernel.cb<cb_in0, 140128, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>, i32, i32) -> ()
      "ttkernel.exp_tile"(%c0_i32) : (i32) -> ()
      "ttkernel.tile_regs_commit"() : () -> ()
      "ttkernel.tile_regs_wait"() : () -> ()
      "ttkernel.pack_tile"(%c0_i32, %arg2, %arg6) : (i32, !ttkernel.cb<cb_out0, 172896, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>, i32) -> ()
      "ttkernel.tile_regs_release"() : () -> ()
      %8 = arith.addi %arg6, %c1_i32 : i32
      scf.yield %8 : i32
    }
    scf.yield %7 : i32
  }
  "ttkernel.return"() : () -> ()
}) : (tensor<64x128xf32, #layout1>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1>

which resembles the already existing form for binary ops, such as

    %5 = "ttmetal.dispatch"(%1, %3, %4) <{core_ranges = [#ttmetal.core_range<0x0, 1x1>], kernelConfigs = [#ttkernel.tensix_config<hifi4, false, false, false>], operandSegmentSizes = array<i32: 2, 1>}> ({
    ^bb0(%arg2: !ttkernel.cb<cb_in0, 172896, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>, %arg3: !ttkernel.cb<cb_in1, 205664, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>, %arg4: !ttkernel.cb<cb_out0, 238432, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>):
      %c4_i32 = arith.constant 4 : i32
      %c1_i32 = arith.constant 1 : i32
      %c2_i32 = arith.constant 2 : i32
      %c0_i32 = arith.constant 0 : i32
      "ttkernel.binary_op_init_common"(%arg2, %arg3, %arg4) : (!ttkernel.cb<cb_in0, 172896, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>, !ttkernel.cb<cb_in1, 205664, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>, !ttkernel.cb<cb_out0, 238432, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>) -> ()
      "ttkernel.add_tiles_init"(%arg2, %arg3) : (!ttkernel.cb<cb_in0, 172896, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>, !ttkernel.cb<cb_in1, 205664, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>) -> ()
      %8 = scf.for %arg5 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg6 = %c0_i32) -> (i32)  : i32 {
        %9 = scf.for %arg7 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg8 = %arg6) -> (i32)  : i32 {
          "ttkernel.tile_regs_acquire"() : () -> ()
          "ttkernel.add_tiles"(%arg2, %arg3, %arg8, %arg8, %c0_i32) : (!ttkernel.cb<cb_in0, 172896, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>, !ttkernel.cb<cb_in1, 205664, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>, i32, i32, i32) -> ()
          "ttkernel.tile_regs_commit"() : () -> ()
          "ttkernel.tile_regs_wait"() : () -> ()
          "ttkernel.pack_tile"(%c0_i32, %arg4, %arg8) : (i32, !ttkernel.cb<cb_out0, 238432, memref<2x4x!tt.tile<32x32, f32>, #l1_>, 4096, 1>, i32) -> ()
          "ttkernel.tile_regs_release"() : () -> ()
          %10 = arith.addi %arg8, %c1_i32 : i32
          scf.yield %10 : i32
        }
        scf.yield %9 : i32
      }
      "ttkernel.return"() : () -> ()
    }) : (tensor<64x128xf32, #layout1>, tensor<64x128xf32, #layout1>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1>

This form differs from the original tt_metal/kernels/compute/eltwise_sfpu.cpp, because the ordering used here made more sense logically. For example, tile_regs_wait is executed after tile_regs_commit rather than after tile_regs_acquire, and the unary-op-specific init function is called once before the for loops rather than right before the unary op itself.
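
The call order used in both kernels above (acquire, compute, commit, wait, pack, release) can be pictured as a small state machine. The sketch below is a toy Python model and not part of tt-mlir or tt-metal: `DstRegs`, its states and `run_tile` are hypothetical names that only mirror the tile_regs_* ops. It enforces the order chosen in this PR and makes no claim that other orders are invalid on hardware.

```python
class DstRegs:
    """Toy model of the tile_regs_* call order used in this PR (hypothetical helper)."""

    # Each call is accepted in exactly one state and moves to the next one.
    _TRANSITIONS = {
        "acquire": ("free", "acquired"),
        "commit": ("acquired", "committed"),
        "wait": ("committed", "ready_to_pack"),
        "release": ("ready_to_pack", "free"),
    }

    def __init__(self):
        self.state = "free"
        self.trace = []

    def step(self, call):
        expected, nxt = self._TRANSITIONS[call]
        if self.state != expected:
            raise RuntimeError(f"tile_regs_{call} called in state {self.state!r}")
        self.state = nxt
        self.trace.append(call)


def run_tile(regs):
    # Order in the lowered kernel:
    # acquire, (copy_tile, exp_tile), commit, wait, (pack_tile), release.
    for call in ("acquire", "commit", "wait", "release"):
        regs.step(call)


regs = DstRegs()
for _ in range(8):  # 2 x 4 tiles, as in the example loops
    run_tile(regs)
print(regs.state, len(regs.trace))
```

In this model, calling `step("wait")` directly after `step("acquire")` raises, which is the ordering the lowering deliberately avoids.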

Running this through ttrt produces the following kernel:

void kernel_main() {
  ::tt::CB v1 = ::tt::CB::c_in0;
  ::tt::CB v2 = ::tt::CB::c_out0;
  int32_t v3 = 4;
  int32_t v4 = 1;
  int32_t v5 = 2;
  int32_t v6 = 0;
  unary_op_init_common(v1, v2);
  exp_tile_init();
  int32_t v7;
  v7 = v6;
  for (int32_t v8 = v6; v8 < v5; v8 += v4) {
    int32_t v9;
    v9 = v7;
    for (int32_t v10 = v6; v10 < v3; v10 += v4) {
      tile_regs_acquire();
      copy_tile(v1, v9, v6);
      exp_tile(v6);
      tile_regs_commit();
      tile_regs_wait();
      pack_tile(v6, v2, v9);
      tile_regs_release();
      uint32_t v11 = (uint32_t) v9;
      uint32_t v12 = (uint32_t) v4;
      uint32_t v13 = v11 + v12;
      int32_t v14 = (int32_t) v13;
      v9 = v14;
    };
    v7 = v9;
  }
  return;
}
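
For readers tracing the generated loops: a 64x128 tensor split into 32x32 tiles gives the 2x4 tile grid seen in the memref type, and the carried variable (v7 / v9 above) walks a linear tile index across both loops. The Python sketch below illustrates that indexing only; the function names are hypothetical and it assumes the tensor shape divides evenly into tiles.

```python
def tile_grid(shape, tile=(32, 32)):
    """Number of tiles along each dimension; assumes the shape divides evenly."""
    rows, cols = shape
    assert rows % tile[0] == 0 and cols % tile[1] == 0
    return rows // tile[0], cols // tile[1]


def visited_tile_indices(grid):
    """Mirror the nested loops: a carried index is bumped once per inner iteration."""
    outer, inner = grid
    visited = []
    carried = 0  # plays the role of v7 in the generated kernel
    for _ in range(outer):
        idx = carried  # plays the role of v9
        for _ in range(inner):
            visited.append(idx)  # copy_tile / pack_tile use this index
            idx += 1
        carried = idx
    return visited


grid = tile_grid((64, 128))
indices = visited_tile_indices(grid)
print(grid, indices)
```

Each of the eight tiles is read from the input circular buffer and packed to the output at the same linear index.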

Collecting the generated input and output tensors and passing them through a Python script (see PR #760) that compares golden = numpy.exp(input) with the device-generated output gives True for np.allclose(output_tile, golden_tile, rtol=1e-1, atol=1e-1) for each tile in the tensor.
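
The per-tile check described above can be sketched as follows. This is not the script from PR #760. It is a stdlib-only illustration that applies the same elementwise criterion np.allclose documents, |actual - golden| <= atol + rtol * |golden|, with math.exp standing in for numpy.exp. The function names, the list-of-lists tensor layout and the simulated device output are all assumptions made for the example.

```python
import math


def allclose(actual, golden, rtol=1e-1, atol=1e-1):
    """Elementwise |a - g| <= atol + rtol * |g| over two flat sequences."""
    return all(abs(a - g) <= atol + rtol * abs(g) for a, g in zip(actual, golden))


def tiles(matrix, tile=32):
    """Yield each tile of a 2D list-of-lists as a flat list, row-major over tiles."""
    for r in range(0, len(matrix), tile):
        for c in range(0, len(matrix[0]), tile):
            yield [matrix[i][j] for i in range(r, r + tile) for j in range(c, c + tile)]


# Hypothetical 64x128 input with values in [0, 1).
inp = [[((i * 128 + j) % 17) / 17.0 for j in range(128)] for i in range(64)]
golden = [[math.exp(x) for x in row] for row in inp]
# Stand-in for the device output: golden with a 1% relative error.
device = [[g * 1.01 for g in row] for row in golden]

results = [allclose(o, g) for o, g in zip(tiles(device), tiles(golden))]
print(len(results), all(results))
```

With rtol and atol both at 1e-1 the check is loose, so it confirms the kernel computes exp on the right tiles rather than bounding its numerical precision tightly.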

nsmithtt (Contributor) left a comment

Overall looks great! Couple of minor things inline.

include/ttmlir/Dialect/TTIR/IR/TTIROps.td (resolved)
runtime/tools/python/ttrt/common/api.py (outdated, resolved)
include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td (outdated, resolved)
include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td (outdated, resolved)
lib/Dialect/TTMetal/Transforms/Passes.cpp (outdated, resolved)
lib/Dialect/TTMetal/Transforms/Passes.cpp (outdated, resolved)
kmitrovicTT force-pushed the kmitrovic/ttmetal_direct_eltwise_exp branch from 4e3717e to 899fd45 on September 23, 2024 09:43
kmitrovicTT force-pushed the kmitrovic/ttmetal_direct_eltwise_exp branch 4 times, most recently from f050f1b to 0b3a37a on September 24, 2024 09:15
kmitrovicTT force-pushed the kmitrovic/ttmetal_direct_eltwise_exp branch from 0b3a37a to eb6da07 on September 24, 2024 09:41
kmitrovicTT merged commit f3ed73b into main on Sep 24, 2024
12 checks passed
Linked issue: [Metal Direct] Eltwise exp
4 participants