From 9ead67a42ea179e51b483bf4d2e43870b4ea1f4f Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Fri, 20 Dec 2024 07:40:21 +0000 Subject: [PATCH] #0: Refactor/commonize some embeddings code and add support for output sharded embeddings --- .../unit_tests/operations/test_embedding.py | 135 +++++++++ .../tilize/device/kernels/compute/tilize.cpp | 27 ++ .../device/embedding_device_operation.cpp | 32 +- .../device/embedding_device_operation.hpp | 2 +- .../device/embedding_program_factory.hpp | 281 ++++++++++-------- .../device/kernels/dataflow/embeddings.cpp | 143 ++------- .../kernels/dataflow/embeddings_common.hpp | 75 +++++ .../kernels/dataflow/embeddings_tilize.cpp | 145 +++------ ttnn/cpp/ttnn/tensor/types.hpp | 8 + 9 files changed, 497 insertions(+), 351 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/tilize/device/kernels/compute/tilize.cpp create mode 100644 ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_common.hpp diff --git a/tests/ttnn/unit_tests/operations/test_embedding.py b/tests/ttnn/unit_tests/operations/test_embedding.py index 89dc39a0788a..8976f340573e 100644 --- a/tests/ttnn/unit_tests/operations/test_embedding.py +++ b/tests/ttnn/unit_tests/operations/test_embedding.py @@ -174,3 +174,138 @@ def test_embedding_tiled_input( output_tensor = ttnn.to_torch(output_tensor) assert_with_pcc(torch_output_tensor, output_tensor) + + +@pytest.mark.parametrize("batch_size, sentence_size, hidden_embedding_dim, vocabulary_size", [(10, 96, 2048, 128256)]) +@pytest.mark.parametrize( + "output_memory_layout, num_cores_x, num_cores_y", + [ + (ttnn.TensorMemoryLayout.WIDTH_SHARDED, 8, 4), + (ttnn.TensorMemoryLayout.HEIGHT_SHARDED, 6, 1), + (ttnn.TensorMemoryLayout.BLOCK_SHARDED, 4, 6), + ], +) +@pytest.mark.parametrize("input_mem_config", [ttnn.DRAM_MEMORY_CONFIG]) +def test_embedding_tiled_sharded_output( + device, + batch_size, + sentence_size, + hidden_embedding_dim, + vocabulary_size, + output_memory_layout, + num_cores_x, + num_cores_y, + input_mem_config, +): + torch.manual_seed(1234) + layout = ttnn.TILE_LAYOUT + + output_shape = (batch_size, 1, sentence_size, hidden_embedding_dim) + shard_grid = ttnn.CoreRangeSet( + [ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(num_cores_x - 1, num_cores_y - 1))] + ) + fused_height = output_shape[0] * output_shape[1] * output_shape[2] + width = output_shape[-1] + if output_memory_layout == ttnn.TensorMemoryLayout.WIDTH_SHARDED: + shard_shape = (fused_height, width // (num_cores_x * num_cores_y)) + elif output_memory_layout == ttnn.TensorMemoryLayout.HEIGHT_SHARDED: + shard_shape = (fused_height // (num_cores_x * num_cores_y), width) + else: + shard_shape = (fused_height // num_cores_y, width // num_cores_x) + shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False) + output_mem_config = ttnn.MemoryConfig( + output_memory_layout, + ttnn.BufferType.L1, + shard_spec, + ) + + torch_input_tensor = torch.randint(0, vocabulary_size - 1, (batch_size, sentence_size)) + torch_weights = torch_random((vocabulary_size, hidden_embedding_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_embedding = torch.nn.Embedding.from_pretrained(torch_weights) + torch_output_tensor = torch_embedding(torch_input_tensor) + + input_tensor = ttnn.to_device( + ttnn.from_torch(torch_input_tensor, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT), + device, + memory_config=input_mem_config, + ) + weights = ttnn.to_device( + ttnn.from_torch(torch_weights, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT), + device, + memory_config=input_mem_config, + ) + + output_tensor = ttnn.embedding( + input_tensor, + weights, + embeddings_type=ttnn.EmbeddingsType.GENERIC, # Default embeddings type + dtype=ttnn.bfloat16, + memory_config=output_mem_config, # Default memory config + queue_id=0, # Default queue id + layout=layout, + ) + output_tensor = ttnn.to_torch(output_tensor) + + assert_with_pcc(torch_output_tensor, output_tensor) + + +@pytest.mark.parametrize( + "device_params", + [{"dispatch_core_axis": ttnn.DispatchCoreAxis.COL}], + indirect=True, +) +def test_tg_llama_sharded_embedding( + device, +): + torch.manual_seed(1234) + unharvested_grid_size = (7, 10) + compute_grid_size = device.compute_with_storage_grid_size() + if unharvested_grid_size[0] > compute_grid_size.x or unharvested_grid_size[1] > compute_grid_size.y: + pytest.skip(f"Need {unharvested_grid_size} grid size to run this test but core grid is {compute_grid_size}") + batch_size = 8 + vocabulary_size = 4096 + hidden_embedding_dim = 128 + token_padding = 31 + sentence_size = 1 + token_padding + torch_input_tensor = torch.randint(0, vocabulary_size - 1, (batch_size, sentence_size)) + torch_weights = torch.randn(vocabulary_size, hidden_embedding_dim) + torch_output_tensor = torch.nn.functional.embedding(torch_input_tensor, torch_weights) + + start_core = ttnn.CoreCoord(1, 0) + core_grid = ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)), + ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)), + ] + ) + num_cores = batch_size + shard_grid = ttnn.num_cores_to_corerangeset_in_subcoregrids(start_core, num_cores, core_grid, row_wise=True) + shard_spec = ttnn.ShardSpec( + shard_grid, + (batch_size * sentence_size // num_cores, hidden_embedding_dim), + ttnn.ShardOrientation.ROW_MAJOR, + False, + ) + output_mem_config = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.BufferType.L1, + shard_spec, + ) + + input_tensor = ttnn.as_tensor( + torch_input_tensor, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + weights = ttnn.as_tensor( + torch_weights, + device=device, + dtype=ttnn.bfloat16, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + output_tensor = ttnn.embedding(input_tensor, weights, layout=ttnn.TILE_LAYOUT, memory_config=output_mem_config) + output_tensor = ttnn.reshape( + output_tensor, + ttnn.Shape((batch_size, 1, hidden_embedding_dim), (batch_size, sentence_size, hidden_embedding_dim)), + ) + output_tensor = ttnn.to_torch(output_tensor) + assert_with_pcc(output_tensor, torch_output_tensor[:, 0, :].unsqueeze(1)) diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize/device/kernels/compute/tilize.cpp b/ttnn/cpp/ttnn/operations/data_movement/tilize/device/kernels/compute/tilize.cpp new file mode 100644 index 000000000000..ab92ecf818fc --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/tilize/device/kernels/compute/tilize.cpp @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "compute_kernel_api/tilize.h" + +namespace NAMESPACE { +void MAIN { + constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0); + constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1); + constexpr uint32_t per_core_block_cnt = get_compile_time_arg_val(2); + constexpr uint32_t per_core_block_tile_cnt = get_compile_time_arg_val(3); + tilize_init(cb_id_in0, per_core_block_tile_cnt, cb_id_out0); + + for (uint32_t b = 0; b < per_core_block_cnt; ++b) { + cb_wait_front(cb_id_in0, per_core_block_tile_cnt); + cb_reserve_back(cb_id_out0, per_core_block_tile_cnt); + + tilize_block(cb_id_in0, per_core_block_tile_cnt, cb_id_out0); + + cb_push_back(cb_id_out0, per_core_block_tile_cnt); + cb_pop_front(cb_id_in0, per_core_block_tile_cnt); + } +} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp b/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp index 96bac8fbbc2b..79adec2c76cf 100644 --- a/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp @@ -22,26 +22,34 @@ void Embeddings::validate(const std::vector &input_tensors) const { TT_FATAL(input_tensors.size() == 2, "Must have between 2 input tensors"); auto &a = input_tensors.at(0); const auto &weights = input_tensors.at(1); - TT_FATAL(a.get_layout() == Layout::ROW_MAJOR, "Error"); - TT_FATAL(weights.get_layout() == Layout::ROW_MAJOR, "Error"); + TT_FATAL(a.get_layout() == Layout::ROW_MAJOR, "Input tensor must be Row Major Layout"); + TT_FATAL(weights.get_layout() == Layout::ROW_MAJOR, "Weights tensor must be Row Major Layout"); TT_FATAL(a.get_dtype() == DataType::UINT32 or a.get_dtype() == DataType::BFLOAT16, "Input must be UINT32 or BFLOAT16"); - TT_FATAL(weights.get_dtype() == DataType::BFLOAT16, "Error"); - TT_FATAL(a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Embedding does not currently support sharding"); - TT_FATAL(weights.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Embedding does not currently support sharding"); - TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Embedding does not currently support sharding"); + TT_FATAL(weights.get_dtype() == DataType::BFLOAT16, "Weights tensor must have BFLOAT16 dtype"); + TT_FATAL(a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Embedding does not currently support sharded inputs"); + TT_FATAL(weights.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Embedding does not currently support sharded weights"); TT_FATAL(weights.get_legacy_shape()[0] == 1 && weights.get_legacy_shape()[1] == 1, "First two dimensions for the weights must be 1"); if (this->tilized) { - TT_FATAL(a.get_legacy_shape()[-1] % TILE_HEIGHT == 0, "Error"); - TT_FATAL(weights.get_legacy_shape()[-1] % TILE_WIDTH == 0, "Number of columns in table must be factor of tile width"); + TT_FATAL(a.get_legacy_shape()[-1] % TILE_HEIGHT == 0, "Input tensor width {} must be a multiple of tile height {} to have the output tensor tilized", a.get_legacy_shape()[-1], TILE_HEIGHT); + TT_FATAL(weights.get_legacy_shape()[-1] % TILE_WIDTH == 0, "Number of columns in table {} must be factor of tile width {}", weights.get_legacy_shape()[-1], TILE_WIDTH); + if (is_sharded(this->output_mem_config.memory_layout)) { + const auto& shard_spec = this->output_mem_config.shard_spec; + TT_FATAL(shard_spec.has_value(), "Sharded memory config must have a shard spec"); + TT_FATAL(shard_spec->shape[0] % TILE_HEIGHT == 0, "Shard height {} must be a multiple of tile height {} to have the output tensor tilized", shard_spec->shape[0], TILE_HEIGHT); + TT_FATAL(shard_spec->shape[1] % TILE_WIDTH == 0, "Shard width {} must be a multiple of tile width {} to have the output tensor tilized", shard_spec->shape[1], TILE_WIDTH); + TT_FATAL(a.volume() % shard_spec->shape[0] == 0, "Input tensor volume {} must be a multiple of shard height {}", a.volume(), shard_spec->shape[0]); + TT_FATAL(weights.get_legacy_shape()[-1] % shard_spec->shape[1] == 0, "Number of columns in table {} must be factor of shard width {}", weights.get_legacy_shape()[-1], shard_spec->shape[1]); + } } else { - TT_FATAL(this->output_dtype != DataType::BFLOAT8_B, "Error"); + TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Embedding only supports interleaved RM outputs"); + TT_FATAL(!is_block_float(this->output_dtype), "Output cannot be a block float dtype when not tilized"); } TT_FATAL(a.get_legacy_shape()[1] == 1 && a.get_legacy_shape()[2] == 1, "Only dim 0 && 3 for the input can be non 1"); switch (this->embeddings_type) { - case EmbeddingsType::PADDED: TT_FATAL(this->pad_token.has_value(), "Error"); break; - case EmbeddingsType::BINARY: TT_FATAL(weights.get_legacy_shape()[-2] == 2, "Error"); - default: TT_FATAL(!this->pad_token.has_value(), "Error"); + case EmbeddingsType::PADDED: TT_FATAL(this->pad_token.has_value(), "Pad token must be specified when PADDED Embeddings Type is specified"); break; + case EmbeddingsType::BINARY: TT_FATAL(weights.get_legacy_shape()[-2] == 2, "Weight tensor must have 2 embeddings for BINARY Embeddings Type"); break; + default: TT_FATAL(!this->pad_token.has_value(), "Pad token must not be specified when PADDED Embeddings Type is not specified"); } } diff --git a/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.hpp b/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.hpp index 69e14c39165c..3fad9391e0fa 100644 --- a/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.hpp @@ -14,7 +14,7 @@ using namespace tt::constants; namespace ttnn::operations::embedding { enum class EmbeddingsType { GENERIC, PADDED, BINARY }; -enum class EmbeddingsIndexType { UINT32, BFP16}; +enum class EmbeddingsIndexType { UINT32, BFP16 }; struct Embeddings { const MemoryConfig output_mem_config; diff --git a/ttnn/cpp/ttnn/operations/embedding/device/embedding_program_factory.hpp b/ttnn/cpp/ttnn/operations/embedding/device/embedding_program_factory.hpp index 33210a1386ad..3f5f3d8aa01c 100644 --- a/ttnn/cpp/ttnn/operations/embedding/device/embedding_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/embedding/device/embedding_program_factory.hpp @@ -4,7 +4,6 @@ #pragma once -#include "tt_metal/host_api.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" #include "tt_metal/host_api.hpp" @@ -44,6 +43,8 @@ operation::ProgramWithCallbacks embeddings_tilized( bool weights_is_dram = weights.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; bool out_is_dram = output.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + bool output_sharded = is_sharded(output.buffer()->buffer_layout()); + uint32_t input_element_size_bytes = a.element_size(); uint32_t weights_element_size_bytes = weights.element_size(); @@ -57,27 +58,40 @@ operation::ProgramWithCallbacks embeddings_tilized( uint32_t batch_size = a.get_legacy_shape()[0]; uint32_t num_output_rows_per_batch = a.get_legacy_shape()[-1]; uint32_t num_output_rows = num_output_rows_per_batch * batch_size; + // Note: num_blocks is just blocks along height uint32_t num_blocks = num_output_rows / TILE_HEIGHT; uint32_t num_blocks_per_batch = num_output_rows_per_batch / TILE_HEIGHT; - - auto num_embedding_dims = weights.get_legacy_shape()[-1]; - - // setup problem and grid size - uint32_t start_core_x = 0; - uint32_t start_core_y = 0; - - uint32_t problem_size = num_blocks; - - auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - auto [num_cores, all_cores, core_group_1, core_group_2, num_blocks_per_core_group_1, num_blocks_per_core_group_2] = - tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, problem_size); + uint32_t num_cores, num_blocks_per_core_group_1, num_blocks_per_core_group_2, num_tiles_per_block; + CoreRangeSet all_cores, core_group_1, core_group_2; + bool row_major; + if (output_sharded) { + const auto& shard_spec = output.shard_spec().value(); + all_cores = shard_spec.grid; + core_group_1 = all_cores; + num_cores = all_cores.num_cores(); + num_blocks_per_core_group_1 = shard_spec.shape[0] / TILE_HEIGHT; + num_blocks_per_core_group_2 = 0; + num_tiles_per_block = shard_spec.shape[1] / TILE_WIDTH; + row_major = shard_spec.orientation == ShardOrientation::ROW_MAJOR; + } else { + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + std::tie( + num_cores, + all_cores, + core_group_1, + core_group_2, + num_blocks_per_core_group_1, + num_blocks_per_core_group_2) = + tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_blocks); + num_tiles_per_block = weights.get_legacy_shape()[-1] / TILE_WIDTH; + row_major = false; + } uint32_t g1_numcores = core_group_1.num_cores(); uint32_t g2_numcores = core_group_2.num_cores(); // Create Buffers - uint32_t num_tiles_per_block = weights.get_legacy_shape()[-1] / TILE_WIDTH; tt::DataFormat input_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); EmbeddingsIndexType embeddings_index_type; @@ -92,63 +106,74 @@ operation::ProgramWithCallbacks embeddings_tilized( tt::DataFormat output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); uint32_t output_single_tile_size = tt_metal::detail::TileSize(output_cb_data_format); - uint32_t buffering = weights.get_legacy_shape()[-1] > 2048 ? 1 : 2; + // Hardcoded limit to reduce L1 usage. Should be updated to be tuned based on overall L1 usage + constexpr uint32_t max_double_buffer_tiles = 64; + uint32_t buffering = num_tiles_per_block > max_double_buffer_tiles ? 1 : 2; - uint32_t src0_cb_index = 0; + constexpr uint32_t src0_cb_index = CBIndex::c_0; tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig( buffering * num_tiles_per_block * weights_single_tile_size, {{src0_cb_index, weights_cb_data_format}}) .set_page_size(src0_cb_index, weights_single_tile_size); auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); - uint32_t src1_cb_index = 1; + constexpr uint32_t src1_cb_index = CBIndex::c_1; tt_metal::CircularBufferConfig cb_src1_config = tt_metal::CircularBufferConfig(TILE_HEIGHT * input_element_size_bytes, {{src1_cb_index, input_cb_data_format}}) .set_page_size(src1_cb_index, TILE_HEIGHT * input_element_size_bytes); auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src1_config); + constexpr uint32_t output_cb_index = CBIndex::c_2; + uint32_t output_cb_size; + if (output_sharded) { + output_cb_size = output.buffer()->aligned_size_per_bank(); + } else { + output_cb_size = buffering * num_tiles_per_block * output_single_tile_size; + } + tt_metal::CircularBufferConfig cb_output_config = + tt_metal::CircularBufferConfig(output_cb_size, {{output_cb_index, output_cb_data_format}}) + .set_page_size(output_cb_index, output_single_tile_size); + if (output_sharded) { + cb_output_config.set_globally_allocated_address(*out_buffer); + } + auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config); + + constexpr uint32_t src2_cb_index = CBIndex::c_3; if (embeddings_type == EmbeddingsType::PADDED) { - uint32_t src2_cb_index = 2; uint32_t cache_page_size = round_up_to_mul32(weight_page_size); tt_metal::CircularBufferConfig cb_src2_config = tt_metal::CircularBufferConfig(cache_page_size, {{src2_cb_index, weights_cb_data_format}}) .set_page_size(src2_cb_index, cache_page_size); auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src2_config); } else if (embeddings_type == EmbeddingsType::BINARY) { - uint32_t src2_cb_index = 2; uint32_t cache_page_size = round_up_to_mul32(weight_page_size); tt_metal::CircularBufferConfig cb_src2_config = tt_metal::CircularBufferConfig(2 * cache_page_size, {{src2_cb_index, weights_cb_data_format}}) .set_page_size(src2_cb_index, cache_page_size); auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src2_config); } + uint32_t weight_block_size; + if (output_sharded) { + weight_block_size = output.shard_spec().value().shape[1] * weights_element_size_bytes; + } else { + weight_block_size = weight_page_size; + } - uint32_t output_cb_index = CBIndex::c_16; - tt_metal::CircularBufferConfig cb_output_config = - tt_metal::CircularBufferConfig( - buffering * num_tiles_per_block * output_single_tile_size, {{output_cb_index, output_cb_data_format}}) - .set_page_size(output_cb_index, output_single_tile_size); - auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config); - - bool input_stick_size_is_power_of_two = is_power_of_two_at_least_32(input_page_size); - uint32_t input_log2_stick_size = input_stick_size_is_power_of_two ? (std::uint32_t)std::log2(input_page_size) : 0; - bool weight_stick_size_is_power_of_two = is_power_of_two_at_least_32(weight_page_size); - uint32_t weight_log2_stick_size = - weight_stick_size_is_power_of_two ? (std::uint32_t)std::log2(weight_page_size) : 0; - + // TODO: Can increase size for larger reads + uint32_t input_block_size_bytes = TILE_HEIGHT * input_element_size_bytes; // Create Kernels // reader std::vector embedding_compile_time_args = { + (std::uint32_t)src0_cb_index, + (std::uint32_t)src1_cb_index, + (std::uint32_t)src2_cb_index, (std::uint32_t)in0_is_dram, - (std::uint32_t)input_stick_size_is_power_of_two, (std::uint32_t)input_page_size, - (std::uint32_t)input_log2_stick_size, (std::uint32_t)weights_is_dram, - (std::uint32_t)weight_stick_size_is_power_of_two, (std::uint32_t)weight_page_size, - (std::uint32_t)weight_log2_stick_size, + (std::uint32_t)weight_block_size, (std::uint32_t)num_tiles_per_block, - (std::uint32_t)TILE_HEIGHT * input_element_size_bytes}; + (std::uint32_t)input_block_size_bytes}; std::map embedding_defines = { {magic_enum::enum_name(embeddings_type).data(), "1"}, @@ -162,48 +187,53 @@ operation::ProgramWithCallbacks embeddings_tilized( if (num_blocks_per_core_group_1 > 0) { std::vector compute_args_1 = { + uint32_t(src0_cb_index), // input embeddings_cb_index + uint32_t(output_cb_index), // output_cb_index uint32_t(num_blocks_per_core_group_1), // per_core_block_cnt uint32_t(num_tiles_per_block) // per_core_block_tile_cnt }; auto tilize_kernel_id_1 = tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", + "ttnn/cpp/ttnn/operations/data_movement/tilize/device/kernels/compute/tilize.cpp", core_group_1, tt_metal::ComputeConfig{.compile_args = compute_args_1}); } if (num_blocks_per_core_group_2 > 0) { std::vector compute_args_2 = { + uint32_t(src0_cb_index), // input embeddings_cb_index + uint32_t(output_cb_index), // output_cb_index uint32_t(num_blocks_per_core_group_2), // per_core_block_cnt uint32_t(num_tiles_per_block) // per_core_block_tile_cnt }; auto tilize_kernel_id_2 = tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/tilize.cpp", + "ttnn/cpp/ttnn/operations/data_movement/tilize/device/kernels/compute/tilize.cpp", core_group_2, tt_metal::ComputeConfig{.compile_args = compute_args_2}); } + KernelHandle writer_kernel_id = 0; + // TODO: We can use the second risc to do more work in parallel + if (!output_sharded) { + std::vector writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)out_is_dram}; - std::vector writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)out_is_dram}; + // Tilized writer + writer_kernel_id = tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", + all_cores, + tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + } - // Tilized writer - auto writer_kernel_id = tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", - all_cores, - tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + auto cores = corerange_to_cores(all_cores, std::nullopt, row_major); - uint32_t input_offset = 0; - uint32_t weight_offset = 0; - uint32_t tile_offset = 0; - - auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, false); std::vector reader_runtime_args = { (std::uint32_t)a.buffer()->address(), (std::uint32_t)weights.buffer()->address(), (std::uint32_t)0, (std::uint32_t)0, (std::uint32_t)0, + (std::uint32_t)0, }; if (embeddings_type == EmbeddingsType::PADDED) { reader_runtime_args.push_back(pad_token.value()); @@ -212,54 +242,73 @@ operation::ProgramWithCallbacks embeddings_tilized( std::vector writer_runtime_args = { (std::uint32_t)output.buffer()->address(), (std::uint32_t)0, (std::uint32_t)0}; + uint32_t input_offset = 0; + uint32_t weight_offset = 0; + uint32_t tile_offset = 0; for (uint32_t i = 0; i < cores.size(); ++i) { const CoreCoord& core = cores[i]; - uint32_t local_input_offset = input_offset; uint32_t local_num_blocks = i < g1_numcores ? num_blocks_per_core_group_1 : num_blocks_per_core_group_2; // Reader { reader_runtime_args[2] = input_offset / num_blocks_per_batch; - reader_runtime_args[3] = input_offset % num_blocks_per_batch * TILE_HEIGHT * input_element_size_bytes; - reader_runtime_args[4] = local_num_blocks; + reader_runtime_args[3] = input_offset % num_blocks_per_batch * input_block_size_bytes; + reader_runtime_args[4] = weight_offset; + reader_runtime_args[5] = local_num_blocks; tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); } // Writer - { + if (!output_sharded) { writer_runtime_args[1] = num_tiles_per_block * local_num_blocks; writer_runtime_args[2] = tile_offset; tile_offset += local_num_blocks * num_tiles_per_block; tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, writer_runtime_args); + input_offset += local_num_blocks; + } else { + weight_offset += weight_block_size; + if (weight_offset == weight_page_size) { + weight_offset = 0; + input_offset += local_num_blocks; + } } - - input_offset += local_num_blocks; } - auto override_runtime_args_callback = [num_cores_x, num_cores_y, reader_kernel_id, writer_kernel_id, cores, device]( - const Program& program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto output_dram_buffer = output_buffers.at(0); - auto input_dram_buffer = input_buffers.at(0); - auto weights_dram_buffer = input_buffers.at(1); - - for (const auto& core : cores) { - { - auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = input_dram_buffer->address(); - runtime_args[1] = weights_dram_buffer->address(); + auto override_runtime_arguments_callback = + [reader_kernel_id, writer_kernel_id, cores, cb_output]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + auto output_buffer = output_tensors.at(0).buffer(); + auto output_buffer_address = output_buffer->address(); + auto input_buffer_address = input_tensors.at(0).buffer()->address(); + auto weights_buffer_address = input_tensors.at(1).buffer()->address(); + + auto& reader_runtime_args = GetRuntimeArgs(program, reader_kernel_id); + auto& writer_runtime_args = GetRuntimeArgs(program, writer_kernel_id); + const bool output_sharded = is_sharded(output_buffer->buffer_layout()); + if (output_sharded) { + UpdateDynamicCircularBufferAddress(program, cb_output, *output_buffer); } - { - auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = output_dram_buffer->address(); + for (const auto& core : cores) { + { + auto& runtime_args = reader_runtime_args[core.x][core.y]; + runtime_args[0] = input_buffer_address; + runtime_args[1] = weights_buffer_address; + } + + if (!output_sharded) { + auto& runtime_args = writer_runtime_args[core.x][core.y]; + runtime_args[0] = output_buffer_address; + } } - } - }; + }; - return {std::move(program), override_runtime_args_callback}; + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; } operation::ProgramWithCallbacks embeddings_rm( @@ -307,13 +356,11 @@ operation::ProgramWithCallbacks embeddings_rm( uint32_t batch_size = a.get_legacy_shape()[0]; uint32_t num_output_rows_per_batch = a.get_legacy_shape()[-1]; uint32_t num_output_rows = num_output_rows_per_batch * batch_size; - constexpr uint32_t alignment = 32; + auto alignment = a.buffer()->alignment(); uint32_t block_height = (alignment / input_element_size_bytes); uint32_t num_blocks = num_output_rows; uint32_t num_blocks_per_batch = num_output_rows_per_batch; - auto num_embedding_dims = weights.get_legacy_shape()[-1]; - // setup problem and grid size uint32_t start_core_x = 0; uint32_t start_core_y = 0; @@ -335,29 +382,30 @@ operation::ProgramWithCallbacks embeddings_rm( tt::DataFormat weights_cb_data_format = tt_metal::datatype_to_dataformat_converter(weights.get_dtype()); tt::DataFormat output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); - uint32_t src0_cb_index = 0; + constexpr uint32_t src0_cb_index = CBIndex::c_0; uint32_t rounded_weight_page_size = round_up_to_mul32(weight_page_size); tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig(2 * rounded_weight_page_size, {{src0_cb_index, weights_cb_data_format}}) .set_page_size(src0_cb_index, rounded_weight_page_size); auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); - uint32_t src1_cb_index = 1; + constexpr uint32_t src1_cb_index = CBIndex::c_1; uint32_t index_page_size = round_up_to_mul32(input_element_size_bytes); tt_metal::CircularBufferConfig cb_src1_config = tt_metal::CircularBufferConfig(block_height * index_page_size, {{src1_cb_index, input_cb_data_format}}) .set_page_size(src1_cb_index, block_height * index_page_size); auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src1_config); + constexpr uint32_t output_cb_index = src0_cb_index; + + constexpr uint32_t src2_cb_index = CBIndex::c_2; if (embeddings_type == EmbeddingsType::PADDED) { - uint32_t src2_cb_index = 2; uint32_t cache_page_size = round_up_to_mul32(weight_page_size); tt_metal::CircularBufferConfig cb_src2_config = tt_metal::CircularBufferConfig(cache_page_size, {{src2_cb_index, weights_cb_data_format}}) .set_page_size(src2_cb_index, cache_page_size); auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src2_config); } else if (embeddings_type == EmbeddingsType::BINARY) { - uint32_t src2_cb_index = 2; uint32_t cache_page_size = round_up_to_mul32(weight_page_size); tt_metal::CircularBufferConfig cb_src2_config = tt_metal::CircularBufferConfig(2 * cache_page_size, {{src2_cb_index, weights_cb_data_format}}) @@ -365,25 +413,16 @@ operation::ProgramWithCallbacks embeddings_rm( auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src2_config); } - uint32_t output_cb_index = src0_cb_index; - - bool input_stick_size_is_power_of_two = is_power_of_two_at_least_32(input_page_size); - uint32_t input_log2_stick_size = input_stick_size_is_power_of_two ? (std::uint32_t)std::log2(input_page_size) : 0; - bool weight_stick_size_is_power_of_two = is_power_of_two_at_least_32(weight_page_size); - uint32_t weight_log2_stick_size = - weight_stick_size_is_power_of_two ? (std::uint32_t)std::log2(weight_page_size) : 0; - // Create Kernels // reader std::vector embedding_compile_time_args = { + (std::uint32_t)src0_cb_index, + (std::uint32_t)src1_cb_index, + (std::uint32_t)src2_cb_index, (std::uint32_t)in0_is_dram, - (std::uint32_t)input_stick_size_is_power_of_two, (std::uint32_t)input_page_size, - (std::uint32_t)input_log2_stick_size, (std::uint32_t)weights_is_dram, - (std::uint32_t)weight_stick_size_is_power_of_two, (std::uint32_t)weight_page_size, - (std::uint32_t)weight_log2_stick_size, (std::uint32_t)block_height, (std::uint32_t)block_height * input_element_size_bytes}; @@ -463,29 +502,35 @@ operation::ProgramWithCallbacks embeddings_rm( input_offset += local_num_blocks; } - auto override_runtime_args_callback = [num_cores_x, num_cores_y, reader_kernel_id, writer_kernel_id, cores, device]( - const Program& program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto output_dram_buffer = output_buffers.at(0); - auto input_dram_buffer = input_buffers.at(0); - auto weights_dram_buffer = input_buffers.at(1); - - for (const auto& core : cores) { - { - auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = input_dram_buffer->address(); - runtime_args[1] = weights_dram_buffer->address(); - } - - { - auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = output_dram_buffer->address(); + auto override_runtime_arguments_callback = + [reader_kernel_id, writer_kernel_id, cores]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + auto output_buffer_address = output_tensors.at(0).buffer()->address(); + auto input_buffer_address = input_tensors.at(0).buffer()->address(); + auto weights_buffer_address = input_tensors.at(1).buffer()->address(); + + auto& reader_runtime_args = GetRuntimeArgs(program, reader_kernel_id); + auto& writer_runtime_args = GetRuntimeArgs(program, writer_kernel_id); + + for (const auto& core : cores) { + { + auto& runtime_args = reader_runtime_args[core.x][core.y]; + runtime_args[0] = input_buffer_address; + runtime_args[1] = weights_buffer_address; + } + + { + auto& runtime_args = writer_runtime_args[core.x][core.y]; + runtime_args[0] = output_buffer_address; + } } - } - }; + }; - return {std::move(program), override_runtime_args_callback}; + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; } operation::ProgramWithCallbacks embeddings_( diff --git a/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings.cpp b/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings.cpp index 686677befd3a..4915f8103c81 100644 --- a/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings.cpp +++ b/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings.cpp @@ -3,145 +3,68 @@ // SPDX-License-Identifier: Apache-2.0 #include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_common.hpp" void kernel_main() { - const std::uint32_t input_dram_buffer_src_addr = get_arg_val(0); - const std::uint32_t weights_dram_buffer_src_addr = get_arg_val(1); + const std::uint32_t input_buffer_src_addr = get_arg_val(0); + const std::uint32_t weight_buffer_src_addr = get_arg_val(1); const std::uint32_t batch_offset = get_arg_val(2); const std::uint32_t weights_offset = get_arg_val(3); const std::uint32_t num_blocks = get_arg_val(4); const std::uint32_t index_idx = get_arg_val(5); -#define in_is_dram get_compile_time_arg_val(0) == 1 -#define in_stick_size_is_power_of_two get_compile_time_arg_val(1) == 1 - constexpr uint32_t input_page_size = get_compile_time_arg_val(2); -#if (in_stick_size_is_power_of_two) - constexpr uint32_t log_base_2_of_input_page_size = get_compile_time_arg_val(3); - const InterleavedPow2AddrGen input = { - .bank_base_address = input_dram_buffer_src_addr, - .log_base_2_of_page_size = log_base_2_of_input_page_size // TODO(AP): refactor - }; -#else - const InterleavedAddrGen input = { - .bank_base_address = input_dram_buffer_src_addr, .page_size = input_page_size}; -#endif + constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0); + constexpr uint32_t cb_id_in1 = get_compile_time_arg_val(1); + constexpr uint32_t cb_id_in2 = get_compile_time_arg_val(2); -#define weights_is_dram get_compile_time_arg_val(4) == 1 -#define weight_stick_size_is_power_of_two get_compile_time_arg_val(5) == 1 - constexpr uint32_t weight_stick_size = get_compile_time_arg_val(6); -#if (weight_stick_size_is_power_of_two) - constexpr uint32_t log_base_2_of_weights_page_size = get_compile_time_arg_val(7); - const InterleavedPow2AddrGen weights = { - .bank_base_address = weights_dram_buffer_src_addr, - .log_base_2_of_page_size = log_base_2_of_weights_page_size // TODO(AP): refactor - }; -#else - const InterleavedAddrGen weights = { - .bank_base_address = weights_dram_buffer_src_addr, .page_size = weight_stick_size}; -#endif + constexpr bool input_in_dram = get_compile_time_arg_val(3) == 1; + constexpr uint32_t input_page_size = get_compile_time_arg_val(4); + const auto input = get_interleaved_addr_gen(input_buffer_src_addr); - constexpr uint32_t rows_per_block = get_compile_time_arg_val(8); - constexpr uint32_t input_block_size_bytes = get_compile_time_arg_val(9); + constexpr bool weight_in_dram = get_compile_time_arg_val(5) == 1; + constexpr uint32_t weight_stick_size = get_compile_time_arg_val(6); + const auto weights = get_interleaved_addr_gen(weight_buffer_src_addr); - constexpr uint32_t cb_id_in0 = 0; - constexpr uint32_t cb_id_in1 = 1; - constexpr uint32_t cb_id_in2 = 2; + constexpr uint32_t rows_per_block = get_compile_time_arg_val(7); + constexpr uint32_t input_block_size_bytes = get_compile_time_arg_val(8); - constexpr uint32_t tile_height = 32; + prepare_local_cache(cb_id_in2, weights, weight_stick_size, /*pad_token_arg_idx=*/6); -#if defined PADDED - const std::uint32_t pad_token = get_arg_val(6); - uint64_t pad_noc_addr; - { - cb_reserve_back(cb_id_in2, 1); - uint32_t local_pad_addr = get_write_ptr(cb_id_in2); - uint64_t src_noc_addr = get_noc_addr(pad_token, weights); - noc_async_read(src_noc_addr, local_pad_addr, weight_stick_size); - noc_async_read_barrier(); - pad_noc_addr = get_noc_addr(local_pad_addr); - } -#elif defined BINARY - uint64_t zero_noc_addr, one_noc_addr; - { - cb_reserve_back(cb_id_in2, 2); - uint32_t local_write_addr = get_write_ptr(cb_id_in2); - uint64_t src_noc_addr = get_noc_addr(0, weights); - noc_async_read(src_noc_addr, local_write_addr, weight_stick_size); - zero_noc_addr = get_noc_addr(local_write_addr); + cb_reserve_back(cb_id_in1, 1); + uint32_t input_l1_addr = get_write_ptr(cb_id_in1); + volatile tt_l1_ptr input_token_t* input_l1_ptr = reinterpret_cast(input_l1_addr); - local_write_addr += weight_stick_size; - src_noc_addr = get_noc_addr(1, weights); - noc_async_read(src_noc_addr, local_write_addr, weight_stick_size); - one_noc_addr = get_noc_addr(local_write_addr); + uint32_t curr_row = batch_offset; + uint32_t offset = weights_offset; + uint32_t index = index_idx; - noc_async_read_barrier(); - } -#endif + uint64_t noc_input_src_addr = get_noc_addr(curr_row, input) + offset; + noc_async_read(noc_input_src_addr, input_l1_addr, input_block_size_bytes); + noc_async_read_barrier(); - cb_reserve_back(cb_id_in1, 1); - uint32_t input_l1_addr = get_write_ptr(cb_id_in1); -#if defined BFP16 - volatile tt_l1_ptr uint16_t* input_l1_ptr = reinterpret_cast(input_l1_addr); -#else - volatile tt_l1_ptr uint32_t* input_l1_ptr = reinterpret_cast(input_l1_addr); -#endif - auto read_block = [&](const uint32_t& token_idx, const uint32_t& width_size) { + for (uint32_t i = 0; i < num_blocks; ++i) { cb_reserve_back(cb_id_in0, 1); uint32_t l1_write_addr = get_write_ptr(cb_id_in0); - uint64_t src_noc_addr; - uint32_t token = input_l1_ptr[token_idx]; -#if defined PADDED - if (token == pad_token) { - src_noc_addr = pad_noc_addr; - } else { - src_noc_addr = get_noc_addr(token, weights); - } -#elif defined BINARY - if (token == 0) { - src_noc_addr = zero_noc_addr; - } else { - src_noc_addr = one_noc_addr; - } -#else -#if defined BFP16 - union { - float f; - uint32_t u; - } u; - u.u = (uint32_t)input_l1_ptr[token_idx] << 16; - uint32_t token_casted = static_cast(u.f); - src_noc_addr = get_noc_addr(token_casted, weights); -#else - src_noc_addr = get_noc_addr(token, weights); -#endif -#endif - noc_async_read(src_noc_addr, l1_write_addr, width_size); + input_token_t token = input_l1_ptr[index]; + uint64_t src_noc_addr = get_token_noc_addr(token, weights); + noc_async_read(src_noc_addr, l1_write_addr, weight_stick_size); noc_async_read_barrier(); cb_push_back(cb_id_in0, 1); - }; - uint32_t curr_row = batch_offset; - uint32_t offset = weights_offset; - uint32_t index = index_idx; - bool read_indices = true; - for (uint32_t i = 0; i < num_blocks; ++i) { - if (read_indices) { - uint64_t noc_input_src_addr = get_noc_addr(curr_row, input) + offset; - noc_async_read(noc_input_src_addr, input_l1_addr, input_block_size_bytes); - noc_async_read_barrier(); - read_indices = false; - } - read_block(index, weight_stick_size); index++; if (index == rows_per_block) { index = 0; - read_indices = true; offset += input_block_size_bytes; if (offset == input_page_size) { offset = 0; curr_row++; } + if (i != num_blocks - 1) { + noc_input_src_addr = get_noc_addr(curr_row, input) + offset; + noc_async_read(noc_input_src_addr, input_l1_addr, input_block_size_bytes); + noc_async_read_barrier(); + } } } } diff --git a/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_common.hpp b/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_common.hpp new file mode 100644 index 000000000000..3b2b44fff2d6 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_common.hpp @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "dataflow_api.h" + +// TODO: Should get this from somewhere +constexpr uint32_t tile_height = 32; + +#if defined BFP16 +typedef uint16_t input_token_t; +#else +typedef uint32_t input_token_t; +#endif + +// TODO: Can probably make this not global +uint64_t pad_noc_addr; +uint64_t zero_noc_addr; +uint64_t one_noc_addr; + +template +FORCE_INLINE constexpr void prepare_local_cache( + uint32_t local_cache_cb, const T& weights, uint32_t weights_stick_size, uint32_t pad_token_arg_idx = 0) { +#if defined PADDED + uint32_t pad_token = get_arg_val(pad_token_arg_idx); + cb_reserve_back(local_cache_cb, 1); + uint32_t local_pad_addr = get_write_ptr(local_cache_cb); + uint64_t src_noc_addr = get_noc_addr(pad_token, weights); + noc_async_read(src_noc_addr, local_pad_addr, weight_stick_size); + noc_async_read_barrier(); + pad_noc_addr = get_noc_addr(local_pad_addr); +#elif defined BINARY + cb_reserve_back(local_cache_cb, 2); + uint32_t local_write_addr = get_write_ptr(local_cache_cb); + uint64_t src_noc_addr = get_noc_addr(0, weights); + noc_async_read(src_noc_addr, local_write_addr, weight_stick_size); + zero_noc_addr = get_noc_addr(local_write_addr); + + local_write_addr += weight_stick_size; + src_noc_addr = get_noc_addr(1, weights); + noc_async_read(src_noc_addr, local_write_addr, weight_stick_size); + one_noc_addr = get_noc_addr(local_write_addr); + + noc_async_read_barrier(); +#endif +} + +template +FORCE_INLINE uint64_t get_token_noc_addr(input_token_t token, const T& weights) { +#if defined PADDED + if (token == pad_token) { + return pad_noc_addr; + } else { + return get_noc_addr(token, weights); + } +#elif defined BINARY + if (token == 0) { + return zero_noc_addr; + } else { + return one_noc_addr; + } +#elif defined BFP16 + union { + float f; + uint32_t u; + } u; + u.u = (uint32_t)token << 16; + uint32_t token_casted = static_cast(u.f); + return get_noc_addr(token_casted, weights); +#else + return get_noc_addr(token, weights); +#endif +} diff --git a/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_tilize.cpp b/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_tilize.cpp index f9f21155de16..31a1954c39ff 100644 --- a/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_tilize.cpp +++ b/ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_tilize.cpp @@ -3,132 +3,57 @@ // SPDX-License-Identifier: Apache-2.0 #include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/embedding/device/kernels/dataflow/embeddings_common.hpp" void kernel_main() { - const std::uint32_t input_dram_buffer_src_addr = get_arg_val(0); - const std::uint32_t weights_dram_buffer_src_addr = get_arg_val(1); - const std::uint32_t batch_offset = get_arg_val(2); - const std::uint32_t weights_offset = get_arg_val(3); - const std::uint32_t num_blocks = get_arg_val(4); - -#define in_is_dram get_compile_time_arg_val(0) == 1 -#define in_stick_size_is_power_of_two get_compile_time_arg_val(1) == 1 - constexpr uint32_t input_page_size = get_compile_time_arg_val(2); -#if (in_stick_size_is_power_of_two) - constexpr uint32_t log_base_2_of_input_page_size = get_compile_time_arg_val(3); - const InterleavedPow2AddrGen input = { - .bank_base_address = input_dram_buffer_src_addr, - .log_base_2_of_page_size = log_base_2_of_input_page_size // TODO(AP): refactor - }; -#else - const InterleavedAddrGen input = { - .bank_base_address = input_dram_buffer_src_addr, .page_size = input_page_size}; -#endif - -#define weights_is_dram get_compile_time_arg_val(4) == 1 -#define weight_stick_size_is_power_of_two get_compile_time_arg_val(5) == 1 + const uint32_t input_buffer_src_addr = get_arg_val(0); + const uint32_t weight_buffer_src_addr = get_arg_val(1); + const uint32_t input_start_id = get_arg_val(2); + const uint32_t input_start_offset = get_arg_val(3); + const uint32_t weight_offset = get_arg_val(4); + const uint32_t num_blocks = get_arg_val(5); + + constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0); + constexpr uint32_t cb_id_in1 = get_compile_time_arg_val(1); + constexpr uint32_t cb_id_in2 = get_compile_time_arg_val(2); + + constexpr bool input_in_dram = get_compile_time_arg_val(3) == 1; + constexpr uint32_t input_page_size = get_compile_time_arg_val(4); + auto input = get_interleaved_addr_gen(input_buffer_src_addr); + + constexpr bool weight_in_dram = get_compile_time_arg_val(5) == 1; constexpr uint32_t weight_stick_size = get_compile_time_arg_val(6); -#if (weight_stick_size_is_power_of_two) - constexpr uint32_t log_base_2_of_weights_page_size = get_compile_time_arg_val(7); - const InterleavedPow2AddrGen weights = { - .bank_base_address = weights_dram_buffer_src_addr, - .log_base_2_of_page_size = log_base_2_of_weights_page_size // TODO(AP): refactor - }; -#else - const InterleavedAddrGen weights = { - .bank_base_address = weights_dram_buffer_src_addr, .page_size = weight_stick_size}; -#endif + constexpr uint32_t weight_block_size = get_compile_time_arg_val(7); + auto weights = get_interleaved_addr_gen(weight_buffer_src_addr + weight_offset); + constexpr uint32_t tiles_per_block = get_compile_time_arg_val(8); constexpr uint32_t input_block_size_bytes = get_compile_time_arg_val(9); - constexpr uint32_t cb_id_in0 = 0; - constexpr uint32_t cb_id_in1 = 1; - constexpr uint32_t cb_id_in2 = 2; + prepare_local_cache(cb_id_in2, weights, weight_block_size, /*pad_token_arg_idx=*/6); - constexpr uint32_t tile_height = 32; - -#if defined PADDED - const std::uint32_t pad_token = get_arg_val(5); - uint64_t pad_noc_addr; - { - cb_reserve_back(cb_id_in2, 1); - uint32_t local_pad_addr = get_write_ptr(cb_id_in2); - uint64_t src_noc_addr = get_noc_addr(pad_token, weights); - noc_async_read(src_noc_addr, local_pad_addr, weight_stick_size); - noc_async_read_barrier(); - pad_noc_addr = get_noc_addr(local_pad_addr); - } -#elif defined BINARY - uint64_t zero_noc_addr, one_noc_addr; - { - cb_reserve_back(cb_id_in2, 2); - uint32_t local_write_addr = get_write_ptr(cb_id_in2); - uint64_t src_noc_addr = get_noc_addr(0, weights); - noc_async_read(src_noc_addr, local_write_addr, weight_stick_size); - zero_noc_addr = get_noc_addr(local_write_addr); + cb_reserve_back(cb_id_in1, 1); + uint32_t input_l1_addr = get_write_ptr(cb_id_in1); - local_write_addr += weight_stick_size; - src_noc_addr = get_noc_addr(1, weights); - noc_async_read(src_noc_addr, local_write_addr, weight_stick_size); - one_noc_addr = get_noc_addr(local_write_addr); + volatile tt_l1_ptr input_token_t* input_l1_ptr = reinterpret_cast(input_l1_addr); + uint32_t curr_row = input_start_id; + uint32_t offset = input_start_offset; + for (uint32_t i = 0; i < num_blocks; ++i) { + uint64_t noc_input_src_addr = get_noc_addr(curr_row, input) + offset; + noc_async_read(noc_input_src_addr, input_l1_addr, input_block_size_bytes); noc_async_read_barrier(); - } -#endif - - cb_reserve_back(cb_id_in1, 1); - uint32_t input_l1_addr = get_write_ptr(cb_id_in1); -#if defined BFP16 - volatile tt_l1_ptr uint16_t* input_l1_ptr = reinterpret_cast(input_l1_addr); -#else - volatile tt_l1_ptr uint32_t* input_l1_ptr = reinterpret_cast(input_l1_addr); -#endif - auto read_tiles = [&](const uint32_t& num_tiles, const uint32_t& width_size) { - cb_reserve_back(cb_id_in0, num_tiles); + cb_reserve_back(cb_id_in0, tiles_per_block); uint32_t l1_write_addr = get_write_ptr(cb_id_in0); for (uint32_t k = 0; k < tile_height; ++k) { - uint64_t src_noc_addr; - uint32_t token = input_l1_ptr[k]; -#if defined PADDED - if (token == pad_token) { - src_noc_addr = pad_noc_addr; - } else { - src_noc_addr = get_noc_addr(token, weights); - } -#elif defined BINARY - if (token == 0) { - src_noc_addr = zero_noc_addr; - } else { - src_noc_addr = one_noc_addr; - } -#else -#if defined BFP16 - union { - float f; - uint32_t u; - } u; - u.u = (uint32_t)input_l1_ptr[k] << 16; - uint32_t token_casted = static_cast(u.f); - src_noc_addr = get_noc_addr(token_casted, weights); -#else - src_noc_addr = get_noc_addr(token, weights); -#endif -#endif - noc_async_read(src_noc_addr, l1_write_addr, width_size); - l1_write_addr += width_size; + input_token_t token = input_l1_ptr[k]; + uint64_t src_noc_addr = get_token_noc_addr(token, weights); + noc_async_read(src_noc_addr, l1_write_addr, weight_block_size); + l1_write_addr += weight_block_size; } noc_async_read_barrier(); - cb_push_back(cb_id_in0, num_tiles); - }; + cb_push_back(cb_id_in0, tiles_per_block); - uint32_t curr_row = batch_offset; - uint32_t offset = weights_offset; - for (uint32_t i = 0; i < num_blocks; ++i) { - uint64_t noc_input_src_addr = get_noc_addr(curr_row, input) + offset; - noc_async_read(noc_input_src_addr, input_l1_addr, input_block_size_bytes); - noc_async_read_barrier(); - read_tiles(tiles_per_block, weight_stick_size); offset += input_block_size_bytes; if (offset == input_page_size) { offset = 0; diff --git a/ttnn/cpp/ttnn/tensor/types.hpp b/ttnn/cpp/ttnn/tensor/types.hpp index 33fd91e8b402..60d48929395e 100644 --- a/ttnn/cpp/ttnn/tensor/types.hpp +++ b/ttnn/cpp/ttnn/tensor/types.hpp @@ -71,6 +71,14 @@ inline bool is_floating_point(DataType dtype) { } } +inline bool is_block_float(DataType dtype) { + switch (dtype) { + case DataType::BFLOAT8_B: + case DataType::BFLOAT4_B: return true; + default: return false; + } +} + enum class StorageType { OWNED, DEVICE,