From 5e9d5613b5618616232b2791b4a0daa714a1110b Mon Sep 17 00:00:00 2001
From: Salar Hosseini
Date: Wed, 5 Jun 2024 22:15:08 +0000
Subject: [PATCH] #5383: [Falcon7b] Add support for decode-2k (l1-sharded) by
 disabling fp32acc on QK^T and fixing l1 fragmentation

Signed-off-by: Salar Hosseini
---
 .../demos/falcon7b/tests/test_perf_falcon.py | 14 +++++++++-----
 models/demos/falcon7b/tt/falcon_attention.py | 13 +++++++------
 models/demos/falcon7b/tt/model_config.py     | 18 +++++++++++++++---
 3 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/models/demos/falcon7b/tests/test_perf_falcon.py b/models/demos/falcon7b/tests/test_perf_falcon.py
index 4813a5fc92a8..610a556411c5 100644
--- a/models/demos/falcon7b/tests/test_perf_falcon.py
+++ b/models/demos/falcon7b/tests/test_perf_falcon.py
@@ -470,8 +470,6 @@ def run_perf_wh_bare_metal(
     all_devices,
     async_mode,
 ):
-    if model_config_str == "BFLOAT16-L1_SHARDED" and kv_cache_len == 2047:
-        pytest.skip(f"kv_cache_len={kv_cache_len} does not fit with L1_SHARDED")
     if model_config_str == "BFLOAT16-L1_SHARDED" and llm_mode == "prefill":
         pytest.skip(f"prefill does not support L1_SHARDED")
     if num_devices > 1:
@@ -518,9 +516,10 @@ def run_perf_wh_bare_metal(
         ("decode", 32, 32, 1, 128, "BFLOAT16-L1_SHARDED", 0.92, 0.95, 0.95, 0.1),
         ("decode", 32, 32, 1, 1024, "BFLOAT16-DRAM", 0.86, 0.92, 0.92, 0.4),
         ("decode", 32, 32, 1, 1024, "BFLOAT16-L1", 0.86, 0.92, 0.92, 0.35),
-        ("decode", 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.85, 0.93, 0.94, 0.1),
+        ("decode", 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.94, 0.94, 0.1),
         ("decode", 32, 32, 1, 2047, "BFLOAT16-DRAM", 0.88, 0.93, 0.93, 0.75),
         ("decode", 32, 32, 1, 2047, "BFLOAT16-L1", 0.88, 0.93, 0.93, 0.6),
+        ("decode", 32, 32, 1, 2047, "BFLOAT16-L1_SHARDED", 0.88, 0.92, 0.93, 0.6),
     ),
     ids=[
         "prefill_seq128_bf16_dram",
@@ -534,6 +533,7 @@ def run_perf_wh_bare_metal(
         "decode_batch32_1024_bf16_l1_sharded",
         "decode_batch32_2047_bf16_dram",
         "decode_batch32_2047_bf16_l1",
+        "decode_batch32_2047_bf16_l1_sharded",
     ],
 )
 @pytest.mark.parametrize("async_mode", (False, True))
@@ -589,12 +589,14 @@ def test_perf_wh_bare_metal(
         ("prefill", 4, 32, 1, 256, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.97, 0.18, False),  # Issue 7816 Inference time
         ("prefill", 4, 32, 1, 1024, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.98, 0.5, False),
         ("prefill", 4, 32, 1, 2048, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.98, 1.1, False),
-        ("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.91, 0.91, 0.21, False),
+        ("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.86, 0.90, 0.91, 0.21, False),
+        ("decode", 4, 32, 32, 1, 2047, "BFLOAT16-L1_SHARDED", 0.77, 0.69, 0.72, 0.21, False),
         ("prefill", 4, 32, 1, 128, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.97, 0.1, True),
         ("prefill", 4, 32, 1, 256, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.97, 0.18, True),
         ("prefill", 4, 32, 1, 1024, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.98, 0.5, True),
         ("prefill", 4, 32, 1, 2048, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.98, 1.1, True),
-        ("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.91, 0.91, 0.09, True),
+        ("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.86, 0.90, 0.91, 0.09, True),
+        ("decode", 4, 32, 32, 1, 2047, "BFLOAT16-L1_SHARDED", 0.77, 0.69, 0.72, 0.21, True),
     ),
     ids=[
         "prefill_seq128",
@@ -602,11 +604,13 @@ def test_perf_wh_bare_metal(
         "prefill_seq1024",
         "prefill_seq2048",
         "decode_batch32_1024",
+        "decode_batch32_2047",
         "prefill_seq128_async",
         "prefill_seq256_async",
         "prefill_seq1024_async",
         "prefill_seq2048_async",
         "decode_batch32_1024_async",
+        "decode_batch32_2047_async",
"decode_batch32_2047_async", ], ) @skip_for_grayskull() diff --git a/models/demos/falcon7b/tt/falcon_attention.py b/models/demos/falcon7b/tt/falcon_attention.py index 57a4d51809b9..97fdfa952a0f 100644 --- a/models/demos/falcon7b/tt/falcon_attention.py +++ b/models/demos/falcon7b/tt/falcon_attention.py @@ -704,6 +704,7 @@ def forward( for i in range(self.num_devices): # Update kv_cache in place ttnn.experimental.tensor.update_cache(layer_past[i][0], key_layer[i], layer_past_len) + key_layer[i].deallocate(True) for i in range(self.num_devices): # key and value layers will have kv_seq_len padded to nearest 32 key_layer[i] = ttnn.experimental.tensor.unpad( @@ -779,8 +780,8 @@ def forward( for i, device in enumerate(self.devices): attn_weights.append( ttnn.experimental.operations.primary.matmul( - query_layer[i], - key_layer_transposed[i], + query_layer[i], # [batch, 1, padded_local_heads, head_dim] + key_layer_transposed[i], # [batch, 1, head_dim, padded_layer_past_len] program_config=self.model_config["ATTN_BATCHED_MM_PROGCFG"]( self.head_dim // 32, self.padded_local_heads // 32, padded_layer_past_len // 32 ), @@ -788,7 +789,7 @@ def forward( self.padded_local_heads, padded_layer_past_len ), output_dtype=self.model_config["PRE_SOFTMAX_MM_OUTPUT_DTYPE"], # Must be BFLOAT16 - compute_kernel_config=self.model_config["COMPUTE_KERNEL_CONFIG"], + compute_kernel_config=self.model_config["PRE_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"], ) ) query_layer[i].deallocate() @@ -894,8 +895,8 @@ def forward( for i in range(self.num_devices): attn_output.append( ttnn.experimental.operations.primary.matmul( - attn_weights[i], - value_layer[i], + attn_weights[i], # [batch, 1, padded_local_heads, padded_layer_past_len] + value_layer[i], # [batch, 1, padded_layer_past_len, head_dim] program_config=self.model_config["ATTN_BATCHED_MM_PROGCFG"]( padded_layer_past_len // 32, self.padded_local_heads // 32, @@ -906,7 +907,7 @@ def forward( self.head_dim, ), output_dtype=self.model_config["POST_SOFTMAX_MM_OUTPUT_DTYPE"], - compute_kernel_config=self.model_config["COMPUTE_KERNEL_CONFIG"], + compute_kernel_config=self.model_config["POST_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"], ) ) attn_weights[i].deallocate(True) diff --git a/models/demos/falcon7b/tt/model_config.py b/models/demos/falcon7b/tt/model_config.py index 20973390fd6c..adc982e3f054 100644 --- a/models/demos/falcon7b/tt/model_config.py +++ b/models/demos/falcon7b/tt/model_config.py @@ -145,7 +145,9 @@ def get_model_config(model_config_str, prefill_seq_len=0): if model_config_str == "BFLOAT16-L1_SHARDED": model_config["ATTN_MASK_MEMCFG"] = L1_MEMCFG model_config["ROTARY_EMBEDDING_OUTPUT_MEMCFG"] = L1_MEMCFG - model_config["K_CACHE_SLICE_OUTPUT_MEMCFG"] = L1_MEMCFG + if not model_config_str == "BFLOAT16-L1_SHARDED": + # Don't send keys to l1 before converting to l1-sharded (after kcache update) to avoid l1 framgentation issues with kv_cache_size=2048 + model_config["K_CACHE_SLICE_OUTPUT_MEMCFG"] = L1_MEMCFG model_config["V_CACHE_SLICE_OUTPUT_MEMCFG"] = L1_MEMCFG model_config["K_TRANSPOSED_OUTPUT_MEMCFG"] = L1_MEMCFG model_config["PRE_SOFTMAX_MM_OUTPUT_MEMCFG"] = L1_MEMCFG @@ -191,17 +193,27 @@ def get_model_config(model_config_str, prefill_seq_len=0): ) if is_wormhole_b0(): - model_config["COMPUTE_KERNEL_CONFIG"] = ttnn.experimental.tensor.WormholeComputeKernelConfig( + model_config["PRE_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"] = ttnn.experimental.tensor.WormholeComputeKernelConfig( + math_fidelity=ttnn.experimental.tensor.MathFidelity.LoFi, + math_approx_mode=True, + 
+                packer_l1_acc=True,
+            )
+            model_config[
+                "POST_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"
+            ] = ttnn.experimental.tensor.WormholeComputeKernelConfig(
                 math_fidelity=ttnn.experimental.tensor.MathFidelity.LoFi,
                 math_approx_mode=True,
                 fp32_dest_acc_en=True,
                 packer_l1_acc=True,
             )
         else:
-            model_config["COMPUTE_KERNEL_CONFIG"] = ttnn.experimental.tensor.GrayskullComputeKernelConfig(
+            gs_compute_kernel_config = ttnn.experimental.tensor.GrayskullComputeKernelConfig(
                 math_fidelity=ttnn.experimental.tensor.MathFidelity.LoFi,
                 math_approx_mode=True,
             )
+            model_config["PRE_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"] = gs_compute_kernel_config
+            model_config["POST_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"] = gs_compute_kernel_config
 
     # uncomment if need to see all the configs
     # logger.debug(f"Falcon model config: \n{pretty_print_model_config(model_config)}")
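
Note (not part of the patch): the functional core of this change is that the batched QK^T matmul and the post-softmax attention-output matmul no longer share a single COMPUTE_KERNEL_CONFIG; on Wormhole, fp32 destination accumulation is disabled only for QK^T while the attention-output matmul keeps it, matching the "disabling fp32acc on QK^T" in the commit subject. A minimal standalone sketch of that split, using the same ttnn.experimental WormholeComputeKernelConfig API that appears in the diff (variable names below are illustrative only):

    # Sketch of the asymmetric compute-kernel configs introduced above (Wormhole path).
    # fp32_dest_acc_en is turned off for the pre-softmax (QK^T) matmul and kept on for
    # the post-softmax (attention x V) matmul; both use LoFi math fidelity with
    # packer L1 accumulation enabled.
    import ttnn

    pre_softmax_mm_compute_kernel_config = ttnn.experimental.tensor.WormholeComputeKernelConfig(
        math_fidelity=ttnn.experimental.tensor.MathFidelity.LoFi,
        math_approx_mode=True,
        fp32_dest_acc_en=False,  # disabled on QK^T (its output dtype must be BFLOAT16 anyway)
        packer_l1_acc=True,
    )
    post_softmax_mm_compute_kernel_config = ttnn.experimental.tensor.WormholeComputeKernelConfig(
        math_fidelity=ttnn.experimental.tensor.MathFidelity.LoFi,
        math_approx_mode=True,
        fp32_dest_acc_en=True,  # retained for the attention-output matmul
        packer_l1_acc=True,
    )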