From 7de49aa86c7f169eb0962b6db29ad53fff519ffb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Sep 2024 00:11:55 -0700 Subject: [PATCH] [torch.compile] hide slicing under custom op for inductor (#8384) --- tests/compile/test_full_graph.py | 4 +- vllm/attention/backends/flash_attn.py | 105 +++++++++++++++++--------- 2 files changed, 74 insertions(+), 35 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index d5b59db8c7887..0a6e781e18834 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -16,5 +16,7 @@ def test_full_graph(model): "The future of AI is", ] sampling_params = SamplingParams(temperature=0) - llm = LLM(model="meta-llama/Meta-Llama-3-8B") + llm = LLM(model="meta-llama/Meta-Llama-3-8B", + enforce_eager=True, + load_format="dummy") llm.generate(prompts, sampling_params) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 69faa6d343eda..ec9cbde7467d6 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -122,6 +122,40 @@ def _( return torch.empty_like(decode_query) +@torch.library.custom_op("vllm::reshape_and_cache_flash", + mutates_args=["kv_cache"]) +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + """Inductor cannot deal with inplace operations on views. + See https://github.com/pytorch/pytorch/issues/131192 + and https://github.com/pytorch/pytorch/issues/130174 + This is a workaround to hide the view operation from the inductor. + """ + return torch.ops._C_cache_ops.reshape_and_cache_flash( + key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype, + k_scale, v_scale) + + +@reshape_and_cache_flash.register_fake # type: ignore +def _( + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + pass + + class FlashAttentionBackend(AttentionBackend): @staticmethod @@ -653,11 +687,10 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - ops.reshape_and_cache_flash( + torch.ops.vllm.reshape_and_cache_flash( key, value, - key_cache, - value_cache, + kv_cache, attn_metadata.slot_mapping.flatten(), self.kv_cache_dtype, k_scale, @@ -669,7 +702,6 @@ def forward( assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens - output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. @@ -680,6 +712,9 @@ def forward( assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache is None or prefill_meta.block_tables is None @@ -687,7 +722,7 @@ def forward( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - out = torch.ops.vllm.flash_attn_varlen_func( + prefill_output = torch.ops.vllm.flash_attn_varlen_func( q=query, k=key, v=value, @@ -701,42 +736,44 @@ def forward( alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out else: # prefix-enabled attention assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) - output[: - num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=self.logits_soft_cap, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - output[ - num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache( - decode_query.unsqueeze(1), - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, + prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, softmax_scale=self.scale, causal=True, alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, softcap=self.logits_soft_cap, - ).squeeze(1) + ) - # Reshape the output tensor. + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + decode_output = torch.ops.vllm.flash_attn_with_kvcache( + decode_query.unsqueeze(1), + key_cache, + value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ).squeeze(1) + + if prefill_output is None: + assert decode_output is not None + return decode_output.view(num_decode_tokens, hidden_size) + if decode_output is None: + assert prefill_output is not None + return prefill_output.view(num_prefill_tokens, hidden_size) + output = torch.cat([prefill_output, decode_output], dim=0) return output.view(num_tokens, hidden_size)