From a74e5e530db83d1d662e589deae3caa20e21a180 Mon Sep 17 00:00:00 2001 From: Sofija Jovic Date: Fri, 17 May 2024 10:17:47 +0000 Subject: [PATCH] #5592: Disable attention --- .../falcon7b/tests/ci/test_falcon_end_to_end_prefill.py | 2 +- models/demos/falcon7b/tests/test_perf_falcon.py | 2 +- models/demos/falcon7b/tt/falcon_attention.py | 6 +++++- models/demos/falcon7b/tt/falcon_model.py | 6 +++++- models/demos/falcon7b/tt/model_config.py | 1 + 5 files changed, 13 insertions(+), 4 deletions(-) diff --git a/models/demos/falcon7b/tests/ci/test_falcon_end_to_end_prefill.py b/models/demos/falcon7b/tests/ci/test_falcon_end_to_end_prefill.py index 416854f8ca4..2d18afc669b 100644 --- a/models/demos/falcon7b/tests/ci/test_falcon_end_to_end_prefill.py +++ b/models/demos/falcon7b/tests/ci/test_falcon_end_to_end_prefill.py @@ -15,7 +15,7 @@ ("prefill", 32, 1, 32, 0, "BFLOAT16-DRAM", 0.97, 0.95, 0.95), ("prefill", 32, 1, 128, 0, "BFLOAT16-DRAM", 0.97, 0.99, 0.96), ("prefill", 32, 1, 1024, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96), - # ("prefill", 32, 1, 2048, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.97), # CI machines don't have enough RAM memory to run this test atm; to reduce memory usage (#8349) + # ("prefill", 32, 1, 2048, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96), # CI machines don't have enough RAM memory to run this test atm; to reduce memory usage (#8349) ), ids=[ "prefill_seq32", diff --git a/models/demos/falcon7b/tests/test_perf_falcon.py b/models/demos/falcon7b/tests/test_perf_falcon.py index aacf997eb7d..859d9f4c49e 100644 --- a/models/demos/falcon7b/tests/test_perf_falcon.py +++ b/models/demos/falcon7b/tests/test_perf_falcon.py @@ -511,7 +511,7 @@ def run_perf_wh_bare_metal( ( ("prefill", 32, 1, 128, 0, "BFLOAT16-DRAM", 0.97, 0.99, 0.96, 0.1), ("prefill", 32, 1, 1024, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 1), - ("prefill", 32, 1, 2048, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.97, 1), + ("prefill", 32, 1, 2048, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 1), ("decode", 32, 32, 1, 128, "BFLOAT16-DRAM", 0.91, 0.92, 0.93, 0.15), ("decode", 32, 32, 1, 128, "BFLOAT16-L1", 0.91, 0.92, 0.93, 0.15), ("decode", 32, 32, 1, 128, "BFLOAT16-L1_SHARDED", 0.92, 0.95, 0.95, 0.1), diff --git a/models/demos/falcon7b/tt/falcon_attention.py b/models/demos/falcon7b/tt/falcon_attention.py index b670755ec60..9d4e47cc141 100644 --- a/models/demos/falcon7b/tt/falcon_attention.py +++ b/models/demos/falcon7b/tt/falcon_attention.py @@ -216,7 +216,11 @@ def forward( seq_len = hidden_states[0].get_legacy_shape()[2] - if self.model_config["PREFILL_OPTIMIZED_MODE"] and seq_len in [128, 1024, 2048]: + if ( + self.model_config["PREFILL_OPTIMIZED_MODE"] + and self.model_config["PREFILL_ATTENTION_OPTIMIZED_MODE"] + and seq_len in [128, 1024, 2048] + ): attn_output, layer_present = self._optimized_forward( hidden_states, attention_mask, diff --git a/models/demos/falcon7b/tt/falcon_model.py b/models/demos/falcon7b/tt/falcon_model.py index d8e7f82c2b6..76aa979b5f6 100644 --- a/models/demos/falcon7b/tt/falcon_model.py +++ b/models/demos/falcon7b/tt/falcon_model.py @@ -117,7 +117,11 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token dim=-1, ) - if self.model_config["PREFILL_OPTIMIZED_MODE"] and num_input_tokens in [128, 1024, 2048]: + if ( + self.model_config["PREFILL_OPTIMIZED_MODE"] + and self.model_config["PREFILL_ATTENTION_OPTIMIZED_MODE"] + and num_input_tokens in [128, 1024, 2048] + ): attention_mask_ = create_prefill_attn_mask_for_sharded_softmax( attention_mask_bool_padded * -1e5, self.config.num_attention_heads, diff --git a/models/demos/falcon7b/tt/model_config.py b/models/demos/falcon7b/tt/model_config.py index 81a0f8120eb..08575446369 100644 --- a/models/demos/falcon7b/tt/model_config.py +++ b/models/demos/falcon7b/tt/model_config.py @@ -195,6 +195,7 @@ def get_model_config(model_config_str, prefill_seq_len=0): def set_prefill_config(model_config, seq_len, dram_memcfg): model_config["PREFILL_OPTIMIZED_MODE"] = not is_grayskull() + model_config["PREFILL_ATTENTION_OPTIMIZED_MODE"] = False # enable when #8349 is fixed model_config["MLP_SEQ_LEN"] = seq_len model_config["MLP_PADDING_VALUE"] = 4608 model_config["MLP_GRID_SIZE"] = (8, 8)