diff --git a/tests/ttnn/unit_tests/test_reshape.py b/tests/ttnn/unit_tests/test_reshape.py
index f448624e19b..7080e995a01 100644
--- a/tests/ttnn/unit_tests/test_reshape.py
+++ b/tests/ttnn/unit_tests/test_reshape.py
@@ -410,3 +410,34 @@ def test_fp32_support(input_shape, output_shape, device):
     output = ttnn.to_torch(ttnn_output)
 
     assert_with_pcc(torch_result, output, 0.9999)
+
+
+@pytest.mark.parametrize(
+    "input_shape, output_shape",
+    [
+        ((1, 1, 864, 128), (1, 27, 32, 128)),
+        ((1, 256, 32), (32, 256)),
+        ((1, 256, 1024), (1, 128, 32, 64)),
+        ((64, 32), (32, 64)),
+        ((1, 1445, 192), (1445, 192)),
+        ((1, 256), (1, 1, 256)),
+        ((16, 1, 32), (16, 1, 32)),
+    ],
+)
+def test_bf8_support(input_shape, output_shape, device):
+    torch_input_tensor = torch.randint(0, 100, input_shape)
+    torch_result = torch_input_tensor.reshape(output_shape)
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor,
+        dtype=ttnn.bfloat8_b,
+        layout=ttnn.TILE_LAYOUT,
+        device=device,
+        memory_config=ttnn.DRAM_MEMORY_CONFIG,
+    )
+
+    ttnn_output = ttnn.reshape(input_tensor, output_shape)
+
+    output = ttnn.to_torch(ttnn_output)
+
+    assert_with_pcc(torch_result, output, 0.9999)
diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp
index 6ab3a58f5cb..27c68f53b18 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp
@@ -32,10 +32,11 @@ FORCE_INLINE void enhanced_noc_async_write(
     const uint32_t src_l1_addr, const uint64_t dst_noc_addr, const uint32_t bytes) {
     // If you do not know the max_transfer_size at compile time, write 0 to it.
     // only_writes is true if we ONLY use noc_async_write and all calls to tt_memmove have use_read_datamover as False
-    if constexpr (((max_transfer_size < NOC_MAX_BURST_SIZE) && (max_transfer_size != 0)) || only_writes) {
+    if constexpr (only_writes) {
         noc_async_write_one_packet(src_l1_addr, dst_noc_addr, bytes);
     } else {
-        noc_async_write(src_l1_addr, dst_noc_addr, bytes);
+        noc_async_write<max_transfer_size == 0 ? NOC_MAX_BURST_SIZE + 1 : max_transfer_size>(
+            src_l1_addr, dst_noc_addr, bytes);
     }
 }
 
diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp
index 6af8646a007..8fb90cbc761 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp
@@ -17,6 +17,7 @@
 #include "ttnn/operations/data_movement/slice/slice.hpp"
 #include "ttnn/operations/core/core.hpp"
 #include "device/reshape_rm_op.hpp"
+#include "ttnn/cpp/ttnn/operations/copy.hpp"
 #include "ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp"
 #include "ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp"
 #include "ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.hpp"
@@ -36,18 +37,20 @@ ttnn::Tensor convert_tile_to_rm(
     const uint8_t queue_id,
     const PadValue &pad_value
 ) {
-    //Convert the 3D->3D reshaping to row major and back to tile
-    auto rm_tensor =
-        ttnn::to_layout(tensor, ttnn::ROW_MAJOR_LAYOUT, tensor.get_dtype(), std::nullopt, (Device*)nullptr);
-    rm_tensor = ReshapeViewOperation::invoke(
-        rm_tensor,
-        shape,
-        memory_config,
-        queue_id,
-        pad_value
-    );
-    rm_tensor = ttnn::to_layout(rm_tensor, ttnn::TILE_LAYOUT, rm_tensor.get_dtype(), memory_config, (Device*)nullptr);
-    return rm_tensor;
+    // Convert the tile reshape to row major and back to tile; bfloat8_b round-trips through bfloat16
+    TT_FATAL(
+        !(((shape[-1] % tile_first_dim != 0) || (shape[-2] % tile_second_dim != 0) ||
+           (tensor.get_shape()[-1] % tile_first_dim != 0) || (tensor.get_shape()[-2] % tile_second_dim != 0)) &&
+          (tensor.get_dtype() == DataType::BFLOAT8_B)),
+        "illegal dimensions for a bfloat8 tensor");
+    auto new_tensor = (tensor.get_dtype() == DataType::BFLOAT8_B) ? ttnn::typecast(tensor, DataType::BFLOAT16) : tensor;
+    new_tensor = ttnn::to_layout(new_tensor, ttnn::ROW_MAJOR_LAYOUT, new_tensor.get_dtype(), std::nullopt, (Device*)nullptr);
+    new_tensor = ReshapeViewOperation::invoke(new_tensor, shape, memory_config, queue_id, pad_value);
+    new_tensor =
+        ttnn::to_layout(new_tensor, ttnn::TILE_LAYOUT, new_tensor.get_dtype(), memory_config, (Device*)nullptr);
+    new_tensor =
+        (tensor.get_dtype() == DataType::BFLOAT8_B) ? ttnn::typecast(new_tensor, tensor.get_dtype()) : new_tensor;
+    return new_tensor;
 }
 
 ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) {
     //This function is due to embedding issue 15558, once the issue is fixed we want to delete it
@@ -350,10 +353,10 @@ ttnn::Tensor ReshapeViewOperation::invoke(
         return tensor;
     }
     PadValue default_pad_value;
-    if(tensor.get_dtype() == DataType::BFLOAT16 or tensor.get_dtype() == DataType::FLOAT32) {
+    if (tensor.get_dtype() == DataType::BFLOAT8_B or tensor.get_dtype() == DataType::BFLOAT16 or
+        tensor.get_dtype() == DataType::FLOAT32) {
         default_pad_value = 0.0f;
-    }
-    else {
+    } else {
         default_pad_value = (uint32_t)0;
     }
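
For reference, a minimal host-side sketch (not part of the patch) of how the new bfloat8_b reshape path is exercised. It assumes a single local device opened with ttnn.open_device (the unit test above instead receives `device` from a pytest fixture), and it uses one shape pair from the parametrization:

import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.randn(1, 256, 1024)
input_tensor = ttnn.from_torch(
    torch_input,
    dtype=ttnn.bfloat8_b,      # block-float dtype that previously lacked reshape support
    layout=ttnn.TILE_LAYOUT,   # tile layout routes through convert_tile_to_rm
    device=device,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
)

# Internally: typecast bfloat8_b -> bfloat16, reshape in row major, tilize,
# then typecast back to bfloat8_b (see convert_tile_to_rm above).
output_tensor = ttnn.reshape(input_tensor, (1, 128, 32, 64))

# bfloat8_b is lossy, so compare with PCC rather than exact equality.
result = ttnn.to_torch(output_tensor)

ttnn.close_device(device)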