Revert remaining affected matmuls to previous configs
Signed-off-by: Salar Hosseini <skhorasgani@tenstorrent.com>
skhorasganiTT committed Nov 25, 2024
Parent: 429f86b, commit: 9832c19
Showing 7 changed files with 75 additions and 48 deletions.
models/demos/falcon7b_common/tests/run_falcon_end_to_end.py: 4 changes (2 additions, 2 deletions)
@@ -95,8 +95,8 @@ class DeviceSetup(Enum):
},
DeviceSetup.WORMHOLE_B0: {
"BFLOAT16-DRAM": {128: (0.89, 0.92, 0.91), 1024: (0.92, 0.94, 0.95), 2047: (0.95, 0.96, 0.97)},
"BFLOAT16-L1": {128: (0.89, 0.92, 0.91), 1024: (0.89, 0.93, 0.93), 2047: (0.96, 0.99, 0.99)},
"BFLOAT16-L1_SHARDED": {128: (0.89, 0.91, 0.91), 1024: (0.90, 0.96, 0.95), 2047: (0.96, 0.99, 0.99)},
"BFLOAT16-L1": {128: (0.89, 0.92, 0.91), 1024: (0.92, 0.94, 0.95), 2047: (0.95, 0.96, 0.97)},
"BFLOAT16-L1_SHARDED": {128: (0.90, 0.91, 0.91), 1024: (0.93, 0.94, 0.96), 2047: (0.92, 0.93, 0.94)},
},
DeviceSetup.T3000: {
"BFLOAT16-L1_SHARDED": {128: (0.85, 0.89, 0.90), 1024: (0.90, 0.92, 0.93), 2047: (0.95, 0.91, 0.89)}
models/demos/falcon7b_common/tests/test_perf_falcon.py: 44 changes (22 additions, 22 deletions)
@@ -163,32 +163,32 @@ def run_perf_wh_bare_metal(
@pytest.mark.parametrize(
"llm_mode, num_layers, batch, seq_len, kv_cache_len, model_config_str, expected_inference_time",
(
# ("prefill", 32, 1, 128, 0, "BFLOAT16-DRAM", 0.1),
# ("prefill", 32, 1, 1024, 0, "BFLOAT16-DRAM", 0.5),
# ("prefill", 32, 1, 2048, 0, "BFLOAT16-DRAM", 1.1),
# ("decode", 32, 32, 1, 128, "BFLOAT16-DRAM", 0.15),
# ("decode", 32, 32, 1, 128, "BFLOAT16-L1", 0.15),
# ("decode", 32, 32, 1, 128, "BFLOAT16-L1_SHARDED", 0.1),
# ("decode", 32, 32, 1, 1024, "BFLOAT16-DRAM", 0.4),
("prefill", 32, 1, 128, 0, "BFLOAT16-DRAM", 0.1),
("prefill", 32, 1, 1024, 0, "BFLOAT16-DRAM", 0.5),
("prefill", 32, 1, 2048, 0, "BFLOAT16-DRAM", 1.1),
("decode", 32, 32, 1, 128, "BFLOAT16-DRAM", 0.15),
("decode", 32, 32, 1, 128, "BFLOAT16-L1", 0.15),
("decode", 32, 32, 1, 128, "BFLOAT16-L1_SHARDED", 0.1),
("decode", 32, 32, 1, 1024, "BFLOAT16-DRAM", 0.4),
("decode", 32, 32, 1, 1024, "BFLOAT16-L1", 0.35),
# ("decode", 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.1),
# ("decode", 32, 32, 1, 2047, "BFLOAT16-DRAM", 0.75),
# ("decode", 32, 32, 1, 2047, "BFLOAT16-L1", 0.6),
# ("decode", 32, 32, 1, 2047, "BFLOAT16-L1_SHARDED", 0.11),
("decode", 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.1),
("decode", 32, 32, 1, 2047, "BFLOAT16-DRAM", 0.75),
("decode", 32, 32, 1, 2047, "BFLOAT16-L1", 0.6),
("decode", 32, 32, 1, 2047, "BFLOAT16-L1_SHARDED", 0.11),
),
ids=[
# "prefill_seq128_bf16_dram",
# "prefill_seq1024_bf16_dram",
# "prefill_seq2048_bf16_dram",
# "decode_batch32_128_bf16_dram",
# "decode_batch32_128_bf16_l1",
# "decode_batch32_128_bf16_l1_sharded",
# "decode_batch32_1024_bf16_dram",
# "decode_batch32_1024_bf16_l1",
# "decode_batch32_1024_bf16_l1_sharded",
# "decode_batch32_2047_bf16_dram",
"prefill_seq128_bf16_dram",
"prefill_seq1024_bf16_dram",
"prefill_seq2048_bf16_dram",
"decode_batch32_128_bf16_dram",
"decode_batch32_128_bf16_l1",
"decode_batch32_128_bf16_l1_sharded",
"decode_batch32_1024_bf16_dram",
"decode_batch32_1024_bf16_l1",
"decode_batch32_1024_bf16_l1_sharded",
"decode_batch32_2047_bf16_dram",
"decode_batch32_2047_bf16_l1",
# "decode_batch32_2047_bf16_l1_sharded",
"decode_batch32_2047_bf16_l1_sharded",
],
)
@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True, ids=["noasync", "async"])
models/demos/falcon7b_common/tt/falcon_attention.py: 8 changes (6 additions, 2 deletions)
@@ -217,6 +217,7 @@ def forward(
memory_config=self.model_config["FUSED_QKV_MM_OUTPUT_MEMCFG"],
dtype=self.model_config["FUSED_QKV_MM_OUTPUT_DTYPE"],
core_grid=get_falcon_default_core_grid(hidden_states.device()),
+compute_kernel_config=self.model_config["DEFAULT_LoFi_KERNEL_CONFIG"],
)

###########
@@ -261,7 +262,7 @@ def forward(
query_layer,
key_layer_transposed,
memory_config=self.model_config["PRE_SOFTMAX_MM_OUTPUT_MEMCFG"],
compute_kernel_config=self.model_config["HiFi2_KERNEL_CONFIG"],
compute_kernel_config=self.model_config["DEFAULT_HiFi2_KERNEL_CONFIG"],
)
query_layer.deallocate()
key_layer_transposed.deallocate()
@@ -296,7 +297,7 @@ def forward(
attn_weights,
value_layer,
memory_config=self.model_config["POST_SOFTMAX_MM_OUTPUT_MEMCFG"],
compute_kernel_config=self.model_config["HiFi2_KERNEL_CONFIG"],
compute_kernel_config=self.model_config["DEFAULT_HiFi2_KERNEL_CONFIG"],
)
attn_weights.deallocate()
value_layer.deallocate()
@@ -315,6 +316,7 @@ def forward(
memory_config=self.model_config["SELFOUT_MM_OUTPUT_MEMCFG"],
dtype=self.model_config["SELFOUT_MM_OUTPUT_DTYPE"],
core_grid=get_falcon_default_core_grid(attn_output.device()),
+compute_kernel_config=self.model_config["DEFAULT_LoFi_KERNEL_CONFIG"],
)

return attn_output, layer_present
@@ -591,6 +593,7 @@ def forward(
memory_config=self.model_config["FUSED_QKV_MM_OUTPUT_MEMCFG"],
dtype=self.model_config["FUSED_QKV_MM_OUTPUT_DTYPE"],
core_grid=get_falcon_default_core_grid(hidden_states.device()),
+compute_kernel_config=self.model_config["DEFAULT_LoFi_KERNEL_CONFIG"],
)

###########
@@ -847,6 +850,7 @@ def forward(
memory_config=self.model_config["SELFOUT_MM_OUTPUT_MEMCFG"],
dtype=self.model_config["SELFOUT_MM_OUTPUT_DTYPE"],
core_grid=get_falcon_default_core_grid(attn_output.device()),
+compute_kernel_config=self.model_config["DEFAULT_LoFi_KERNEL_CONFIG"],
)

return attn_output, layer_present
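Note: with this change the attention score matmuls read DEFAULT_HiFi2_KERNEL_CONFIG while the fused QKV and self-output projections use the new DEFAULT_LoFi_KERNEL_CONFIG. A minimal standalone sketch of passing such a config to a ttnn matmul (requires a Tenstorrent Wormhole B0 device and a working ttnn install; the tensor shapes are arbitrary):

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)
    hifi2 = ttnn.WormholeComputeKernelConfig(
        math_fidelity=ttnn.MathFidelity.HiFi2,
        math_approx_mode=False,
        fp32_dest_acc_en=False,
        packer_l1_acc=False,
    )
    a = ttnn.from_torch(torch.randn(1, 1, 64, 64), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    b = ttnn.from_torch(torch.randn(1, 1, 64, 64), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    out = ttnn.matmul(a, b, compute_kernel_config=hifi2)  # explicit fidelity for this matmul
    ttnn.close_device(device)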
models/demos/falcon7b_common/tt/falcon_causallm.py: 10 changes (6 additions, 4 deletions)
@@ -9,7 +9,11 @@
from ttnn import ReplicateTensorToMesh
from models.demos.falcon7b_common.tt.falcon_lm_head import falcon_lm_head_matmul_2d
from models.demos.falcon7b_common.tt.falcon_model import TtFalconModelShared
-from models.demos.falcon7b_common.tt.model_utils import get_falcon_default_core_grid, get_weights_cached
+from models.demos.falcon7b_common.tt.model_utils import (
+    get_falcon_default_core_grid,
+    get_weights_cached,
+    get_default_hifi2_kernel_config,
+)
from models.demos.falcon7b_common.tests.test_utils import tt_from_torch
from models.utility_functions import (
is_grayskull,
@@ -21,7 +25,6 @@ def falcon_lm_head_matmul(
input_tensor_a,
input_tensor_b,
core_grid,
-compute_kernel_config,
output_mem_config=ttnn.DRAM_MEMORY_CONFIG,
output_dtype=None,
):
@@ -33,7 +36,7 @@ def falcon_lm_head_matmul(
input_tensor_b,
memory_config=output_mem_config,
dtype=output_dtype,
-compute_kernel_config=compute_kernel_config,
+compute_kernel_config=get_default_hifi2_kernel_config(),
)

if is_grayskull():
@@ -177,7 +180,6 @@ def forward(
lm_logits = falcon_lm_head_matmul(
hidden_states,
self.lm_head_weights,
compute_kernel_config=self.model_config["HiFi2_KERNEL_CONFIG"],
output_mem_config=self.model_config["LM_HEAD_MM_OUTPUT_MEMCFG"],
output_dtype=self.model_config["LM_HEAD_MM_OUTPUT_DTYPE"],
core_grid=get_falcon_default_core_grid(hidden_states.device()),
models/demos/falcon7b_common/tt/falcon_mlp.py: 10 changes (6 additions, 4 deletions)
@@ -5,7 +5,11 @@
import torch
import ttnn
from ttnn import ReplicateTensorToMesh
-from models.demos.falcon7b_common.tt.model_utils import get_falcon_default_core_grid, get_weights_cached
+from models.demos.falcon7b_common.tt.model_utils import (
+    get_falcon_default_core_grid,
+    get_weights_cached,
+    get_default_hifi2_kernel_config,
+)
from models.demos.falcon7b_common.tests.test_utils import tt_from_torch
from torch import nn
from models.utility_functions import (
@@ -50,7 +54,6 @@ def falcon_dense_h_to_4h_matmul(
input_tensor_a,
input_tensor_b,
core_grid,
-compute_kernel_config,
fused_activation=None,
output_mem_config=ttnn.DRAM_MEMORY_CONFIG,
output_dtype=None,
@@ -64,7 +67,7 @@ def falcon_dense_h_to_4h_matmul(
input_tensor_b,
memory_config=output_mem_config,
dtype=output_dtype,
-compute_kernel_config=compute_kernel_config,
+compute_kernel_config=get_default_hifi2_kernel_config(),
)

if is_grayskull():
@@ -357,7 +360,6 @@ def forward(self, x: ttnn.Tensor) -> ttnn.Tensor:
hidden_states = falcon_dense_h_to_4h_matmul(
x,
self.dense_h_to_4h_weights,
compute_kernel_config=self.model_config["HiFi2_KERNEL_CONFIG"],
fused_activation="gelu",
output_mem_config=self.model_config["DENSE_H_TO_4H_MM_OUTPUT_MEMCFG"],
output_dtype=self.model_config["DENSE_H_TO_4H_MM_OUTPUT_DTYPE"],
models/demos/falcon7b_common/tt/model_config.py: 31 changes (17 additions, 14 deletions)
@@ -8,6 +8,7 @@
from pathlib import Path
from transformers import FalconConfig
from models.utility_functions import is_grayskull, is_wormhole_b0
+from models.demos.falcon7b_common.tt.model_utils import get_default_hifi2_kernel_config

OP_KEYS = (
# Inputs
@@ -290,6 +291,22 @@ def get_model_config(model_config_str, prefill_seq_len=0, decode_batch_size=32):
model_config["PRE_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"] = gs_compute_kernel_config
model_config["POST_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"] = gs_compute_kernel_config

+    if is_wormhole_b0():
+        default_lofi_kernel_config = ttnn.WormholeComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.LoFi,
+            math_approx_mode=False,
+            fp32_dest_acc_en=False,
+            packer_l1_acc=False,
+        )
+    else:
+        default_lofi_kernel_config = ttnn.GrayskullComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.LoFi,
+            math_approx_mode=True,
+        )
+    model_config["DEFAULT_LoFi_KERNEL_CONFIG"] = default_lofi_kernel_config
+
+    model_config["DEFAULT_HiFi2_KERNEL_CONFIG"] = get_default_hifi2_kernel_config()
+
# uncomment if need to see all the configs
# logger.debug(f"Falcon model config: \n{pretty_print_model_config(model_config)}")
set_prefill_config(model_config, prefill_seq_len, DRAM_MEMCFG)
@@ -317,20 +334,6 @@ def set_prefill_config(model_config, seq_len, dram_memcfg):
)
model_config["MLP_KERNEL_CONFIG"] = default_kernel_config

-    if is_wormhole_b0():
-        hifi2_kernel_config = ttnn.WormholeComputeKernelConfig(
-            math_fidelity=ttnn.MathFidelity.HiFi2,
-            math_approx_mode=False,
-            fp32_dest_acc_en=False,
-            packer_l1_acc=True,
-        )
-    else:
-        hifi2_kernel_config = ttnn.GrayskullComputeKernelConfig(
-            math_fidelity=ttnn.MathFidelity.HiFi2,
-            math_approx_mode=True,
-        )
-    model_config["HiFi2_KERNEL_CONFIG"] = hifi2_kernel_config
-
mm_h_to_4h_prog_cfg = ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
compute_with_storage_grid_size=model_config["MLP_GRID_SIZE"],
in0_block_w=3,
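Note: after this change both default kernel configs are populated unconditionally in get_model_config, so they can be inspected directly. A hedged sketch (assumes the tt-metal repo root is on PYTHONPATH and ttnn is installed; the signature matches the hunk header above):

    from models.demos.falcon7b_common.tt.model_config import get_model_config

    model_config = get_model_config("BFLOAT16-L1_SHARDED", prefill_seq_len=0, decode_batch_size=32)
    print(model_config["DEFAULT_LoFi_KERNEL_CONFIG"])   # LoFi config with the per-architecture flags above
    print(model_config["DEFAULT_HiFi2_KERNEL_CONFIG"])  # resolved via get_default_hifi2_kernel_config()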
models/demos/falcon7b_common/tt/model_utils.py: 16 changes (16 additions, 0 deletions)
@@ -70,6 +70,22 @@ def get_falcon_default_core_grid(device):
return ttnn.CoreGrid(y=grid_size.y, x=grid_size.x)


+def get_default_hifi2_kernel_config():
+    if is_wormhole_b0():
+        hifi2_kernel_config = ttnn.WormholeComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.HiFi2,
+            math_approx_mode=False,
+            fp32_dest_acc_en=False,
+            packer_l1_acc=False,
+        )
+    else:
+        hifi2_kernel_config = ttnn.GrayskullComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.HiFi2,
+            math_approx_mode=True,
+        )
+    return hifi2_kernel_config
+
+
def layernorm(ln_input, ln_eps, ln_gamma, ln_betta, model_config):
h_dim = ln_input.shape.with_tile_padding()[-2] # corresponds to batch size (decode) or seq_len (prefill)
if h_dim in [32, 128, 256, 1024, 2048]:
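Note: a hedged usage sketch of the new helper (assumes the repo root is on PYTHONPATH); it selects the HiFi2 config for the detected architecture, so call sites such as falcon_lm_head_matmul and falcon_dense_h_to_4h_matmul no longer carry architecture checks of their own:

    from models.demos.falcon7b_common.tt.model_utils import get_default_hifi2_kernel_config

    kernel_config = get_default_hifi2_kernel_config()
    # WormholeComputeKernelConfig on Wormhole B0, GrayskullComputeKernelConfig otherwise
    print(kernel_config)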
