diff --git a/models/demos/t3000/llama2_70b/tt/llama_generation.py b/models/demos/t3000/llama2_70b/tt/llama_generation.py index 6f3e382c020..31e5da5f119 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_generation.py +++ b/models/demos/t3000/llama2_70b/tt/llama_generation.py @@ -216,13 +216,12 @@ def prefill_forward_single_user( to keep it otherwise unaware that it is operating on a chunk. - due to the above point, we must always set user_id to 0 for chunked prefill. """ - # TODO: Ensure that last_token_idx is within the last chunk! assert page_table is not None, "page_table must be provided for chunked prefill" - # TODO: Uncomment assert assert kv_cache is not None, "kv_cache must be provided for chunked prefill" chunk_size = get_max_prefill_chunk_size(seq_len, self.tt_model.model_config["MAX_PREFILL_SEQ_LEN"]) block_size = get_block_size(kv_cache) last_token_idx_in_chunk = last_token_idx % chunk_size if last_token_idx is not None else None + # Calculate which chunk contains the last_token_idx last_chunk_start = (last_token_idx // chunk_size) * chunk_size if last_token_idx is not None else None page_table_user = page_table[user_id : user_id + 1, :] # Pad page table to match number of blocks in seq_len @@ -232,8 +231,6 @@ def prefill_forward_single_user( ) CHUNK_USER_ID = 0 - # Calculate which chunk contains the last_token_idx - logits_list = [] for chunk_start in range(0, seq_len, chunk_size): chunk_end = chunk_start + chunk_size