diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 7a667dbcaff..d840a10b55b 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -490,7 +490,6 @@ def _multi_queries_paged_attention_nonkernel(
     lengths,  # seq_lengths, [batch_size]. nb batch_size = len(seq_lens)
     page_indices,  # [batch_size, pages_per_sequence]
 ) -> torch.Tensor:  # [batch_size, query_len, num_heads, head_dim]
-  print('Running the nonkernel version of multi-queries paged attention.')
   batch_size, query_len, num_query_heads, head_size = q.shape
   num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
   num_query_per_kv = num_query_heads // num_kv_heads
@@ -561,7 +560,6 @@ def multi_queries_paged_attention(
         lengths,
         page_indices,
     )
-  print('Running the kernel version of multi-queries paged attention.')
 
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
@@ -612,8 +610,6 @@ def paged_attention(q,
                     pages_per_compute_block,
                     megacore_mode: str = None,
                     attn_logits_soft_cap: float = None):
-  print('Running the single-query paged attention.')
-
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
   jax_import_guard()