Commit b23d07e

Divanovic/llama tg demo (#13105)
* #0: TG prefill+decode demo and 8k prefill pcc fix

* #0: Update setup_llama_env to use tg implementation

* #0: Remove prefill tests until CI is fixed
djordje-tt authored Sep 30, 2024
1 parent 8d6ee5d commit b23d07e
Showing 13 changed files with 775 additions and 502 deletions.
models/demos/t3000/llama2_70b/tt/llama_common.py (12 changes: 0 additions & 12 deletions)
@@ -9,7 +9,6 @@
 from typing import Tuple
 import numpy as np
 import torch
-from torch import nn
 import ttnn
 from models.utility_functions import tt2torch_tensor, torch2tt_tensor
 from loguru import logger
@@ -18,7 +17,6 @@
     load_chunked_checkpoints,
     load_sharded_checkpoints,
 )
-import pytest
 from models.demos.t3000.llama2_70b.tt.model_config import get_model_config
 
 MAX_SEQ_LEN = 4096
@@ -126,10 +124,6 @@ def setup_llama_env(llama_version="llama3", max_batch_size=32, max_context_len=4
         ckpt_dir = "/mnt/MLPerf/tt_dnn-models/llama-3/llama-3-70b-repacked/"
         tokenizer_path = "/mnt/MLPerf/tt_dnn-models/llama-3/tokenizer.model"
         cache_path = Path("/mnt/MLPerf/tt_dnn-models/llama-3/llama-data-cache/weights-cache-3")
-    elif llama_version == "llama3-tg":
-        ckpt_dir = "/mnt/MLPerf/tt_dnn-models/llama-3/llama-3-70b-repacked/"
-        tokenizer_path = "/mnt/MLPerf/tt_dnn-models/llama-3/tokenizer.model"
-        cache_path = Path("/mnt/MLPerf/tt_dnn-models/llama-3/llama-data-cache/weights-cache-tg")
     elif llama_version == "llama3-405b":
         ckpt_dir = "/mnt/MLPerf/tt_dnn-models/llama-3-405b/llama-3-405b-repacked/"
         tokenizer_path = "/mnt/MLPerf/tt_dnn-models/llama-3-405b/tokenizer.model"
@@ -143,12 +137,6 @@
         ckpt_dir = os.getenv("LLAMA3_CKPT_DIR", "/proj_sw/llama3-data-repacked/llama-3-70b/")
         tokenizer_path = os.getenv("LLAMA3_TOKENIZER_PATH", "/proj_sw/llama3-data-repacked/tokenizer.model")
         cache_path = Path(os.getenv("LLAMA3_CACHE_PATH", "/proj_sw/llama-cache/llama-3-70b"))
-    elif llama_version == "llama3-tg":
-        ckpt_dir = os.getenv("LLAMA3_CKPT_DIR", "/proj_sw/user_dev/llama3-data-repacked/llama-3-70b/")
-        tokenizer_path = os.getenv(
-            "LLAMA3_TOKENIZER_PATH", "/proj_sw/user_dev/llama3-data-repacked/tokenizer.model"
-        )
-        cache_path = Path(os.getenv("LLAMA3_CACHE_PATH", "/proj_sw/user_dev/llama3-data-cache/weights-cache-2"))
     elif llama_version == "llama3-405b":
         ckpt_dir = os.getenv("LLAMA3_405B_CKPT_DIR", "/proj_sw/user_dev/llama3-405B-data-repacked/llama-3-405b/")
         tokenizer_path = os.getenv(
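The deleted llama3-tg branches do not simply disappear: per the commit message, setup_llama_env moves to a TG-specific implementation, and the demo and tests below switch their import to models.demos.tg.llama3_70b.tt.llama_common. That module's body is not part of this diff; the following is only a minimal sketch of what such a TG-side helper could look like, assuming it reuses the paths deleted above and the return signature visible in test_LlamaModel_demo below (the CI switch and the get_model_config signature are assumptions).

```python
# Hypothetical sketch only -- the real models/demos/tg/llama3_70b/tt/llama_common.py
# is not shown in this diff. The paths are the ones deleted above; the CI check
# and the get_model_config signature are assumptions.
import os
from pathlib import Path

from models.demos.t3000.llama2_70b.tt.model_config import get_model_config


def setup_llama_env(llama_version="llama3-tg", max_batch_size=32, max_context_len=4096):
    if os.getenv("CI") == "true":  # assumption: CI machines mount the MLPerf weights
        ckpt_dir = "/mnt/MLPerf/tt_dnn-models/llama-3/llama-3-70b-repacked/"
        tokenizer_path = "/mnt/MLPerf/tt_dnn-models/llama-3/tokenizer.model"
        cache_path = Path("/mnt/MLPerf/tt_dnn-models/llama-3/llama-data-cache/weights-cache-tg")
    else:
        ckpt_dir = os.getenv("LLAMA3_CKPT_DIR", "/proj_sw/user_dev/llama3-data-repacked/llama-3-70b/")
        tokenizer_path = os.getenv("LLAMA3_TOKENIZER_PATH", "/proj_sw/user_dev/llama3-data-repacked/tokenizer.model")
        cache_path = Path(os.getenv("LLAMA3_CACHE_PATH", "/proj_sw/user_dev/llama3-data-cache/weights-cache-2"))

    # Same return shape as the call site in test_LlamaModel_demo below.
    model_config = get_model_config(
        llama_version=llama_version,
        max_batch_size=max_batch_size,
        max_context_len=max_context_len,
    )
    return model_config, ckpt_dir, tokenizer_path, cache_path
```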
models/demos/tg/llama3_70b/demo/demo.py (16 changes: 12 additions & 4 deletions)
@@ -15,9 +15,9 @@
 from models.demos.t3000.llama2_70b.reference.llama.llama import Llama
 from transformers.generation.utils import top_k_top_p_filtering
 from models.demos.tg.llama3_70b.tt.llama_generation_galaxy import TtLlamaModelForGeneration
+from models.demos.tg.llama3_70b.tt.llama_common import setup_llama_env
 from models.demos.t3000.llama2_70b.reference.llama.llama.tokenizer3 import ChatFormat
 from models.demos.t3000.llama2_70b.tt.llama_common import (
-    setup_llama_env,
     check_mesh_device,
     string_similarity_score,
     load_llama_state_dict,
@@ -267,8 +267,14 @@ def run_decode(
                 break
 
         # Decode the entire sequence generated so far and log it
-        for user_id in range(max(0, bsz - 3), bsz):
-            text = tokenizer.decode(tokens[user_id, : cur_pos + 1].tolist())
+        for user_id in range(max(0, bsz - 5), bsz):
+            eos_found = False
+            for eos_idx, tk in enumerate(tokens[user_id, : cur_pos + 1].tolist()):
+                if tk == tokenizer.eos_id:
+                    text = tokenizer.decode(tokens[user_id, :eos_idx].tolist())
+                    eos_found = True
+            if not eos_found:
+                text = tokenizer.decode(tokens[user_id, : cur_pos + 1].tolist())
             if data_args.print_output_as_generated:
                 logger.info(f"Loop {cur_pos} user {user_id}: {text}\n")
 
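The rewritten logging loop above truncates each logged transcript at the tokenizer's end-of-sequence token instead of always decoding up to cur_pos. Here is a self-contained sketch of the same truncation idea, with a toy tokenizer standing in for the real Llama tokenizer (everything in the snippet is illustrative):

```python
# Toy reproduction of the truncate-at-EOS logging added above.
# FakeTokenizer is a stand-in; the demo uses the real Llama tokenizer,
# which exposes eos_id and decode() the same way.
class FakeTokenizer:
    eos_id = 2

    def decode(self, ids):
        return " ".join(str(i) for i in ids)


tokenizer = FakeTokenizer()
tokens = [5, 9, 11, 2, 0, 0]  # generated ids; everything after EOS (2) is padding

eos_found = False
text = ""
for eos_idx, tk in enumerate(tokens):
    if tk == tokenizer.eos_id:
        text = tokenizer.decode(tokens[:eos_idx])  # drop the EOS and the padding
        eos_found = True
if not eos_found:
    text = tokenizer.decode(tokens)

print(text)  # -> "5 9 11"
```

Note that, as in the diff, the scan does not break on the first match, so a sequence containing several EOS tokens ends up truncated at the last one.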
@@ -364,7 +370,7 @@ def top_pk_logits_efficient(logits, p=0.9, k=10, temperature=1.0, return_probs=F
     ),
     ids=("chat_completion", "text_completion"),
 )
-@pytest.mark.parametrize("decode_only", (True,), ids=("decode_only",))
+@pytest.mark.parametrize("decode_only", (True, False), ids=("decode_only", "prefill_decode"))
 @pytest.mark.parametrize("num_layers", (1, 2, 10, 80), ids=("1L", "2L", "10L", "80L"))
 @pytest.mark.parametrize(
     "implementation, skip_model_load, n_devices",
@@ -432,6 +438,8 @@ def test_LlamaModel_demo(
     ## Get model config
     model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env(
         llama_version=llama_version,
+        max_batch_size=max_batch_size,
+        max_context_len=max_context_len,
     )
 
     check_mesh_device(mesh_device, model_config)
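Since decode_only is now parametrized over both True and False, the prefill+decode path becomes an ordinary test variant. It should be selectable through pytest's usual id matching, e.g. `pytest models/demos/tg/llama3_70b/demo/demo.py -k "prefill_decode and 80L"` (command shape assumed; the ids come from the parametrize calls above), with the original decode-only path still available under the decode_only id.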
models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py (25 changes: 3 additions & 22 deletions)
@@ -11,9 +11,9 @@
 
 from models.demos.t3000.llama2_70b.reference.llama.llama import Llama
 from models.demos.tg.llama3_70b.tt.llama_attention_galaxy import TtLlamaAttention_galaxy
+from models.demos.tg.llama3_70b.tt.llama_common import setup_llama_env
 from models.demos.t3000.llama2_70b.reference.llama.llama.model import precompute_freqs_cis
 from models.demos.t3000.llama2_70b.tt.llama_common import (
-    setup_llama_env,
     check_mesh_device,
     extract_pcc_from_log,
     generate_rot_emb,
@@ -31,31 +31,13 @@
     get_rot_transformation_mat,
     should_skip_model_load,
     check_kv_cache,
-)
-
-from models.utility_functions import skip_for_grayskull
-from models.demos.t3000.llama2_70b.tt.llama_common import (
-    setup_llama_env,
-    check_mesh_device,
-    extract_pcc_from_log,
-    generate_rot_emb,
-    get_rotation_mat,
-    MAX_SEQ_LEN,
-    MAX_SEQ_LEN_LLAMA3,
-    BASE_URL,
-    UNIT_TEST_N_LAYER,
-    UNIT_TEST_LAYER_NUM,
-    UNIT_TEST_START_POS,
-    UNIT_TEST_GENERATION_LENGTH,
-    comp_pcc,
-    get_rot_transformation_mat,
-    should_skip_model_load,
-    check_kv_cache,
     num_to_corerange,
     ConcatMesh2DToTensor,
     ShardTensor2dMesh,
 )
 
 from models.utility_functions import skip_for_grayskull
+
+
 class PytorchLlamaAttentionModel(torch.nn.Module):
     def __init__(self, hf_reference_model, layer_num, rope_theta):
@@ -476,7 +458,6 @@ def test_LlamaAttention_inference(
         max_batch_size=max_batch_size,
         max_context_len=max_context_len,
     )
-
     check_mesh_device(mesh_device, model_config)
     run_test_LlamaAttention_inference(
         mesh_device,
models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py
@@ -6,7 +6,7 @@
 from loguru import logger
 import torch
 import ttnn
-from ttnn import ReplicateTensorToMesh, ListMeshToTensor
+from ttnn import ReplicateTensorToMesh
 
 from models.demos.t3000.llama2_70b.reference.llama.llama import Llama
 from models.demos.tg.llama3_70b.tt.llama_decoder_galaxy import TtLlamaDecoder_galaxy
models/demos/tg/llama3_70b/tests/test_llama_mlp_galaxy.py (1 change: 0 additions & 1 deletion)
@@ -6,7 +6,6 @@
 from loguru import logger
 import torch
 import ttnn
-from ttnn import ListMeshToTensor
 
 from models.demos.t3000.llama2_70b.reference.llama.llama import Llama
 from models.demos.tg.llama3_70b.tt.llama_mlp_galaxy import TtLlamaMLP_galaxy
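In this test and the decoder test above, ListMeshToTensor is dropped from the ttnn import with no other change in the hunk, which reads as dead-import cleanup rather than a behavioral change.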
models/demos/tg/llama3_70b/tests/test_llama_model_galaxy.py (6 changes: 3 additions & 3 deletions)
@@ -16,8 +16,8 @@
 from models.demos.tg.llama3_70b.tt.llama_model_galaxy import TtLlamaModel_galaxy
 from models.demos.tg.llama3_70b.tt.llama_common import PytorchLlamaModel
 from models.utility_functions import skip_for_grayskull
+from models.demos.tg.llama3_70b.tt.llama_common import setup_llama_env
 from models.demos.t3000.llama2_70b.tt.llama_common import (
-    setup_llama_env,
     check_mesh_device,
     extract_pcc_from_log,
     BASE_URL,
@@ -242,11 +242,11 @@ def run_test_LlamaModel_inference(
     "batch, seq_len",
     [
         (32, 1),
-        # (1, 256), (1, 8192), (1, 32768), (1, 128 * 1024)
+        # (1, 32), (1, 256), (1, 8192), (1, 32768), (1, 128 * 1024)
     ],
     ids=[
         "decode",
-        # "prefill_256", "prefill_8k", "prefill_32k", "prefill_128k"
+        # "prefill_32", "prefill_256", "prefill_8k", "prefill_32k", "prefill_128k"
     ],
 )
 @pytest.mark.parametrize(