
Add transpose WH sharded, generalize row major permute when N > 4, and do a minor refactor of ttnn::permute #15881

Merged: sjameelTT merged 20 commits into main from sjameel/transpose_wh_sharded on Dec 17, 2024

Conversation

@sjameelTT sjameelTT (Contributor) commented Dec 10, 2024

Ticket

#14790 add transpose wh sharded implementation when shard shape < height dimension
#15165 add N-d permute with width dimension
#15589 correct permute dimensionality when less than 4D
#15750 remove the composite flag from permute
#12550 re-enable some permute tests for blackhole
#12349 re-enable working transpose tests for blackhole
#16066 disable test uniform as it's stochastic

Problem description

This PR addresses several permute and transpose problems at once:

  • Transpose WH sharded does not currently work when the shard shape is less than the height
  • Permute on greater than 4 dimensions does not work when moving width around (for both tiled and RM)
  • The Permute kernel when width doesn't change is single core
  • Permute has an unclean API in which we have a composite flag that is not generically applicable
  • Permute on less than 4 dimensions gets an incorrect output shape in cases where it's a no-op
  • Permute tests are disabled for BH due to LLK issues
  • Transpose tests are disabled for BH due to LLK issues

What's changed

  • Add transpose WH sharded implementation for when shard shape is less than the height dim (outputs a block sharded output)
  • Add an N-d permute kernel that works generically on any row-major input. We have to call a global init on each loop iteration of the compute kernel because transpose sets some registers that aren't cleared (there's no transpose_uninit), which otherwise results in bad PCC when there's more than one loop. For GS/BH, even the global init doesn't solve the problem, so the test is disabled. For tiled inputs we need 5D untilize/tilize. This increases sweeps coverage for permute from 50% to 86%
  • For the optimized case where Permute's width dimension is not shuffled, make the kernel multicore
  • Remove the composite flag, which was enabled by default and made permute non-generic. It caused Forge models to have bad PCC because they were not aware of this optional argument.
  • Refactor ttnn::permute to add nop checks and correct shape calculations
  • Re-enable permute and transpose tests for blackhole

When I tried replacing variants of transpose with this RM permute kernel, a lot of tests on BH/GS failed, so I will address that in a follow-up; the LLK issues are causing pain there. Once we get N-d untilize/tilize support and the LLK issues are fixed, permute should be able to be fully generic. The remaining issue for the PyTorch 2.0 sweeps after the untilize/tilize fix is the CB overflow on transpose WH, which should be fixed out of the box once we replace the kernel that is used (which I am not doing in this PR since it doesn't work for GS/BH at the moment).
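For illustration, here is a minimal host-side C++ sketch of the row-major N-d permute index math that the generalized kernel implements on device. The function and names are illustrative only, not the kernel code.

```cpp
// Hedged host-side reference of the row-major N-d permute index math (not the
// device kernel): build row-major strides for the permuted output shape and
// scatter every input element to its destination offset.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> permute_rm(const std::vector<float>& in,
                              const std::vector<uint32_t>& shape,   // input shape, row-major
                              const std::vector<uint32_t>& perm) {  // output axis i = input axis perm[i]
    const size_t rank = shape.size();
    std::vector<uint32_t> out_shape(rank), out_strides(rank);
    for (size_t i = 0; i < rank; ++i) {
        out_shape[i] = shape[perm[i]];
    }
    out_strides[rank - 1] = 1;
    for (size_t i = rank - 1; i-- > 0;) {
        out_strides[i] = out_strides[i + 1] * out_shape[i + 1];
    }

    std::vector<float> out(in.size());
    std::vector<uint32_t> idx(rank, 0);  // multi-dimensional index of the current input element
    for (size_t flat = 0; flat < in.size(); ++flat) {
        size_t dst = 0;
        for (size_t o = 0; o < rank; ++o) {
            dst += static_cast<size_t>(idx[perm[o]]) * out_strides[o];
        }
        out[dst] = in[flat];
        // advance the row-major input index (rightmost dimension varies fastest)
        for (size_t d = rank; d-- > 0;) {
            if (++idx[d] < shape[d]) {
                break;
            }
            idx[d] = 0;
        }
    }
    return out;
}
```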

Checklist

@@ -30,6 +30,8 @@ class SimpleShape final : protected ShapeBase {
[[nodiscard]] size_t rank() const;
[[nodiscard]] uint64_t volume() const;

Member

No need to return const value here.

Why is this API needed?

Contributor Author

I added it for convenience since it's a commonly used helper for the legacy shape. It makes it easier to switch to SimpleShape. I can move it elsewhere if that's preferred.

@sjameelTT sjameelTT force-pushed the sjameel/transpose_wh_sharded branch 5 times, most recently from 515730e to 22d0b64 Compare December 12, 2024 17:54
)
tt_input_tensor = ttnn.to_memory_config(tt_input_tensor, sharded_mem_config)

tt_output_tensor = ttnn.transpose(tt_input_tensor, 2, 3, memory_config=sharded_mem_config)

Member

just curious, why is tt_input_tensor not sharded, but output is?
would there be any practical difference if input is sharded and output configuration is taken from input?

Contributor Author

the tt_input_tensor that goes into transpose is sharded. Setting the memory_config=sharded_mem_config shouldn't make a difference here since it doesn't match the output.

Comment on lines 1041 to 1040
tt_output_tensor = ttnn.from_device(tt_output_tensor)
tt_output_tensor = ttnn.to_torch(tt_output_tensor)

Member

should be enough to call ttnn.to_torch, no need to explicitly call from_device before it

Contributor Author

Fixed

Comment on lines +145 to +147
uint32_t temp = array[i];
array[i] = array[j];
array[j] = temp;

Member

is std::swap available in kernels?

Contributor Author

We don't have std support in kernels unfortunately.

Contributor

Yeah, that would blow up the kernel program size and not fit.
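For reference, a minimal sketch of the kind of hand-rolled swap kernels end up using since the standard library isn't available; the helper name is illustrative, not an existing utility.

```cpp
// Minimal sketch: a header-only swap usable in kernel code without the C++
// standard library. "kernel_swap" is an illustrative name, not an existing helper.
template <typename T>
inline void kernel_swap(T& a, T& b) {
    T tmp = a;
    a = b;
    b = tmp;
}
```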

uint32_t element_size,
uint32_t input_page_size,
uint32_t output_page_size) {
volatile tt_l1_ptr uint8_t* input_ptr = reinterpret_cast<volatile tt_l1_ptr uint8_t*>(input_l1_addr);

Member

Is there anywhere I can read about tt_l1_ptr?

Contributor Author

tt_metal/hw/inc/risc_attribs.h

#define tt_l1_ptr __attribute__((rvtt_l1_ptr))

Seems to be a wrapper around a RISC-V hardware pointer concept.


for (uint32_t n = 0; n < num_blocks; n++) {
// have to global init here, otherwise pcc is bad
// if n > 0, then some register isn't cleared and the output of tilize_block is garbage

Member

I think I saw a thread about this, but did you file a ticket?

Contributor Author

yep

#15930

// removing this line causes the output of tilize_block to be garbage in the second iteration
tilize_block(cb_in, 1, cb_tilize); // tilize and pack into cb_tilize

// tile slice according to unpacker is garbage after tilize_block in the second iteration, missing an uninit?

Member

Reading these comments, it looks like we have many lines of workarounds. Is that the case?

Contributor Author

It's just the one global init workaround that we have atm, but I wrote a lot of comments documenting the issue so it's not lost to history

Contributor Author

And @rdjogoTT found the issue and all these comments are no longer needed

Contributor

What was the issue?

Contributor Author

pack_untilize_uninit is hard-coded to a default cb_id = 16, which used to be cb_out0. Since the recommendation is now to use CB ids in sequential order, what worked with the old cb_id = 16 does not work with the cb_id I had. There's a cleanup item they're working on that removes the defaults, since they're misleading.
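For context, a hedged sketch of what the fix implies: pass the circular buffer index explicitly instead of relying on the defaulted argument described above. The CB index below is a placeholder, not this kernel's actual id.

```cpp
// Hedged sketch only: pass the output CB explicitly rather than relying on the
// defaulted cb_id = 16 mentioned in the comment above. The index is a placeholder.
constexpr uint32_t cb_out = 2;   // whichever CB this kernel actually packs into
pack_untilize_uninit(cb_out);    // instead of the implicit pack_untilize_uninit()
```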

Comment on lines 18 to 20
auto BUFFER_ALIGNMENT = input_tensor.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? DRAM_ALIGNMENT : L1_ALIGNMENT;
const auto& padded_shape = input_tensor.get_logical_shape(); // in anticipation of RM padding
return padded_shape[-1] * input_tensor.element_size();
return tt::round_up(padded_shape[-1] * input_tensor.element_size(), BUFFER_ALIGNMENT);

Member

can you get page_size from input_tensor.get_tensor_spec()?
CC @sminakov-tt

Member

Ideally you should be able to ask the spec

Member

this is the case for many pieces of logic that you see reused between program factories

Contributor Author

Page size is actually configurable based on the op.

My page sizes are actually a constant size here for my other kernel, independent of the inputs, which is what makes this kernel generic. We sacrifice some perf to ensure it scales for all inputs (I read in 64B chunks regardless of input size). A generic page size getter probably wouldn't work (though to be fair, most kernels just set it equal to the row width for row-major).
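To illustrate the constant-chunk idea described here, a minimal sketch; names are illustrative, and plain pointers stand in for the kernel's tt_l1_ptr-qualified ones and NOC transfers.

```cpp
// Hedged illustration of the fixed-size-chunk approach: copy a row in constant
// 64-byte pieces so the loop structure is independent of the row width.
#include <cstdint>

constexpr uint32_t CHUNK_BYTES = 64;

inline void copy_row_in_chunks(uint8_t* dst, const uint8_t* src, uint32_t row_bytes) {
    for (uint32_t off = 0; off < row_bytes; off += CHUNK_BYTES) {
        // last chunk may be shorter than CHUNK_BYTES
        const uint32_t n = (row_bytes - off < CHUNK_BYTES) ? (row_bytes - off) : CHUNK_BYTES;
        for (uint32_t b = 0; b < n; ++b) {
            dst[off + b] = src[off + b];
        }
    }
}
```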

std::transform(dims.begin(), dims.end(), normalized_dims.begin(), [input_tensor](std::int64_t idx) {
return input_tensor.get_logical_shape().get_normalized_index(idx);
});
if (detail::is_permute_nop(input_tensor, normalized_dims)) {

Member

I'm thinking about 0-volume shapes with an arbitrary rank, like [1, 0, 32, 32]. Will this be handled nicely? Can I ask for a test?

And what about rank < 2, will that be handled correctly?

Member

For the 0-volume case there is no need for a device operation, but the shape must still be adjusted.

Contributor Author

I return it as a nop when it's rank < 2 now (not just rank 1).

ttnn::empty doesn't have SimpleShape support right now, which is a bit annoying. I have to work out and set my padded shape for the 0-logical-volume case as well as the normal shape. For now it still works with the device operation launch, though we could boost perf a bit if we avoided it altogether.
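A hedged sketch of the nop conditions discussed in this thread, not the PR's actual detail::is_permute_nop implementation:

```cpp
// Hedged sketch of the nop conditions described above: tensors of rank < 2 and
// identity permutations need no device work (0-volume shapes still need their
// output shape adjusted separately).
#include <cstdint>
#include <vector>

bool permute_is_nop(uint32_t rank, const std::vector<uint32_t>& normalized_dims) {
    if (rank < 2) {
        return true;
    }
    for (uint32_t i = 0; i < normalized_dims.size(); ++i) {
        if (normalized_dims[i] != i) {
            return false;  // at least one axis actually moves
        }
    }
    return true;  // identity permutation
}
```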

@@ -220,33 +206,14 @@ ttnn::Tensor ExecutePermute::invoke(
return new_order;
};
auto itensor = (input_tensor.get_logical_shape().rank() < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor;

Member

Unrelated to the PR: can we make unsqueeze_to_4D call view inside to make it clear that this operation is basically free?

Member

CC @sminakov-tt to discuss how changing tensor params from the other thread is synchronized

// removing this line causes the output of tilize_block to be garbage in the second iteration
tilize_block(cb_in, 1, cb_tilize); // tilize and pack into cb_tilize

// tile slice according to unpacker is garbage after tilize_block in the second iteration, missing an uninit?

Contributor

What was the issue?

@sjameelTT sjameelTT force-pushed the sjameel/transpose_wh_sharded branch 2 times, most recently from 8fb31d3 to 546d29d Compare December 16, 2024 20:21
@sjameelTT sjameelTT changed the title from "Add tranpose WH sharded, generalize row major permute when N > 4, and do a minor refactor of ttnn::permute" to "Add transpose WH sharded, generalize row major permute when N > 4, and do a minor refactor of ttnn::permute" on Dec 16, 2024

@ntarafdar ntarafdar (Contributor) left a comment

Some nitpicks


const InterleavedAddrGen<src0_is_dram> s0 = {.bank_base_address = src_addr, .page_size = page_size};

uint32_t curr_addr = src_addr;
for (uint32_t i = 0; i < num_rows; ++i) {
for (uint32_t row = start_row; row < end_row; ++row) {

Contributor

We love good variable names

input_shape[i - 1] = get_arg_val<uint32_t>(i);
perm[i - 1] = get_arg_val<uint32_t>(i + N);
dest_strides[i - 1] = get_arg_val<uint32_t>(i + 2 * N);
for (uint32_t i = 3; i < N + 3; i++) {

Contributor

Maybe a comment about the magic number 3

Contributor Author

Done

@sjameelTT sjameelTT force-pushed the sjameel/transpose_wh_sharded branch 2 times, most recently from 89ac439 to cc2301b Compare December 16, 2024 22:00
@sjameelTT sjameelTT force-pushed the sjameel/transpose_wh_sharded branch from cc2301b to 01fa552 Compare December 17, 2024 05:48
@sjameelTT sjameelTT merged commit b80a975 into main Dec 17, 2024
229 of 231 checks passed
@sjameelTT sjameelTT deleted the sjameel/transpose_wh_sharded branch December 17, 2024 19:21