#8866: Move ln_tensor creation to model_utils and support for tests
djordje-tt committed May 30, 2024
1 parent c77248a commit 319652f
Showing 3 changed files with 45 additions and 18 deletions.
13 changes: 13 additions & 0 deletions models/demos/t3000/falcon40b/tests/test_falcon_decoder.py
@@ -18,6 +18,7 @@
     comp_pcc,
 )
 from models.utility_functions import torch2tt_tensor, tt2torch_tensor, skip_for_grayskull, get_devices_for_t3000
+from models.demos.t3000.falcon40b.tt.model_utils import generate_layernorm_persistent_tensors


 class PytorchFalconDecoderModel(torch.nn.Module):
@@ -72,6 +73,7 @@ def run_test_FalconDecoder_inference(
     use_cache = True
     user_id = 0

+    ln_output_tensors_dict = {"final_layernorm": dict(), "mlp_layernorm": dict(), "attn_layernorm": dict()}
     # Generate input, attention_mask, and kv_cache --------------------------------------
     # TODO: Generate attention_mask on device
     if llm_mode == "prefill":
@@ -146,6 +148,16 @@
         )
         tt_layer_past = (tt_k_cache, tt_v_cache)

+        if seq_len > model_config["layernorm_params"]["slice_size"]:
+            generate_layernorm_persistent_tensors(
+                seq_len,
+                model_config["layernorm_params"]["slice_size"],
+                ln_output_tensors_dict,
+                devices,
+                configuration.hidden_size,
+                model_config["LN_MLP_OUTPUT_DTYPE"],
+            )
+
     elif llm_mode == "decode":
         q_len, kv_len = seq_len, kv_cache_len + 1
         assert batch % 32 == 0, "For decode, batch must be multiple of 32!"
@@ -261,6 +273,7 @@ def run_test_FalconDecoder_inference(
         model_config,
         tt_cache_path,
         None,
+        ln_output_tensors_dict,
     )

     tt_out, tt_layer_present = tt_FalconDecoder_model(
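The helper populates the persistent-tensor dict in place, keyed by layernorm name and then by sequence length, with one device-resident tensor per device. A minimal sketch of the resulting shape (hypothetical values: two devices and a seq_len of 2048; tt_tensor_dev0/dev1 stand in for the DRAM-resident zero tensors of shape (1, 1, seq_len, hidden_size)):

    # Shape of ln_output_tensors_dict after generate_layernorm_persistent_tensors runs
    # (sketch, assuming 2 devices and seq_len=2048):
    {
        "final_layernorm": {2048: [tt_tensor_dev0, tt_tensor_dev1]},
        "mlp_layernorm": {2048: [tt_tensor_dev0, tt_tensor_dev1]},
        "attn_layernorm": {2048: [tt_tensor_dev0, tt_tensor_dev1]},
    }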
30 changes: 12 additions & 18 deletions models/demos/t3000/falcon40b/tt/falcon_model.py
@@ -18,7 +18,7 @@
     nearest_32,
 )

-from models.demos.t3000.falcon40b.tt.model_utils import convert_to_layout, partial_layernorm
+from models.demos.t3000.falcon40b.tt.model_utils import partial_layernorm, generate_layernorm_persistent_tensors


 class TtFalconModelShared:
@@ -48,7 +48,7 @@ def __init__(
         self.num_layers = num_layers
         self.hidden_size = config.hidden_size
         self.num_devices = len(devices)
-        self.ln_output_tensors_dict = {"layernorm": dict(), "mlp_layernorm": dict(), "attn_layernorm": dict()}
+        self.ln_output_tensors_dict = {"final_layernorm": dict(), "mlp_layernorm": dict(), "attn_layernorm": dict()}

         # Word Embeddings
         self.embeddings = TtFalconEmbeddings(
@@ -208,23 +208,17 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token
             # Generate ln output tensors for prefill if not existing
             do_generate_ln_tensors = (
                 sequence_size > self.model_config["layernorm_params"]["slice_size"]
-                and sequence_size not in self.ln_output_tensors_dict["layernorm"]
+                and sequence_size not in self.ln_output_tensors_dict["final_layernorm"]
             )
             if do_generate_ln_tensors:
-                for name in ["layernorm", "mlp_layernorm", "attn_layernorm"]:
-                    output_tensor = [
-                        torch2tt_tensor(
-                            torch.zeros(1, 1, sequence_size, self.hidden_size),
-                            self.devices[i],
-                            tt_memory_config=ttnn.DRAM_MEMORY_CONFIG,
-                            tt_dtype=self.model_config["LN_MLP_OUTPUT_DTYPE"],
-                        )
-                        for i in range(self.num_devices)
-                    ]
-                    if name in self.ln_output_tensors_dict and self.ln_output_tensors_dict[name] is not None:
-                        self.ln_output_tensors_dict[name].update({sequence_size: output_tensor})
-                    else:
-                        self.ln_output_tensors_dict[name] = {sequence_size: output_tensor}
+                generate_layernorm_persistent_tensors(
+                    sequence_size,
+                    self.model_config["layernorm_params"]["slice_size"],
+                    self.ln_output_tensors_dict,
+                    self.devices,
+                    self.hidden_size,
+                    self.model_config["LN_MLP_OUTPUT_DTYPE"],
+                )

         elif llm_mode == "decode":
             assert batch_size % 32 == 0, "For decode, batch_size must be multiple of 32!"
@@ -369,7 +363,7 @@ def fwd_prefill(
             self.model_config["LN_MLP_OUTPUT_DTYPE"],
             self.hidden_size,
             self.devices,
-            self.ln_output_tensors_dict["layernorm"],
+            self.ln_output_tensors_dict["final_layernorm"],
         )

         return layer_output, presents
20 changes: 20 additions & 0 deletions models/demos/t3000/falcon40b/tt/model_utils.py
@@ -394,3 +394,23 @@ def deallocate_ln_tensors(layernorm_slice_size, seq_len):
         return False
     else:
         return True
+
+
+def generate_layernorm_persistent_tensors(seq_len, slice_size, ln_output_tensors_dict, devices, hidden_size, dtype):
+    if seq_len <= slice_size:
+        return
+
+    for name in ["final_layernorm", "mlp_layernorm", "attn_layernorm"]:
+        output_tensor = [
+            torch2tt_tensor(
+                torch.zeros(1, 1, seq_len, hidden_size),
+                devices[i],
+                tt_memory_config=ttnn.DRAM_MEMORY_CONFIG,
+                tt_dtype=dtype,
+            )
+            for i in range(len(devices))
+        ]
+        if name in ln_output_tensors_dict and ln_output_tensors_dict[name] is not None:
+            ln_output_tensors_dict[name].update({seq_len: output_tensor})
+        else:
+            ln_output_tensors_dict[name] = {seq_len: output_tensor}
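For context, a minimal sketch of how a call site drives the new helper, following the pattern in the test change above (seq_len, devices, configuration, and model_config are assumed to be set up by the caller; the early return makes the call a no-op when seq_len fits in a single layernorm slice):

    # Hypothetical call site mirroring the test above.
    ln_output_tensors_dict = {"final_layernorm": dict(), "mlp_layernorm": dict(), "attn_layernorm": dict()}
    generate_layernorm_persistent_tensors(
        seq_len,
        model_config["layernorm_params"]["slice_size"],
        ln_output_tensors_dict,  # populated in place: {name: {seq_len: [one tensor per device]}}
        devices,
        configuration.hidden_size,
        model_config["LN_MLP_OUTPUT_DTYPE"],
    )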
