Skip to content

Commit

Permalink
#14008: Enabled non-causal flash decode and paged flash decode
Browse files Browse the repository at this point in the history
  • Loading branch information
caixunshiren committed Oct 22, 2024
1 parent a7ba1ad commit 4e1b7ae
Show file tree
Hide file tree
Showing 14 changed files with 468 additions and 174 deletions.
2 changes: 1 addition & 1 deletion models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def forward_decode(
q_heads_1B4D,
keys_1BPD,
values_1BPD,
start_pos_ids,
cur_pos=start_pos_ids,
scale=self.scale,
program_config=self.model_config["SDPA_DECODE_PROGCFG"],
compute_kernel_config=self.model_config["SDPA_DECODE_COMPUTE_PROGCFG"],
Expand Down
2 changes: 1 addition & 1 deletion models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def attn_mqa(
query_layer,
keys,
values,
[start_pos for _ in range(self.max_batch_size)],
cur_pos=[start_pos for _ in range(self.max_batch_size)],
scale=self.scale,
program_config=program_config,
compute_kernel_config=self.attention_config["COMPUTE_KERNEL_SDPA"],
Expand Down
Loading

0 comments on commit 4e1b7ae

Please sign in to comment.