
Commit

Remove unnecessary prints in PagedAttention (pytorch#8374)
Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon authored Nov 12, 2024
1 parent 071ddfa commit baa14f3
Showing 1 changed file with 0 additions and 4 deletions.
torch_xla/experimental/custom_kernel.py: 4 changes (0 additions, 4 deletions)
@@ -490,7 +490,6 @@ def _multi_queries_paged_attention_nonkernel(
     lengths,  # seq_lengths, [batch_size]. nb batch_size = len(seq_lens)
     page_indices,  # [batch_size, pages_per_sequence]
 ) -> torch.Tensor:  # [batch_size, query_len, num_heads, head_dim]
-  print('Running the nonkernel version of multi-queries paged attention.')
   batch_size, query_len, num_query_heads, head_size = q.shape
   num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
   num_query_per_kv = num_query_heads // num_kv_heads
@@ -561,7 +560,6 @@ def multi_queries_paged_attention(
         lengths,
         page_indices,
     )
-  print('Running the kernel version of multi-queries paged attention.')
 
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
@@ -612,8 +610,6 @@ def paged_attention(q,
                     pages_per_compute_block,
                     megacore_mode: str = None,
                     attn_logits_soft_cap: float = None):
-  print('Running the single-query paged attention.')
-
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
   jax_import_guard()
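For context, below is a minimal usage sketch (not part of this commit) of the single-query paged_attention kernel touched in the last hunk. Only pages_per_compute_block, megacore_mode, and attn_logits_soft_cap are visible in this diff; the q, k_pages, v_pages, lengths, and page_indices arguments, their ordering, and all tensor shapes are assumptions inferred from the multi-query variant's shape comments above, and running it requires a TPU-backed XLA device.

# Hypothetical usage sketch; argument names/order and shapes beyond what the
# diff shows are assumptions, not taken from this commit.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import paged_attention

device = xm.xla_device()
batch_size, num_heads, head_dim = 2, 8, 128            # assumed sizes
num_kv_heads, total_num_pages, page_size = 8, 64, 16   # assumed paging layout
pages_per_sequence = 8

q = torch.randn(batch_size, num_heads, head_dim, device=device)
k_pages = torch.randn(num_kv_heads, total_num_pages, page_size, head_dim,
                      device=device)
v_pages = torch.randn_like(k_pages)
lengths = torch.tensor([96, 48], dtype=torch.int32, device=device)   # [batch_size]
page_indices = torch.randint(
    0, total_num_pages, (batch_size, pages_per_sequence),
    dtype=torch.int32, device=device)                                 # [batch_size, pages_per_sequence]

# After this commit, the call no longer emits a debug print on every step.
out = paged_attention(q, k_pages, v_pages, lengths, page_indices,
                      pages_per_compute_block=4)

The removed print statements ran on every invocation, so dropping them keeps per-step decoding output clean when the kernels are called in a serving loop.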
