#5383: [Falcon7b] Add support for decode-2k (l1-sharded) by disabling fp32acc on QK^T and fixing l1 fragmentation

Signed-off-by: Salar Hosseini <[email protected]>
skhorasganiTT committed Jun 5, 2024
1 parent 102534a commit 5e9d561
Showing 3 changed files with 31 additions and 14 deletions.
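
The core of the change is visible in the compute-kernel configs: the single COMPUTE_KERNEL_CONFIG previously shared by both attention matmuls is split into a pre-softmax config with fp32 destination accumulation off for QK^T (which, per the commit title, is what lets decode fit at 2k context) and a post-softmax config with it on for softmax(QK^T) @ V. A condensed sketch, using the ttnn.experimental APIs exactly as they appear in the diffs below:

# Condensed sketch of the split kernel configs (APIs as shown in the diffs below).
import ttnn

# QK^T: no fp32 destination accumulation, so decode fits at 2k context.
pre_softmax_mm_config = ttnn.experimental.tensor.WormholeComputeKernelConfig(
    math_fidelity=ttnn.experimental.tensor.MathFidelity.LoFi,
    math_approx_mode=True,
    fp32_dest_acc_en=False,
    packer_l1_acc=True,
)

# softmax(QK^T) @ V: fp32 accumulation enabled.
post_softmax_mm_config = ttnn.experimental.tensor.WormholeComputeKernelConfig(
    math_fidelity=ttnn.experimental.tensor.MathFidelity.LoFi,
    math_approx_mode=True,
    fp32_dest_acc_en=True,
    packer_l1_acc=True,
)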
models/demos/falcon7b/tests/test_perf_falcon.py (9 additions, 5 deletions)
@@ -470,8 +470,6 @@ def run_perf_wh_bare_metal(
     all_devices,
     async_mode,
 ):
-    if model_config_str == "BFLOAT16-L1_SHARDED" and kv_cache_len == 2047:
-        pytest.skip(f"kv_cache_len={kv_cache_len} does not fit with L1_SHARDED")
     if model_config_str == "BFLOAT16-L1_SHARDED" and llm_mode == "prefill":
         pytest.skip(f"prefill does not support L1_SHARDED")
     if num_devices > 1:
@@ -518,9 +516,10 @@ def run_perf_wh_bare_metal(
         ("decode", 32, 32, 1, 128, "BFLOAT16-L1_SHARDED", 0.92, 0.95, 0.95, 0.1),
         ("decode", 32, 32, 1, 1024, "BFLOAT16-DRAM", 0.86, 0.92, 0.92, 0.4),
         ("decode", 32, 32, 1, 1024, "BFLOAT16-L1", 0.86, 0.92, 0.92, 0.35),
-        ("decode", 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.85, 0.93, 0.94, 0.1),
+        ("decode", 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.94, 0.94, 0.1),
         ("decode", 32, 32, 1, 2047, "BFLOAT16-DRAM", 0.88, 0.93, 0.93, 0.75),
         ("decode", 32, 32, 1, 2047, "BFLOAT16-L1", 0.88, 0.93, 0.93, 0.6),
+        ("decode", 32, 32, 1, 2047, "BFLOAT16-L1_SHARDED", 0.88, 0.92, 0.93, 0.6),
     ),
     ids=[
         "prefill_seq128_bf16_dram",
@@ -534,6 +533,7 @@ def run_perf_wh_bare_metal(
         "decode_batch32_1024_bf16_l1_sharded",
         "decode_batch32_2047_bf16_dram",
         "decode_batch32_2047_bf16_l1",
+        "decode_batch32_2047_bf16_l1_sharded",
     ],
 )
 @pytest.mark.parametrize("async_mode", (False, True))
@@ -589,24 +589,28 @@ def test_perf_wh_bare_metal(
         ("prefill", 4, 32, 1, 256, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.97, 0.18, False),  # Issue 7816 Inference time
         ("prefill", 4, 32, 1, 1024, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.98, 0.5, False),
         ("prefill", 4, 32, 1, 2048, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.98, 1.1, False),
-        ("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.91, 0.91, 0.21, False),
+        ("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.86, 0.90, 0.91, 0.21, False),
+        ("decode", 4, 32, 32, 1, 2047, "BFLOAT16-L1_SHARDED", 0.77, 0.69, 0.72, 0.21, False),
         ("prefill", 4, 32, 1, 128, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.97, 0.1, True),
         ("prefill", 4, 32, 1, 256, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.97, 0.18, True),
         ("prefill", 4, 32, 1, 1024, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.98, 0.5, True),
         ("prefill", 4, 32, 1, 2048, 0, "BFLOAT16-DRAM", 0.99, 0.99, 0.98, 1.1, True),
-        ("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.91, 0.91, 0.09, True),
+        ("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.86, 0.90, 0.91, 0.09, True),
+        ("decode", 4, 32, 32, 1, 2047, "BFLOAT16-L1_SHARDED", 0.77, 0.69, 0.72, 0.21, True),
     ),
     ids=[
         "prefill_seq128",
         "prefill_seq256",
         "prefill_seq1024",
         "prefill_seq2048",
         "decode_batch32_1024",
+        "decode_batch32_2047",
         "prefill_seq128_async",
         "prefill_seq256_async",
         "prefill_seq1024_async",
         "prefill_seq2048_async",
         "decode_batch32_1024_async",
+        "decode_batch32_2047_async",
     ],
 )
 @skip_for_grayskull()
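With the 2047-token skip removed above, the new l1-sharded decode-2k perf cases can be selected by their pytest ids. A hypothetical invocation, assuming nothing beyond pytest's standard -k substring matching:

# Hypothetical: run only the new 2k-context l1-sharded perf case by its id.
import pytest

pytest.main(
    [
        "models/demos/falcon7b/tests/test_perf_falcon.py",
        "-k", "decode_batch32_2047_bf16_l1_sharded",
    ]
)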
models/demos/falcon7b/tt/falcon_attention.py (7 additions, 6 deletions)
@@ -704,6 +704,7 @@ def forward(
         for i in range(self.num_devices):
             # Update kv_cache in place
             ttnn.experimental.tensor.update_cache(layer_past[i][0], key_layer[i], layer_past_len)
+            key_layer[i].deallocate(True)
         for i in range(self.num_devices):
             # key and value layers will have kv_seq_len padded to nearest 32
             key_layer[i] = ttnn.experimental.tensor.unpad(
@@ -779,16 +780,16 @@ def forward(
         for i, device in enumerate(self.devices):
             attn_weights.append(
                 ttnn.experimental.operations.primary.matmul(
-                    query_layer[i],
-                    key_layer_transposed[i],
+                    query_layer[i],  # [batch, 1, padded_local_heads, head_dim]
+                    key_layer_transposed[i],  # [batch, 1, head_dim, padded_layer_past_len]
                     program_config=self.model_config["ATTN_BATCHED_MM_PROGCFG"](
                         self.head_dim // 32, self.padded_local_heads // 32, padded_layer_past_len // 32
                     ),
                     output_mem_config=self.model_config["ATTN_BATCH_SHARDED_MEMCFG"](
                         self.padded_local_heads, padded_layer_past_len
                     ),
                     output_dtype=self.model_config["PRE_SOFTMAX_MM_OUTPUT_DTYPE"],  # Must be BFLOAT16
-                    compute_kernel_config=self.model_config["COMPUTE_KERNEL_CONFIG"],
+                    compute_kernel_config=self.model_config["PRE_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"],
                 )
             )
             query_layer[i].deallocate()
@@ -894,8 +895,8 @@ def forward(
         for i in range(self.num_devices):
             attn_output.append(
                 ttnn.experimental.operations.primary.matmul(
-                    attn_weights[i],
-                    value_layer[i],
+                    attn_weights[i],  # [batch, 1, padded_local_heads, padded_layer_past_len]
+                    value_layer[i],  # [batch, 1, padded_layer_past_len, head_dim]
                     program_config=self.model_config["ATTN_BATCHED_MM_PROGCFG"](
                         padded_layer_past_len // 32,
                         self.padded_local_heads // 32,
@@ -906,7 +907,7 @@ def forward(
                         self.head_dim,
                     ),
                     output_dtype=self.model_config["POST_SOFTMAX_MM_OUTPUT_DTYPE"],
-                    compute_kernel_config=self.model_config["COMPUTE_KERNEL_CONFIG"],
+                    compute_kernel_config=self.model_config["POST_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"],
                 )
             )
             attn_weights[i].deallocate(True)
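The shape comments added to the two matmuls above pin down the batched attention math. A plain-torch sketch of the same shapes, where the concrete sizes are illustrative assumptions only (Falcon7b's real padded values come from model_config):

# Shape check mirroring the comments added in the diff above.
import torch

batch, padded_local_heads, head_dim = 32, 96, 64  # assumed example values
padded_layer_past_len = 2048  # e.g. kv_cache_len=2047 padded to a multiple of 32

q = torch.randn(batch, 1, padded_local_heads, head_dim)
k_t = torch.randn(batch, 1, head_dim, padded_layer_past_len)
attn_weights = q @ k_t  # [batch, 1, padded_local_heads, padded_layer_past_len]

v = torch.randn(batch, 1, padded_layer_past_len, head_dim)
attn_output = attn_weights @ v  # [batch, 1, padded_local_heads, head_dim]
assert attn_output.shape == (batch, 1, padded_local_heads, head_dim)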
models/demos/falcon7b/tt/model_config.py (15 additions, 3 deletions)
@@ -145,7 +145,9 @@ def get_model_config(model_config_str, prefill_seq_len=0):
     if model_config_str == "BFLOAT16-L1_SHARDED":
         model_config["ATTN_MASK_MEMCFG"] = L1_MEMCFG
         model_config["ROTARY_EMBEDDING_OUTPUT_MEMCFG"] = L1_MEMCFG
-    model_config["K_CACHE_SLICE_OUTPUT_MEMCFG"] = L1_MEMCFG
+    if not model_config_str == "BFLOAT16-L1_SHARDED":
+        # Don't send keys to l1 before converting to l1-sharded (after kcache update) to avoid l1 fragmentation issues with kv_cache_size=2048
+        model_config["K_CACHE_SLICE_OUTPUT_MEMCFG"] = L1_MEMCFG
     model_config["V_CACHE_SLICE_OUTPUT_MEMCFG"] = L1_MEMCFG
     model_config["K_TRANSPOSED_OUTPUT_MEMCFG"] = L1_MEMCFG
     model_config["PRE_SOFTMAX_MM_OUTPUT_MEMCFG"] = L1_MEMCFG
@@ -191,17 +193,27 @@ def get_model_config(model_config_str, prefill_seq_len=0):
     )
 
     if is_wormhole_b0():
-        model_config["COMPUTE_KERNEL_CONFIG"] = ttnn.experimental.tensor.WormholeComputeKernelConfig(
+        model_config["PRE_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"] = ttnn.experimental.tensor.WormholeComputeKernelConfig(
             math_fidelity=ttnn.experimental.tensor.MathFidelity.LoFi,
             math_approx_mode=True,
             fp32_dest_acc_en=False,
             packer_l1_acc=True,
         )
+        model_config[
+            "POST_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"
+        ] = ttnn.experimental.tensor.WormholeComputeKernelConfig(
+            math_fidelity=ttnn.experimental.tensor.MathFidelity.LoFi,
+            math_approx_mode=True,
+            fp32_dest_acc_en=True,
+            packer_l1_acc=True,
+        )
     else:
-        model_config["COMPUTE_KERNEL_CONFIG"] = ttnn.experimental.tensor.GrayskullComputeKernelConfig(
+        gs_compute_kernel_config = ttnn.experimental.tensor.GrayskullComputeKernelConfig(
             math_fidelity=ttnn.experimental.tensor.MathFidelity.LoFi,
             math_approx_mode=True,
         )
+        model_config["PRE_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"] = gs_compute_kernel_config
+        model_config["POST_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"] = gs_compute_kernel_config
 
     # uncomment if need to see all the configs
     # logger.debug(f"Falcon model config: \n{pretty_print_model_config(model_config)}")
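The other half of the fix is the K_CACHE_SLICE_OUTPUT_MEMCFG change at the top of this file: for the l1-sharded config, the sliced K cache is no longer sent to interleaved L1 before the post-update conversion to l1-sharded, presumably staying in the DRAM default set earlier in model_config.py. A standalone restatement of that choice; the helper name and parameters are hypothetical, for illustration only:

# Hypothetical helper restating the fragmentation fix; the real code mutates
# model_config in place with the L1/DRAM memcfgs defined in model_config.py.
def k_cache_slice_memcfg(model_config_str, l1_memcfg, dram_memcfg):
    if model_config_str == "BFLOAT16-L1_SHARDED":
        # Keep sliced keys out of interleaved L1; they are converted to
        # l1-sharded after the kcache update, avoiding L1 fragmentation
        # at kv_cache_size=2048.
        return dram_memcfg
    return l1_memcfg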
