
#0: Remove TODOs
cglagovichTT committed Dec 12, 2024
1 parent e7b3a33 commit 48e9da1
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions models/demos/t3000/llama2_70b/tt/llama_generation.py
@@ -216,13 +216,12 @@ def prefill_forward_single_user(
         to keep it otherwise unaware that it is operating on a chunk.
         - due to the above point, we must always set user_id to 0 for chunked prefill.
         """
-        # TODO: Ensure that last_token_idx is within the last chunk!
         assert page_table is not None, "page_table must be provided for chunked prefill"
-        # TODO: Uncomment assert
         assert kv_cache is not None, "kv_cache must be provided for chunked prefill"
         chunk_size = get_max_prefill_chunk_size(seq_len, self.tt_model.model_config["MAX_PREFILL_SEQ_LEN"])
         block_size = get_block_size(kv_cache)
         last_token_idx_in_chunk = last_token_idx % chunk_size if last_token_idx is not None else None
+        # Calculate which chunk contains the last_token_idx
         last_chunk_start = (last_token_idx // chunk_size) * chunk_size if last_token_idx is not None else None
         page_table_user = page_table[user_id : user_id + 1, :]
         # Pad page table to match number of blocks in seq_len
@@ -232,8 +231,6 @@ def prefill_forward_single_user(
         )
         CHUNK_USER_ID = 0
 
-        # Calculate which chunk contains the last_token_idx
-
         logits_list = []
         for chunk_start in range(0, seq_len, chunk_size):
             chunk_end = chunk_start + chunk_size
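For context, here is a minimal, self-contained sketch (not part of the commit) of the chunk-index arithmetic used in prefill_forward_single_user. The concrete values of seq_len, chunk_size, and last_token_idx are made up for illustration, and the loop is a simplified stand-in for the chunked prefill loop in the diff.

    # Hypothetical example of the chunked-prefill index math; values are assumptions.
    seq_len = 8192          # assumed total prefill length
    chunk_size = 2048       # assumed result of get_max_prefill_chunk_size(...)
    last_token_idx = 5000   # assumed 0-based index of the token whose logits are needed

    # Offset of last_token_idx within its chunk.
    last_token_idx_in_chunk = last_token_idx % chunk_size            # 904
    # Start offset of the chunk that contains last_token_idx.
    last_chunk_start = (last_token_idx // chunk_size) * chunk_size   # 4096

    # The two quantities recombine into the original index.
    assert last_chunk_start + last_token_idx_in_chunk == last_token_idx

    # Walk the sequence in chunks, flagging the chunk that holds last_token_idx.
    for chunk_start in range(0, seq_len, chunk_size):
        chunk_end = chunk_start + chunk_size
        contains_last_token = chunk_start == last_chunk_start
        print(chunk_start, chunk_end, contains_last_token)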
