diff --git a/tests/tt_metal/tt_metal/test_bcast.cpp b/tests/tt_metal/tt_metal/test_bcast.cpp index fb8fb784840..9796de5d1b4 100644 --- a/tests/tt_metal/tt_metal/test_bcast.cpp +++ b/tests/tt_metal/tt_metal/test_bcast.cpp @@ -154,28 +154,30 @@ int main(int argc, char** argv) { vector tiled_bcast_values; vector ref_bcast_values; - vector ref_bcast_shape = {N, C, 1, 1}; float bcast_1value = 10.0f; uint16_t bcast_1value16 = bfloat16(bcast_1value).to_uint16(); unsigned num_bcast_tiles = 0; // build the constant tiles to be broadcast if (bcast_dim == BcastDim::HW) { - num_bcast_tiles = NC; ref_bcast_values.resize(NC, 0); + vector ref_bcast_shape_with_tile_padding = {N, C, TILE_HEIGHT, TILE_WIDTH}; + vector ref_bcast_values_with_tile_padding; + ref_bcast_values_with_tile_padding.resize(NC * TILE_HEIGHT * TILE_WIDTH, 0); for (int j = 0; j < NC; j++) { // add something not too large but different between tiles - ref_bcast_values[j] = bfloat16(bcast_1value + (j % 7)).to_uint16(); + auto val = bfloat16(bcast_1value + (j % 7)).to_uint16(); + ref_bcast_values[j] = val; + ref_bcast_values_with_tile_padding[j * TILE_HEIGHT * TILE_WIDTH] = val; } // convert the reference broadcast tensor to tiled format tiled_bcast_values = convert_layout( - ref_bcast_values, - ref_bcast_shape, + ref_bcast_values_with_tile_padding, + ref_bcast_shape_with_tile_padding, tests::utils::TensorLayoutType::LIN_ROW_MAJOR, tests::utils::TensorLayoutType::TILED_NFACES); TT_FATAL(tiled_bcast_values[0] == bcast_1value16, "Error"); + num_bcast_tiles = NC; // restore ref values and shape to 1 - ref_bcast_shape[3] = 1; - ref_bcast_shape[4] = 1; } else if (bcast_dim == BcastDim::H) { // For bcast_h a.k.a. Dim::R we broadcast _over_ H, meaning we take a W vector and += it over each // element in the H dimension At least that's the behavior i've seen from a single tile bcast-H So @@ -185,14 +187,18 @@ int main(int argc, char** argv) { // generate broadcast values along the W axis with one extra tile (needed by the kernel I believe) // TODO(AP): need to figure out why the extra tile in broadcast inputs is expected by the kernel ref_bcast_values.resize(NC * W, 0); - ref_bcast_shape[3] = W; + vector ref_bcast_shape_with_tile_padding = {N, C, TILE_HEIGHT, W}; + vector ref_bcast_values_with_tile_padding; + ref_bcast_values_with_tile_padding.resize(NC * TILE_HEIGHT * W, 0); for (int j = 0; j < NC * W; j++) { // add something not too large but different between tiles - ref_bcast_values[j] = bfloat16(bcast_1value + (j % 7)).to_uint16(); + auto val = bfloat16(bcast_1value + (j % 7)).to_uint16(); + ref_bcast_values[j] = val; + ref_bcast_values_with_tile_padding[j % W + (j / W) * TILE_HEIGHT * W] = val; } tiled_bcast_values = convert_layout( - ref_bcast_values, - ref_bcast_shape, + ref_bcast_values_with_tile_padding, + ref_bcast_shape_with_tile_padding, tests::utils::TensorLayoutType::LIN_ROW_MAJOR, tests::utils::TensorLayoutType::TILED_NFACES); num_bcast_tiles = NC * Wt; @@ -200,14 +206,18 @@ int main(int argc, char** argv) { } else if (bcast_dim == BcastDim::W) { // see the comments above for BCAST_H ref_bcast_values.resize(NC * H, 0); - ref_bcast_shape[2] = H; + vector ref_bcast_shape_with_tile_padding = {N, C, H, TILE_WIDTH}; + vector ref_bcast_values_with_tile_padding; + ref_bcast_values_with_tile_padding.resize(NC * H * TILE_WIDTH, 0); for (int j = 0; j < NC * H; j++) { // add something not too large but different between tiles - ref_bcast_values[j] = bfloat16(bcast_1value + (j % 7)).to_uint16(); + auto val = bfloat16(bcast_1value + (j % 7)).to_uint16(); + ref_bcast_values[j] = val; + ref_bcast_values_with_tile_padding[j * TILE_WIDTH] = val; } tiled_bcast_values = convert_layout( - ref_bcast_values, - ref_bcast_shape, + ref_bcast_values_with_tile_padding, + ref_bcast_shape_with_tile_padding, tests::utils::TensorLayoutType::LIN_ROW_MAJOR, tests::utils::TensorLayoutType::TILED_NFACES); num_bcast_tiles = NC * Ht; diff --git a/tests/ttnn/unit_tests/tensor/test_tensor_conversion.py b/tests/ttnn/unit_tests/tensor/test_tensor_conversion.py index 6e00f178fc0..63442308831 100644 --- a/tests/ttnn/unit_tests/tensor/test_tensor_conversion.py +++ b/tests/ttnn/unit_tests/tensor/test_tensor_conversion.py @@ -77,16 +77,14 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device): tt_tensor = ttnn.Tensor(py_tensor, tt_dtype) if tt_dtype in {ttnn.bfloat8_b, ttnn.bfloat4_b}: assert tt_tensor.storage_type() == ttnn.StorageType.OWNED - tt_tensor = tt_tensor.to(ttnn.TILE_LAYOUT) + assert tt_tensor.layout == ttnn.TILE_LAYOUT else: assert tt_tensor.storage_type() == ttnn.StorageType.BORROWED + assert tt_tensor.layout == ttnn.ROW_MAJOR_LAYOUT tt_tensor = tt_tensor.to(device) tt_tensor = tt_tensor.cpu() - if tt_dtype in {ttnn.bfloat8_b, ttnn.bfloat4_b}: - tt_tensor = tt_tensor.to(ttnn.ROW_MAJOR_LAYOUT) - if python_lib == torch: py_tensor_after_round_trip = tt_tensor.to_torch() elif python_lib == np: diff --git a/tt_metal/common/test_tiles.hpp b/tt_metal/common/test_tiles.hpp index 50674abc39d..44e18fbc448 100644 --- a/tt_metal/common/test_tiles.hpp +++ b/tt_metal/common/test_tiles.hpp @@ -25,13 +25,15 @@ enum class TensorLayoutType { }; } // namespace tests::utils +using PhysicalSize = std::array; + template typename BufferType> std::vector convert_to_tile_layout( const BufferType& data, - std::optional> tile_shape = std::nullopt, - std::optional> face_shape = std::nullopt, - const std::optional& transpose_within_face = std::nullopt, - const std::optional& transpose_of_faces = std::nullopt) { + std::optional tile_shape = std::nullopt, + std::optional face_shape = std::nullopt, + const bool transpose_face = false, + const bool transpose_face_order = false) { ZoneScoped; std::vector result; if(data.size() == 0) { @@ -45,8 +47,6 @@ std::vector convert_to_tile_layout( auto face_W = face_shape.has_value() ? face_shape.value()[1] : tt::constants::FACE_WIDTH; auto tile_HW = tile_H * tile_W; auto face_HW = face_H * face_W; - bool transpose_face = transpose_within_face.has_value() ? transpose_within_face.value() : false; - bool transpose_face_order = transpose_of_faces.has_value() ? transpose_of_faces.value() : false; TT_ASSERT(data.size() % tile_HW == 0); int num_tiles = data.size() / tile_HW; for(int tile_idx = 0; tile_idx < num_tiles; tile_idx++) { @@ -116,10 +116,10 @@ std::vector convert_to_tile_layout( template typename BufferTyp> std::vector convert_to_flat_layout( const BufferTyp& data, - std::optional> tile_shape = std::nullopt, - std::optional> face_shape = std::nullopt, - const std::optional& transpose_within_face = std::nullopt, - const std::optional& transpose_of_faces = std::nullopt) { + std::optional tile_shape = std::nullopt, + std::optional face_shape = std::nullopt, + const bool transpose_face = false, + const bool transpose_face_order = false) { ZoneScoped; std::vector result; if(data.size() == 0) { @@ -134,8 +134,6 @@ std::vector convert_to_flat_layout( auto face_HW = face_H * face_W; auto num_faces_col = tile_W / face_W; auto num_faces_row = tile_H / face_H; - bool transpose_face = transpose_within_face.has_value() ? transpose_within_face.value() : false; - bool transpose_face_order = transpose_of_faces.has_value() ? transpose_of_faces.value() : false; TT_ASSERT(data.size() % tile_HW == 0); int num_tiles = data.size() / tile_HW; for(int tile_idx = 0; tile_idx < num_tiles; tile_idx++) { @@ -194,38 +192,35 @@ std::vector convert_to_flat_layout( // Converts a 32-swizzled tilized row-major tensor to a linear 32-zero-padded row-major tensor template typename BufferType> -inline std::vector untilize_nchw(const BufferType& in, tt::stl::Span shape, std::optional> tile_shape = std::nullopt) { +inline std::vector untilize_nchw( + const BufferType& in, const PhysicalSize& shape, std::optional tile_shape = std::nullopt) { ZoneScoped; - auto tile_H = tile_shape.has_value() ? tile_shape.value()[0] : tt::constants::TILE_HEIGHT; - auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH; - std::vector result; if(in.size() == 0) { return result; } - TT_ASSERT(shape[shape.size() - 2] % tile_H == 0 && shape[shape.size() - 1] % tile_W == 0); + auto tile_H = tile_shape.has_value() ? tile_shape.value()[0] : tt::constants::TILE_HEIGHT; + auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH; + + TT_ASSERT(shape[0] % tile_H == 0 && shape[1] % tile_W == 0); // Untilize into row major - uint32_t H = shape[shape.size() - 2], W = shape[shape.size() - 1]; - uint64_t batch_size = 1; - for (uint32_t i = 0; i < shape.size() - 2; i++) { - batch_size *= shape[i]; - } - result.resize(batch_size * H * W); + uint32_t H = shape[0]; + uint32_t W = shape[1]; + + result.resize(H * W); uint64_t linear = 0; - for (auto batch_index = 0; batch_index < batch_size; batch_index++) { - for (auto hs = 0; hs < H; hs += tile_H) { // iterate over h with stride 32 - for (auto ws = 0; ws < W; ws += tile_W) { // iterate over w with stride 32 - for (auto ht = 0; ht < tile_H; ht++) { // hs + ht = h - for (auto wt = 0; wt < tile_W; wt++) { // ws + wt = w - T val = in[linear]; - auto w = wt + ws; - auto h = ht + hs; - auto offs = w + h * W + batch_index * H * W; - result[offs] = val; - linear++; - } + for (auto hs = 0; hs < H; hs += tile_H) { // iterate over h with stride 32 + for (auto ws = 0; ws < W; ws += tile_W) { // iterate over w with stride 32 + for (auto ht = 0; ht < tile_H; ht++) { // hs + ht = h + for (auto wt = 0; wt < tile_W; wt++) { // ws + wt = w + T val = in[linear]; + auto w = wt + ws; + auto h = ht + hs; + auto offs = w + h * W; // + batch_index * H * W; + result[offs] = val; + linear++; } } } @@ -240,50 +235,42 @@ inline std::uint32_t round_up_to_mul32(std::uint32_t val) { return ((val & 31) = inline std::uint32_t round_up_to_tile(int val, int tile_val) { return (val + tile_val - 1) & ~(tile_val - 1); } -// Converts a linear non-zero-padded row-major tensor to zero-padded-32 32-swizzled tilized row-major tensor +// Converts a linear non-zero-padded row-major tensor to 32-swizzled tilized row-major tensor template typename BufferType> -inline std::vector tilize_nchw(const BufferType& in_rowmajor, tt::stl::Span shape, std::optional> tile_shape = std::nullopt) { +inline std::vector tilize_nchw( + const BufferType& in_rowmajor, + const PhysicalSize& shape, + std::optional tile_shape = std::nullopt) { ZoneScoped; std::vector tilized_result; if(in_rowmajor.size() == 0) { return tilized_result; } - uint32_t H = shape[shape.size() - 2], W = shape[shape.size() - 1]; - uint64_t batch_size = 1; - for (uint32_t i = 0; i < shape.size() - 2; i++) { - batch_size *= shape[i]; - } - uint64_t input_volume = batch_size * H * W; auto tile_H = tile_shape.has_value() ? tile_shape.value()[0] : tt::constants::TILE_HEIGHT; auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH; - uint32_t OH = round_up_to_tile(H, tile_H); - uint32_t OW = round_up_to_tile(W, tile_W); - tilized_result.resize(batch_size * OH * OW); - std::fill(tilized_result.begin(), tilized_result.end(), 0); + + TT_ASSERT(shape[0] % tile_H == 0 && shape[1] % tile_W == 0); + + uint32_t H = shape[0]; + uint32_t W = shape[1]; + + tilized_result.resize(H * W); uint64_t out_index = 0; - for (auto batch_index = 0; batch_index < batch_size; batch_index++) { - for (auto hs = 0; hs < H; hs += tile_H) { - for (auto ws = 0; ws < W; ws += tile_W) { - for (auto ht = 0; ht < tile_H; ht++) { - for (auto wt = 0; wt < tile_W; wt++) { - auto w = wt + ws; - auto h = ht + hs; - auto in_offs = w + h * W + batch_index * H * W; - auto val = (w >= W || h >= H || in_offs >= input_volume) ? 0 : in_rowmajor[in_offs]; - auto out_w = (out_index % OW); - auto out_h = (out_index / OW) % OH; - TT_ASSERT(w < OW); - TT_ASSERT(h < OH); - auto out_offs = out_w + out_h * OW + batch_index * OH * OW; - tilized_result[out_offs] = val; - out_index++; - } + for (auto hs = 0; hs < H; hs += tile_H) { + for (auto ws = 0; ws < W; ws += tile_W) { + for (auto ht = 0; ht < tile_H; ht++) { + for (auto wt = 0; wt < tile_W; wt++) { + auto w = wt + ws; + auto h = ht + hs; + auto in_offs = w + h * W; + auto val = in_rowmajor[in_offs]; + tilized_result[out_index] = val; + out_index++; } } } } - TT_ASSERT(tilized_result.size() == batch_size * OH * OW); return tilized_result; } @@ -308,13 +295,13 @@ struct TensAddr { template typename BufferType> inline std::vector convert_layout( const BufferType& inp, - tt::stl::Span shape, + const PhysicalSize& shape, tests::utils::TensorLayoutType inL, tests::utils::TensorLayoutType outL, - std::optional> tile_shape = std::nullopt, - std::optional> face_shape = std::nullopt, - const std::optional& transpose_within_face = std::nullopt, - const std::optional& transpose_of_faces = std::nullopt) { + std::optional tile_shape = std::nullopt, + std::optional face_shape = std::nullopt, + const bool transpose_within_face = false, + const bool transpose_of_faces = false) { ZoneScoped; if(inp.size() == 0) { return std::vector(); @@ -333,8 +320,9 @@ inline std::vector convert_layout( if (outL == tests::utils::TensorLayoutType::TILED_SWIZZLED) { return tilize_nchw(inp, shape, tile_shape); } else if (outL == tests::utils::TensorLayoutType::TILED_NFACES) { - auto swiz32 = convert_layout(inp, shape, inL, tests::utils::TensorLayoutType::TILED_SWIZZLED, tile_shape, face_shape, transpose_within_face, transpose_of_faces); - return convert_layout(swiz32, shape, tests::utils::TensorLayoutType::TILED_SWIZZLED, outL, tile_shape, face_shape, transpose_within_face, transpose_of_faces); + auto swiz32 = tilize_nchw(inp, shape, tile_shape); + return convert_to_tile_layout( + swiz32, tile_shape, face_shape, transpose_within_face, transpose_of_faces); } else TT_ASSERT(false && "Unsupported conversion."); break; @@ -342,7 +330,8 @@ inline std::vector convert_layout( if (outL == tests::utils::TensorLayoutType::TILED_SWIZZLED) { return convert_to_flat_layout(inp, tile_shape, face_shape, transpose_within_face, transpose_of_faces); } else if (outL == tests::utils::TensorLayoutType::LIN_ROW_MAJOR) { - auto swiz32 = convert_layout(inp, shape, inL, tests::utils::TensorLayoutType::TILED_SWIZZLED, tile_shape, face_shape, transpose_within_face, transpose_of_faces); + auto swiz32 = + convert_to_flat_layout(inp, tile_shape, face_shape, transpose_within_face, transpose_of_faces); return untilize_nchw(swiz32, shape, tile_shape); } else { TT_ASSERT(false && "Unsupported conversion"); @@ -353,3 +342,25 @@ inline std::vector convert_layout( } return std::vector(); } + +template typename BufferType> +inline std::vector convert_layout( + const BufferType& inp, + tt::stl::Span shape, + tests::utils::TensorLayoutType inL, + tests::utils::TensorLayoutType outL, + std::optional tile_shape = std::nullopt, + std::optional face_shape = std::nullopt, + const bool transpose_within_face = false, + const bool transpose_of_faces = false) { + ZoneScoped; + + TT_ASSERT(shape.size() >= 2, "Shape size {} must be at least rank 2!", shape.size()); + uint32_t H = shape[shape.size() - 2]; + uint32_t W = shape[shape.size() - 1]; + for (int i = 0; i < shape.size() - 2; i++) { + H *= shape[i]; + } + return convert_layout( + inp, PhysicalSize{H, W}, inL, outL, tile_shape, face_shape, transpose_within_face, transpose_of_faces); +} diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 0a244cc1655..8cd2e3da094 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -18,6 +18,7 @@ #include "ttnn/tensor/host_buffer/types.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/tensor_impl.hpp" +#include "ttnn/tensor/tensor_ops.hpp" using namespace tt::tt_metal; @@ -78,6 +79,72 @@ Tensor create_owned_tensor( return Tensor(std::move(storage), shape, data_type, layout, optional_tile); } +OwnedBuffer create_owned_buffer_from_vector_of_floats(std::vector&& data, DataType data_type) { + switch (data_type) { + case DataType::FLOAT32: { + return owned_buffer::create(std::move(data)); + } + case DataType::BFLOAT16: { + std::vector<::bfloat16> bfloat16_data(data.size()); + std::transform(std::begin(data), std::end(data), std::begin(bfloat16_data), [](float value) { + return ::bfloat16(value); + }); + return owned_buffer::create<::bfloat16>(std::move(bfloat16_data)); + } + default: { + TT_THROW("Cannot create a host buffer!"); + } + } +} + +Tensor convert_float_vector_to_tt_tensor( + std::vector&& data, + const std::array& shape, + DataType data_type, + Layout layout, + Device* device, + const std::optional& memory_config, + const std::optional& tile) { + if (data_type == DataType::BFLOAT8_B || data_type == DataType::BFLOAT4_B) { + if (layout != Layout::TILE) { + log_warning( + tt::LogAlways, + "Tensor layout must be Layout::TILE for bfloat8_b or bfloat4_b! Tensor layout will be {} instead of " + "the requested {}!", + Layout::TILE, + layout); + } + auto owned_buffer = create_owned_buffer_from_vector_of_floats(std::move(data), DataType::FLOAT32); + auto float_tensor = Tensor(OwnedStorage{owned_buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, tile); + auto tile_val = tile.value_or(Tile()); + if (shape[2] % tile_val.get_height() != 0 || shape[3] % tile_val.get_width() != 0) { + auto padded_shape = shape; + padded_shape[2] = tt::round_up(shape[2], tile_val.get_height()); + padded_shape[3] = tt::round_up(shape[3], tile_val.get_width()); + + float_tensor = tensor_ops::tensor_pad( + float_tensor, LegacyShape(shape, padded_shape), ttnn::SimpleShape{0, 0, 0, 0}, 0); + } + auto output_float_data = owned_buffer::get_as(float_tensor.to(Layout::TILE)).get(); + auto output_packed_data = + data_type == DataType::BFLOAT8_B + ? pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile) + : pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); + auto output_buffer = owned_buffer::create(std::move(output_packed_data)); + auto tensor = Tensor(std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tile); + if (device) { + return tensor.to(device, memory_config.value_or(MemoryConfig{})); + } + return tensor; + } + auto owned_buffer = create_owned_buffer_from_vector_of_floats(std::move(data), data_type); + auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile); + if (device) { + return tensor.to(device, memory_config.value_or(MemoryConfig{})); + } + return tensor; +} + Tensor create_tt_tensor_from_py_data( std::size_t num_elements, std::size_t py_data_ptr, @@ -149,19 +216,7 @@ Tensor create_tt_tensor_from_py_data( return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); } } - case DataType::BFLOAT8_B: { - auto data_ptr = reinterpret_cast(py_data_ptr); - auto data = std::vector(data_ptr, data_ptr + num_elements); - auto buffer = owned_buffer::create(std::move(data)); - auto tile = optional_tile.value_or(Tile()); - auto tensor = Tensor(OwnedStorage{buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, optional_tile) - .to(Layout::TILE); - auto output_float_data = owned_buffer::get_as(tensor).get(); - auto output_packed_data = - pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); - auto output_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor(std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tile); - } + case DataType::BFLOAT8_B: case DataType::BFLOAT4_B: { auto data_ptr = reinterpret_cast(py_data_ptr); auto data = std::vector(data_ptr, data_ptr + num_elements); @@ -170,8 +225,11 @@ Tensor create_tt_tensor_from_py_data( auto tensor = Tensor(OwnedStorage{buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, optional_tile) .to(Layout::TILE); auto output_float_data = owned_buffer::get_as(tensor).get(); - auto output_packed_data = - pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); + auto output_packed_data = data_type == DataType::BFLOAT8_B + ? pack_fp32_vec_as_bfp8_tiles( + output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile) + : pack_fp32_vec_as_bfp4_tiles( + output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); auto output_buffer = owned_buffer::create(std::move(output_packed_data)); return Tensor(std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tile); } @@ -374,32 +432,6 @@ Tensor convert_python_tensors_to_tt_tensors( return output; } -OwnedBuffer create_owned_buffer_from_vector_of_floats(std::vector&& data, DataType data_type) { - switch (data_type) { - case DataType::BFLOAT8_B: { - auto uint32_vector = pack_fp32_vec_as_bfp8_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false); - return owned_buffer::create(std::move(uint32_vector)); - } - case DataType::BFLOAT4_B: { - auto uint32_vector = pack_fp32_vec_as_bfp4_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false); - return owned_buffer::create(std::move(uint32_vector)); - } - case DataType::FLOAT32: { - return owned_buffer::create(std::move(data)); - } - case DataType::BFLOAT16: { - std::vector<::bfloat16> bfloat16_data(data.size()); - std::transform(std::begin(data), std::end(data), std::begin(bfloat16_data), [](float value) { - return ::bfloat16(value); - }); - return owned_buffer::create<::bfloat16>(std::move(bfloat16_data)); - } - default: { - TT_THROW("Cannot create a host buffer!"); - } - } -} - std::pair, DataType> get_buffer_and_dtype_from_tensor( const Tensor& tt_tensor) { TT_ASSERT(tt_tensor.storage_type() == StorageType::OWNED or tt_tensor.storage_type() == StorageType::BORROWED); @@ -427,34 +459,16 @@ std::pair, DataType> get_buffer_and_dt const auto tile = tt_tensor.get_tensor_spec().tile(); auto tt_dtype = tt_tensor.get_dtype(); - if (tt_dtype == DataType::BFLOAT8_B) { - TT_ASSERT( - std::holds_alternative(buffer), - "Unexpected type {}", - tt::stl::get_active_type_name_in_variant(buffer)); - auto uint32_data = std::get>(std::get(buffer)).get(); - auto float_unpacked_data = - unpack_bfp8_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); - auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); - auto float_tensor = Tensor( - OwnedStorage{input_float_buffer}, - tt_tensor.get_shape(), - DataType::FLOAT32, - tt_tensor.get_layout(), - tile) - .to(Layout::ROW_MAJOR); - auto output_float_data = owned_buffer::get_as(float_tensor).get(); - buffer = owned_buffer::create(std::move(output_float_data)); - tt_dtype = DataType::FLOAT32; - } - if (tt_dtype == DataType::BFLOAT4_B) { + if (tt_dtype == DataType::BFLOAT8_B || tt_dtype == DataType::BFLOAT4_B) { TT_ASSERT( std::holds_alternative(buffer), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(buffer)); auto uint32_data = std::get>(std::get(buffer)).get(); auto float_unpacked_data = - unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); + tt_dtype == DataType::BFLOAT8_B + ? unpack_bfp8_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile) + : unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); auto float_tensor = Tensor( OwnedStorage{input_float_buffer}, @@ -676,8 +690,8 @@ void pytensor_module(py::module& m_tensor) { DataType data_type, Layout layout, const std::optional& tile) { - auto owned_buffer = detail::create_owned_buffer_from_vector_of_floats(std::move(data), data_type); - return Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile); + return detail::convert_float_vector_to_tt_tensor( + std::move(data), shape, data_type, layout, nullptr, std::nullopt, tile); }), py::arg("data"), py::arg("shape"), @@ -717,16 +731,15 @@ void pytensor_module(py::module& m_tensor) { Layout layout, Device* device, const std::optional& tile) { - auto owned_buffer = detail::create_owned_buffer_from_vector_of_floats(std::move(data), data_type); - auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile); - return tensor.to(device, MemoryConfig{}); + return detail::convert_float_vector_to_tt_tensor( + std::move(data), shape, data_type, layout, device, std::nullopt, tile); }), py::keep_alive<1, 6>(), py::arg("data"), py::arg("shape"), py::arg("data_type"), py::arg("layout"), - py::arg("device") = std::nullopt, + py::arg("device") = nullptr, py::arg("tile") = std::nullopt, py::return_value_policy::move, R"doc( @@ -771,16 +784,15 @@ void pytensor_module(py::module& m_tensor) { Device* device, const MemoryConfig& memory_config, const std::optional& tile) { - auto owned_buffer = detail::create_owned_buffer_from_vector_of_floats(std::move(data), data_type); - auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile); - return tensor.to(device, memory_config); + return detail::convert_float_vector_to_tt_tensor( + std::move(data), shape, data_type, layout, device, memory_config, tile); }), py::keep_alive<1, 7>(), py::arg("data"), py::arg("shape"), py::arg("data_type"), py::arg("layout"), - py::arg("device") = std::nullopt, + py::arg("device") = nullptr, py::arg("memory_config"), py::arg("tile") = std::nullopt, py::return_value_policy::move, diff --git a/ttnn/cpp/ttnn/operations/core/to_dtype/to_dtype_op.hpp b/ttnn/cpp/ttnn/operations/core/to_dtype/to_dtype_op.hpp index 2d8c6f43a0c..4fa171be016 100644 --- a/ttnn/cpp/ttnn/operations/core/to_dtype/to_dtype_op.hpp +++ b/ttnn/cpp/ttnn/operations/core/to_dtype/to_dtype_op.hpp @@ -137,21 +137,20 @@ inline Tensor create_tensor_from_buffer( auto data = cast<::bfloat16, T>(input_buffer); return create_owned_tensor(std::move(data), shape, dtype, Layout::ROW_MAJOR).to(input_layout); } - case DataType::BFLOAT8_B: { - auto data = cast(input_buffer); - auto uint32_vector = pack_fp32_vec_as_bfp8_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto buffer = owned_buffer::create(std::move(uint32_vector)); - auto storage = OwnedStorage{std::move(buffer)}; - return Tensor(std::move(storage), shape, dtype, Layout::ROW_MAJOR) - .to(ttnn::TILE_LAYOUT); // has to be in tile layout - } + case DataType::BFLOAT8_B: case DataType::BFLOAT4_B: { auto data = cast(input_buffer); - auto uint32_vector = pack_fp32_vec_as_bfp4_tiles(data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto buffer = owned_buffer::create(std::move(uint32_vector)); - auto storage = OwnedStorage{std::move(buffer)}; - return Tensor(std::move(storage), shape, dtype, Layout::ROW_MAJOR) - .to(ttnn::TILE_LAYOUT); // has to be in tile layout + auto buffer = owned_buffer::create(std::move(data)); + auto tensor = + Tensor(OwnedStorage{std::move(buffer)}, shape, DataType::FLOAT32, Layout::ROW_MAJOR).to(Layout::TILE); + auto output_float_data = owned_buffer::get_as(tensor).get(); + auto output_packed_data = + dtype == DataType::BFLOAT8_B + ? pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false) + : pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_buffer = owned_buffer::create(std::move(output_packed_data)); + return Tensor( + OwnedStorage{std::move(output_buffer)}, shape, dtype, Layout::TILE); // has to be in tile layout } default: { TT_THROW("Unsupported DataType: {}", dtype); diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp index 456b65d7f1a..20df7107370 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp @@ -822,21 +822,21 @@ Tensor to_layout(const Tensor& tensor, Layout target_layout) { return tensor; } - auto shape = tensor.get_legacy_shape(); auto source_layout = tensor.get_layout(); auto tile = tensor.tensor_spec().tile(); - auto convert = [tile, &shape, source_layout, target_layout](const auto& input_data) -> std::vector { + auto physical_shape = tensor.tensor_spec().physical_shape(); + auto convert = [tile, &physical_shape, source_layout, target_layout](const auto& input_data) -> std::vector { switch (source_layout) { case Layout::ROW_MAJOR: if (target_layout == Layout::TILE) { - return convert_layout_row_major_to_tile(shape, tile, input_data); + return convert_layout_row_major_to_tile(physical_shape, tile, input_data); } else { TT_THROW("Unsupported layout conversion"); } break; case Layout::TILE: if (target_layout == Layout::ROW_MAJOR) { - return convert_layout_tile_to_row_major(shape, tile, input_data); + return convert_layout_tile_to_row_major(physical_shape, tile, input_data); } else { TT_THROW("Unsupported layout conversion"); } @@ -916,91 +916,18 @@ inline std::vector pack_fp32_vec_as_bfloat_tiles(const bfloat4_b&, Arg return pack_fp32_vec_as_bfp4_tiles(std::forward(args)...); } -// Template specialization for BFloatLayout based on type T -template -struct bfloat_enum; - -template <> -struct bfloat_enum { - static constexpr DataType value = DataType::BFLOAT8_B; -}; - -template <> -struct bfloat_enum { - static constexpr DataType value = DataType::BFLOAT4_B; -}; - template Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout) { static_assert(std::is_same_v || std::is_same_v, "Invalid type T"); - - // TODO(arakhmati): do not convert to FLOA32 - - if (tensor.get_layout() == target_layout) { - return tensor; + // TODO: Flip to assert when we remove use cases in python and c++ + if (tensor.get_layout() != target_layout or tensor.get_layout() != Layout::TILE) { + log_warning( + tt::LogAlways, + "Tensor layout must be Layout::TILE for bfloat8_b or bfloat4_b! Conversion from {} to {} was not executed!", + tensor.get_layout(), + target_layout); } - auto tile = tensor.get_tensor_spec().tile(); - return std::visit( - [&tensor, &target_layout, &tile](auto&& storage) -> Tensor { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - std::vector output_buffers; - for (int i = 0; i < storage.num_buffers(); i++) { - // Convert to FLOAT32 tensor and change layout - auto input_packed_data = owned_buffer::get_as(storage.get_buffer(i)).get(); - auto input_float_data = unpack_bfloat_tiles_into_float_vec( - T{}, input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); - auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); - auto float_tensor = Tensor( - OwnedStorage{input_float_buffer}, - tensor.get_legacy_shape(), - DataType::FLOAT32, - tensor.get_layout(), - tile) - .to(target_layout); - - // Convert back to BFLOAT8_B - auto output_float_data = owned_buffer::get_as(float_tensor).get(); - auto output_packed_data = pack_fp32_vec_as_bfloat_tiles( - T{}, output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - output_buffers.push_back(output_uint32_buffer); - } - return Tensor( - std::move(MultiDeviceHostStorage{storage.strategy, output_buffers, storage.shapes}), - tensor.get_legacy_shape(), - bfloat_enum::value, - target_layout, - tile); - - } else { - // Convert to FLOAT32 tensor and change layout - auto input_packed_data = owned_buffer::get_as(tensor).get(); - auto input_float_data = unpack_bfloat_tiles_into_float_vec( - T{}, input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); - auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); - auto float_tensor = Tensor( - OwnedStorage{input_float_buffer}, - tensor.get_legacy_shape(), - DataType::FLOAT32, - tensor.get_layout(), - tile) - .to(target_layout); - - // Convert back to BFLOAT - auto output_float_data = owned_buffer::get_as(float_tensor).get(); - auto output_packed_data = pack_fp32_vec_as_bfloat_tiles( - T{}, output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), - tensor.get_legacy_shape(), - bfloat_enum::value, - target_layout, - tile); - } - }, - tensor.get_storage()); + return tensor; } template <> diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp index e8b059dc155..5a0ec30ecdd 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp @@ -98,22 +98,23 @@ static ttnn::SmallVector to_4D_shape(const tt::tt_metal::LegacyShape& template typename BufferType> inline std::vector convert_layout_row_major_to_tile( - const tt::tt_metal::LegacyShape& shape, const Tile& tile, const BufferType& data_to_convert) { + const Size& shape, const Tile& tile, const BufferType& data_to_convert) { TT_FATAL( - (shape[-2] % tile.get_tile_shape()[0] == 0 && shape[-1] % tile.get_tile_shape()[1] == 0), + (shape.height() % tile.get_tile_shape()[0] == 0 && shape.width() % tile.get_tile_shape()[1] == 0), "Unsupported shape for tensor conversion from row-major to tile layout. The tensor shape height and width must " "be a multiple of tile height ({}) and width ({}), but the provided shape is {}", tile.get_tile_shape()[0], tile.get_tile_shape()[1], shape); - auto tile_shape = ttnn::SmallVector{tile.get_tile_shape()[0], tile.get_tile_shape()[1]}; - auto face_shape = ttnn::SmallVector{tile.get_face_shape()[0], tile.get_face_shape()[1]}; + auto tile_shape = tile.get_tile_shape(); + auto face_shape = tile.get_face_shape(); auto transpose_within_face = tile.get_transpose_within_face(); auto transpose_of_faces = tile.get_transpose_of_faces(); + return convert_layout( data_to_convert, - tt::stl::Span(shape.begin(), shape.end()), + shape, tests::utils::TensorLayoutType::LIN_ROW_MAJOR, tests::utils::TensorLayoutType::TILED_NFACES, tile_shape, @@ -124,14 +125,15 @@ inline std::vector convert_layout_row_major_to_tile( template typename BufferType> inline std::vector convert_layout_tile_to_row_major( - const tt::tt_metal::LegacyShape& shape, const Tile& tile, const BufferType& data_to_convert) { - auto tile_shape = ttnn::SmallVector{tile.get_tile_shape()[0], tile.get_tile_shape()[1]}; - auto face_shape = ttnn::SmallVector{tile.get_face_shape()[0], tile.get_face_shape()[1]}; + const Size& shape, const Tile& tile, const BufferType& data_to_convert) { + auto tile_shape = tile.get_tile_shape(); + auto face_shape = tile.get_face_shape(); auto transpose_within_face = tile.get_transpose_within_face(); auto transpose_of_faces = tile.get_transpose_of_faces(); + return convert_layout( data_to_convert, - tt::stl::Span(shape.begin(), shape.end()), + shape, tests::utils::TensorLayoutType::TILED_NFACES, tests::utils::TensorLayoutType::LIN_ROW_MAJOR, tile_shape, diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index f5871ac5213..8a46d676dc7 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -248,7 +248,14 @@ Tensor tensor_pad( input_tensor.storage_type() == StorageType::OWNED or input_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST or input_tensor.storage_type() == StorageType::BORROWED && "Tensor must be on host for padding"); - TT_ASSERT(input_tensor.get_layout() == Layout::ROW_MAJOR && "Tensor layout must be ROW_MAJOR for padding"); + // TODO: Flip to assert when we remove use cases in python and c++ + if (input_tensor.get_layout() != Layout::ROW_MAJOR) { + log_warning( + tt::LogOp, + "Tensor layout {} must be ROW_MAJOR for padding! Returning original tensor!", + input_tensor.get_layout()); + return input_tensor; + } auto input_shape = input_tensor.get_legacy_shape(); auto dimensions_pads = std::vector(); diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp index 6d11ff06bd6..b364413f57b 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp @@ -45,29 +45,21 @@ Tensor to_weight_special_padding_tile_layout( } } if constexpr (std::is_same::value) { - if (output_dtype == DataType::BFLOAT8_B) { - auto output_float_data = output_buffer.get(); + if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { + auto tensor = Tensor( + std::move(OwnedStorage{std::move(output_buffer)}), + output_shape, + DataType::FLOAT32, + Layout::ROW_MAJOR) + .to(Layout::TILE); + auto output_float_data = owned_buffer::get_as(tensor).get(); auto output_packed_data = - pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + output_dtype == DataType::BFLOAT8_B + ? pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false) + : pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - auto rm_tensor = Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), - output_shape, - output_dtype, - Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); - } - if (output_dtype == DataType::BFLOAT4_B) { - auto output_float_data = output_buffer.get(); - auto output_packed_data = - pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - auto rm_tensor = Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), - output_shape, - output_dtype, - Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); + return Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE); } } else { TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B)); @@ -130,29 +122,21 @@ Tensor to_weight_tile_layout( } } if constexpr (std::is_same::value) { - if (output_dtype == DataType::BFLOAT8_B) { - auto output_float_data = output_buffer.get(); - auto output_packed_data = - pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - auto rm_tensor = Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), - output_shape, - output_dtype, - Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); - } - if (output_dtype == DataType::BFLOAT4_B) { - auto output_float_data = output_buffer.get(); + if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { + auto tensor = Tensor( + std::move(OwnedStorage{std::move(output_buffer)}), + output_shape, + DataType::FLOAT32, + Layout::ROW_MAJOR) + .to(Layout::TILE); + auto output_float_data = owned_buffer::get_as(tensor).get(); auto output_packed_data = - pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + output_dtype == DataType::BFLOAT8_B + ? pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false) + : pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - auto rm_tensor = Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), - output_shape, - output_dtype, - Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); + return Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::TILE); } } else { TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B));