Skip to content

Commit

Permalink
#8049: Update bfp8 config: only residual in fp16
Browse files Browse the repository at this point in the history
  • Loading branch information
johanna-rock-tt committed May 22, 2024
1 parent 276dcfd commit e2c5f52
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,18 @@ def test_FalconCausalLM_prefill_end_to_end_t3000_ci_loops_10(
if data_type == "BFLOAT8_B":
if seq_len == 32:
out_pcc = 0.983
-            k_cache_pcc = 0.985
-            v_cache_pcc = 0.957
+            k_cache_pcc = 0.982
+            v_cache_pcc = 0.949
token_pcc = 0.99
elif seq_len == 128:
out_pcc = 0.990
-            k_cache_pcc = 0.990
-            v_cache_pcc = 0.963
+            k_cache_pcc = 0.989
+            v_cache_pcc = 0.950
token_pcc = 0.99
elif seq_len == 2048:
out_pcc = 0.993
-            k_cache_pcc = 0.992
-            v_cache_pcc = 0.979
+            k_cache_pcc = 0.991
+            v_cache_pcc = 0.972
token_pcc = 0.99
elif data_type == "BFLOAT16":
if seq_len == 32:
Expand Down
12 changes: 6 additions & 6 deletions models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,18 +474,18 @@ def test_FalconCausalLM_end_to_end_with_program_cache(
if data_type == "BFLOAT8_B":
if seq_len == 32:
out_pcc = 0.983
-            k_cache_pcc = 0.985
-            v_cache_pcc = 0.957
+            k_cache_pcc = 0.982
+            v_cache_pcc = 0.949
token_pcc = 0.99
elif seq_len == 128:
out_pcc = 0.990
-            k_cache_pcc = 0.990
-            v_cache_pcc = 0.963
+            k_cache_pcc = 0.989
+            v_cache_pcc = 0.950
token_pcc = 0.99
elif seq_len == 2048:
out_pcc = 0.993
-            k_cache_pcc = 0.992
-            v_cache_pcc = 0.979
+            k_cache_pcc = 0.991
+            v_cache_pcc = 0.972
token_pcc = 0.99
elif data_type == "BFLOAT16":
if seq_len == 32:
Expand Down
3 changes: 0 additions & 3 deletions models/demos/t3000/falcon40b/tt/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,9 +719,6 @@ def get_prefill_model_config(model_config_str, input_shape, num_devices):

model_config["ATTN_MASK_DTYPE"] = BFP4_DTYPE

-    model_config["LN_INPUT_DTYPE"] = BFLOAT16_DTYPE
-    model_config["LN_MLP_OUTPUT_DTYPE"] = BFLOAT16_DTYPE
-    model_config["ATTENTION_DTYPE"] = BFLOAT16_DTYPE  # used for SDPA
model_config["WORD_EMBEDDING_OUTPUT_DTYPE"] = BFLOAT16_DTYPE # embeddings output and the residual stream

# Set input df for AllGathers to bfp8 to save data bandwidth
Expand Down

0 comments on commit e2c5f52

Please sign in to comment.