Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support lowering add and mul through ttir.generic metal backend #468

Merged
merged 1 commit into from
Aug 29, 2024

Conversation

nsmithtt
Copy link
Contributor

@nsmithtt nsmithtt commented Aug 22, 2024

The core of this change is generating a loop nest from arith on tensors, consider the following ttir.generic body:

  ^bb0(%arg2: tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>, %arg3, %arg4):
    %8 = arith.addf %arg2, %arg3 : tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>
    "ttir.yield"(%8) : (tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>) -> ()
  })

Into a loop nest using the scf dialect:

  "ttkernel.binary_op_init_common"(%arg2, %arg3, %arg4)
  "ttkernel.add_tiles_init"(%arg2, %arg3)
  %8 = scf.for %arg5 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg6 = %c0_i32) -> (i32)  : i32 {
    %9 = scf.for %arg7 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg8 = %arg6) -> (i32)  : i32 {
      "ttkernel.tile_regs_acquire"() : () -> ()
      "ttkernel.add_tiles"(%arg2, %arg3, %arg8, %arg8, %c0_i32)
      "ttkernel.tile_regs_commit"() : () -> ()
      "ttkernel.tile_regs_wait"() : () -> ()
      "ttkernel.pack_tile"(%c0_i32, %arg4, %arg8)
      "ttkernel.tile_regs_release"() : () -> ()
      %10 = arith.addi %arg8, %c1_i32 : i32
      scf.yield %10 : i32
    }
    scf.yield %9 : i32
  }
  "ttkernel.return"() : () -> ()

@nsmithtt
Copy link
Contributor Author

If someone has a chance to take a look at this review that'd be great!

lib/Dialect/TTMetal/Transforms/Passes.cpp Outdated Show resolved Hide resolved
lib/Dialect/TTMetal/Transforms/Passes.cpp Outdated Show resolved Hide resolved
lib/Dialect/TTMetal/Transforms/Passes.cpp Show resolved Hide resolved
lib/Dialect/TTMetal/Transforms/Passes.cpp Show resolved Hide resolved
lib/Dialect/TTMetal/Transforms/Passes.cpp Show resolved Hide resolved
lib/Dialect/TTMetal/Transforms/Passes.cpp Show resolved Hide resolved
@rpavlovicTT
Copy link
Contributor

Nice work! Looks good to me, minor comments.


// Build the inner loop compute / unpack / pack
{
Value output = computeBlock->getArgument(numDPSInputs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we don't do push_back/pop_front from CBs, did you intentionally skip it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes intentionally skipped, currently only works in buffer mode alias. When we add support for buffer mode stream we'll need to generate the cb push/pops for the streaming inputs.

The core of this change is generating a loop nest from arith on tensors,
consider the following `ttir.generic` body:

  ^bb0(%arg2: tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>, %arg3, %arg4):
    %8 = arith.addf %arg2, %arg3 : tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>
    "ttir.yield"(%8) : (tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>) -> ()
  })

Into a loop nest using the scf dialect:

  "ttkernel.binary_op_init_common"(%arg2, %arg3, %arg4)
  "ttkernel.add_tiles_init"(%arg2, %arg3)
  %8 = scf.for %arg5 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg6 = %c0_i32) -> (i32)  : i32 {
    %9 = scf.for %arg7 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg8 = %arg6) -> (i32)  : i32 {
      "ttkernel.tile_regs_acquire"() : () -> ()
      "ttkernel.add_tiles"(%arg2, %arg3, %arg8, %arg8, %c0_i32)
      "ttkernel.tile_regs_commit"() : () -> ()
      "ttkernel.tile_regs_wait"() : () -> ()
      "ttkernel.pack_tile"(%c0_i32, %arg4, %arg8)
      "ttkernel.tile_regs_release"() : () -> ()
      %10 = arith.addi %arg8, %c1_i32 : i32
      scf.yield %10 : i32
    }
    scf.yield %9 : i32
  }
  "ttkernel.return"() : () -> ()
@nsmithtt nsmithtt merged commit 82c079b into main Aug 29, 2024
13 checks passed
@nsmithtt nsmithtt deleted the nsmith/kernel11 branch August 29, 2024 01:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants