Skip to content

Commit

Permalink
#0: Refactor/commonize some embeddings code and add support for outpu…
Browse files Browse the repository at this point in the history
…t sharded embeddings
  • Loading branch information
tt-aho committed Dec 20, 2024
1 parent df8bbdc commit 18e6ba2
Show file tree
Hide file tree
Showing 9 changed files with 497 additions and 351 deletions.
135 changes: 135 additions & 0 deletions tests/ttnn/unit_tests/operations/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,138 @@ def test_embedding_tiled_input(
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor)


@pytest.mark.parametrize("batch_size, sentence_size, hidden_embedding_dim, vocabulary_size", [(10, 96, 2048, 128256)])
@pytest.mark.parametrize(
"output_memory_layout, num_cores_x, num_cores_y",
[
(ttnn.TensorMemoryLayout.WIDTH_SHARDED, 8, 4),
(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, 6, 1),
(ttnn.TensorMemoryLayout.BLOCK_SHARDED, 4, 6),
],
)
@pytest.mark.parametrize("input_mem_config", [ttnn.DRAM_MEMORY_CONFIG])
def test_embedding_tiled_sharded_output(
device,
batch_size,
sentence_size,
hidden_embedding_dim,
vocabulary_size,
output_memory_layout,
num_cores_x,
num_cores_y,
input_mem_config,
):
torch.manual_seed(1234)
layout = ttnn.TILE_LAYOUT

output_shape = (batch_size, 1, sentence_size, hidden_embedding_dim)
shard_grid = ttnn.CoreRangeSet(
[ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(num_cores_x - 1, num_cores_y - 1))]
)
fused_height = output_shape[0] * output_shape[1] * output_shape[2]
width = output_shape[-1]
if output_memory_layout == ttnn.TensorMemoryLayout.WIDTH_SHARDED:
shard_shape = (fused_height, width // (num_cores_x * num_cores_y))
elif output_memory_layout == ttnn.TensorMemoryLayout.HEIGHT_SHARDED:
shard_shape = (fused_height // (num_cores_x * num_cores_y), width)
else:
shard_shape = (fused_height // num_cores_y, width // num_cores_x)
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False)
output_mem_config = ttnn.MemoryConfig(
output_memory_layout,
ttnn.BufferType.L1,
shard_spec,
)

torch_input_tensor = torch.randint(0, vocabulary_size - 1, (batch_size, sentence_size))
torch_weights = torch_random((vocabulary_size, hidden_embedding_dim), -0.1, 0.1, dtype=torch.bfloat16)
torch_embedding = torch.nn.Embedding.from_pretrained(torch_weights)
torch_output_tensor = torch_embedding(torch_input_tensor)

input_tensor = ttnn.to_device(
ttnn.from_torch(torch_input_tensor, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT),
device,
memory_config=input_mem_config,
)
weights = ttnn.to_device(
ttnn.from_torch(torch_weights, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT),
device,
memory_config=input_mem_config,
)

output_tensor = ttnn.embedding(
input_tensor,
weights,
embeddings_type=ttnn.EmbeddingsType.GENERIC, # Default embeddings type
dtype=ttnn.bfloat16,
memory_config=output_mem_config, # Default memory config
queue_id=0, # Default queue id
layout=layout,
)
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor)


@pytest.mark.parametrize(
"device_params",
[{"dispatch_core_axis": ttnn.DispatchCoreAxis.COL}],
indirect=True,
)
def test_tg_llama_sharded_embedding(
device,
):
torch.manual_seed(1234)
unharvested_grid_size = (7, 10)
compute_grid_size = device.compute_with_storage_grid_size()
if unharvested_grid_size[0] > compute_grid_size.x or unharvested_grid_size[1] > compute_grid_size.y:
pytest.skip(f"Need {unharvested_grid_size} grid size to run this test but core grid is {compute_grid_size}")
batch_size = 8
vocabulary_size = 4096
hidden_embedding_dim = 128
token_padding = 31
sentence_size = 1 + token_padding
torch_input_tensor = torch.randint(0, vocabulary_size - 1, (batch_size, sentence_size))
torch_weights = torch.randn(vocabulary_size, hidden_embedding_dim)
torch_output_tensor = torch.nn.functional.embedding(torch_input_tensor, torch_weights)

start_core = ttnn.CoreCoord(1, 0)
core_grid = ttnn.CoreRangeSet(
[
ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
]
)
num_cores = batch_size
shard_grid = ttnn.num_cores_to_corerangeset_in_subcoregrids(start_core, num_cores, core_grid, row_wise=True)
shard_spec = ttnn.ShardSpec(
shard_grid,
(batch_size * sentence_size // num_cores, hidden_embedding_dim),
ttnn.ShardOrientation.ROW_MAJOR,
False,
)
output_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
ttnn.BufferType.L1,
shard_spec,
)

input_tensor = ttnn.as_tensor(
torch_input_tensor, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
)
weights = ttnn.as_tensor(
torch_weights,
device=device,
dtype=ttnn.bfloat16,
layout=ttnn.ROW_MAJOR_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
output_tensor = ttnn.embedding(input_tensor, weights, layout=ttnn.TILE_LAYOUT, memory_config=output_mem_config)
output_tensor = ttnn.reshape(
output_tensor,
ttnn.Shape((batch_size, 1, hidden_embedding_dim), (batch_size, sentence_size, hidden_embedding_dim)),
)
output_tensor = ttnn.to_torch(output_tensor)
assert_with_pcc(output_tensor, torch_output_tensor[:, 0, :].unsqueeze(1))
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#include "compute_kernel_api/tilize.h"

namespace NAMESPACE {
void MAIN {
constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0);
constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(1);
constexpr uint32_t per_core_block_cnt = get_compile_time_arg_val(2);
constexpr uint32_t per_core_block_tile_cnt = get_compile_time_arg_val(3);
tilize_init(cb_id_in0, per_core_block_tile_cnt, cb_id_out0);

for (uint32_t b = 0; b < per_core_block_cnt; ++b) {
cb_wait_front(cb_id_in0, per_core_block_tile_cnt);
cb_reserve_back(cb_id_out0, per_core_block_tile_cnt);

tilize_block(cb_id_in0, per_core_block_tile_cnt, cb_id_out0);

cb_push_back(cb_id_out0, per_core_block_tile_cnt);
cb_pop_front(cb_id_in0, per_core_block_tile_cnt);
}
}
} // namespace NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,34 @@ void Embeddings::validate(const std::vector<Tensor> &input_tensors) const {
TT_FATAL(input_tensors.size() == 2, "Must have between 2 input tensors");
auto &a = input_tensors.at(0);
const auto &weights = input_tensors.at(1);
TT_FATAL(a.get_layout() == Layout::ROW_MAJOR, "Error");
TT_FATAL(weights.get_layout() == Layout::ROW_MAJOR, "Error");
TT_FATAL(a.get_layout() == Layout::ROW_MAJOR, "Input tensor must be Row Major Layout");
TT_FATAL(weights.get_layout() == Layout::ROW_MAJOR, "Weights tensor must be Row Major Layout");
TT_FATAL(a.get_dtype() == DataType::UINT32 or a.get_dtype() == DataType::BFLOAT16, "Input must be UINT32 or BFLOAT16");
TT_FATAL(weights.get_dtype() == DataType::BFLOAT16, "Error");
TT_FATAL(a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Embedding does not currently support sharding");
TT_FATAL(weights.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Embedding does not currently support sharding");
TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Embedding does not currently support sharding");
TT_FATAL(weights.get_dtype() == DataType::BFLOAT16, "Weights tensor must have BFLOAT16 dtype");
TT_FATAL(a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Embedding does not currently support sharded inputs");
TT_FATAL(weights.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Embedding does not currently support sharded weights");

TT_FATAL(weights.get_legacy_shape()[0] == 1 && weights.get_legacy_shape()[1] == 1, "First two dimensions for the weights must be 1");
if (this->tilized) {
TT_FATAL(a.get_legacy_shape()[-1] % TILE_HEIGHT == 0, "Error");
TT_FATAL(weights.get_legacy_shape()[-1] % TILE_WIDTH == 0, "Number of columns in table must be factor of tile width");
TT_FATAL(a.get_legacy_shape()[-1] % TILE_HEIGHT == 0, "Input tensor width {} must be a multiple of tile height {} to have the output tensor tilized", a.get_legacy_shape()[-1], TILE_HEIGHT);
TT_FATAL(weights.get_legacy_shape()[-1] % TILE_WIDTH == 0, "Number of columns in table {} must be factor of tile width {}", weights.get_legacy_shape()[-1], TILE_WIDTH);
if (is_sharded(this->output_mem_config.memory_layout)) {
const auto& shard_spec = this->output_mem_config.shard_spec;
TT_FATAL(shard_spec.has_value(), "Sharded memory config must have a shard spec");
TT_FATAL(shard_spec->shape[0] % TILE_HEIGHT == 0, "Shard height {} must be a multiple of tile height {} to have the output tensor tilized", shard_spec->shape[0], TILE_HEIGHT);
TT_FATAL(shard_spec->shape[1] % TILE_WIDTH == 0, "Shard width {} must be a multiple of tile width {} to have the output tensor tilized", shard_spec->shape[1], TILE_WIDTH);
TT_FATAL(a.volume() % shard_spec->shape[0] == 0, "Input tensor volume {} must be a multiple of shard height {}", a.volume(), shard_spec->shape[0]);
TT_FATAL(weights.get_legacy_shape()[-1] % shard_spec->shape[1] == 0, "Number of columns in table {} must be factor of shard width {}", weights.get_legacy_shape()[-1], shard_spec->shape[1]);
}
} else {
TT_FATAL(this->output_dtype != DataType::BFLOAT8_B, "Error");
TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Embedding only supports interleaved RM outputs");
TT_FATAL(!is_block_float(this->output_dtype), "Output cannot be a block float dtype when not tilized");
}
TT_FATAL(a.get_legacy_shape()[1] == 1 && a.get_legacy_shape()[2] == 1, "Only dim 0 && 3 for the input can be non 1");
switch (this->embeddings_type) {
case EmbeddingsType::PADDED: TT_FATAL(this->pad_token.has_value(), "Error"); break;
case EmbeddingsType::BINARY: TT_FATAL(weights.get_legacy_shape()[-2] == 2, "Error");
default: TT_FATAL(!this->pad_token.has_value(), "Error");
case EmbeddingsType::PADDED: TT_FATAL(this->pad_token.has_value(), "Pad token must be specified when PADDED Embeddings Type is specified"); break;
case EmbeddingsType::BINARY: TT_FATAL(weights.get_legacy_shape()[-2] == 2, "Weight tensor must have 2 embeddings for BINARY Embeddings Type"); break;
default: TT_FATAL(!this->pad_token.has_value(), "Pad token must not be specified when PADDED Embeddings Type is not specified");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using namespace tt::constants;
namespace ttnn::operations::embedding {

enum class EmbeddingsType { GENERIC, PADDED, BINARY };
enum class EmbeddingsIndexType { UINT32, BFP16};
enum class EmbeddingsIndexType { UINT32, BFP16 };

struct Embeddings {
const MemoryConfig output_mem_config;
Expand Down
Loading

0 comments on commit 18e6ba2

Please sign in to comment.