diff --git a/models/demos/t3000/llama2_70b/tt/llama_common.py b/models/demos/t3000/llama2_70b/tt/llama_common.py index 760e0279656..63c8aad8233 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_common.py +++ b/models/demos/t3000/llama2_70b/tt/llama_common.py @@ -9,7 +9,6 @@ from typing import Tuple import numpy as np import torch -from torch import nn import ttnn from models.utility_functions import tt2torch_tensor, torch2tt_tensor from loguru import logger @@ -18,7 +17,6 @@ load_chunked_checkpoints, load_sharded_checkpoints, ) -import pytest from models.demos.t3000.llama2_70b.tt.model_config import get_model_config MAX_SEQ_LEN = 4096 @@ -126,10 +124,6 @@ def setup_llama_env(llama_version="llama3", max_batch_size=32, max_context_len=4 ckpt_dir = "/mnt/MLPerf/tt_dnn-models/llama-3/llama-3-70b-repacked/" tokenizer_path = "/mnt/MLPerf/tt_dnn-models/llama-3/tokenizer.model" cache_path = Path("/mnt/MLPerf/tt_dnn-models/llama-3/llama-data-cache/weights-cache-3") - elif llama_version == "llama3-tg": - ckpt_dir = "/mnt/MLPerf/tt_dnn-models/llama-3/llama-3-70b-repacked/" - tokenizer_path = "/mnt/MLPerf/tt_dnn-models/llama-3/tokenizer.model" - cache_path = Path("/mnt/MLPerf/tt_dnn-models/llama-3/llama-data-cache/weights-cache-tg") elif llama_version == "llama3-405b": ckpt_dir = "/mnt/MLPerf/tt_dnn-models/llama-3-405b/llama-3-405b-repacked/" tokenizer_path = "/mnt/MLPerf/tt_dnn-models/llama-3-405b/tokenizer.model" @@ -143,12 +137,6 @@ def setup_llama_env(llama_version="llama3", max_batch_size=32, max_context_len=4 ckpt_dir = os.getenv("LLAMA3_CKPT_DIR", "/proj_sw/llama3-data-repacked/llama-3-70b/") tokenizer_path = os.getenv("LLAMA3_TOKENIZER_PATH", "/proj_sw/llama3-data-repacked/tokenizer.model") cache_path = Path(os.getenv("LLAMA3_CACHE_PATH", "/proj_sw/llama-cache/llama-3-70b")) - elif llama_version == "llama3-tg": - ckpt_dir = os.getenv("LLAMA3_CKPT_DIR", "/proj_sw/user_dev/llama3-data-repacked/llama-3-70b/") - tokenizer_path = os.getenv( - "LLAMA3_TOKENIZER_PATH", "/proj_sw/user_dev/llama3-data-repacked/tokenizer.model" - ) - cache_path = Path(os.getenv("LLAMA3_CACHE_PATH", "/proj_sw/user_dev/llama3-data-cache/weights-cache-2")) elif llama_version == "llama3-405b": ckpt_dir = os.getenv("LLAMA3_405B_CKPT_DIR", "/proj_sw/user_dev/llama3-405B-data-repacked/llama-3-405b/") tokenizer_path = os.getenv( diff --git a/models/demos/tg/llama3_70b/demo/demo.py b/models/demos/tg/llama3_70b/demo/demo.py index 0754321b900..62c28d348fd 100644 --- a/models/demos/tg/llama3_70b/demo/demo.py +++ b/models/demos/tg/llama3_70b/demo/demo.py @@ -15,9 +15,9 @@ from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from transformers.generation.utils import top_k_top_p_filtering from models.demos.tg.llama3_70b.tt.llama_generation_galaxy import TtLlamaModelForGeneration +from models.demos.tg.llama3_70b.tt.llama_common import setup_llama_env from models.demos.t3000.llama2_70b.reference.llama.llama.tokenizer3 import ChatFormat from models.demos.t3000.llama2_70b.tt.llama_common import ( - setup_llama_env, check_mesh_device, string_similarity_score, load_llama_state_dict, @@ -267,8 +267,14 @@ def run_decode( break # Decode the entire sequence generated so far and log it - for user_id in range(max(0, bsz - 3), bsz): - text = tokenizer.decode(tokens[user_id, : cur_pos + 1].tolist()) + for user_id in range(max(0, bsz - 5), bsz): + eos_found = False + for eos_idx, tk in enumerate(tokens[user_id, : cur_pos + 1].tolist()): + if tk == tokenizer.eos_id: + text = tokenizer.decode(tokens[user_id, :eos_idx].tolist()) + 
eos_found = True + if not eos_found: + text = tokenizer.decode(tokens[user_id, : cur_pos + 1].tolist()) if data_args.print_output_as_generated: logger.info(f"Loop {cur_pos} user {user_id}: {text}\n") @@ -364,7 +370,7 @@ def top_pk_logits_efficient(logits, p=0.9, k=10, temperature=1.0, return_probs=F ), ids=("chat_completion", "text_completion"), ) -@pytest.mark.parametrize("decode_only", (True,), ids=("decode_only",)) +@pytest.mark.parametrize("decode_only", (True, False), ids=("decode_only", "prefill_decode")) @pytest.mark.parametrize("num_layers", (1, 2, 10, 80), ids=("1L", "2L", "10L", "80L")) @pytest.mark.parametrize( "implementation, skip_model_load, n_devices", @@ -432,6 +438,8 @@ def test_LlamaModel_demo( ## Get model config model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env( llama_version=llama_version, + max_batch_size=max_batch_size, + max_context_len=max_context_len, ) check_mesh_device(mesh_device, model_config) diff --git a/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py index 5f664456ec2..454b797453e 100644 --- a/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py @@ -11,9 +11,9 @@ from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.tg.llama3_70b.tt.llama_attention_galaxy import TtLlamaAttention_galaxy +from models.demos.tg.llama3_70b.tt.llama_common import setup_llama_env from models.demos.t3000.llama2_70b.reference.llama.llama.model import precompute_freqs_cis from models.demos.t3000.llama2_70b.tt.llama_common import ( - setup_llama_env, check_mesh_device, extract_pcc_from_log, generate_rot_emb, @@ -31,31 +31,13 @@ get_rot_transformation_mat, should_skip_model_load, check_kv_cache, -) - -from models.utility_functions import skip_for_grayskull -from models.demos.t3000.llama2_70b.tt.llama_common import ( - setup_llama_env, - check_mesh_device, - extract_pcc_from_log, - generate_rot_emb, - get_rotation_mat, - MAX_SEQ_LEN, - MAX_SEQ_LEN_LLAMA3, - BASE_URL, - UNIT_TEST_N_LAYER, - UNIT_TEST_LAYER_NUM, - UNIT_TEST_START_POS, - UNIT_TEST_GENERATION_LENGTH, - comp_pcc, - get_rot_transformation_mat, - should_skip_model_load, - check_kv_cache, num_to_corerange, ConcatMesh2DToTensor, ShardTensor2dMesh, ) +from models.utility_functions import skip_for_grayskull + class PytorchLlamaAttentionModel(torch.nn.Module): def __init__(self, hf_reference_model, layer_num, rope_theta): @@ -476,7 +458,6 @@ def test_LlamaAttention_inference( max_batch_size=max_batch_size, max_context_len=max_context_len, ) - check_mesh_device(mesh_device, model_config) run_test_LlamaAttention_inference( mesh_device, diff --git a/models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py b/models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py index 4af00e8f2e3..fa885e122b1 100644 --- a/models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py @@ -6,7 +6,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh, ListMeshToTensor +from ttnn import ReplicateTensorToMesh from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.tg.llama3_70b.tt.llama_decoder_galaxy import TtLlamaDecoder_galaxy diff --git a/models/demos/tg/llama3_70b/tests/test_llama_mlp_galaxy.py b/models/demos/tg/llama3_70b/tests/test_llama_mlp_galaxy.py index 6a952b21046..8fa77b8d37d 100644 
--- a/models/demos/tg/llama3_70b/tests/test_llama_mlp_galaxy.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_mlp_galaxy.py @@ -6,7 +6,6 @@ from loguru import logger import torch import ttnn -from ttnn import ListMeshToTensor from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.tg.llama3_70b.tt.llama_mlp_galaxy import TtLlamaMLP_galaxy diff --git a/models/demos/tg/llama3_70b/tests/test_llama_model_galaxy.py b/models/demos/tg/llama3_70b/tests/test_llama_model_galaxy.py index f6666b26ecf..be05ae91057 100644 --- a/models/demos/tg/llama3_70b/tests/test_llama_model_galaxy.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_model_galaxy.py @@ -16,8 +16,8 @@ from models.demos.tg.llama3_70b.tt.llama_model_galaxy import TtLlamaModel_galaxy from models.demos.tg.llama3_70b.tt.llama_common import PytorchLlamaModel from models.utility_functions import skip_for_grayskull +from models.demos.tg.llama3_70b.tt.llama_common import setup_llama_env from models.demos.t3000.llama2_70b.tt.llama_common import ( - setup_llama_env, check_mesh_device, extract_pcc_from_log, BASE_URL, @@ -242,11 +242,11 @@ def run_test_LlamaModel_inference( "batch, seq_len", [ (32, 1), - # (1, 256), (1, 8192), (1, 32768), (1, 128 * 1024) + # (1, 32), (1, 256), (1, 8192), (1, 32768), (1, 128 * 1024) ], ids=[ "decode", - # "prefill_256", "prefill_8k", "prefill_32k", "prefill_128k" + # "prefill_32", "prefill_256", "prefill_8k", "prefill_32k", "prefill_128k" ], ) @pytest.mark.parametrize( diff --git a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py index f9e14f638d1..5f251da552b 100644 --- a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py @@ -107,139 +107,6 @@ def get_user_selection_mat(self): mesh_mapper=ReplicateTensorToMesh(self.mesh_device), ) - def get_attn_model_config(self, mode): - if mode == "decode": - # 32 x 2048 X 2048 x 1280 - self.FUSED_QKV_MM_PROGCFG = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( - compute_with_storage_grid_size=(8, 5), - in0_block_w=2, - out_subblock_h=1, - out_subblock_w=1, - per_core_M=1, - per_core_N=1, - fuse_batch=True, - fused_activation=None, - mcast_in0=True, - ) - self.COMPUTE_KERNEL_QKV = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi2, - math_approx_mode=True, - fp32_dest_acc_en=True, - packer_l1_acc=True, - ) - self.COMPUTE_KERNEL_SELFOUT = self.COMPUTE_KERNEL_QKV - - total_cores = (self.n_local_heads + self.n_local_kv_heads * 2) * self.head_dim // 32 - shard_spec_n_cores_grid = ttnn.CoreRangeSet({num_to_corerange(total_cores)}) - self.CREATE_HEAD_INPUT_MEMCFG = ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.WIDTH_SHARDED, - ttnn.BufferType.L1, - ttnn.ShardSpec( - shard_spec_n_cores_grid, - [ - 32, - 32, - ], - ttnn.ShardOrientation.ROW_MAJOR, - False, - ), - ) - self.COMPUTE_KERNEL_ROTARY = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - math_approx_mode=False, - fp32_dest_acc_en=True, - packer_l1_acc=True, - ) - - self.ROTARY_PROGCFG = ttnn.MatmulMultiCoreReuseProgramConfig( - compute_with_storage_grid_size=[8, 1], - in0_block_w=4, - out_subblock_h=1, - out_subblock_w=4, - per_core_M=1, - per_core_N=4, - ) - - self.COMPUTE_KERNEL_SDPA = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - math_approx_mode=False, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - - shard_grid = 
ttnn.CoreRangeSet({num_to_corerange(self.batch_size_per_device_group)}) - shard_spec = ttnn.ShardSpec( - shard_grid, (self.padded_local_heads, self.head_dim), ttnn.ShardOrientation.ROW_MAJOR, False - ) - - self.SDPA_HEIGHT_SHARDED_MEMCFG = ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec - ) - mesh_rows, mesh_cols = 8, 4 - self.QKV_OUT_GATHERED_MEMCFG = ttnn.create_sharded_memory_config( - shape=(32 * mesh_cols, 1280 // 40), - core_grid=ttnn.CoreGrid(y=5, x=8), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - self.SELF_OUT_GATHERED_MEMCFG = ttnn.create_sharded_memory_config( - shape=(32 * mesh_rows, 2048 // 32), - core_grid=ttnn.CoreGrid(y=4, x=8), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - - self.GATHER_USERS_MEMCFG = ttnn.create_sharded_memory_config( - shape=(32 * mesh_cols, 1024 // 32), - core_grid=ttnn.CoreGrid(y=4, x=8), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - - elif mode == "prefill": - self.COMPUTE_KERNEL_QKV = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi2, - math_approx_mode=True, - fp32_dest_acc_en=True, - packer_l1_acc=True, - ) - self.COMPUTE_KERNEL_SELFOUT = self.COMPUTE_KERNEL_QKV - - self.COMPUTE_KERNEL_ROTARY = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - math_approx_mode=False, - fp32_dest_acc_en=True, - packer_l1_acc=True, - ) - - self.COMPUTE_KERNEL_SDPA = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - math_approx_mode=False, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - - self.FUSED_QKV_MM_PROGCFG = get_matmul_2d_config_from_tensor_shapes( - (1, 1, self.model_config["MAX_MM_SEQ_LEN"], 2048), - (1, 1, 2048, 1280), - grid=ttnn.CoreGrid(x=8, y=4), - overwrite_subblock_h=1, - overwrite_subblock_w=1, - fuse_batch=False, - ) - - self.SELFOUT_PROGCFG = get_matmul_2d_config_from_tensor_shapes( - (1, 1, self.model_config["MAX_MM_SEQ_LEN"], 1024), - (1, 1, 1024, 2048), - grid=ttnn.CoreGrid(x=8, y=4), - overwrite_subblock_h=1, - overwrite_subblock_w=1, - fuse_batch=False, - ) - def init_kv_cache(self): """ Generates empty KV cache and pushed to device memory @@ -263,17 +130,14 @@ def init_kv_cache(self): ) layer_past = [cache_k, cache_v] self.layer_past = [ - ttnn.to_device( - ttnn.as_tensor( - lp, - device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - dtype=ttnn.bfloat8_b, - cache_file_name=self.cache_path / f"empty_attn_cache_galaxy_{cache_k.shape}", - ), - self.mesh_device, + ttnn.as_tensor( + lp, + device=self.mesh_device, + mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + cache_file_name=self.cache_path / f"empty_attn_cache_galaxy_{cache_k.shape}", ) for lp in layer_past ] @@ -357,7 +221,7 @@ def load_weights(self): ) def __call__(self, xs, rot_mats, start_pos: int, attn_masks, user_id: int = 0, mode="decode"): - self.get_attn_model_config(mode) + self.attention_config = self.model_config["attention"][mode] # Decode should have input tensor of shape (seqlen=1, 1, batch, hidden_size) if mode == "decode": return self.decode_forward(xs, rot_mats, start_pos, attn_masks) @@ -382,20 
+246,21 @@ def attn_qkv( xs, rot_mats, ): + batch_size = xs.shape[2] # Fused QKV fused_query_key_value = ttnn.matmul( xs, self.qkv, - program_config=self.FUSED_QKV_MM_PROGCFG, + program_config=self.attention_config["FUSED_QKV_MM_PROGCFG"], dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.COMPUTE_KERNEL_QKV, + compute_kernel_config=self.attention_config["COMPUTE_KERNEL_QKV"], ) xs.deallocate(True) # TODO: Use sharded all_reduce when PCC issue is fixed in this particular configuration # fused_query_key_value = tt_sharded_all_reduce( - # fused_query_key_value, self.mesh_device, cluster_axis=1, num_links=2, memory_config=self.QKV_OUT_GATHERED_MEMCFG + # fused_query_key_value, self.mesh_device, cluster_axis=1, num_links=2, memory_config=self.attention_config["QKV_OUT_GATHERED_MEMCFG"](self.cluster_shape[0]) # ) fused_query_key_value = tt_all_reduce( @@ -420,7 +285,7 @@ def attn_qkv( ) fused_query_key_value = ttnn.to_memory_config( - fused_query_key_value, memory_config=self.CREATE_HEAD_INPUT_MEMCFG + fused_query_key_value, memory_config=self.attention_config["CREATE_HEAD_INPUT_MEMCFG"] ) # Split QKV @@ -442,17 +307,17 @@ def attn_qkv( query_layer = ttnn.matmul( query_layer, rot_mats, - program_config=self.model_config["ROT_MAT_MM_PROGCFG"], + program_config=self.attention_config["ROT_MAT_MM_PROGCFG"](batch_size), memory_config=ttnn.L1_HEIGHT_SHARDED_MEMORY_CONFIG, - compute_kernel_config=self.COMPUTE_KERNEL_ROTARY, + compute_kernel_config=self.attention_config["COMPUTE_KERNEL_ROTARY"], ) key_layer = ttnn.matmul( key_layer, rot_mats, - program_config=self.model_config["ROT_MAT_MM_PROGCFG"], + program_config=self.attention_config["ROT_MAT_MM_PROGCFG"](batch_size), memory_config=ttnn.L1_HEIGHT_SHARDED_MEMORY_CONFIG, - compute_kernel_config=self.COMPUTE_KERNEL_ROTARY, + compute_kernel_config=self.attention_config["COMPUTE_KERNEL_ROTARY"], ) return query_layer, key_layer, value_layer @@ -499,8 +364,8 @@ def attn_mqa( [start_pos for _ in range(self.max_batch_size)], scale=self.scale, program_config=program_config, - compute_kernel_config=self.COMPUTE_KERNEL_SDPA, - memory_config=self.SDPA_HEIGHT_SHARDED_MEMCFG, + compute_kernel_config=self.attention_config["COMPUTE_KERNEL_SDPA"], + memory_config=self.attention_config["SDPA_HEIGHT_SHARDED_MEMCFG"](self.batch_size_per_device_group), ) return attn_output @@ -509,7 +374,6 @@ def attn_selfout( attn_output, ): # ATTENTION SELFOUT - # breakpoint() # (1, 8, 8(32), 128) - > (1, 1, 8(32), 1024) ->(1, 1, 32, 1024) attn_output = ttnn.experimental.nlp_concat_heads_decode( attn_output, @@ -522,7 +386,7 @@ def attn_selfout( dim=2, cluster_axis=1, num_links=2, - memory_config=self.GATHER_USERS_MEMCFG, + memory_config=self.attention_config["GATHER_USERS_MEMCFG"](self.cluster_shape[0]), ) attn_output = ttnn.to_memory_config(attn_output, ttnn.L1_MEMORY_CONFIG) # user_selection_matrix = [1, 1, 32, 128] @@ -541,7 +405,7 @@ def attn_selfout( core_grid=ttnn.CoreGrid(y=4, x=8), memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, dtype=ttnn.bfloat8_b, - compute_kernel_config=self.COMPUTE_KERNEL_SELFOUT, + compute_kernel_config=self.attention_config["COMPUTE_KERNEL_SELFOUT"], ) attn_output = tt_sharded_all_reduce( @@ -549,7 +413,7 @@ def attn_selfout( self.mesh_device, cluster_axis=0, num_links=2, - memory_config=self.SELF_OUT_GATHERED_MEMCFG, + memory_config=self.attention_config["SELF_OUT_GATHERED_MEMCFG"](self.cluster_shape[1]), ) return attn_output @@ -571,10 +435,10 @@ def prefill_attn_qkv( rot_mats, ): assert xs.shape[1] == 1, 
"batch must be 1" - assert xs.shape[2] % 128 == 0 and xs.shape[2] > 0, "Seqlen must be divisible by 128" + assert xs.shape[2] % 32 == 0 and xs.shape[2] > 0, "Seqlen must be divisible by 32" _, _, seq_len, _ = xs.shape - max_mm_seq_len = self.model_config["MAX_MM_SEQ_LEN"] + max_mm_seq_len = self.model_config["MAX_MM_SEQ_LEN"](seq_len) batch_dim = 1 if seq_len < max_mm_seq_len else seq_len // max_mm_seq_len # Find the division factor xs = ttnn.reshape(xs, (1, batch_dim, seq_len // batch_dim, -1)) @@ -585,8 +449,8 @@ def prefill_attn_qkv( self.qkv, dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.COMPUTE_KERNEL_QKV, - program_config=self.FUSED_QKV_MM_PROGCFG, + compute_kernel_config=self.attention_config["COMPUTE_KERNEL_QKV"], + program_config=self.attention_config["FUSED_QKV_MM_PROGCFG"](seq_len), ) fused_query_key_value = tt_all_reduce( @@ -647,8 +511,8 @@ def prefill_attn_mqa( single_user_key_layer = self.prefill_prepare_tensor_for_kv_cache(key_layer, user_id) # Fill cache with multi-device tensor - ttnn.experimental.paged_fill_cache( - keys_reshaped, + ttnn.fill_cache( + keys, ttnn.experimental.typecast(single_user_key_layer, ttnn.bfloat8_b, memory_config=ttnn.DRAM_MEMORY_CONFIG), user_id % self.batch_size_per_device_group, ) @@ -657,8 +521,8 @@ def prefill_attn_mqa( values = self.layer_past[1] single_user_value_layer = self.prefill_prepare_tensor_for_kv_cache(value_layer, user_id) - ttnn.experimental.paged_fill_cache( - values_reshaped, + ttnn.fill_cache( + values, ttnn.experimental.typecast(single_user_value_layer, ttnn.bfloat8_b, memory_config=ttnn.DRAM_MEMORY_CONFIG), user_id % self.batch_size_per_device_group, ) @@ -670,6 +534,8 @@ def prefill_attn_mqa( value_layer, is_causal=True, scale=self.scale, + compute_kernel_config=self.attention_config["COMPUTE_KERNEL_SDPA"], + program_config=self.attention_config["SDPA_PROG_CFG"](query_layer.shape[-2]), # pass seq_len ) return attn_output @@ -683,7 +549,7 @@ def prefill_attn_selfout(self, attn_output): _, _, seq_len, _ = attn_output.shape - max_mm_seq_len = self.model_config["MAX_MM_SEQ_LEN"] + max_mm_seq_len = self.model_config["MAX_MM_SEQ_LEN"](seq_len) batch_dim = 1 if seq_len < max_mm_seq_len else seq_len // max_mm_seq_len # Find the division factor attn_output = ttnn.reshape(attn_output, (1, batch_dim, seq_len // batch_dim, -1)) @@ -692,7 +558,8 @@ def prefill_attn_selfout(self, attn_output): self.wo, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat16, - program_config=self.SELFOUT_PROGCFG, + program_config=self.attention_config["SELFOUT_PROGCFG"](seq_len), + compute_kernel_config=self.attention_config["COMPUTE_KERNEL_SELFOUT"], ) # bsz, 1, seqlen, hidden_size attn_output = ttnn.reshape(attn_output, (1, 1, seq_len, -1)) diff --git a/models/demos/tg/llama3_70b/tt/llama_common.py b/models/demos/tg/llama3_70b/tt/llama_common.py index 41fe530814b..4ad620a59b7 100644 --- a/models/demos/tg/llama3_70b/tt/llama_common.py +++ b/models/demos/tg/llama3_70b/tt/llama_common.py @@ -4,6 +4,11 @@ import ttnn import torch +import os +from pathlib import Path +from loguru import logger + +from models.demos.tg.llama3_70b.tt.model_config import get_model_config class PytorchLlamaModel(torch.nn.Module): @@ -82,3 +87,50 @@ def tt_sharded_all_gather(input_tensor, mesh_device, cluster_axis, dim, num_link mesh_device=mesh_device, memory_config=memory_config, ) + + +def upper_pad_sequence_length(length, padding_size): + if length % padding_size == 0: + return length # No padding needed + return ((length + 
padding_size - 1) // padding_size) * padding_size + + +def setup_llama_env(llama_version="llama3", max_batch_size=32, max_context_len=4096): + if os.getenv("CI") == "true": + if llama_version == "llama3-tg": + ckpt_dir = "/mnt/MLPerf/tt_dnn-models/llama-3/llama-3-70b-repacked/" + tokenizer_path = "/mnt/MLPerf/tt_dnn-models/llama-3/tokenizer.model" + cache_path = Path("/mnt/MLPerf/tt_dnn-models/llama-3/llama-data-cache/weights-cache-tg") + else: + raise ValueError(f"Unknown llama version: {llama_version}") + else: + if llama_version == "llama3-tg": + ckpt_dir = os.getenv("LLAMA3_CKPT_DIR", "/proj_sw/user_dev/llama3-data-repacked/llama-3-70b/") + tokenizer_path = os.getenv( + "LLAMA3_TOKENIZER_PATH", "/proj_sw/user_dev/llama3-data-repacked/tokenizer.model" + ) + cache_path = Path(os.getenv("LLAMA3_CACHE_PATH", "/proj_sw/user_dev/llama3-data-cache/weights-cache-2")) + else: + raise ValueError(f"Unknown llama version: {llama_version}") + + assert os.path.exists( + ckpt_dir + ), f"Checkpoint directory {ckpt_dir} does not exist, please use export {llama_version.upper()}_CKPT_DIR=..." + assert os.path.exists( + tokenizer_path + ), f"Tokenizer file {tokenizer_path} does not exist, please use export {llama_version.upper()}_TOKENIZER_PATH=..." + assert os.path.exists( + cache_path + ), f"Cache directory {cache_path} does not exist, please use export {llama_version.upper()}_CACHE_PATH=..." + + logger.info(f"Checkpoint directory: {ckpt_dir}") + logger.info(f"Tokenizer file: {tokenizer_path}") + logger.info(f"Cache directory: {cache_path}") + + model_config = get_model_config( + llama_version=llama_version, + max_batch_size=max_batch_size, + max_context_len=max_context_len, + ) + + return model_config, ckpt_dir, tokenizer_path, cache_path diff --git a/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py index 5348cb3decc..ed2c6030b16 100644 --- a/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py @@ -2,17 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 -from loguru import logger from typing import List -import torch import ttnn -from ttnn import ReplicateTensorToMesh, ShardTensorToMesh from models.demos.tg.llama3_70b.tt.llama_attention_galaxy import TtLlamaAttention_galaxy from models.demos.tg.llama3_70b.tt.llama_mlp_galaxy import TtLlamaMLP_galaxy from models.demos.t3000.llama2_70b.tt.llama_common import ( ShardTensor2dMesh, - ConcatMesh2DToTensor, ) from models.demos.tg.llama3_70b.tt.llama_common import tt_all_gather @@ -80,60 +76,6 @@ def __init__( self.load_weights() - def get_decoder_config(self, mode): - self.LN_COMPUTE_KERNEL_CONFIG = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi2, - math_approx_mode=False, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - - if mode == "decode": - self.LN_PROGCFG = ttnn.LayerNormShardedMultiCoreProgramConfig( - compute_with_storage_grid_size=[8, 4], - subblock_w=8, - block_h=32 // 32, - block_w=8, - inplace=False, - ) - - shard_spec_32_cores_grid = ttnn.CoreRangeSet( - { - ttnn.CoreRange( - ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(7, 3), - ), - } - ) - - self.LN_OUTPUT_MEMCFG = ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.WIDTH_SHARDED, - ttnn.BufferType.L1, - ttnn.ShardSpec( - shard_spec_32_cores_grid, - [ - 32, - 8192 // 32, - ], - ttnn.ShardOrientation.ROW_MAJOR, - False, - ), - ) - self.ATTN_ACT_MEMCFG = ttnn.create_sharded_memory_config( - shape=(32, 2048 // 32), - core_grid=ttnn.CoreGrid(y=4, x=8), - 
strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - self.MLP_ACT_MEMCFG = ttnn.create_sharded_memory_config( - shape=(32, 2048 // 8), - core_grid=ttnn.CoreGrid(y=1, x=8), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - def set_model_config(self, model_config): self.model_config = model_config self.attention.set_model_config(model_config) @@ -190,7 +132,7 @@ def __call__( user_id: int = 0, mode="decode", ) -> ttnn.Tensor: - self.get_decoder_config(mode) + self.decoder_config = self.model_config["decoder"][mode] if mode == "decode": return self.decode_forward(xs, rot_mats, start_pos, attn_masks) elif mode == "prefill": @@ -201,7 +143,7 @@ def __call__( def tt_distributed_rmsnorm(self, inp, epsilon, gamma): # Run distributed rmsnorm part 1 tt_stats = ttnn.rms_norm_pre_all_gather( - inp, compute_kernel_config=self.LN_COMPUTE_KERNEL_CONFIG, dtype=ttnn.bfloat16 + inp, compute_kernel_config=self.decoder_config["LN_COMPUTE_KERNEL_CONFIG"], dtype=ttnn.bfloat16 ) tt_stats = ttnn.reshape( @@ -220,7 +162,11 @@ def tt_distributed_rmsnorm(self, inp, epsilon, gamma): # Run distributed rmsnorm part 2 tt_out = ttnn.rms_norm_post_all_gather( - inp, tt_stats, epsilon=epsilon, weight=gamma, compute_kernel_config=self.LN_COMPUTE_KERNEL_CONFIG + inp, + tt_stats, + epsilon=epsilon, + weight=gamma, + compute_kernel_config=self.decoder_config["LN_COMPUTE_KERNEL_CONFIG"], ) tt_stats.deallocate(True) @@ -242,9 +188,9 @@ def decode_forward( gamma=self.attn_norm_sharded, ) - attn_norm_out = ttnn.to_memory_config(attn_norm_out, memory_config=self.ATTN_ACT_MEMCFG) + attn_norm_out = ttnn.to_memory_config(attn_norm_out, memory_config=self.decoder_config["ATTN_ACT_MEMCFG"]) attn_outs = self.attention(attn_norm_out, rot_mats, start_pos, attn_masks, mode="decode") - attn_outs = ttnn.to_memory_config(attn_outs, memory_config=self.MLP_ACT_MEMCFG) + attn_outs = ttnn.to_memory_config(attn_outs, memory_config=self.decoder_config["MLP_ACT_MEMCFG"]) output = xs output = ttnn.add( @@ -261,7 +207,7 @@ def decode_forward( gamma=self.ffn_norm_sharded, ) - ffn_norm_out = ttnn.to_memory_config(ffn_norm_out, memory_config=self.MLP_ACT_MEMCFG) + ffn_norm_out = ttnn.to_memory_config(ffn_norm_out, memory_config=self.decoder_config["MLP_ACT_MEMCFG"]) ffn_out = self.mlp(ffn_norm_out, mode="decode") ### residual add diff --git a/models/demos/tg/llama3_70b/tt/llama_generation_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_generation_galaxy.py index 489842c5b8d..ffd80c157b5 100644 --- a/models/demos/tg/llama3_70b/tt/llama_generation_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_generation_galaxy.py @@ -9,9 +9,10 @@ import copy from models.demos.tg.llama3_70b.tt.llama_model_galaxy import TtLlamaModel_galaxy as TtLlamaModel from models.demos.t3000.llama2_70b.tt.llama_common import BASE_URL, ConcatMesh2DToTensor -from models.demos.t3000.llama2_70b.tt.model_config import ( +from models.demos.tg.llama3_70b.tt.model_config import ( get_model_config, ) +from models.demos.tg.llama3_70b.tt.llama_common import upper_pad_sequence_length class TtLlamaModelForGeneration: @@ -54,6 +55,8 @@ def forward(self, tokens: torch.Tensor, start_pos: int): _, seq_len = tokens.shape if seq_len == 1: return self.decode_forward(tokens, start_pos) + else: + return self.prefill_forward(tokens, start_pos) def decode_forward(self, tokens: torch.Tensor, start_pos: int): batch = tokens.shape[0] @@ -73,6 +76,55 
@@ def decode_forward(self, tokens: torch.Tensor, start_pos: int): return logits + def prefill_forward_single_user( + self, tokens: torch.Tensor, start_pos: int, user_id: int, last_token_idx=None, page_table=None + ): + batch, seq_len = tokens.shape + assert batch == 1 + assert start_pos == 0, "start_pos must be 0 for prefill_forward_single_user" + assert seq_len % 32 == 0, f"seq_len must be divisible by 32, got {seq_len}" + tt_inp_emb, start_pos, rot_mat, attn_mask = self.tt_model.prepare_inputs( + tokens, + start_pos=start_pos, + valid_seq_len=seq_len, + mode="prefill", + ) + + tt_logits = self.tt_model(tt_inp_emb, rot_mat, start_pos, attn_mask, user_id=user_id, mode="prefill") + + del tt_inp_emb + del rot_mat + del attn_mask + + logits = self._process_logits(tt_logits) + logits = logits.squeeze(1) + del tt_logits + return logits + + def prefill_forward(self, tokens: torch.Tensor, start_pos: int): + batch, seq_len = tokens.shape + assert seq_len <= 8 * 1024, f"Only prefill up to 2048 tokens is supported, got {seq_len}" + prefill_seq_len = upper_pad_sequence_length( + seq_len, self.tt_model.model_config["PADDING_LENGTH"] + ) # Pad seq_len to nearest_32 multiple + + batch, seq_len = tokens.shape + last_token_idx = seq_len - 1 + output_logits = torch.zeros(batch, seq_len, self.params.vocab_size) + # pad tokens to nearest 32 multiple + prefill_ids = torch.cat([tokens, torch.zeros(batch, prefill_seq_len - seq_len).long()], dim=-1) + + for user_id in range(batch): + logger.info(f"Filling kv cache for user {user_id + 1}") + + logits = self.prefill_forward_single_user(prefill_ids[user_id : user_id + 1], start_pos, user_id) + + # Since we give padded_seq_len, we get only the last token + output_logits[user_id] = logits[:, last_token_idx % 32 : last_token_idx % 32 + 1, :] + logger.info(f"Finished prefill for all users up to {seq_len} tokens, Starting decode...") + + return output_logits + def _process_logits(self, tt_logits): logits = ttnn.to_torch( tt_logits, diff --git a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py index 4514be7af5f..de79d44af01 100644 --- a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py @@ -2,11 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 -from loguru import logger from typing import List -import torch import ttnn -from ttnn import ReplicateTensorToMesh from models.demos.t3000.llama2_70b.tt.llama_common import ShardTensor2dMesh, ConcatMesh2DToTensor from models.utility_functions import nearest_32 from models.demos.tg.llama3_70b.tt.llama_common import tt_all_reduce, tt_sharded_all_reduce @@ -46,148 +43,6 @@ def __init__( def set_model_config(self, model_config): self.model_config = model_config - def get_mlp_model_config(self, mode): - if mode == "decode": - # Weight Sharding - weight_grid = ttnn.CoreRangeSet( - { - ttnn.CoreRange( - ttnn.CoreCoord(0, 0), - ttnn.CoreCoord( - self.mesh_device.dram_grid_size().x - 1, - self.mesh_device.dram_grid_size().y - 1, - ), - ) - } - ) - M, K, N = 32, self.model_config["HIDDEN_SIZE"], self.model_config["FFN_EXPANDED_HIDDEN_SIZE"] - - K = K // self.cluster_shape[0] - N = N // self.cluster_shape[1] - shard_shape = (K, nearest_32(N // 12)) # padded cols to divide by 12 - shard_spec = ttnn.ShardSpec(weight_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False) - self.w1_mem_config = ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.DRAM, shard_spec - ) - - w2_K, w2_N = N, K - shard_shape = (w2_K, 
nearest_32(w2_N // 12)) # padded cols to divide by 12 - shard_spec = ttnn.ShardSpec(weight_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False) - self.w2_mem_config = ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.DRAM, shard_spec - ) - - self.FF1_DRAM_SHARDED_PROGCFG = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig( - in0_block_w=K - // 8 - // 32, # K = 8192 / TILE_WIDTH=32 / Grid_Size is based on compute_with_storage_grid_size - per_core_M=M // 32, # M / TILE_HEIGHT = 32 / 32 - per_core_N=N // 8 // 32, # N / TILE_WIDTH / Grid_Size is based on compute_with_storage_grid_size - fused_activation=None, - ) - - self.FF2_DRAM_SHARDED_PROGCFG = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig( - in0_block_w=w2_K - // 8 - // 32, # K = 8192 / TILE_WIDTH=32 / Grid_Size is based on compute_with_storage_grid_size - per_core_M=M // 32, # M / TILE_HEIGHT = 32 / 32 - per_core_N=w2_N // 8 // 32, # N / TILE_WIDTH / Grid_Size is based on compute_with_storage_grid_size - fused_activation=None, - ) - - self.COMPUTE_KERNEL_LOFI = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.LoFi, - math_approx_mode=True, - fp32_dest_acc_en=True, - packer_l1_acc=True, - ) - - full_grid = ttnn.CoreRangeSet( - { - ttnn.CoreRange( - ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(7, 7), - ) - } - ) - self.FULL_GRID_MEMCFG = ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.WIDTH_SHARDED, - ttnn.BufferType.L1, - ttnn.ShardSpec( - full_grid, - [ - 32, - nearest_32(56), - ], - ttnn.ShardOrientation.ROW_MAJOR, - False, - ), - ) - - self.FF2_ACT_MEMCFG = ttnn.create_sharded_memory_config( - shape=(M, N // 8), - core_grid=ttnn.CoreGrid(y=1, x=8), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - - self.FF1_ACT_MEMCFG = ttnn.create_sharded_memory_config( - shape=(32, 2048 // 8), - core_grid=ttnn.CoreGrid(y=1, x=8), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - - mesh_rows, mesh_cols = 8, 4 - self.FF1_OUT_GATHERED_MEMCFG = ttnn.create_sharded_memory_config( - shape=(M * mesh_cols, N // 8), - core_grid=ttnn.CoreGrid(y=1, x=8), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - self.FF2_OUT_GATHERED_MEMCFG = ttnn.create_sharded_memory_config( - shape=(32 * mesh_rows, 2048 // 8), - core_grid=ttnn.CoreGrid(y=1, x=8), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - - elif mode == "prefill": - hidden_dim_per_chip = self.hidden_size // self.cluster_shape[0] # 2048 - ff_outer_dim_per_chip = ( - self.state_dict["layers.0.feed_forward.w1.weight"].shape[0] // self.cluster_shape[1] - ) # 3584 - self.FF1_PROGCFG = get_matmul_2d_config_from_tensor_shapes( - ( - 1, - 1, - self.model_config["MAX_MM_SEQ_LEN"], - hidden_dim_per_chip, - ), # (1, 1, self.model_config["MAX_MM_SEQ_LEN"], 2048) - (1, 1, hidden_dim_per_chip, ff_outer_dim_per_chip), # (1, 1, 2048, 3584) - grid=ttnn.CoreGrid(x=8, y=4), - overwrite_subblock_h=1, - overwrite_subblock_w=1, - fuse_batch=False, - ) - self.FF2_PROGCFG = get_matmul_2d_config_from_tensor_shapes( - ( - 1, - 1, - self.model_config["MAX_MM_SEQ_LEN"], - ff_outer_dim_per_chip, - ), # (1, 1, self.model_config["MAX_MM_SEQ_LEN"], 3584) - (1, 1, ff_outer_dim_per_chip, hidden_dim_per_chip), # (1, 1, 3584, 2048) - 
grid=ttnn.CoreGrid(x=8, y=4), - overwrite_subblock_h=1, - overwrite_subblock_w=1, - fuse_batch=False, - ) - def load_weights(self): assert not hasattr(self, "w1_list"), "w1_list is already an attribute of this object" assert not hasattr(self, "w3_list"), "w3_list is already an attribute of this object" @@ -233,7 +88,7 @@ def load_weights(self): dtype=w3_dtype, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - # memory_config=self.w1_mem_config, # TODO: Reenable when DRAM-SHARDED PCC issues resolves + # memory_config=self.mlp_config["W1_MEM_CONFIG"](self.mesh_device, self.cluster_shape), # TODO: Reenable when DRAM-SHARDED PCC issues resolves memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(2, 3), cluster_shape=self.cluster_shape), cache_file_name=self.cache_path / w3_cache_str, @@ -244,14 +99,14 @@ def load_weights(self): dtype=w2_dtype, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - # memory_config=self.w2_mem_config, # TODO: Reenable when DRAM-SHARDED PCC issues resolves + # memory_config=self.mlp_config["W2_MEM_CONFIG"](self.mesh_device), # TODO: Reenable when DRAM-SHARDED PCC issues resolves memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(3, 2), cluster_shape=self.cluster_shape), cache_file_name=self.cache_path / w2_cache_str, ) def __call__(self, x: List[ttnn.Tensor], mode="decode") -> List[ttnn.Tensor]: - self.get_mlp_model_config(mode) + self.mlp_config = self.model_config["mlp"][mode] # Decode should have input tensor of shape (seqlen=1, 1, batch, hidden_size) if mode == "decode": return self.decode_forward(x) @@ -264,9 +119,9 @@ def decode_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]: w1_out = ttnn.matmul( x, self.w1, - # program_config=self.FF1_DRAM_SHARDED_PROGCFG, + # program_config=self.mlp_config["FF1_DRAM_SHARDED_PROGCFG"], core_grid=ttnn.CoreGrid(y=1, x=8), - compute_kernel_config=self.COMPUTE_KERNEL_LOFI, + compute_kernel_config=self.mlp_config["COMPUTE_KERNEL_LOFI"], dtype=ttnn.bfloat16, memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, ) @@ -274,23 +129,31 @@ def decode_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]: w3_out = ttnn.matmul( x, self.w3, - # program_config=self.FF1_DRAM_SHARDED_PROGCFG, # TODO: Reenable when DRAM-SHARDED PCC issues resolves + # program_config=self.mlp_config["FF1_DRAM_SHARDED_PROGCFG"], # TODO: Reenable when DRAM-SHARDED PCC issues resolves core_grid=ttnn.CoreGrid(y=1, x=8), - compute_kernel_config=self.COMPUTE_KERNEL_LOFI, + compute_kernel_config=self.mlp_config["COMPUTE_KERNEL_LOFI"], dtype=ttnn.bfloat16, memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, ) x.deallocate(True) w1_out = tt_sharded_all_reduce( - w1_out, self.mesh_device, cluster_axis=1, num_links=2, memory_config=self.FF1_OUT_GATHERED_MEMCFG + w1_out, + self.mesh_device, + cluster_axis=1, + num_links=2, + memory_config=self.mlp_config["FF1_OUT_GATHERED_MEMCFG"], ) w3_out = tt_sharded_all_reduce( - w3_out, self.mesh_device, cluster_axis=1, num_links=2, memory_config=self.FF1_OUT_GATHERED_MEMCFG + w3_out, + self.mesh_device, + cluster_axis=1, + num_links=2, + memory_config=self.mlp_config["FF1_OUT_GATHERED_MEMCFG"], ) - w1_out = ttnn.to_memory_config(w1_out, self.FULL_GRID_MEMCFG) - w3_out = ttnn.to_memory_config(w3_out, self.FULL_GRID_MEMCFG) + w1_out = ttnn.to_memory_config(w1_out, self.mlp_config["FULL_GRID_MEMCFG"]) + w3_out = ttnn.to_memory_config(w3_out, self.mlp_config["FULL_GRID_MEMCFG"]) hidden_states = ttnn.mul( w1_out, @@ -302,13 +165,13 @@ def 
decode_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]: w1_out.deallocate(True) w3_out.deallocate(True) - hidden_states = ttnn.to_memory_config(hidden_states, self.FF2_ACT_MEMCFG) + hidden_states = ttnn.to_memory_config(hidden_states, self.mlp_config["FF2_ACT_MEMCFG"]) hidden_states = ttnn.matmul( hidden_states, self.w2, - # program_config=self.FF2_DRAM_SHARDED_PROGCFG, # TODO: Reenable when DRAM-SHARDED PCC issues resolves + # program_config=self.mlp_config["FF2_DRAM_SHARDED_PROGCFG"], # TODO: Reenable when DRAM-SHARDED PCC issues resolves core_grid=ttnn.CoreGrid(y=1, x=8), - compute_kernel_config=self.COMPUTE_KERNEL_LOFI, + compute_kernel_config=self.mlp_config["COMPUTE_KERNEL_LOFI"], dtype=ttnn.bfloat16, memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, ) @@ -318,16 +181,16 @@ def decode_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]: self.mesh_device, cluster_axis=0, num_links=2, - memory_config=self.FF2_OUT_GATHERED_MEMCFG, + memory_config=self.mlp_config["FF2_OUT_GATHERED_MEMCFG"], ) - hidden_states = ttnn.to_memory_config(hidden_states, self.FF1_ACT_MEMCFG) + hidden_states = ttnn.to_memory_config(hidden_states, self.mlp_config["FF1_ACT_MEMCFG"]) return hidden_states def prefill_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]: _, _, seq_len, _ = x.shape - max_mm_seq_len = self.model_config["MAX_MM_SEQ_LEN"] + max_mm_seq_len = self.model_config["MAX_MM_SEQ_LEN"](seq_len) batch_dim = 1 if seq_len < max_mm_seq_len else seq_len // max_mm_seq_len # Find the division factor x = ttnn.reshape(x, (1, batch_dim, seq_len // batch_dim, -1)) @@ -335,15 +198,17 @@ def prefill_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]: x, self.w1, dtype=ttnn.bfloat16, - program_config=self.FF1_PROGCFG, + program_config=self.mlp_config["FF1_PROGCFG"](seq_len), memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.model_config["COMPUTE_KERNEL_CONFIG"], ) w3_out = ttnn.matmul( x, self.w3, dtype=ttnn.bfloat16, - program_config=self.FF1_PROGCFG, + program_config=self.mlp_config["FF1_PROGCFG"](seq_len), memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.model_config["COMPUTE_KERNEL_CONFIG"], ) w1_out = ttnn.reshape(w1_out, (1, 1, seq_len, -1)) @@ -363,23 +228,22 @@ def prefill_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]: num_links=2, ) - # w1_out.deallocate(True) - hidden_states = ttnn.mul( w1_out, w3_out, input_tensor_a_activation=ttnn.UnaryOpType.SILU, dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, ) - # hidden_states = ttnn.to_memory_config(hidden_states, self.FF2_ACT_MEMCFG) hidden_states = ttnn.reshape(hidden_states, (1, batch_dim, seq_len // batch_dim, -1)) hidden_states = ttnn.matmul( hidden_states, self.w2, dtype=ttnn.bfloat16, - program_config=self.FF2_PROGCFG, + program_config=self.mlp_config["FF2_PROGCFG"](seq_len), memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.model_config["COMPUTE_KERNEL_CONFIG"], ) hidden_states = ttnn.reshape(hidden_states, (1, 1, seq_len, -1)) diff --git a/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py index 0b0bfb4ba4f..d0cc87efc5c 100644 --- a/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py @@ -7,9 +7,7 @@ from tqdm import tqdm import torch import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh - -from models.utility_functions import nearest_32, profiler +from ttnn import ReplicateTensorToMesh from 
models.demos.tg.llama3_70b.tt.llama_decoder_galaxy import TtLlamaDecoder_galaxy from models.demos.tg.llama3_70b.tt.llama_embedding_galaxy import TtLlamaEmbedding_galaxy from models.demos.t3000.llama2_70b.tt.llama_common import ( @@ -20,16 +18,17 @@ num_to_corerange, gather_cos_sin, ShardTensor2dMesh, - ConcatMesh2DToTensor, ) from models.demos.tg.llama3_70b.tt.llama_common import ( tt_all_reduce, tt_all_gather, ) -from models.demos.t3000.falcon40b.tt.model_utils import ( - matmul_2d_config_from_tensor_shapes, - matmul_1d_config_from_tensor_shapes, -) + + +def is_power_of_two(n): + if n <= 0: + return False + return (n & (n - 1)) == 0 class TtLlamaModel_galaxy: @@ -114,42 +113,6 @@ def set_model_config(self, model_config): for layer in self.layers: layer.set_model_config(model_config) - def get_model_config(self, mode): - self.LN_COMPUTE_KERNEL_CONFIG = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi2, - math_approx_mode=False, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - self.COMPUTE_KERNEL_CONFIG = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - math_approx_mode=True, - fp32_dest_acc_en=True, - packer_l1_acc=True, - ) - self.LM_HEAD_ACT_MEMCFG = ttnn.create_sharded_memory_config( - shape=(32, 2048 // 32), - core_grid=ttnn.CoreGrid(y=4, x=8), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - if mode == "prefill": - # seq_len is 32 if we slice LM head input - hidden_size_per_chip = self.hidden_size // self.cluster_shape[0] - self.LM_HEAD_PROGCFG = matmul_1d_config_from_tensor_shapes( - (1, 1, 32, hidden_size_per_chip), # get only last 32 tokens # (1, 1, 32, 2048) - ( - 1, - 1, - hidden_size_per_chip, - self.padded_vocab_size // self.cluster_shape[1], - ), # (1, 1, 2048, 16 * 1024) - grid=ttnn.CoreGrid(x=8, y=2), - overwrite_subblock_h=1, - overwrite_subblock_w=1, - ) - def load_weights(self): norm_str = "norm.weight" lm_head_str = "output.weight" @@ -227,7 +190,14 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None, xs = ttnn.to_memory_config(xs, memory_config=ACT_MEMCFG) - rot_mat = get_rotation_mat(self.rot_emb, start_pos, seq_len, batch // self.cluster_shape[0]) + if isinstance(start_pos, int): + cache_idxs = torch.tensor([start_pos for _ in range(batch // self.cluster_shape[0])], dtype=torch.int64) + else: + raise ValueError("start_pos must be an int, different start_pos for each user not supported yet") + cache_idxs = start_pos + + # TODO : Create different rot_mat for each user_groups in the cluster + rot_mat = get_rotation_mat(self.rot_emb, cache_idxs, seq_len, batch // self.cluster_shape[0]) assert rot_mat.size() == (1, batch // self.cluster_shape[0], self.head_dim, self.head_dim) shard_spec_n_cores_grid = ttnn.CoreRangeSet({num_to_corerange(batch // 4)}) @@ -256,9 +226,10 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None, ) attn_masks = None - elif mode == "prefill": - assert seq_len % 128 == 0 and seq_len > 0, "Prefill mode only supports seq_len > 0 and seq_len % 128" + elif mode == "prefill": + # check if seq_len is power of 2 + assert is_power_of_two(seq_len), "Prefill mode only supports seq_len as power of 2" assert xs.shape == (batch, 1, seq_len, self.hidden_size // self.cluster_shape[0]) cos_gathered, sin_gathered = gather_cos_sin( @@ -317,7 +288,7 @@ def __call__( user_id: int = 0, mode="decode", ) -> ttnn.Tensor: - self.get_model_config(mode) + 
self.core_model_config = self.model_config["core_model"][mode] if mode == "decode": return self.decode_forward(xs, rot_mats, start_pos, attn_masks) elif mode == "prefill": @@ -328,7 +299,7 @@ def __call__( def tt_distributed_rmsnorm(self, inp, epsilon, gamma): # Run distributed rmsnorm part 1 tt_stats = ttnn.rms_norm_pre_all_gather( - inp, compute_kernel_config=self.LN_COMPUTE_KERNEL_CONFIG, dtype=ttnn.bfloat16 + inp, compute_kernel_config=self.core_model_config["LN_COMPUTE_KERNEL_CONFIG"], dtype=ttnn.bfloat16 ) padded_shape = (1, 1, inp.shape[-2], 32) @@ -344,7 +315,11 @@ def tt_distributed_rmsnorm(self, inp, epsilon, gamma): # Run distributed rmsnorm part 2 tt_out = ttnn.rms_norm_post_all_gather( - inp, tt_stats, epsilon=epsilon, weight=gamma, compute_kernel_config=self.LN_COMPUTE_KERNEL_CONFIG + inp, + tt_stats, + epsilon=epsilon, + weight=gamma, + compute_kernel_config=self.core_model_config["LN_COMPUTE_KERNEL_CONFIG"], ) tt_stats.deallocate(True) @@ -370,7 +345,7 @@ def decode_forward( gamma=self.norm_sharded, ) - norm_out = ttnn.to_memory_config(norm_out, memory_config=self.LM_HEAD_ACT_MEMCFG) + norm_out = ttnn.to_memory_config(norm_out, memory_config=self.core_model_config["LM_HEAD_ACT_MEMCFG"]) ### Each device does an LM head fracture lm_head_out = ttnn.matmul( @@ -378,7 +353,7 @@ def decode_forward( self.lm_head, memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, dtype=ttnn.bfloat16, - compute_kernel_config=self.COMPUTE_KERNEL_CONFIG, + compute_kernel_config=self.core_model_config["COMPUTE_KERNEL_CONFIG"], ) norm_out.deallocate(True) @@ -410,14 +385,13 @@ def prefill_forward( epsilon=self.norm_eps, gamma=self.norm_sharded, ) - - # Slice out last 32 tokens in LM head to produce next token + # Slice out last padding_length(32) tokens in LM head to produce next token # TODO: Does not work for perplexity, or if we padded input to current sequence length seq_len = norm_out.shape[2] dmodel = norm_out.shape[3] norm_out = ttnn.slice( norm_out, - [0, 0, seq_len - 32, 0], + [0, 0, seq_len - self.model_config["PADDING_LENGTH"], 0], [1, 1, seq_len, dmodel], ) @@ -427,8 +401,8 @@ def prefill_forward( self.lm_head, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat8_b, - compute_kernel_config=self.COMPUTE_KERNEL_CONFIG, - program_config=self.LM_HEAD_PROGCFG, + compute_kernel_config=self.core_model_config["COMPUTE_KERNEL_CONFIG"], + program_config=self.core_model_config["LM_HEAD_PROGCFG"], ) lm_head_out = tt_all_reduce( diff --git a/models/demos/tg/llama3_70b/tt/model_config.py b/models/demos/tg/llama3_70b/tt/model_config.py new file mode 100644 index 00000000000..72a6a69002d --- /dev/null +++ b/models/demos/tg/llama3_70b/tt/model_config.py @@ -0,0 +1,542 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import ttnn + +from models.demos.t3000.falcon40b.tt.model_utils import ( + matmul_2d_config_from_tensor_shapes, + matmul_1d_config_from_tensor_shapes, +) +from models.utility_functions import nearest_32 + +MAX_SEQ_LEN = 4096 +MAX_SEQ_LEN_LLAMA3 = 8192 +MAX_SEQ_LEN_LLAMA3_1 = 128 * 1024 + + +def num_to_corerange(x): + assert x < 8 or x % 8 == 0 + num_x = min(x, 8) + num_y = x // num_x + assert num_x * num_y == x + return ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(num_x - 1, num_y - 1), + ) + + +def num_to_corerange_set(x): + assert x < 8 or x % 8 == 0 + num_x = min(x, 8) + num_y = x // num_x + assert num_x * num_y == x + return ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(num_x - 1, num_y - 1), + ), + } + ) + + +def get_model_config(llama_version="llama3-tg", max_batch_size=32, max_context_len=4096, cluster_shape=(4, 8)): + assert max_batch_size in (1, 16, 32) + + if max_context_len == 8192: + assert max_batch_size == 16 + elif max_context_len == 128 * 1024: + assert max_batch_size == 1 + else: + assert max_batch_size == 32 + + model_config = { + "MAX_GRID_SIZE": (8, 8), + "CLUSTER_SHAPE": cluster_shape, + "HIDDEN_SIZE": model_config_entries["hidden_size"], + "MAX_BATCH_SIZE": max_batch_size, + "MAX_CONTEXT_LEN": max_context_len, + "llama3-tg": MAX_SEQ_LEN_LLAMA3, + "llama3.1-tg": MAX_SEQ_LEN_LLAMA3_1, + "NUM_DEVICES": 32, + "PADDING_LENGTH": 32, + "MAX_MM_SEQ_LEN": lambda seq_len: min(seq_len, 1024), # Used to support seq len greater than 2k + "CORE_GRID_Y": lambda seq_len: 4 + if min(seq_len, 1024) // 32 >= 4 + else min(seq_len, 1024) // 32, # Core grid must be ratio of seq_len // 32 + "COMPUTE_KERNEL_CONFIG": ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ), + } + + if llama_version == "llama3" or llama_version == "llama3-tg": + model_config["FFN_EXPANDED_HIDDEN_SIZE"] = 28 * 1024 + elif llama_version == "llama3-405b": + model_config["FFN_EXPANDED_HIDDEN_SIZE"] = 52 * 1024 + + # Set attention config + model_config["attention"] = set_attention_config(model_config, max_batch_size) + # Set mlp config + model_config["mlp"] = set_mlp_config(model_config, cluster_shape) + # Set decoder config + model_config["decoder"] = set_decoder_config(model_config) + # Set core model config + model_config["core_model"] = set_core_model_config(model_config, cluster_shape) + return model_config + + +def get_batch_grid_size(batch_size): + if batch_size == 1: + return [1, 1] + elif batch_size == 16: + return [8, 2] + elif batch_size == 32: + return [8, 4] + else: + raise ValueError(f"Unsupported batch size: {batch_size}") + + +def set_attention_config(model_config, max_batch_size): + # Set decode config first + decode_config = {} + + decode_config["ROT_MAT_MM_PROGCFG"] = lambda batch_size: ttnn.MatmulMultiCoreReuseProgramConfig( + compute_with_storage_grid_size=get_batch_grid_size(batch_size), + in0_block_w=4, # 128 // TILE_SIZE (dynamic) + out_subblock_h=1, + out_subblock_w=4, + per_core_M=1, + per_core_N=4, + ) + + decode_config["FUSED_QKV_MM_PROGCFG"] = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( + compute_with_storage_grid_size=(8, 5), + in0_block_w=2, + out_subblock_h=1, + out_subblock_w=1, + per_core_M=1, + per_core_N=1, + fuse_batch=True, + fused_activation=None, + mcast_in0=True, + ) + + decode_config["COMPUTE_KERNEL_QKV"] = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + 
math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + decode_config["COMPUTE_KERNEL_SELFOUT"] = decode_config["COMPUTE_KERNEL_QKV"] + n_local_heads = model_config_entries["num_attention_heads"] // model_config_entries["num_kv_heads"] + n_local_kv_heads = 1 + head_dim = model_config_entries["head_dim"] + total_cores = (n_local_heads + n_local_kv_heads * 2) * head_dim // 32 # 1280 / 32 = 40 + assert total_cores == 40, f"total_cores: {total_cores}" + shard_spec_n_cores_grid = ttnn.CoreRangeSet({num_to_corerange(total_cores)}) + + decode_config["CREATE_HEAD_INPUT_MEMCFG"] = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + shard_spec_n_cores_grid, + [ + 32, + 32, + ], + ttnn.ShardOrientation.ROW_MAJOR, + False, + ), + ) + + decode_config["COMPUTE_KERNEL_ROTARY"] = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + decode_config["ROTARY_PROGCFG"] = ttnn.MatmulMultiCoreReuseProgramConfig( + compute_with_storage_grid_size=[8, 1], + in0_block_w=4, + out_subblock_h=1, + out_subblock_w=4, + per_core_M=1, + per_core_N=4, + ) + + decode_config["COMPUTE_KERNEL_SDPA"] = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + math_approx_mode=False, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) + padded_local_heads = 32 + + decode_config["SDPA_HEIGHT_SHARDED_MEMCFG"] = lambda batch_size_per_device_group: ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + ttnn.CoreRangeSet({num_to_corerange(batch_size_per_device_group)}), + (padded_local_heads, head_dim), + ttnn.ShardOrientation.ROW_MAJOR, + False, + ), + ) + + decode_config["QKV_OUT_GATHERED_MEMCFG"] = lambda mesh_cols: ttnn.create_sharded_memory_config( + shape=(32 * mesh_cols, 1280 // 40), # mesh_cols = 4 + core_grid=ttnn.CoreGrid(y=5, x=8), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + decode_config["SELF_OUT_GATHERED_MEMCFG"] = lambda mesh_rows: ttnn.create_sharded_memory_config( + shape=(32 * mesh_rows, 2048 // 32), # mesh_rows = 8 + core_grid=ttnn.CoreGrid(y=4, x=8), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + + decode_config["GATHER_USERS_MEMCFG"] = lambda mesh_cols: ttnn.create_sharded_memory_config( + shape=(32 * mesh_cols, 1024 // 32), # mesh_cols = 4 + core_grid=ttnn.CoreGrid(y=4, x=8), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + + # Set prefill config + prefill_config = {} + + prefill_config["COMPUTE_KERNEL_QKV"] = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + prefill_config["COMPUTE_KERNEL_SELFOUT"] = prefill_config["COMPUTE_KERNEL_QKV"] + prefill_config["COMPUTE_KERNEL_ROTARY"] = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + prefill_config["COMPUTE_KERNEL_SDPA"] = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + math_approx_mode=False, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) + + prefill_config["SDPA_PROG_CFG"] = lambda seq_len: ttnn.SDPAProgramConfig( + 
compute_with_storage_grid_size=[8, 7], + q_chunk_size=256 if seq_len % 256 == 0 else 32, + k_chunk_size=256 if seq_len % 256 == 0 else 32, + ) + + prefill_config["FUSED_QKV_MM_PROGCFG"] = lambda seq_len: matmul_2d_config_from_tensor_shapes( + (1, 1, model_config["MAX_MM_SEQ_LEN"](seq_len), 2048), + (1, 1, 2048, 1280), + grid=ttnn.CoreGrid(x=8, y=model_config["CORE_GRID_Y"](seq_len)), + overwrite_subblock_h=1, + overwrite_subblock_w=1, + fuse_batch=False, + ) + + prefill_config["SELFOUT_PROGCFG"] = lambda seq_len: matmul_2d_config_from_tensor_shapes( + (1, 1, model_config["MAX_MM_SEQ_LEN"](seq_len), 1024), + (1, 1, 1024, 2048), + grid=ttnn.CoreGrid(x=8, y=model_config["CORE_GRID_Y"](seq_len)), + overwrite_subblock_h=1, + overwrite_subblock_w=1, + fuse_batch=False, + ) + + return {"prefill": prefill_config, "decode": decode_config} + + +def set_mlp_config(model_config, cluster_shape): + decode_config = {} + + decode_config["COMPUTE_KERNEL_LOFI"] = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + M, K, N = 32, model_config["HIDDEN_SIZE"], model_config["FFN_EXPANDED_HIDDEN_SIZE"] + K = K // cluster_shape[0] + N = N // cluster_shape[1] + decode_config["W1_MEM_CONFIG"] = lambda mesh_device: ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.DRAM, + ttnn.ShardSpec( + setup_weight_grid(mesh_device), + (K, nearest_32(N // 12)), + ttnn.ShardOrientation.ROW_MAJOR, + False, + ), + ) + + decode_config["W2_MEM_CONFIG"] = lambda mesh_device: ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.DRAM, + ttnn.ShardSpec( + setup_weight_grid(mesh_device), + (N, nearest_32(K // 12)), + ttnn.ShardOrientation.ROW_MAJOR, + False, + ), + ) + + decode_config["FF1_DRAM_SHARDED_PROGCFG"] = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig( + in0_block_w=K // 8 // 32, # K = 8192 / TILE_WIDTH=32 / Grid_Size is based on compute_with_storage_grid_size + per_core_M=M // 32, # M / TILE_HEIGHT = 32 / 32 + per_core_N=N // 8 // 32, # N / TILE_WIDTH / Grid_Size is based on compute_with_storage_grid_size + fused_activation=None, + ) + + decode_config["FF2_DRAM_SHARDED_PROGCFG"] = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig( + in0_block_w=N // 8 // 32, # K = 8192 / TILE_WIDTH=32 / Grid_Size is based on compute_with_storage_grid_size + per_core_M=M // 32, # M / TILE_HEIGHT = 32 / 32 + per_core_N=K // 8 // 32, # N / TILE_WIDTH / Grid_Size is based on compute_with_storage_grid_size + fused_activation=None, + ) + + full_grid = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(7, 7), + ) + } + ) + decode_config["FULL_GRID_MEMCFG"] = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + full_grid, + [ + 32, + nearest_32(56), + ], + ttnn.ShardOrientation.ROW_MAJOR, + False, + ), + ) + decode_config["FF2_ACT_MEMCFG"] = ttnn.create_sharded_memory_config( + shape=(M, N // 8), + core_grid=ttnn.CoreGrid(y=1, x=8), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + decode_config["FF1_ACT_MEMCFG"] = ttnn.create_sharded_memory_config( + shape=(32, 2048 // 8), + core_grid=ttnn.CoreGrid(y=1, x=8), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + decode_config["FF1_OUT_GATHERED_MEMCFG"] = ttnn.create_sharded_memory_config( + 
+    decode_config["FF1_OUT_GATHERED_MEMCFG"] = ttnn.create_sharded_memory_config(
+        shape=(M * cluster_shape[0], N // 8),
+        core_grid=ttnn.CoreGrid(y=1, x=8),
+        strategy=ttnn.ShardStrategy.WIDTH,
+        orientation=ttnn.ShardOrientation.ROW_MAJOR,
+        use_height_and_width_as_shard_shape=True,
+    )
+    decode_config["FF2_OUT_GATHERED_MEMCFG"] = ttnn.create_sharded_memory_config(
+        shape=(32 * cluster_shape[1], 2048 // 8),
+        core_grid=ttnn.CoreGrid(y=1, x=8),
+        strategy=ttnn.ShardStrategy.WIDTH,
+        orientation=ttnn.ShardOrientation.ROW_MAJOR,
+        use_height_and_width_as_shard_shape=True,
+    )
+
+    prefill_config = {}
+
+    prefill_config["COMPUTE_KERNEL_LOFI"] = ttnn.WormholeComputeKernelConfig(
+        math_fidelity=ttnn.MathFidelity.LoFi,
+        math_approx_mode=True,
+        fp32_dest_acc_en=True,
+        packer_l1_acc=True,
+    )
+    hidden_dim_per_chip = model_config_entries["hidden_size"] // cluster_shape[0]  # 2048
+    ff_outer_dim_per_chip = model_config["FFN_EXPANDED_HIDDEN_SIZE"] // cluster_shape[1]  # 3584
+    prefill_config["FF1_PROGCFG"] = lambda seq_len: matmul_2d_config_from_tensor_shapes(
+        (
+            1,
+            1,
+            model_config["MAX_MM_SEQ_LEN"](seq_len),
+            hidden_dim_per_chip,
+        ),  # (1, 1, model_config["MAX_MM_SEQ_LEN"], 2048)
+        (1, 1, hidden_dim_per_chip, ff_outer_dim_per_chip),  # (1, 1, 2048, 3584)
+        grid=ttnn.CoreGrid(x=8, y=model_config["CORE_GRID_Y"](seq_len)),
+        overwrite_subblock_h=1,
+        overwrite_subblock_w=1,
+        fuse_batch=False,
+    )
+    prefill_config["FF2_PROGCFG"] = lambda seq_len: matmul_2d_config_from_tensor_shapes(
+        (
+            1,
+            1,
+            model_config["MAX_MM_SEQ_LEN"](seq_len),
+            ff_outer_dim_per_chip,
+        ),  # (1, 1, model_config["MAX_MM_SEQ_LEN"], 3584)
+        (1, 1, ff_outer_dim_per_chip, hidden_dim_per_chip),  # (1, 1, 3584, 2048)
+        grid=ttnn.CoreGrid(x=8, y=model_config["CORE_GRID_Y"](seq_len)),
+        overwrite_subblock_h=1,
+        overwrite_subblock_w=1,
+        fuse_batch=False,
+    )
+
+    return {"prefill": prefill_config, "decode": decode_config}
+
+
+def set_decoder_config(model_config):
+    decode_config = {}
+
+    decode_config["LN_COMPUTE_KERNEL_CONFIG"] = ttnn.WormholeComputeKernelConfig(
+        math_fidelity=ttnn.MathFidelity.HiFi2,
+        math_approx_mode=False,
+        fp32_dest_acc_en=False,
+        packer_l1_acc=False,
+    )
+
+    decode_config["LN_PROGCFG"] = ttnn.LayerNormShardedMultiCoreProgramConfig(
+        compute_with_storage_grid_size=[8, 4],
+        subblock_w=8,
+        block_h=32 // 32,
+        block_w=8,
+        inplace=False,
+    )
+
+    shard_spec_32_cores_grid = ttnn.CoreRangeSet(
+        {
+            ttnn.CoreRange(
+                ttnn.CoreCoord(0, 0),
+                ttnn.CoreCoord(7, 3),
+            ),
+        }
+    )
+
+    decode_config["LN_OUTPUT_MEMCFG"] = ttnn.MemoryConfig(
+        ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+        ttnn.BufferType.L1,
+        ttnn.ShardSpec(
+            shard_spec_32_cores_grid,
+            [
+                32,
+                8192 // 32,
+            ],
+            ttnn.ShardOrientation.ROW_MAJOR,
+            False,
+        ),
+    )
+    decode_config["ATTN_ACT_MEMCFG"] = ttnn.create_sharded_memory_config(
+        shape=(32, 2048 // 32),
+        core_grid=ttnn.CoreGrid(y=4, x=8),
+        strategy=ttnn.ShardStrategy.WIDTH,
+        orientation=ttnn.ShardOrientation.ROW_MAJOR,
+        use_height_and_width_as_shard_shape=True,
+    )
+    decode_config["MLP_ACT_MEMCFG"] = ttnn.create_sharded_memory_config(
+        shape=(32, 2048 // 8),
+        core_grid=ttnn.CoreGrid(y=1, x=8),
+        strategy=ttnn.ShardStrategy.WIDTH,
+        orientation=ttnn.ShardOrientation.ROW_MAJOR,
+        use_height_and_width_as_shard_shape=True,
+    )
+
+    prefill_config = {}
+
+    prefill_config["LN_COMPUTE_KERNEL_CONFIG"] = decode_config["LN_COMPUTE_KERNEL_CONFIG"]
+
+    return {"prefill": prefill_config, "decode": decode_config}
+
+
+def set_core_model_config(model_config, cluster_shape):
+    decode_config = {}
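+    # LayerNorm kernel below runs at HiFi2 without fp32 accumulation; the general compute kernel runs at HiFi4 with fp32 accumulation and L1 packing.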
+    decode_config["LN_COMPUTE_KERNEL_CONFIG"] = ttnn.WormholeComputeKernelConfig(
+        math_fidelity=ttnn.MathFidelity.HiFi2,
+        math_approx_mode=False,
+        fp32_dest_acc_en=False,
+        packer_l1_acc=False,
+    )
+    decode_config["COMPUTE_KERNEL_CONFIG"] = ttnn.WormholeComputeKernelConfig(
+        math_fidelity=ttnn.MathFidelity.HiFi4,
+        math_approx_mode=True,
+        fp32_dest_acc_en=True,
+        packer_l1_acc=True,
+    )
+    decode_config["LM_HEAD_ACT_MEMCFG"] = ttnn.create_sharded_memory_config(
+        shape=(32, 2048 // 32),
+        core_grid=ttnn.CoreGrid(y=4, x=8),
+        strategy=ttnn.ShardStrategy.WIDTH,
+        orientation=ttnn.ShardOrientation.ROW_MAJOR,
+        use_height_and_width_as_shard_shape=True,
+    )
+
+    prefill_config = {}
+    prefill_config["LN_COMPUTE_KERNEL_CONFIG"] = decode_config["LN_COMPUTE_KERNEL_CONFIG"]
+    prefill_config["COMPUTE_KERNEL_CONFIG"] = decode_config["COMPUTE_KERNEL_CONFIG"]
+    prefill_config["LM_HEAD_ACT_MEMCFG"] = decode_config["LM_HEAD_ACT_MEMCFG"]
+
+    hidden_size_per_chip = model_config_entries["hidden_size"] // cluster_shape[0]
+    prefill_config["LM_HEAD_PROGCFG"] = matmul_1d_config_from_tensor_shapes(
+        (
+            1,
+            1,
+            model_config["PADDING_LENGTH"],
+            hidden_size_per_chip,
+        ),  # get only last padding_length (32) tokens
+        (
+            1,
+            1,
+            hidden_size_per_chip,
+            model_config_entries["padded_vocab_size"] // cluster_shape[1],
+        ),  # (1, 1, 2048, 16 * 1024)
+        grid=ttnn.CoreGrid(x=8, y=4),
+        overwrite_subblock_h=1,
+        overwrite_subblock_w=1,
+    )
+
+    return {"prefill": prefill_config, "decode": decode_config}
+
+
+def setup_weight_grid(mesh_device):
+    weight_grid = ttnn.CoreRangeSet(
+        {
+            ttnn.CoreRange(
+                ttnn.CoreCoord(0, 0),
+                ttnn.CoreCoord(
+                    mesh_device.dram_grid_size().x - 1,
+                    mesh_device.dram_grid_size().y - 1,
+                ),
+            )
+        }
+    )
+    return weight_grid
+
+
+model_config_entries = {
+    "hidden_size": 8192,
+    "head_dim": 128,
+    "num_attention_heads": 64,
+    "num_kv_heads": 8,
+    "num_layers": 80,
+    "weight_cache": True,
+    "vocab_size": 128256,
+    "padded_vocab_size": 128 * 1024,
+    "mlp_dim": 28672,
+    "padded_mlp_dim": 32768,
+    "layer_norm_epsilon": 1e-05,
+}
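+
+# Sanity check on the per-chip shapes used above, assuming cluster_shape == (4, 8) as the inline
+# comments indicate (mesh_cols = 4, mesh_rows = 8):
+#   hidden size per chip        : 8192 // 4 = 2048
+#   fused QKV width per chip    : (64 // 8 + 2 * 1) * 128 = 1280, split over 40 cores -> 32 per core
+#   FFN expanded size per chip  : 28672 // 8 = 3584
+#   padded vocab slice per chip : (128 * 1024) // 8 = 16384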