Support lowering add and mul through ttir.generic metal backend #468
If someone has a chance to take a look at this review that'd be great!
Nice work! Looks good to me, minor comments.
// Build the inner loop compute / unpack / pack
{
  Value output = computeBlock->getArgument(numDPSInputs);
Here we don't do push_back/pop_front on the CBs. Did you skip that intentionally?
Yes, intentionally skipped. This currently only works in buffer mode alias. When we add support for buffer mode stream, we'll need to generate the CB push/pops for the streaming inputs.
The core of this change is generating a loop nest from arith on tensors. Consider the following `ttir.generic` body:

    ^bb0(%arg2: tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>, %arg3, %arg4):
      %8 = arith.addf %arg2, %arg3 : tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>
      "ttir.yield"(%8) : (tensor<64x128xf32, #tt.buffer<memref<2x4x!tt.tile<32x32, f32>, #l1_>, alias>>) -> ()
    })

This is lowered into a loop nest using the scf dialect:

    "ttkernel.binary_op_init_common"(%arg2, %arg3, %arg4)
    "ttkernel.add_tiles_init"(%arg2, %arg3)
    %8 = scf.for %arg5 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg6 = %c0_i32) -> (i32) : i32 {
      %9 = scf.for %arg7 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg8 = %arg6) -> (i32) : i32 {
        "ttkernel.tile_regs_acquire"() : () -> ()
        "ttkernel.add_tiles"(%arg2, %arg3, %arg8, %arg8, %c0_i32)
        "ttkernel.tile_regs_commit"() : () -> ()
        "ttkernel.tile_regs_wait"() : () -> ()
        "ttkernel.pack_tile"(%c0_i32, %arg4, %arg8)
        "ttkernel.tile_regs_release"() : () -> ()
        %10 = arith.addi %arg8, %c1_i32 : i32
        scf.yield %10 : i32
      }
      scf.yield %9 : i32
    }
    "ttkernel.return"() : () -> ()
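For readers less familiar with `scf.for` and `iter_args`, here is a rough Python model of what the generated loop nest appears to do for the 2x4 grid of 32x32 tiles above. It is only a sketch: `add_by_tiles` is a hypothetical name, the inner element loop is a stand-in for the `add_tiles` / `pack_tile` pair, and the row-major tile layout is an assumption made so the linear tile index can be mapped back to a position.

```python
TILE = 32  # tile edge length, from !tt.tile<32x32, f32>


def add_by_tiles(a, b, tile_rows=2, tile_cols=4):
    """Elementwise add of two 2D lists, visiting one tile per inner iteration."""
    rows, cols = tile_rows * TILE, tile_cols * TILE
    out = [[0.0] * cols for _ in range(rows)]
    visited = []
    tile_idx = 0  # plays the role of the iter_args value (%arg6 / %arg8)
    for _ in range(tile_rows):      # outer scf.for, 0 to 2
        for _ in range(tile_cols):  # inner scf.for, 0 to 4
            # Assumption: tiles are laid out row-major in the 2x4 memref.
            tile_r, tile_c = divmod(tile_idx, tile_cols)
            r0, c0 = tile_r * TILE, tile_c * TILE
            # Stand-in for add_tiles followed by pack_tile on one tile.
            for r in range(r0, r0 + TILE):
                for c in range(c0, c0 + TILE):
                    out[r][c] = a[r][c] + b[r][c]
            visited.append(tile_idx)
            tile_idx += 1           # arith.addi %arg8, %c1_i32
    return out, visited
```

The point of the model is that the same carried index selects the tile in both inputs and in the output, and that it runs linearly over all eight tiles regardless of the loop induction variables.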