#5592: Disable attention
s-jovic committed May 17, 2024
1 parent 048d889 commit a74e5e5
Showing 5 changed files with 13 additions and 4 deletions.
@@ -15,7 +15,7 @@
("prefill", 32, 1, 32, 0, "BFLOAT16-DRAM", 0.97, 0.95, 0.95),
("prefill", 32, 1, 128, 0, "BFLOAT16-DRAM", 0.97, 0.99, 0.96),
("prefill", 32, 1, 1024, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96),
# ("prefill", 32, 1, 2048, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.97), # CI machines don't have enough RAM memory to run this test atm; to reduce memory usage (#8349)
# ("prefill", 32, 1, 2048, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96), # CI machines don't have enough RAM memory to run this test atm; to reduce memory usage (#8349)
),
ids=[
"prefill_seq32",
2 changes: 1 addition & 1 deletion models/demos/falcon7b/tests/test_perf_falcon.py
@@ -511,7 +511,7 @@ def run_perf_wh_bare_metal(
(
("prefill", 32, 1, 128, 0, "BFLOAT16-DRAM", 0.97, 0.99, 0.96, 0.1),
("prefill", 32, 1, 1024, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 1),
("prefill", 32, 1, 2048, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.97, 1),
("prefill", 32, 1, 2048, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 1),
("decode", 32, 32, 1, 128, "BFLOAT16-DRAM", 0.91, 0.92, 0.93, 0.15),
("decode", 32, 32, 1, 128, "BFLOAT16-L1", 0.91, 0.92, 0.93, 0.15),
("decode", 32, 32, 1, 128, "BFLOAT16-L1_SHARDED", 0.92, 0.95, 0.95, 0.1),
6 changes: 5 additions & 1 deletion models/demos/falcon7b/tt/falcon_attention.py
@@ -216,7 +216,11 @@ def forward(

seq_len = hidden_states[0].get_legacy_shape()[2]

- if self.model_config["PREFILL_OPTIMIZED_MODE"] and seq_len in [128, 1024, 2048]:
+ if (
+     self.model_config["PREFILL_OPTIMIZED_MODE"]
+     and self.model_config["PREFILL_ATTENTION_OPTIMIZED_MODE"]
+     and seq_len in [128, 1024, 2048]
+ ):
attn_output, layer_present = self._optimized_forward(
hidden_states,
attention_mask,
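For context, a minimal standalone sketch of the new gating, using hypothetical names (choose_attention_path, OPTIMIZED_SEQ_LENS) rather than the repo's own code: the optimized prefill attention kernel now requires both config switches, so it can be turned off globally while #8349 is open.

# Illustrative sketch only; the names below are hypothetical, not falcon7b code.
OPTIMIZED_SEQ_LENS = [128, 1024, 2048]


def choose_attention_path(model_config: dict, seq_len: int) -> str:
    """Mirror the condition added above: both flags and a supported length are required."""
    if (
        model_config["PREFILL_OPTIMIZED_MODE"]
        and model_config["PREFILL_ATTENTION_OPTIMIZED_MODE"]
        and seq_len in OPTIMIZED_SEQ_LENS
    ):
        return "optimized_forward"  # sharded-softmax prefill path
    return "fallback_forward"  # generic path taken while the attention optimization is disabled


# With PREFILL_ATTENTION_OPTIMIZED_MODE hard-coded to False, every sequence
# length now takes the fallback path.
assert (
    choose_attention_path(
        {"PREFILL_OPTIMIZED_MODE": True, "PREFILL_ATTENTION_OPTIMIZED_MODE": False}, 1024
    )
    == "fallback_forward"
)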
6 changes: 5 additions & 1 deletion models/demos/falcon7b/tt/falcon_model.py
@@ -117,7 +117,11 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token
dim=-1,
)

- if self.model_config["PREFILL_OPTIMIZED_MODE"] and num_input_tokens in [128, 1024, 2048]:
+ if (
+     self.model_config["PREFILL_OPTIMIZED_MODE"]
+     and self.model_config["PREFILL_ATTENTION_OPTIMIZED_MODE"]
+     and num_input_tokens in [128, 1024, 2048]
+ ):
attention_mask_ = create_prefill_attn_mask_for_sharded_softmax(
attention_mask_bool_padded * -1e5,
self.config.num_attention_heads,
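The same three-part condition now appears in both falcon_attention.py and falcon_model.py. One way to keep the kernel selection and the mask preprocessing in sync would be a shared predicate; the helper below is a hypothetical sketch, not part of this commit.

# Hypothetical helper (not in the commit): centralizes the prefill-attention
# gating so the two call sites cannot drift apart.
def use_optimized_prefill_attention(model_config: dict, num_tokens: int) -> bool:
    return bool(
        model_config["PREFILL_OPTIMIZED_MODE"]
        and model_config["PREFILL_ATTENTION_OPTIMIZED_MODE"]
        and num_tokens in [128, 1024, 2048]
    )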
1 change: 1 addition & 0 deletions models/demos/falcon7b/tt/model_config.py
@@ -195,6 +195,7 @@ def get_model_config(model_config_str, prefill_seq_len=0):

def set_prefill_config(model_config, seq_len, dram_memcfg):
model_config["PREFILL_OPTIMIZED_MODE"] = not is_grayskull()
model_config["PREFILL_ATTENTION_OPTIMIZED_MODE"] = False # enable when #8349 is fixed
model_config["MLP_SEQ_LEN"] = seq_len
model_config["MLP_PADDING_VALUE"] = 4608
model_config["MLP_GRID_SIZE"] = (8, 8)
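A hedged usage sketch, assuming get_model_config applies set_prefill_config when called with a non-zero prefill_seq_len: after this change the optimized attention path stays off regardless of the config string until #8349 is resolved.

# Sketch only; assumes get_model_config(...) with prefill_seq_len > 0 runs
# set_prefill_config and therefore sets the new flag.
from models.demos.falcon7b.tt.model_config import get_model_config

model_config = get_model_config("BFLOAT16-DRAM", prefill_seq_len=1024)

# MLP prefill optimizations remain gated only by PREFILL_OPTIMIZED_MODE
# (disabled on Grayskull), while the attention optimization is force-disabled.
print(model_config["PREFILL_OPTIMIZED_MODE"])            # True when not on Grayskull
print(model_config["PREFILL_ATTENTION_OPTIMIZED_MODE"])  # False until #8349 is fixed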
