Commit 7939faa
#10622: Llama TG full model implementation added to CI. Added TG llama demos
mikevin920 committed Jul 25, 2024
1 parent 8bc9f18 commit 7939faa
Showing 13 changed files with 1,636 additions and 142 deletions.
460 changes: 460 additions & 0 deletions models/demos/tg/llama3_70b/demo/demo.py

Large diffs are not rendered by default.

16 changes: 12 additions & 4 deletions models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py
@@ -35,6 +35,7 @@
check_kv_cache,
num_to_corerange,
ConcatMesh2DToTensor,
+ ShardTensor2dMesh,
)
from models.utility_functions import nearest_32

@@ -114,7 +115,7 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos):
assert x.shape == (seq_len, 1, batch, llama_attention_model.hidden_size)

ACT_MEMCFG = ttnn.create_sharded_memory_config(
- shape=(x.shape[2], x.shape[3] // 32),
+ shape=(x.shape[2], x.shape[3] // 32 // llama_attention_model.cluster_shape[0]),
core_grid=ttnn.CoreGrid(y=4, x=8),
strategy=ttnn.ShardStrategy.WIDTH,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
@@ -126,7 +127,9 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos):
layout=ttnn.TILE_LAYOUT,
memory_config=ACT_MEMCFG,
device=llama_attention_model.device_mesh,
- mesh_mapper=ReplicateTensorToMesh(llama_attention_model.device_mesh),
+ mesh_mapper=ShardTensor2dMesh(
+     llama_attention_model.device_mesh, dims=(3, None), cluster_shape=llama_attention_model.cluster_shape
+ ),
)

batch_size_per_group = llama_attention_model.batch_size_per_device_group
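Taken together, the hunks above switch the attention input from a replicated tensor to one that is 2D-sharded across the Galaxy mesh: the hidden dimension is split over the first cluster axis and then width-sharded over a 4x8 core grid on each chip. A minimal sketch of that input-prep path, assuming a Llama3-70B hidden size of 8192, decode batch 32, an already-open device_mesh, and the ShardTensor2dMesh mapper imported at the top of this file (anything not shown in the diff is illustrative):

import torch
import ttnn

# Assumed setup: `device_mesh` is an open 4x8 Galaxy mesh and `ShardTensor2dMesh`
# is the mapper this test imports; hidden_size = 8192 is the Llama3-70B value.
cluster_shape = (4, 8)
hidden_size = 8192
x = torch.randn(1, 1, 32, hidden_size)  # [seq_len, 1, batch, hidden_dim]

# Per-core shard width: 8192 // 32 cores // 4 mesh columns = 64 elements.
act_memcfg = ttnn.create_sharded_memory_config(
    shape=(x.shape[2], x.shape[3] // 32 // cluster_shape[0]),
    core_grid=ttnn.CoreGrid(y=4, x=8),
    strategy=ttnn.ShardStrategy.WIDTH,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=True,
)

xs = ttnn.as_tensor(
    x,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    memory_config=act_memcfg,
    device=device_mesh,
    # dims=(3, None): split dim 3 (hidden) over the first cluster axis, replicate over the second.
    mesh_mapper=ShardTensor2dMesh(device_mesh, dims=(3, None), cluster_shape=cluster_shape),
)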
@@ -272,7 +275,12 @@ def run_test_LlamaAttention_inference(
attn_mask,
)

- tt_out = ttnn.to_torch(tt_out, mesh_composer=ListMeshToTensor(device_mesh))[0]
+ # tt_out = ttnn.to_torch(tt_out, mesh_composer=ListMeshToTensor(device_mesh))[0]
+
+ tt_out = ttnn.to_torch(
+     tt_out, mesh_composer=ConcatMesh2DToTensor(device_mesh, dims=(3, 1), cluster_shape=cluster_shape)
+ )
+ tt_out = tt_out[:, 0:1, :, :]
tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1) # [seq, batch, hidden_dim]

# check outputs ----------------------------------------------------------------------
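On the output side the gather mirrors the sharding: per-device fragments are stitched back into one torch tensor along the hidden dimension and a single replica is kept before the comparison. A shape-annotated sketch of that path, assuming batch 32, seq_len 1, hidden 8192, cluster_shape (4, 8), and per-device fragments of shape [1, 1, 32, 2048] (composer semantics inferred from the shard math above, not taken from its docs):

tt_out = ttnn.to_torch(
    tt_out,
    # dims=(3, 1): concatenate hidden fragments on dim 3; the replica axis lands on dim 1.
    mesh_composer=ConcatMesh2DToTensor(device_mesh, dims=(3, 1), cluster_shape=cluster_shape),
)                                               # assumed result: [1, 8, 32, 8192]
tt_out = tt_out[:, 0:1, :, :]                   # keep one replica -> [1, 1, 32, 8192]
tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1)  # reorder to match pytorch_out for the PCC check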
@@ -345,7 +353,7 @@ def run_test_LlamaAttention_inference(
)
@pytest.mark.parametrize(
"batch, seq_len, pcc",
- [(32, 1, 0.9997)],
+ [(32, 1, 0.9995)],
ids=["decode"],
)
@pytest.mark.parametrize(
26 changes: 16 additions & 10 deletions models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@
check_kv_cache,
num_to_corerange,
ConcatMesh2DToTensor,
+ ShardTensor2dMesh,
)
import gc

@@ -113,19 +114,21 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos):
x = x.transpose(0, 1).unsqueeze(1) # [seq_len, 1, batch, hidden_dim]

ACT_MEMCFG = ttnn.create_sharded_memory_config(
- shape=(x.shape[2], x.shape[3] // 32),
- core_grid=ttnn.CoreGrid(y=4, x=8),
+ shape=(x.shape[2], x.shape[3] // 8 // llama_decoder_model.cluster_shape[0]),
+ core_grid=ttnn.CoreGrid(y=1, x=8),
strategy=ttnn.ShardStrategy.WIDTH,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
- xs = ttnn.as_tensor(
+ xs = ttnn.from_torch(
x,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
- memory_config=ACT_MEMCFG,
device=llama_decoder_model.device_mesh,
- mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.device_mesh),
+ memory_config=ACT_MEMCFG,
+ mesh_mapper=ShardTensor2dMesh(
+     llama_decoder_model.device_mesh, dims=(3, None), cluster_shape=llama_decoder_model.cluster_shape
+ ),
)

rot_emb = generate_rot_emb(
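The decoder-level prep differs from the attention test only in granularity: it moves to ttnn.from_torch, a 1x8 core grid, and an x.shape[3] // 8 shard width. A small worked check of the shard arithmetic, assuming the Llama3-70B hidden size of 8192 (the other values are read off the config above):

hidden_size = 8192          # Llama3-70B hidden dim (assumed)
batch = 32
cluster_shape = (4, 8)
num_cores = 1 * 8           # ttnn.CoreGrid(y=1, x=8)

per_device_width = hidden_size // cluster_shape[0]    # 8192 // 4 = 2048 per mesh column
shard_width = hidden_size // 8 // cluster_shape[0]    # 8192 // 8 // 4 = 256 per core
assert shard_width * num_cores == per_device_width    # 256 * 8 = 2048, so the width shards tile exactly
print((batch, shard_width))                           # (32, 256) shard shape per core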
@@ -229,7 +232,7 @@ def run_test_LlamaDecoder_inference(
generation_length = 1
else:
generation_start_pos = UNIT_TEST_START_POS
- generation_length = UNIT_TEST_GENERATION_LENGTH
+ generation_length = UNIT_TEST_GENERATION_LENGTH  # 1
for i in range(generation_length):
# Prepare input
pt_inp_ids = torch.randint(0, configuration.vocab_size, (batch, seq_len))
@@ -264,8 +267,11 @@ def run_test_LlamaDecoder_inference(
attn_mask,
)

- tt_out = ttnn.from_device(tt_out)
- tt_out = ttnn.to_torch(tt_out, mesh_composer=ListMeshToTensor(device_mesh))[0]
+ tt_out = ttnn.to_torch(
+     tt_out, mesh_composer=ConcatMesh2DToTensor(device_mesh, dims=(3, 1), cluster_shape=cluster_shape)
+ )
+
+ tt_out = tt_out[:, 0:1, :, :]
tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1) # [seq, batch, hidden_dim]

# check outputs ----------------------------------------------------------------------
@@ -334,7 +340,7 @@ def run_test_LlamaDecoder_inference(
)
@pytest.mark.parametrize(
"batch, seq_len, pcc",
- [(32, 1, 0.9995)],
+ [(32, 1, 0.995)],
ids=["decode"],
)
@pytest.mark.parametrize(
@@ -357,7 +363,7 @@ def test_LlamaDecoder_inference(
max_context_len,
llama_version,
cluster_shape,
- # use_program_cache,
+ use_program_cache,
):
if batch > max_batch_size:
pytest.skip(f"Decode with {batch} users is not supported with large context")
6 changes: 5 additions & 1 deletion models/demos/tg/llama3_70b/tests/test_llama_mlp_galaxy.py
@@ -21,6 +21,7 @@
comp_pcc,
should_skip_model_load,
ShardTensor2dMesh,
+ ConcatMesh2DToTensor,
)
import gc

@@ -126,7 +127,10 @@ def run_test_LlamaMLP_inference(

tt_out = tt_LlamaMLP_model(tt_mlp_input)

- tt_out = ttnn.to_torch(tt_out, mesh_composer=ListMeshToTensor(device_mesh))[0]
+ tt_out = ttnn.to_torch(
+     tt_out, mesh_composer=ConcatMesh2DToTensor(device_mesh, dims=(3, 1), cluster_shape=cluster_shape)
+ )
+ tt_out = tt_out[:, 0:1, :, :]

does_pass, output_pcc = comp_pcc(pytorch_out, tt_out, pcc)
logger.info(f"PCC value: {output_pcc}")
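This commit also relaxes the decode PCC thresholds (attention 0.9997 -> 0.9995, decoder 0.9995 -> 0.995). As a reference point for what those numbers gate, here is a self-contained stand-in for a PCC check written in plain torch; the repo's comp_pcc helper may differ in detail:

import torch

def pearson_cc(a: torch.Tensor, b: torch.Tensor) -> float:
    # Pearson correlation of the two flattened tensors.
    a = a.flatten().double()
    b = b.flatten().double()
    a, b = a - a.mean(), b - b.mean()
    return float((a @ b) / (a.norm() * b.norm()))

# A small perturbation keeps the correlation very close to 1, well above
# the relaxed 0.9995 / 0.995 thresholds used by these tests.
ref = torch.randn(32, 1, 8192)
out = ref + 0.01 * torch.randn_like(ref)
assert pearson_cc(ref, out) > 0.995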