#15747: Giving bfloat8 support to reshape (#15751)
### Ticket
[Link to GitHub Issue](#15747)

### Problem description
I neglected to add support for bfloat8_b because none of the existing tests exercised it. This PR adds both the support and an extra test.

### What's changed
Added bfloat8_b support. Per the docs, bfloat8_b only works on unpadded tiles, so that is the only case we support; other cases assert out. A usage sketch follows below.
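
A minimal usage sketch of the new path (not part of the PR; device setup is shown with `ttnn.open_device`, and the shape pair is borrowed from the new test):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# bfloat8_b is a block format that only reshapes cleanly when the last two
# dims of both the input and output shapes are tile-aligned (multiples of 32).
torch_input = torch.randn(1, 1, 864, 128, dtype=torch.bfloat16)

tt_input = ttnn.from_torch(
    torch_input,
    dtype=ttnn.bfloat8_b,     # the dtype this PR enables for reshape
    layout=ttnn.TILE_LAYOUT,  # bfloat8_b tensors must be tile-laid-out
    device=device,
)

tt_output = ttnn.reshape(tt_input, (1, 27, 32, 128))  # 864*128 == 27*32*128
result = ttnn.to_torch(tt_output)

ttnn.close_device(device)
```
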
### Checklist
- [x] Post commit passes https://github.com/tenstorrent/tt-metal/actions/runs/12188208619
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] New/Existing tests provide coverage for changes
jvegaTT authored Dec 6, 2024
1 parent 5f656ea commit d422ccf
Showing 3 changed files with 52 additions and 17 deletions.
31 changes: 31 additions & 0 deletions tests/ttnn/unit_tests/test_reshape.py

```diff
@@ -410,3 +410,34 @@ def test_fp32_support(input_shape, output_shape, device):
     output = ttnn.to_torch(ttnn_output)
 
     assert_with_pcc(torch_result, output, 0.9999)
+
+
+@pytest.mark.parametrize(
+    "input_shape, output_shape",
+    [
+        ((1, 1, 864, 128), (1, 27, 32, 128)),
+        ((1, 256, 32), (32, 256)),
+        ((1, 256, 1024), (1, 128, 32, 64)),
+        ((64, 32), (32, 64)),
+        ((1, 1445, 192), (1445, 192)),
+        ((1, 256), (1, 1, 256)),
+        ((16, 1, 32), (16, 1, 32)),
+    ],
+)
+def test_bf8_support(input_shape, output_shape, device):
+    torch_input_tensor = torch.randint(0, 100, input_shape)
+    torch_result = torch_input_tensor.reshape(output_shape)
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor,
+        dtype=ttnn.bfloat8_b,
+        layout=ttnn.TILE_LAYOUT,
+        device=device,
+        memory_config=ttnn.DRAM_MEMORY_CONFIG,
+    )
+
+    ttnn_output = ttnn.reshape(input_tensor, output_shape)
+
+    output = ttnn.to_torch(ttnn_output)
+
+    assert_with_pcc(torch_result, output, 0.9999)
```
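
The shapes above all keep the data movement legal for the block format. A hypothetical negative test (not part of this PR) could pin down the assert path, assuming the repo's `device` pytest fixture and that the C++ TT_FATAL surfaces as a RuntimeError through the Python bindings:

```python
import pytest
import torch
import ttnn


def test_bf8_rejects_unaligned_shape(device):
    torch_input = torch.randn(1, 256, 32, dtype=torch.bfloat16)
    tt_input = ttnn.from_torch(
        torch_input, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device
    )
    # Target last dim 16 is not a multiple of the 32-wide tile, so the
    # bfloat8_b tile -> row-major fallback is expected to assert out.
    with pytest.raises(RuntimeError):
        ttnn.reshape(tt_input, (16, 32, 16))
```
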
The second changed file (3 additions & 2 deletions; its path is truncated in this view) updates the `enhanced_noc_async_write` helper:

```diff
@@ -32,10 +32,11 @@ FORCE_INLINE void enhanced_noc_async_write(
     const uint32_t src_l1_addr, const uint64_t dst_noc_addr, const uint32_t bytes) {
     // If you do not know the max_transfer_size at compile time write 0 to it.
     // only writes is true if we ONLY use noc_async_read and all calls to tt_memmove have use_read_datamover as False
-    if constexpr (((max_transfer_size < NOC_MAX_BURST_SIZE) && (max_transfer_size != 0)) || only_writes) {
+    if constexpr (only_writes) {
         noc_async_write_one_packet(src_l1_addr, dst_noc_addr, bytes);
     } else {
-        noc_async_write(src_l1_addr, dst_noc_addr, bytes);
+        noc_async_write<max_transfer_size == 0 ? NOC_MAX_BURST_SIZE + 1 : max_transfer_size>(
+            src_l1_addr, dst_noc_addr, bytes);
     }
 }
```
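
The template rewrite is easy to misread in diff form. A small model of the new size selection (illustrative only; `NOC_MAX_BURST_SIZE` is a hardware constant, and 8192 is just a stand-in value):

```python
NOC_MAX_BURST_SIZE = 8192  # stand-in; the real value comes from the NOC hardware parameters


def effective_max_transfer_size(max_transfer_size: int) -> int:
    """Model of the template argument now passed to noc_async_write: a
    compile-time value of 0 means "size unknown", which the new code maps
    to NOC_MAX_BURST_SIZE + 1 so the multi-packet write path is kept."""
    return NOC_MAX_BURST_SIZE + 1 if max_transfer_size == 0 else max_transfer_size


assert effective_max_transfer_size(0) == NOC_MAX_BURST_SIZE + 1  # unknown size
assert effective_max_transfer_size(2048) == 2048                 # known size passes through
```
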
33 changes: 18 additions & 15 deletions ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp
```diff
@@ -17,6 +17,7 @@
 #include "ttnn/operations/data_movement/slice/slice.hpp"
 #include "ttnn/operations/core/core.hpp"
 #include "device/reshape_rm_op.hpp"
+#include "ttnn/cpp/ttnn/operations/copy.hpp"
 #include "ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp"
 #include "ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp"
 #include "ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.hpp"
```

```diff
@@ -36,18 +37,20 @@ ttnn::Tensor convert_tile_to_rm(
     const uint8_t queue_id,
     const PadValue &pad_value
 ) {
-    //Convert the 3D->3D reshaping to row major and back to tile
-    auto rm_tensor =
-        ttnn::to_layout(tensor, ttnn::ROW_MAJOR_LAYOUT, tensor.get_dtype(), std::nullopt, (Device*)nullptr);
-    rm_tensor = ReshapeViewOperation::invoke(
-        rm_tensor,
-        shape,
-        memory_config,
-        queue_id,
-        pad_value
-    );
-    rm_tensor = ttnn::to_layout(rm_tensor, ttnn::TILE_LAYOUT, rm_tensor.get_dtype(), memory_config, (Device*)nullptr);
-    return rm_tensor;
+    // Convert the 3D->3D reshaping to row major and back to tile
+    TT_FATAL(
+        !(((shape[-1] % tile_first_dim != 0) || (shape[-2] % tile_second_dim != 0) ||
+           (tensor.get_shape()[-1] % tile_first_dim != 0) || (tensor.get_shape()[-2] % tile_second_dim != 0)) &&
+          (tensor.get_dtype() == DataType::BFLOAT8_B)),
+        "illegal dimensions for a bfloat8 tensor");
+    auto new_tensor = (tensor.get_dtype() == DataType::BFLOAT8_B) ? ttnn::typecast(tensor, DataType::BFLOAT16) : tensor;
+    new_tensor = ttnn::to_layout(tensor, ttnn::ROW_MAJOR_LAYOUT, tensor.get_dtype(), std::nullopt, (Device*)nullptr);
+    new_tensor = ReshapeViewOperation::invoke(new_tensor, shape, memory_config, queue_id, pad_value);
+    new_tensor =
+        ttnn::to_layout(new_tensor, ttnn::TILE_LAYOUT, new_tensor.get_dtype(), memory_config, (Device*)nullptr);
+    new_tensor =
+        (tensor.get_dtype() == DataType::BFLOAT8_B) ? ttnn::typecast(new_tensor, tensor.get_dtype()) : new_tensor;
+    return new_tensor;
 }
 ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) {
     //This function is due to embedding issue 15558, once the issue is fixed we want to delete it
```
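
In words: a bfloat8_b tensor has no row-major representation, so the fallback typecasts to bfloat16, does the row-major reshape, re-tiles, and casts back; the TT_FATAL up front rejects any shape whose last two dims are not tile multiples. A Python-level sketch of the same strategy (illustrative only, built from the public ttnn ops; the real work happens in C++ on the given queue, and 32x32 is the standard tile size assumed here):

```python
import ttnn

TILE_H, TILE_W = 32, 32  # standard ttnn tile dims


def reshape_bf8_via_bf16(tensor, shape):
    """Sketch of the convert_tile_to_rm fallback for bfloat8_b."""
    in_shape = tuple(tensor.shape)
    # Mirror of the TT_FATAL: both shapes must be tile-aligned for bfloat8_b.
    assert shape[-1] % TILE_W == 0 and shape[-2] % TILE_H == 0 and \
        in_shape[-1] % TILE_W == 0 and in_shape[-2] % TILE_H == 0, \
        "illegal dimensions for a bfloat8 tensor"
    t = ttnn.typecast(tensor, ttnn.bfloat16)      # bfloat8_b -> bfloat16
    t = ttnn.to_layout(t, ttnn.ROW_MAJOR_LAYOUT)  # tiles -> row major
    t = ttnn.reshape(t, shape)                    # the actual reshape
    t = ttnn.to_layout(t, ttnn.TILE_LAYOUT)       # row major -> tiles
    return ttnn.typecast(t, ttnn.bfloat8_b)       # cast back to block format
```
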
```diff
@@ -350,10 +353,10 @@ ttnn::Tensor ReshapeViewOperation::invoke(
         return tensor;
     }
     PadValue default_pad_value;
-    if(tensor.get_dtype() == DataType::BFLOAT16 or tensor.get_dtype() == DataType::FLOAT32) {
+    if (tensor.get_dtype() == DataType::BFLOAT8_B or tensor.get_dtype() == DataType::BFLOAT16 or
+        tensor.get_dtype() == DataType::FLOAT32) {
         default_pad_value = 0.0f;
-    }
-    else {
+    } else {
         default_pad_value = (uint32_t)0;
     }
 
```
